summaryrefslogtreecommitdiff
path: root/compiler/luci
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/luci')
-rw-r--r--compiler/luci/CMakeLists.txt3
-rw-r--r--compiler/luci/env/include/luci/UserSettings.h1
-rw-r--r--compiler/luci/env/src/UserSettings.cpp6
-rw-r--r--compiler/luci/env/src/UserSettings.test.cpp12
-rw-r--r--compiler/luci/export/CMakeLists.txt1
-rw-r--r--compiler/luci/export/include/luci/CircleFileExpContract.h2
-rw-r--r--compiler/luci/export/src/CircleExportMetadata.cpp121
-rw-r--r--compiler/luci/export/src/CircleExportMetadata.h36
-rw-r--r--compiler/luci/export/src/CircleExporterImpl.cpp61
-rw-r--r--compiler/luci/export/src/CircleExporterImpl.h2
-rw-r--r--compiler/luci/export/src/CircleExporterUtils.cpp8
-rw-r--r--compiler/luci/export/src/CircleOperationExporter.cpp199
-rw-r--r--compiler/luci/export/src/CircleTensorExporter.cpp175
-rw-r--r--compiler/luci/export/src/Optimize.cpp10
-rw-r--r--compiler/luci/export/src/ProgressReporter.h2
-rw-r--r--compiler/luci/export/src/SerializedData.h32
-rw-r--r--compiler/luci/export/src/TypeBridge.cpp105
-rw-r--r--compiler/luci/import/CMakeLists.txt1
-rw-r--r--compiler/luci/import/include/luci/Import/CircleReader.h4
-rw-r--r--compiler/luci/import/include/luci/Import/GraphBuilder.h8
-rw-r--r--compiler/luci/import/include/luci/Import/GraphBuilderBase.h4
-rw-r--r--compiler/luci/import/include/luci/Import/GraphBuilderContext.h2
-rw-r--r--compiler/luci/import/include/luci/Import/GraphBuilderMultiOutput.h67
-rw-r--r--compiler/luci/import/include/luci/Import/Nodes.h2
-rw-r--r--compiler/luci/import/include/luci/Import/Nodes/CircleBidirectionalSequenceLSTM.h37
-rw-r--r--compiler/luci/import/include/luci/Import/Nodes/CircleCustom.h8
-rw-r--r--compiler/luci/import/include/luci/Import/Nodes/CircleFakeQuant.h37
-rw-r--r--compiler/luci/import/include/luci/Import/Nodes/CircleIf.h8
-rw-r--r--compiler/luci/import/include/luci/Import/Nodes/CircleNonMaxSuppressionV4.h8
-rw-r--r--compiler/luci/import/include/luci/Import/Nodes/CircleNonMaxSuppressionV5.h8
-rw-r--r--compiler/luci/import/include/luci/Import/Nodes/CircleSplit.h8
-rw-r--r--compiler/luci/import/include/luci/Import/Nodes/CircleSplitV.h8
-rw-r--r--compiler/luci/import/include/luci/Import/Nodes/CircleTopKV2.h8
-rw-r--r--compiler/luci/import/include/luci/Import/Nodes/CircleUnique.h8
-rw-r--r--compiler/luci/import/include/luci/Import/Nodes/CircleUnpack.h8
-rw-r--r--compiler/luci/import/include/luci/Import/Nodes/CircleWhile.h2
-rw-r--r--compiler/luci/import/src/CircleImportMetadata.cpp185
-rw-r--r--compiler/luci/import/src/CircleImportMetadata.h56
-rw-r--r--compiler/luci/import/src/CircleReader.cpp16
-rw-r--r--compiler/luci/import/src/GraphBuilder.cpp10
-rw-r--r--compiler/luci/import/src/GraphBuilderMultiOutput.cpp97
-rw-r--r--compiler/luci/import/src/GraphBuilderRegistry.cpp4
-rw-r--r--compiler/luci/import/src/Importer.cpp35
-rw-r--r--compiler/luci/import/src/Nodes/CircleAbs.cpp5
-rw-r--r--compiler/luci/import/src/Nodes/CircleAdd.cpp5
-rw-r--r--compiler/luci/import/src/Nodes/CircleArgMax.cpp5
-rw-r--r--compiler/luci/import/src/Nodes/CircleArgMin.cpp5
-rw-r--r--compiler/luci/import/src/Nodes/CircleAveragePool2D.cpp5
-rw-r--r--compiler/luci/import/src/Nodes/CircleBCQFullyConnected.cpp14
-rw-r--r--compiler/luci/import/src/Nodes/CircleBCQGather.cpp5
-rw-r--r--compiler/luci/import/src/Nodes/CircleBatchMatMul.cpp5
-rw-r--r--compiler/luci/import/src/Nodes/CircleBidirectionalSequenceLSTM.cpp112
-rw-r--r--compiler/luci/import/src/Nodes/CircleCast.cpp7
-rw-r--r--compiler/luci/import/src/Nodes/CircleCeil.cpp10
-rw-r--r--compiler/luci/import/src/Nodes/CircleConv2D.cpp5
-rw-r--r--compiler/luci/import/src/Nodes/CircleCos.cpp5
-rw-r--r--compiler/luci/import/src/Nodes/CircleCustom.cpp65
-rw-r--r--compiler/luci/import/src/Nodes/CircleDepthToSpace.cpp10
-rw-r--r--compiler/luci/import/src/Nodes/CircleDepthwiseConv2D.cpp26
-rw-r--r--compiler/luci/import/src/Nodes/CircleDequantize.cpp5
-rw-r--r--compiler/luci/import/src/Nodes/CircleDiv.cpp8
-rw-r--r--compiler/luci/import/src/Nodes/CircleElu.cpp9
-rw-r--r--compiler/luci/import/src/Nodes/CircleEqual.cpp7
-rw-r--r--compiler/luci/import/src/Nodes/CircleExp.cpp4
-rw-r--r--compiler/luci/import/src/Nodes/CircleExpandDims.cpp7
-rw-r--r--compiler/luci/import/src/Nodes/CircleFakeQuant.cpp49
-rw-r--r--compiler/luci/import/src/Nodes/CircleFill.cpp8
-rw-r--r--compiler/luci/import/src/Nodes/CircleFloor.cpp10
-rw-r--r--compiler/luci/import/src/Nodes/CircleFloorDiv.cpp14
-rw-r--r--compiler/luci/import/src/Nodes/CircleFloorMod.cpp7
-rw-r--r--compiler/luci/import/src/Nodes/CircleFullyConnected.cpp14
-rw-r--r--compiler/luci/import/src/Nodes/CircleGather.cpp10
-rw-r--r--compiler/luci/import/src/Nodes/CircleGatherNd.cpp9
-rw-r--r--compiler/luci/import/src/Nodes/CircleGreater.cpp10
-rw-r--r--compiler/luci/import/src/Nodes/CircleGreaterEqual.cpp14
-rw-r--r--compiler/luci/import/src/Nodes/CircleIf.cpp65
-rw-r--r--compiler/luci/import/src/Nodes/CircleInstanceNorm.cpp6
-rw-r--r--compiler/luci/import/src/Nodes/CircleL2Normalize.cpp15
-rw-r--r--compiler/luci/import/src/Nodes/CircleL2Pool2D.cpp6
-rw-r--r--compiler/luci/import/src/Nodes/CircleLeakyRelu.cpp8
-rw-r--r--compiler/luci/import/src/Nodes/CircleLess.cpp14
-rw-r--r--compiler/luci/import/src/Nodes/CircleLessEqual.cpp14
-rw-r--r--compiler/luci/import/src/Nodes/CircleLocalResponseNormalization.cpp8
-rw-r--r--compiler/luci/import/src/Nodes/CircleLog.cpp6
-rw-r--r--compiler/luci/import/src/Nodes/CircleLogSoftmax.cpp6
-rw-r--r--compiler/luci/import/src/Nodes/CircleLogicalAnd.cpp6
-rw-r--r--compiler/luci/import/src/Nodes/CircleLogicalNot.cpp2
-rw-r--r--compiler/luci/import/src/Nodes/CircleLogicalOr.cpp2
-rw-r--r--compiler/luci/import/src/Nodes/CircleLogistic.cpp8
-rw-r--r--compiler/luci/import/src/Nodes/CircleMatrixDiag.cpp10
-rw-r--r--compiler/luci/import/src/Nodes/CircleMatrixSetDiag.cpp10
-rw-r--r--compiler/luci/import/src/Nodes/CircleMaxPool2D.cpp5
-rw-r--r--compiler/luci/import/src/Nodes/CircleMean.cpp5
-rw-r--r--compiler/luci/import/src/Nodes/CircleMirrorPad.cpp6
-rw-r--r--compiler/luci/import/src/Nodes/CircleMul.cpp8
-rw-r--r--compiler/luci/import/src/Nodes/CircleNeg.cpp5
-rw-r--r--compiler/luci/import/src/Nodes/CircleNonMaxSuppressionV4.cpp76
-rw-r--r--compiler/luci/import/src/Nodes/CircleNonMaxSuppressionV5.cpp78
-rw-r--r--compiler/luci/import/src/Nodes/CircleNotEqual.cpp14
-rw-r--r--compiler/luci/import/src/Nodes/CircleOneHot.cpp11
-rw-r--r--compiler/luci/import/src/Nodes/CirclePRelu.cpp8
-rw-r--r--compiler/luci/import/src/Nodes/CirclePad.cpp6
-rw-r--r--compiler/luci/import/src/Nodes/CirclePadV2.cpp8
-rw-r--r--compiler/luci/import/src/Nodes/CirclePow.cpp8
-rw-r--r--compiler/luci/import/src/Nodes/CircleRange.cpp5
-rw-r--r--compiler/luci/import/src/Nodes/CircleRank.cpp8
-rw-r--r--compiler/luci/import/src/Nodes/CircleReduceAny.cpp8
-rw-r--r--compiler/luci/import/src/Nodes/CircleReduceProd.cpp6
-rw-r--r--compiler/luci/import/src/Nodes/CircleRelu.cpp8
-rw-r--r--compiler/luci/import/src/Nodes/CircleRelu6.cpp8
-rw-r--r--compiler/luci/import/src/Nodes/CircleReluN1To1.cpp9
-rw-r--r--compiler/luci/import/src/Nodes/CircleReshape.cpp15
-rw-r--r--compiler/luci/import/src/Nodes/CircleResizeBilinear.cpp9
-rw-r--r--compiler/luci/import/src/Nodes/CircleResizeNearestNeighbor.cpp11
-rw-r--r--compiler/luci/import/src/Nodes/CircleReverseSequence.cpp9
-rw-r--r--compiler/luci/import/src/Nodes/CircleReverseV2.cpp9
-rw-r--r--compiler/luci/import/src/Nodes/CircleRound.cpp9
-rw-r--r--compiler/luci/import/src/Nodes/CircleRsqrt.cpp6
-rw-r--r--compiler/luci/import/src/Nodes/CircleScatterNd.cpp4
-rw-r--r--compiler/luci/import/src/Nodes/CircleSegmentSum.cpp8
-rw-r--r--compiler/luci/import/src/Nodes/CircleSelect.cpp7
-rw-r--r--compiler/luci/import/src/Nodes/CircleSelectV2.cpp7
-rw-r--r--compiler/luci/import/src/Nodes/CircleShape.cpp10
-rw-r--r--compiler/luci/import/src/Nodes/CircleSin.cpp6
-rw-r--r--compiler/luci/import/src/Nodes/CircleSlice.cpp8
-rw-r--r--compiler/luci/import/src/Nodes/CircleSoftmax.cpp6
-rw-r--r--compiler/luci/import/src/Nodes/CircleSpaceToDepth.cpp7
-rw-r--r--compiler/luci/import/src/Nodes/CircleSparseToDense.cpp5
-rw-r--r--compiler/luci/import/src/Nodes/CircleSplit.cpp63
-rw-r--r--compiler/luci/import/src/Nodes/CircleSplitV.cpp76
-rw-r--r--compiler/luci/import/src/Nodes/CircleSqrt.cpp5
-rw-r--r--compiler/luci/import/src/Nodes/CircleSquare.cpp4
-rw-r--r--compiler/luci/import/src/Nodes/CircleSquaredDifference.cpp10
-rw-r--r--compiler/luci/import/src/Nodes/CircleSqueeze.cpp9
-rw-r--r--compiler/luci/import/src/Nodes/CircleStridedSlice.cpp8
-rw-r--r--compiler/luci/import/src/Nodes/CircleSub.cpp8
-rw-r--r--compiler/luci/import/src/Nodes/CircleSum.cpp5
-rw-r--r--compiler/luci/import/src/Nodes/CircleTanh.cpp8
-rw-r--r--compiler/luci/import/src/Nodes/CircleTile.cpp10
-rw-r--r--compiler/luci/import/src/Nodes/CircleTopKV2.cpp69
-rw-r--r--compiler/luci/import/src/Nodes/CircleTranspose.cpp8
-rw-r--r--compiler/luci/import/src/Nodes/CircleTransposeConv.cpp13
-rw-r--r--compiler/luci/import/src/Nodes/CircleUnidirectionalSequenceLSTM.cpp17
-rw-r--r--compiler/luci/import/src/Nodes/CircleUnique.cpp55
-rw-r--r--compiler/luci/import/src/Nodes/CircleUnpack.cpp61
-rw-r--r--compiler/luci/import/src/Nodes/CircleWhere.cpp10
-rw-r--r--compiler/luci/import/src/Nodes/CircleWhile.cpp5
-rw-r--r--compiler/luci/import/src/Nodes/CircleZerosLike.cpp8
-rw-r--r--compiler/luci/import/src/PostImport.cpp47
-rw-r--r--compiler/luci/lang/CMakeLists.txt1
-rw-r--r--compiler/luci/lang/include/luci/IR/CircleNodeDecl.h5
-rw-r--r--compiler/luci/lang/include/luci/IR/CircleNodeImpl.h4
-rw-r--r--compiler/luci/lang/include/luci/IR/CircleNodeMixins.h107
-rw-r--r--compiler/luci/lang/include/luci/IR/CircleNodeVisitor.h8
-rw-r--r--compiler/luci/lang/include/luci/IR/CircleNodes.h12
-rw-r--r--compiler/luci/lang/include/luci/IR/CircleNodes.lst39
-rw-r--r--compiler/luci/lang/include/luci/IR/CircleOpcode.h2
-rw-r--r--compiler/luci/lang/include/luci/IR/CircleShapeSignature.h53
-rw-r--r--compiler/luci/lang/include/luci/IR/DeadNodeQueryService.h (renamed from compiler/luci/lang/src/DeadNodeQueryService.h)0
-rw-r--r--compiler/luci/lang/include/luci/IR/LuciNodeMixins.h82
-rw-r--r--compiler/luci/lang/include/luci/IR/Nodes/CircleAbs.h2
-rw-r--r--compiler/luci/lang/include/luci/IR/Nodes/CircleAdd.h4
-rw-r--r--compiler/luci/lang/include/luci/IR/Nodes/CircleArgMax.h2
-rw-r--r--compiler/luci/lang/include/luci/IR/Nodes/CircleArgMin.h2
-rw-r--r--compiler/luci/lang/include/luci/IR/Nodes/CircleAveragePool2D.h12
-rw-r--r--compiler/luci/lang/include/luci/IR/Nodes/CircleBCQFullyConnected.h10
-rw-r--r--compiler/luci/lang/include/luci/IR/Nodes/CircleBCQGather.h6
-rw-r--r--compiler/luci/lang/include/luci/IR/Nodes/CircleBatchMatMul.h10
-rw-r--r--compiler/luci/lang/include/luci/IR/Nodes/CircleBatchToSpaceND.h4
-rw-r--r--compiler/luci/lang/include/luci/IR/Nodes/CircleBidirectionalSequenceLSTM.h172
-rw-r--r--compiler/luci/lang/include/luci/IR/Nodes/CircleBidirectionalSequenceLSTMOut.h48
-rw-r--r--compiler/luci/lang/include/luci/IR/Nodes/CircleCast.h2
-rw-r--r--compiler/luci/lang/include/luci/IR/Nodes/CircleCeil.h2
-rw-r--r--compiler/luci/lang/include/luci/IR/Nodes/CircleConcatenation.h8
-rw-r--r--compiler/luci/lang/include/luci/IR/Nodes/CircleConst.h5
-rw-r--r--compiler/luci/lang/include/luci/IR/Nodes/CircleConv2D.h8
-rw-r--r--compiler/luci/lang/include/luci/IR/Nodes/CircleCos.h2
-rw-r--r--compiler/luci/lang/include/luci/IR/Nodes/CircleCustom.h7
-rw-r--r--compiler/luci/lang/include/luci/IR/Nodes/CircleCustomOut.h7
-rw-r--r--compiler/luci/lang/include/luci/IR/Nodes/CircleDepthToSpace.h10
-rw-r--r--compiler/luci/lang/include/luci/IR/Nodes/CircleDepthwiseConv2D.h12
-rw-r--r--compiler/luci/lang/include/luci/IR/Nodes/CircleDequantize.h2
-rw-r--r--compiler/luci/lang/include/luci/IR/Nodes/CircleDiv.h7
-rw-r--r--compiler/luci/lang/include/luci/IR/Nodes/CircleElu.h5
-rw-r--r--compiler/luci/lang/include/luci/IR/Nodes/CircleEqual.h2
-rw-r--r--compiler/luci/lang/include/luci/IR/Nodes/CircleExp.h2
-rw-r--r--compiler/luci/lang/include/luci/IR/Nodes/CircleExpandDims.h5
-rw-r--r--compiler/luci/lang/include/luci/IR/Nodes/CircleFakeQuant.h60
-rw-r--r--compiler/luci/lang/include/luci/IR/Nodes/CircleFill.h2
-rw-r--r--compiler/luci/lang/include/luci/IR/Nodes/CircleFloor.h2
-rw-r--r--compiler/luci/lang/include/luci/IR/Nodes/CircleFloorDiv.h2
-rw-r--r--compiler/luci/lang/include/luci/IR/Nodes/CircleFloorMod.h2
-rw-r--r--compiler/luci/lang/include/luci/IR/Nodes/CircleFullyConnected.h8
-rw-r--r--compiler/luci/lang/include/luci/IR/Nodes/CircleGather.h4
-rw-r--r--compiler/luci/lang/include/luci/IR/Nodes/CircleGatherNd.h2
-rw-r--r--compiler/luci/lang/include/luci/IR/Nodes/CircleGreater.h2
-rw-r--r--compiler/luci/lang/include/luci/IR/Nodes/CircleGreaterEqual.h4
-rw-r--r--compiler/luci/lang/include/luci/IR/Nodes/CircleIf.h2
-rw-r--r--compiler/luci/lang/include/luci/IR/Nodes/CircleIfOut.h5
-rw-r--r--compiler/luci/lang/include/luci/IR/Nodes/CircleInput.h7
-rw-r--r--compiler/luci/lang/include/luci/IR/Nodes/CircleInstanceNorm.h9
-rw-r--r--compiler/luci/lang/include/luci/IR/Nodes/CircleL2Normalize.h6
-rw-r--r--compiler/luci/lang/include/luci/IR/Nodes/CircleL2Pool2D.h10
-rw-r--r--compiler/luci/lang/include/luci/IR/Nodes/CircleLeakyRelu.h8
-rw-r--r--compiler/luci/lang/include/luci/IR/Nodes/CircleLess.h2
-rw-r--r--compiler/luci/lang/include/luci/IR/Nodes/CircleLessEqual.h2
-rw-r--r--compiler/luci/lang/include/luci/IR/Nodes/CircleLocalResponseNormalization.h4
-rw-r--r--compiler/luci/lang/include/luci/IR/Nodes/CircleLog.h2
-rw-r--r--compiler/luci/lang/include/luci/IR/Nodes/CircleLogSoftmax.h2
-rw-r--r--compiler/luci/lang/include/luci/IR/Nodes/CircleLogicalAnd.h2
-rw-r--r--compiler/luci/lang/include/luci/IR/Nodes/CircleLogicalNot.h2
-rw-r--r--compiler/luci/lang/include/luci/IR/Nodes/CircleLogicalOr.h2
-rw-r--r--compiler/luci/lang/include/luci/IR/Nodes/CircleLogistic.h5
-rw-r--r--compiler/luci/lang/include/luci/IR/Nodes/CircleMatrixDiag.h2
-rw-r--r--compiler/luci/lang/include/luci/IR/Nodes/CircleMatrixSetDiag.h4
-rw-r--r--compiler/luci/lang/include/luci/IR/Nodes/CircleMaxPool2D.h10
-rw-r--r--compiler/luci/lang/include/luci/IR/Nodes/CircleMaximum.h2
-rw-r--r--compiler/luci/lang/include/luci/IR/Nodes/CircleMean.h4
-rw-r--r--compiler/luci/lang/include/luci/IR/Nodes/CircleMinimum.h2
-rw-r--r--compiler/luci/lang/include/luci/IR/Nodes/CircleMirrorPad.h5
-rw-r--r--compiler/luci/lang/include/luci/IR/Nodes/CircleMul.h4
-rw-r--r--compiler/luci/lang/include/luci/IR/Nodes/CircleNeg.h2
-rw-r--r--compiler/luci/lang/include/luci/IR/Nodes/CircleNonMaxSuppressionV4.h4
-rw-r--r--compiler/luci/lang/include/luci/IR/Nodes/CircleNonMaxSuppressionV4Out.h7
-rw-r--r--compiler/luci/lang/include/luci/IR/Nodes/CircleNonMaxSuppressionV5.h4
-rw-r--r--compiler/luci/lang/include/luci/IR/Nodes/CircleNonMaxSuppressionV5Out.h7
-rw-r--r--compiler/luci/lang/include/luci/IR/Nodes/CircleNotEqual.h2
-rw-r--r--compiler/luci/lang/include/luci/IR/Nodes/CircleOneHot.h4
-rw-r--r--compiler/luci/lang/include/luci/IR/Nodes/CircleOutput.h10
-rw-r--r--compiler/luci/lang/include/luci/IR/Nodes/CirclePRelu.h5
-rw-r--r--compiler/luci/lang/include/luci/IR/Nodes/CirclePad.h5
-rw-r--r--compiler/luci/lang/include/luci/IR/Nodes/CirclePadV2.h5
-rw-r--r--compiler/luci/lang/include/luci/IR/Nodes/CirclePow.h5
-rw-r--r--compiler/luci/lang/include/luci/IR/Nodes/CircleRange.h2
-rw-r--r--compiler/luci/lang/include/luci/IR/Nodes/CircleRank.h2
-rw-r--r--compiler/luci/lang/include/luci/IR/Nodes/CircleReduceAny.h4
-rw-r--r--compiler/luci/lang/include/luci/IR/Nodes/CircleReduceMax.h4
-rw-r--r--compiler/luci/lang/include/luci/IR/Nodes/CircleReduceMin.h4
-rw-r--r--compiler/luci/lang/include/luci/IR/Nodes/CircleReduceProd.h4
-rw-r--r--compiler/luci/lang/include/luci/IR/Nodes/CircleRelu.h5
-rw-r--r--compiler/luci/lang/include/luci/IR/Nodes/CircleRelu6.h5
-rw-r--r--compiler/luci/lang/include/luci/IR/Nodes/CircleReluN1To1.h5
-rw-r--r--compiler/luci/lang/include/luci/IR/Nodes/CircleReshape.h7
-rw-r--r--compiler/luci/lang/include/luci/IR/Nodes/CircleResizeBilinear.h12
-rw-r--r--compiler/luci/lang/include/luci/IR/Nodes/CircleResizeNearestNeighbor.h10
-rw-r--r--compiler/luci/lang/include/luci/IR/Nodes/CircleReverseSequence.h19
-rw-r--r--compiler/luci/lang/include/luci/IR/Nodes/CircleReverseV2.h2
-rw-r--r--compiler/luci/lang/include/luci/IR/Nodes/CircleRound.h5
-rw-r--r--compiler/luci/lang/include/luci/IR/Nodes/CircleRsqrt.h5
-rw-r--r--compiler/luci/lang/include/luci/IR/Nodes/CircleScatterNd.h2
-rw-r--r--compiler/luci/lang/include/luci/IR/Nodes/CircleSegmentSum.h5
-rw-r--r--compiler/luci/lang/include/luci/IR/Nodes/CircleSelect.h5
-rw-r--r--compiler/luci/lang/include/luci/IR/Nodes/CircleSelectV2.h5
-rw-r--r--compiler/luci/lang/include/luci/IR/Nodes/CircleShape.h5
-rw-r--r--compiler/luci/lang/include/luci/IR/Nodes/CircleSin.h2
-rw-r--r--compiler/luci/lang/include/luci/IR/Nodes/CircleSlice.h2
-rw-r--r--compiler/luci/lang/include/luci/IR/Nodes/CircleSoftmax.h2
-rw-r--r--compiler/luci/lang/include/luci/IR/Nodes/CircleSpaceToBatchND.h4
-rw-r--r--compiler/luci/lang/include/luci/IR/Nodes/CircleSpaceToDepth.h10
-rw-r--r--compiler/luci/lang/include/luci/IR/Nodes/CircleSparseToDense.h4
-rw-r--r--compiler/luci/lang/include/luci/IR/Nodes/CircleSplit.h2
-rw-r--r--compiler/luci/lang/include/luci/IR/Nodes/CircleSplitOut.h5
-rw-r--r--compiler/luci/lang/include/luci/IR/Nodes/CircleSplitV.h2
-rw-r--r--compiler/luci/lang/include/luci/IR/Nodes/CircleSplitVOut.h7
-rw-r--r--compiler/luci/lang/include/luci/IR/Nodes/CircleSqrt.h5
-rw-r--r--compiler/luci/lang/include/luci/IR/Nodes/CircleSquare.h5
-rw-r--r--compiler/luci/lang/include/luci/IR/Nodes/CircleSquaredDifference.h7
-rw-r--r--compiler/luci/lang/include/luci/IR/Nodes/CircleSqueeze.h5
-rw-r--r--compiler/luci/lang/include/luci/IR/Nodes/CircleStridedSlice.h4
-rw-r--r--compiler/luci/lang/include/luci/IR/Nodes/CircleSub.h7
-rw-r--r--compiler/luci/lang/include/luci/IR/Nodes/CircleSum.h2
-rw-r--r--compiler/luci/lang/include/luci/IR/Nodes/CircleTanh.h5
-rw-r--r--compiler/luci/lang/include/luci/IR/Nodes/CircleTile.h5
-rw-r--r--compiler/luci/lang/include/luci/IR/Nodes/CircleTopKV2.h5
-rw-r--r--compiler/luci/lang/include/luci/IR/Nodes/CircleTopKV2Out.h7
-rw-r--r--compiler/luci/lang/include/luci/IR/Nodes/CircleTranspose.h8
-rw-r--r--compiler/luci/lang/include/luci/IR/Nodes/CircleTransposeConv.h6
-rw-r--r--compiler/luci/lang/include/luci/IR/Nodes/CircleUnidirectionalSequenceLSTM.h14
-rw-r--r--compiler/luci/lang/include/luci/IR/Nodes/CircleUnique.h4
-rw-r--r--compiler/luci/lang/include/luci/IR/Nodes/CircleUniqueOut.h7
-rw-r--r--compiler/luci/lang/include/luci/IR/Nodes/CircleUnpack.h5
-rw-r--r--compiler/luci/lang/include/luci/IR/Nodes/CircleUnpackOut.h7
-rw-r--r--compiler/luci/lang/include/luci/IR/Nodes/CircleWhere.h5
-rw-r--r--compiler/luci/lang/include/luci/IR/Nodes/CircleWhile.h2
-rw-r--r--compiler/luci/lang/include/luci/IR/Nodes/CircleWhileOut.h5
-rw-r--r--compiler/luci/lang/include/luci/IR/Nodes/CircleZerosLike.h8
-rw-r--r--compiler/luci/lang/include/luci/IR/SparsityParam.h22
-rw-r--r--compiler/luci/lang/src/CircleDialect.cpp3
-rw-r--r--compiler/luci/lang/src/CircleNodeMixins.cpp (renamed from compiler/luci/lang/src/LuciNodeMixins.cpp)6
-rw-r--r--compiler/luci/lang/src/CircleNodes.cpp25
-rw-r--r--compiler/luci/lang/src/DeadNodeQueryService.cpp3
-rw-r--r--compiler/luci/lang/src/Nodes/CircleBatchMatMul.test.cpp2
-rw-r--r--compiler/luci/lang/src/Nodes/CircleBidrectionalSequenceLSTM.test.cpp130
-rw-r--r--compiler/luci/lang/src/Nodes/CircleConst.test.cpp53
-rw-r--r--compiler/luci/lang/src/Nodes/CircleCustom.test.cpp7
-rw-r--r--compiler/luci/lang/src/Nodes/CircleFakeQuant.test.cpp36
-rw-r--r--compiler/luci/logex/src/FormattedGraph.cpp132
-rw-r--r--compiler/luci/partition/CMakeLists.txt29
-rw-r--r--compiler/luci/partition/README.md4
-rw-r--r--compiler/luci/partition/include/luci/Partition.h71
-rw-r--r--compiler/luci/partition/src/CircleOpCode.cpp79
-rw-r--r--compiler/luci/partition/src/CircleOpCode.h (renamed from compiler/luci/lang/src/CircleShapeSignature.cpp)23
-rw-r--r--compiler/luci/partition/src/CircleOpCode.test.cpp31
-rw-r--r--compiler/luci/partition/src/ConnectNode.cpp38
-rw-r--r--compiler/luci/partition/src/ConnectNode.h209
-rw-r--r--compiler/luci/partition/src/ConnectNode.test.cpp19
-rw-r--r--compiler/luci/partition/src/ConnectNode.test.h146
-rw-r--r--compiler/luci/partition/src/Nodes/CircleAdd.cpp40
-rw-r--r--compiler/luci/partition/src/Nodes/CircleAdd.test.cpp100
-rw-r--r--compiler/luci/partition/src/Nodes/CircleConst.cpp (renamed from compiler/luci/service/src/Nodes/CircleInput.cpp)8
-rw-r--r--compiler/luci/partition/src/Nodes/CircleDiv.cpp40
-rw-r--r--compiler/luci/partition/src/Nodes/CircleDiv.test.cpp100
-rw-r--r--compiler/luci/partition/src/Nodes/CircleMean.cpp41
-rw-r--r--compiler/luci/partition/src/Nodes/CircleMul.cpp40
-rw-r--r--compiler/luci/partition/src/Nodes/CircleMul.test.cpp100
-rw-r--r--compiler/luci/partition/src/Nodes/CirclePow.cpp40
-rw-r--r--compiler/luci/partition/src/Nodes/CircleRsqrt.cpp38
-rw-r--r--compiler/luci/partition/src/Nodes/CircleSqrt.cpp38
-rw-r--r--compiler/luci/partition/src/Nodes/CircleSquaredDifference.cpp40
-rw-r--r--compiler/luci/partition/src/Nodes/CircleSub.cpp40
-rw-r--r--compiler/luci/partition/src/Nodes/CircleSub.test.cpp100
-rw-r--r--compiler/luci/partition/src/Partition.cpp61
-rw-r--r--compiler/luci/partition/src/Partition.test.cpp83
-rw-r--r--compiler/luci/partition/src/PartitionCleanup.cpp139
-rw-r--r--compiler/luci/partition/src/PartitionCleanup.h34
-rw-r--r--compiler/luci/partition/src/PartitionIR.cpp101
-rw-r--r--compiler/luci/partition/src/PartitionIR.h91
-rw-r--r--compiler/luci/partition/src/PartitionIR.test.cpp75
-rw-r--r--compiler/luci/partition/src/PartitionIRDump.cpp70
-rw-r--r--compiler/luci/partition/src/PartitionIRDump.h35
-rw-r--r--compiler/luci/partition/src/PartitionMerge.cpp207
-rw-r--r--compiler/luci/partition/src/PartitionMerge.h31
-rw-r--r--compiler/luci/partition/src/PartitionPGroups.cpp139
-rw-r--r--compiler/luci/partition/src/PartitionPGroups.h39
-rw-r--r--compiler/luci/partition/src/PartitionPGroups.test.cpp80
-rw-r--r--compiler/luci/partition/src/PartitionPModules.cpp203
-rw-r--r--compiler/luci/partition/src/PartitionPModules.h31
-rw-r--r--compiler/luci/partition/src/PartitionPModules.test.cpp82
-rw-r--r--compiler/luci/partition/src/PartitionPModulesDump.cpp47
-rw-r--r--compiler/luci/partition/src/PartitionPModulesDump.h34
-rw-r--r--compiler/luci/pass/CMakeLists.txt2
-rw-r--r--compiler/luci/pass/include/luci/CircleOptimizer.h19
-rw-r--r--compiler/luci/pass/include/luci/Pass/CircleShapeInferencePass.h (renamed from compiler/luci/pass/include/luci/Pass/ShapeInferencePass.h)12
-rw-r--r--compiler/luci/pass/include/luci/Pass/ConvertNCHWToNHWCPass.h60
-rw-r--r--compiler/luci/pass/include/luci/Pass/FoldAddV2Pass.h38
-rw-r--r--compiler/luci/pass/include/luci/Pass/FoldCastPass.h38
-rw-r--r--compiler/luci/pass/include/luci/Pass/FoldSparseToDensePass.h38
-rw-r--r--compiler/luci/pass/include/luci/Pass/ForwardReshapeToUnaryOpPass.h37
-rw-r--r--compiler/luci/pass/include/luci/Pass/FuseBatchNormWithConvPass.h37
-rw-r--r--compiler/luci/pass/include/luci/Pass/FuseBatchNormWithDwConvPass.h37
-rw-r--r--compiler/luci/pass/include/luci/Pass/FuseBatchNormWithTConvPass.h (renamed from compiler/luci/pass/include/luci/Pass/FuseBatchNormWithTConv.h)0
-rw-r--r--compiler/luci/pass/include/luci/Pass/MigrateLegacyShapeDtypePass.h44
-rw-r--r--compiler/luci/pass/include/luci/Pass/QuantizeDequantizeWeightsPass.h2
-rw-r--r--compiler/luci/pass/include/luci/Pass/QuantizeWithMinMaxPass.h2
-rw-r--r--compiler/luci/pass/include/luci/Pass/RemoveRedundantReshapePass.h39
-rw-r--r--compiler/luci/pass/include/luci/Pass/RemoveUnnecessaryReshapePass.h37
-rw-r--r--compiler/luci/pass/include/luci/Pass/RemoveUnnecessarySlicePass.h37
-rw-r--r--compiler/luci/pass/include/luci/Pass/RemoveUnnecessarySplitPass.h37
-rw-r--r--compiler/luci/pass/include/luci/Pass/RemoveUnnecessaryStridedSlicePass.h37
-rw-r--r--compiler/luci/pass/include/luci/Pass/RequantizePass.h2
-rw-r--r--compiler/luci/pass/include/luci/Pass/SparsifyTensorPass.h4
-rw-r--r--compiler/luci/pass/include/luci/Pass/SubstituteSqueezeToReshapePass.h37
-rw-r--r--compiler/luci/pass/include/luci/Pass/SubstituteTransposeToReshapePass.h37
-rw-r--r--compiler/luci/pass/include/luci/Pass/TransformMinMaxToRelu6Pass.h37
-rw-r--r--compiler/luci/pass/src/BatchNormPatternFinder.cpp106
-rw-r--r--compiler/luci/pass/src/BatchNormPatternFinder.h43
-rw-r--r--compiler/luci/pass/src/BatchNormPatternFinder.test.cpp217
-rw-r--r--compiler/luci/pass/src/CircleOptimizer.cpp148
-rw-r--r--compiler/luci/pass/src/CircleOptimizer.test.cpp238
-rw-r--r--compiler/luci/pass/src/CircleOptimizerUtils.cpp72
-rw-r--r--compiler/luci/pass/src/CircleOptimizerUtils.h15
-rw-r--r--compiler/luci/pass/src/CircleShapeInferencePass.cpp91
-rw-r--r--compiler/luci/pass/src/CircleShapeInferencePass.test.cpp364
-rw-r--r--compiler/luci/pass/src/CircleTypeInferencePass.cpp4
-rw-r--r--compiler/luci/pass/src/CircleTypeInferencePass.test.cpp26
-rw-r--r--compiler/luci/pass/src/ConvertNCHWToNHWCPass.cpp698
-rw-r--r--compiler/luci/pass/src/ConvertNCHWToNHWCPass.test.cpp636
-rw-r--r--compiler/luci/pass/src/FoldAddV2Pass.cpp122
-rw-r--r--compiler/luci/pass/src/FoldAddV2Pass.test.cpp137
-rw-r--r--compiler/luci/pass/src/FoldCastPass.cpp107
-rw-r--r--compiler/luci/pass/src/FoldCastPass.test.cpp112
-rw-r--r--compiler/luci/pass/src/FoldDequantizePass.cpp18
-rw-r--r--compiler/luci/pass/src/FoldDequantizePass.test.cpp (renamed from compiler/luci/service/src/Nodes/CircleOutput.cpp)15
-rw-r--r--compiler/luci/pass/src/FoldSparseToDensePass.cpp140
-rw-r--r--compiler/luci/pass/src/FoldSparseToDensePass.test.cpp133
-rw-r--r--compiler/luci/pass/src/ForwardReshapeToUnaryOpPass.cpp154
-rw-r--r--compiler/luci/pass/src/ForwardReshapeToUnaryOpPass.test.cpp125
-rw-r--r--compiler/luci/pass/src/FuseActivationFunctionPass.cpp10
-rw-r--r--compiler/luci/pass/src/FuseActivationFunctionPass.test.cpp150
-rw-r--r--compiler/luci/pass/src/FuseAddWithTConvPass.cpp27
-rw-r--r--compiler/luci/pass/src/FuseAddWithTConvPass.test.cpp26
-rw-r--r--compiler/luci/pass/src/FuseBCQPass.cpp54
-rw-r--r--compiler/luci/pass/src/FuseBCQPass.test.cpp26
-rw-r--r--compiler/luci/pass/src/FuseBatchNormWithConvPass.cpp232
-rw-r--r--compiler/luci/pass/src/FuseBatchNormWithConvPass.test.cpp26
-rw-r--r--compiler/luci/pass/src/FuseBatchNormWithDwConvPass.cpp237
-rw-r--r--compiler/luci/pass/src/FuseBatchNormWithDwConvPass.test.cpp26
-rw-r--r--compiler/luci/pass/src/FuseBatchNormWithTConv.cpp159
-rw-r--r--compiler/luci/pass/src/FuseBatchNormWithTConvPass.cpp208
-rw-r--r--compiler/luci/pass/src/FuseBatchNormWithTConvPass.test.cpp26
-rw-r--r--compiler/luci/pass/src/FuseInstanceNormPass.cpp229
-rw-r--r--compiler/luci/pass/src/FuseInstanceNormPass.test.cpp9
-rw-r--r--compiler/luci/pass/src/FusePreActivationBatchNormPass.cpp111
-rw-r--r--compiler/luci/pass/src/FusePreActivationBatchNormPass.test.cpp25
-rw-r--r--compiler/luci/pass/src/MakeBatchNormGammaPositivePass.cpp89
-rw-r--r--compiler/luci/pass/src/MakeBatchNormGammaPositivePass.test.cpp26
-rw-r--r--compiler/luci/pass/src/MigrateLegacyShapeDtypePass.cpp112
-rw-r--r--compiler/luci/pass/src/ModulePhase.test.cpp57
-rw-r--r--compiler/luci/pass/src/PassTestGraphs.h142
-rw-r--r--compiler/luci/pass/src/ProgressReporter.h4
-rw-r--r--compiler/luci/pass/src/PropagateConcatenationQparam.test.cpp153
-rw-r--r--compiler/luci/pass/src/PropagateQuantParamPass.cpp5
-rw-r--r--compiler/luci/pass/src/PropagateQuantParamPass.test.cpp7
-rw-r--r--compiler/luci/pass/src/QuantizationUtils.cpp15
-rw-r--r--compiler/luci/pass/src/QuantizationUtils.h3
-rw-r--r--compiler/luci/pass/src/QuantizeDequantizeWeightsPass.cpp224
-rw-r--r--compiler/luci/pass/src/QuantizeDequantizeWeightsPass.test.cpp27
-rw-r--r--compiler/luci/pass/src/QuantizeWithMinMaxPass.cpp677
-rw-r--r--compiler/luci/pass/src/QuantizeWithMinMaxPass.test.cpp27
-rw-r--r--compiler/luci/pass/src/QuantizedModelVerifier.cpp71
-rw-r--r--compiler/luci/pass/src/QuantizedModelVerifier.h50
-rw-r--r--compiler/luci/pass/src/QuantizedModelVerifier.test.cpp1668
-rw-r--r--compiler/luci/pass/src/RemoveRedundantReshape.cpp72
-rw-r--r--compiler/luci/pass/src/RemoveRedundantReshape.test.cpp110
-rw-r--r--compiler/luci/pass/src/RemoveRedundantTranspose.test.cpp156
-rw-r--r--compiler/luci/pass/src/RemoveRedundantTransposePass.cpp (renamed from compiler/luci/pass/src/RemoveRedundantTranspose.cpp)73
-rw-r--r--compiler/luci/pass/src/RemoveRedundantTransposePass.test.cpp321
-rw-r--r--compiler/luci/pass/src/RemoveUnnecessaryReshapePass.cpp75
-rw-r--r--compiler/luci/pass/src/RemoveUnnecessaryReshapePass.test.cpp141
-rw-r--r--compiler/luci/pass/src/RemoveUnnecessarySlicePass.cpp111
-rw-r--r--compiler/luci/pass/src/RemoveUnnecessarySlicePass.test.cpp134
-rw-r--r--compiler/luci/pass/src/RemoveUnnecessarySplitPass.cpp (renamed from compiler/luci/pass/src/ShapeSignatureInferencePass.cpp)55
-rw-r--r--compiler/luci/pass/src/RemoveUnnecessarySplitPass.test.cpp149
-rw-r--r--compiler/luci/pass/src/RemoveUnnecessaryStridedSlicePass.cpp124
-rw-r--r--compiler/luci/pass/src/RemoveUnnecessaryStridedSlicePass.test.cpp142
-rw-r--r--compiler/luci/pass/src/ReplaceMulAddWithDepthwiseConvPass.cpp99
-rw-r--r--compiler/luci/pass/src/ReplaceMulAddWithDepthwiseConvPass.test.cpp14
-rw-r--r--compiler/luci/pass/src/RequantizePass.cpp4
-rw-r--r--compiler/luci/pass/src/RequantizePass.test.cpp26
-rw-r--r--compiler/luci/pass/src/ResolveCustomOpAddPass.cpp37
-rw-r--r--compiler/luci/pass/src/ResolveCustomOpAddPass.test.cpp26
-rw-r--r--compiler/luci/pass/src/ResolveCustomOpBatchMatMulPass.cpp36
-rw-r--r--compiler/luci/pass/src/ResolveCustomOpBatchMatMulPass.test.cpp169
-rw-r--r--compiler/luci/pass/src/ResolveCustomOpMatMulPass.cpp50
-rw-r--r--compiler/luci/pass/src/ResolveCustomOpMatMulPass.test.cpp26
-rw-r--r--compiler/luci/pass/src/ShapeInferencePass.cpp57
-rw-r--r--compiler/luci/pass/src/ShuffleWeightTo16x1Float32Pass.cpp8
-rw-r--r--compiler/luci/pass/src/ShuffleWeightTo16x1Float32Pass.test.cpp143
-rw-r--r--compiler/luci/pass/src/Sparsifier.cpp4
-rw-r--r--compiler/luci/pass/src/Sparsifier.test.cpp4
-rw-r--r--compiler/luci/pass/src/SparsifyTensorPass.cpp10
-rw-r--r--compiler/luci/pass/src/SparsifyTensorPass.test.cpp30
-rw-r--r--compiler/luci/pass/src/SubstitutePackToReshapePass.cpp57
-rw-r--r--compiler/luci/pass/src/SubstitutePackToReshapePass.test.cpp30
-rw-r--r--compiler/luci/pass/src/SubstituteSqueezeToReshapePass.cpp183
-rw-r--r--compiler/luci/pass/src/SubstituteSqueezeToReshapePass.test.cpp208
-rw-r--r--compiler/luci/pass/src/SubstituteTransposeToReshapePass.cpp137
-rw-r--r--compiler/luci/pass/src/SubstituteTransposeToReshapePass.test.cpp120
-rw-r--r--compiler/luci/pass/src/TransformMinMaxToRelu6Pass.cpp134
-rw-r--r--compiler/luci/pass/src/TransformMinMaxToRelu6Pass.test.cpp151
-rw-r--r--compiler/luci/pass/src/TypeInferencePass.cpp55
-rw-r--r--compiler/luci/pass/src/VerifyQuantizedNodeChannelWiseGranularity.h401
-rw-r--r--compiler/luci/pass/src/VerifyQuantizedNodeLayerWiseGranularity.h388
-rw-r--r--compiler/luci/pass/src/VerifyQuantizedNodeS16Type.h375
-rw-r--r--compiler/luci/pass/src/VerifyQuantizedNodeU8Type.h375
-rw-r--r--compiler/luci/pass/src/helpers/InferenceCandidates.cpp45
-rw-r--r--compiler/luci/pass/src/helpers/InferenceCandidates.h34
-rw-r--r--compiler/luci/pass/src/helpers/InferenceCandidates.test.cpp122
-rw-r--r--compiler/luci/pass/src/helpers/NodeFiller.cpp20
-rw-r--r--compiler/luci/pass/src/helpers/NodeFiller.h104
-rw-r--r--compiler/luci/pass/src/helpers/NodeFiller.test.cpp59
-rw-r--r--compiler/luci/pass/src/helpers/Strings.cpp91
-rw-r--r--compiler/luci/pass/src/helpers/Strings.h57
-rw-r--r--compiler/luci/pass/src/helpers/Strings.test.cpp58
-rw-r--r--compiler/luci/pass/src/helpers/TypeMapper.cpp20
-rw-r--r--compiler/luci/pass/src/helpers/TypeMapper.h77
-rw-r--r--compiler/luci/pass/src/helpers/TypeMapper.test.cpp93
-rw-r--r--compiler/luci/pass/src/test/TestFirstNode.h43
-rw-r--r--compiler/luci/pass/src/test/TestFirstNode.test.cpp19
-rw-r--r--compiler/luci/pass/src/test/TestIOGraph.h161
-rw-r--r--compiler/luci/pass/src/test/TestIOGraph.test.cpp19
-rw-r--r--compiler/luci/pass/src/test/TestShape.h (renamed from compiler/luci/export/src/TypeBridge.h)30
-rw-r--r--compiler/luci/pass/src/test/TestShape.test.cpp57
-rw-r--r--compiler/luci/profile/CMakeLists.txt22
-rw-r--r--compiler/luci/profile/README.md119
-rw-r--r--compiler/luci/profile/include/luci/Profile/CircleNodeID.h (renamed from compiler/luci/pass/src/FuseActivationFunctionPassInternal.h)20
-rw-r--r--compiler/luci/profile/include/luci/Profile/CircleNodeOrigin.h72
-rw-r--r--compiler/luci/profile/src/CircleNodeID.cpp73
-rw-r--r--compiler/luci/profile/src/CircleNodeID.test.cpp44
-rw-r--r--compiler/luci/profile/src/CircleNodeOrigin.cpp168
-rw-r--r--compiler/luci/profile/src/CircleNodeOrigin.test.cpp108
-rw-r--r--compiler/luci/service/CMakeLists.txt1
-rw-r--r--compiler/luci/service/include/luci/Service/CircleNodeClone.h40
-rw-r--r--compiler/luci/service/include/luci/Service/CircleShapeInference.h36
-rw-r--r--compiler/luci/service/include/luci/Service/CircleShapeSignatureInference.h179
-rw-r--r--compiler/luci/service/include/luci/Service/CircleShapeSignatureInferenceHelper.h45
-rw-r--r--compiler/luci/service/include/luci/Service/CircleTypeInference.h25
-rw-r--r--compiler/luci/service/include/luci/Service/Nodes/CircleConst.h32
-rw-r--r--compiler/luci/service/include/luci/Service/ShapeDescription.h4
-rw-r--r--compiler/luci/service/include/luci/Service/Validate.h13
-rw-r--r--compiler/luci/service/src/CircleCloneNode.h174
-rw-r--r--compiler/luci/service/src/CircleNodeClone.cpp92
-rw-r--r--compiler/luci/service/src/CircleNodeClone.test.cpp109
-rw-r--r--compiler/luci/service/src/CircleShapeInference.cpp23
-rw-r--r--compiler/luci/service/src/CircleShapeInferenceHelper.cpp21
-rw-r--r--compiler/luci/service/src/CircleShapeInferenceHelper.h (renamed from compiler/luci/service/include/luci/Service/CircleShapeInferenceHelper.h)16
-rw-r--r--compiler/luci/service/src/CircleShapeInferenceRule.cpp304
-rw-r--r--compiler/luci/service/src/CircleShapeInferenceRule.test.cpp626
-rw-r--r--compiler/luci/service/src/CircleShapeSignatureInference.cpp64
-rw-r--r--compiler/luci/service/src/CircleShapeSignatureInferenceHelper.cpp160
-rw-r--r--compiler/luci/service/src/CircleTypeInference.cpp55
-rw-r--r--compiler/luci/service/src/CircleTypeInferenceHelper.cpp18
-rw-r--r--compiler/luci/service/src/CircleTypeInferenceHelper.h (renamed from compiler/luci/service/include/luci/Service/CircleTypeInferenceHelper.h)14
-rw-r--r--compiler/luci/service/src/CircleTypeInferenceRule.cpp268
-rw-r--r--compiler/luci/service/src/CircleTypeInferenceRule.test.cpp63
-rw-r--r--compiler/luci/service/src/Nodes/CircleAbs.cpp27
-rw-r--r--compiler/luci/service/src/Nodes/CircleAbs.test.cpp33
-rw-r--r--compiler/luci/service/src/Nodes/CircleAdd.cpp (renamed from compiler/luci/pass/include/luci/Pass/TypeInferencePass.h)30
-rw-r--r--compiler/luci/service/src/Nodes/CircleAdd.test.cpp84
-rw-r--r--compiler/luci/service/src/Nodes/CircleAddN.cpp28
-rw-r--r--compiler/luci/service/src/Nodes/CircleAddN.test.cpp34
-rw-r--r--compiler/luci/service/src/Nodes/CircleArgMax.cpp30
-rw-r--r--compiler/luci/service/src/Nodes/CircleArgMax.test.cpp35
-rw-r--r--compiler/luci/service/src/Nodes/CircleArgMin.cpp30
-rw-r--r--compiler/luci/service/src/Nodes/CircleArgMin.test.cpp35
-rw-r--r--compiler/luci/service/src/Nodes/CircleAveragePool2D.cpp42
-rw-r--r--compiler/luci/service/src/Nodes/CircleAveragePool2D.test.cpp128
-rw-r--r--compiler/luci/service/src/Nodes/CircleBCQFullyConnected.cpp (renamed from compiler/luci/pass/include/luci/Pass/ShapeSignatureInferencePass.h)34
-rw-r--r--compiler/luci/service/src/Nodes/CircleBCQFullyConnected.test.cpp48
-rw-r--r--compiler/luci/service/src/Nodes/CircleBCQGather.cpp33
-rw-r--r--compiler/luci/service/src/Nodes/CircleBCQGather.test.cpp37
-rw-r--r--compiler/luci/service/src/Nodes/CircleBatchMatMul.cpp33
-rw-r--r--compiler/luci/service/src/Nodes/CircleBatchMatMul.test.cpp37
-rw-r--r--compiler/luci/service/src/Nodes/CircleBatchToSpaceND.cpp27
-rw-r--r--compiler/luci/service/src/Nodes/CircleBatchToSpaceND.test.cpp33
-rw-r--r--compiler/luci/service/src/Nodes/CircleCast.cpp33
-rw-r--r--compiler/luci/service/src/Nodes/CircleCast.test.cpp37
-rw-r--r--compiler/luci/service/src/Nodes/CircleCeil.cpp27
-rw-r--r--compiler/luci/service/src/Nodes/CircleCeil.test.cpp33
-rw-r--r--compiler/luci/service/src/Nodes/CircleConcatenation.cpp36
-rw-r--r--compiler/luci/service/src/Nodes/CircleConcatenation.test.cpp49
-rw-r--r--compiler/luci/service/src/Nodes/CircleConst.cpp118
-rw-r--r--compiler/luci/service/src/Nodes/CircleConst.test.cpp177
-rw-r--r--compiler/luci/service/src/Nodes/CircleConv2D.cpp42
-rw-r--r--compiler/luci/service/src/Nodes/CircleConv2D.test.cpp61
-rw-r--r--compiler/luci/service/src/Nodes/CircleCos.cpp27
-rw-r--r--compiler/luci/service/src/Nodes/CircleCos.test.cpp33
-rw-r--r--compiler/luci/service/src/Nodes/CircleCustom.cpp35
-rw-r--r--compiler/luci/service/src/Nodes/CircleCustom.test.cpp46
-rw-r--r--compiler/luci/service/src/Nodes/CircleCustomOut.cpp30
-rw-r--r--compiler/luci/service/src/Nodes/CircleCustomOut.test.cpp35
-rw-r--r--compiler/luci/service/src/Nodes/CircleDepthToSpace.cpp30
-rw-r--r--compiler/luci/service/src/Nodes/CircleDepthToSpace.test.cpp35
-rw-r--r--compiler/luci/service/src/Nodes/CircleDepthwiseConv2D.cpp43
-rw-r--r--compiler/luci/service/src/Nodes/CircleDepthwiseConv2D.test.cpp61
-rw-r--r--compiler/luci/service/src/Nodes/CircleDequantize.cpp27
-rw-r--r--compiler/luci/service/src/Nodes/CircleDequantize.test.cpp33
-rw-r--r--compiler/luci/service/src/Nodes/CircleDiv.cpp33
-rw-r--r--compiler/luci/service/src/Nodes/CircleDiv.test.cpp46
-rw-r--r--compiler/luci/service/src/Nodes/CircleElu.cpp27
-rw-r--r--compiler/luci/service/src/Nodes/CircleElu.test.cpp33
-rw-r--r--compiler/luci/service/src/Nodes/CircleEqual.cpp27
-rw-r--r--compiler/luci/service/src/Nodes/CircleEqual.test.cpp33
-rw-r--r--compiler/luci/service/src/Nodes/CircleExp.cpp27
-rw-r--r--compiler/luci/service/src/Nodes/CircleExp.test.cpp33
-rw-r--r--compiler/luci/service/src/Nodes/CircleExpandDims.cpp27
-rw-r--r--compiler/luci/service/src/Nodes/CircleExpandDims.test.cpp66
-rw-r--r--compiler/luci/service/src/Nodes/CircleFakeQuant.cpp35
-rw-r--r--compiler/luci/service/src/Nodes/CircleFakeQuant.test.cpp41
-rw-r--r--compiler/luci/service/src/Nodes/CircleFill.cpp27
-rw-r--r--compiler/luci/service/src/Nodes/CircleFill.test.cpp33
-rw-r--r--compiler/luci/service/src/Nodes/CircleFloor.cpp27
-rw-r--r--compiler/luci/service/src/Nodes/CircleFloor.test.cpp33
-rw-r--r--compiler/luci/service/src/Nodes/CircleFloorDiv.cpp27
-rw-r--r--compiler/luci/service/src/Nodes/CircleFloorDiv.test.cpp33
-rw-r--r--compiler/luci/service/src/Nodes/CircleFloorMod.cpp27
-rw-r--r--compiler/luci/service/src/Nodes/CircleFloorMod.test.cpp33
-rw-r--r--compiler/luci/service/src/Nodes/CircleFullyConnected.cpp38
-rw-r--r--compiler/luci/service/src/Nodes/CircleFullyConnected.test.cpp61
-rw-r--r--compiler/luci/service/src/Nodes/CircleGather.cpp30
-rw-r--r--compiler/luci/service/src/Nodes/CircleGather.test.cpp35
-rw-r--r--compiler/luci/service/src/Nodes/CircleGatherNd.cpp27
-rw-r--r--compiler/luci/service/src/Nodes/CircleGatherNd.test.cpp113
-rw-r--r--compiler/luci/service/src/Nodes/CircleGreater.cpp27
-rw-r--r--compiler/luci/service/src/Nodes/CircleGreater.test.cpp33
-rw-r--r--compiler/luci/service/src/Nodes/CircleGreaterEqual.cpp27
-rw-r--r--compiler/luci/service/src/Nodes/CircleGreaterEqual.test.cpp33
-rw-r--r--compiler/luci/service/src/Nodes/CircleIfOut.cpp89
-rw-r--r--compiler/luci/service/src/Nodes/CircleInstanceNorm.cpp36
-rw-r--r--compiler/luci/service/src/Nodes/CircleInstanceNorm.test.cpp48
-rw-r--r--compiler/luci/service/src/Nodes/CircleL2Normalize.cpp33
-rw-r--r--compiler/luci/service/src/Nodes/CircleL2Normalize.test.cpp46
-rw-r--r--compiler/luci/service/src/Nodes/CircleL2Pool2D.cpp42
-rw-r--r--compiler/luci/service/src/Nodes/CircleL2Pool2D.test.cpp61
-rw-r--r--compiler/luci/service/src/Nodes/CircleLeakyRelu.cpp30
-rw-r--r--compiler/luci/service/src/Nodes/CircleLeakyRelu.test.cpp35
-rw-r--r--compiler/luci/service/src/Nodes/CircleLess.cpp27
-rw-r--r--compiler/luci/service/src/Nodes/CircleLess.test.cpp33
-rw-r--r--compiler/luci/service/src/Nodes/CircleLessEqual.cpp27
-rw-r--r--compiler/luci/service/src/Nodes/CircleLessEqual.test.cpp33
-rw-r--r--compiler/luci/service/src/Nodes/CircleLocalResponseNormalization.cpp35
-rw-r--r--compiler/luci/service/src/Nodes/CircleLocalResponseNormalization.test.cpp41
-rw-r--r--compiler/luci/service/src/Nodes/CircleLog.cpp27
-rw-r--r--compiler/luci/service/src/Nodes/CircleLog.test.cpp33
-rw-r--r--compiler/luci/service/src/Nodes/CircleLogSoftmax.cpp27
-rw-r--r--compiler/luci/service/src/Nodes/CircleLogSoftmax.test.cpp33
-rw-r--r--compiler/luci/service/src/Nodes/CircleLogicalAnd.cpp27
-rw-r--r--compiler/luci/service/src/Nodes/CircleLogicalAnd.test.cpp33
-rw-r--r--compiler/luci/service/src/Nodes/CircleLogicalNot.cpp27
-rw-r--r--compiler/luci/service/src/Nodes/CircleLogicalNot.test.cpp33
-rw-r--r--compiler/luci/service/src/Nodes/CircleLogicalOr.cpp27
-rw-r--r--compiler/luci/service/src/Nodes/CircleLogicalOr.test.cpp33
-rw-r--r--compiler/luci/service/src/Nodes/CircleLogistic.cpp27
-rw-r--r--compiler/luci/service/src/Nodes/CircleLogistic.test.cpp33
-rw-r--r--compiler/luci/service/src/Nodes/CircleMatrixDiag.cpp27
-rw-r--r--compiler/luci/service/src/Nodes/CircleMatrixDiag.test.cpp33
-rw-r--r--compiler/luci/service/src/Nodes/CircleMatrixSetDiag.cpp27
-rw-r--r--compiler/luci/service/src/Nodes/CircleMatrixSetDiag.test.cpp33
-rw-r--r--compiler/luci/service/src/Nodes/CircleMaxPool2D.cpp42
-rw-r--r--compiler/luci/service/src/Nodes/CircleMaxPool2D.test.cpp69
-rw-r--r--compiler/luci/service/src/Nodes/CircleMaximum.cpp27
-rw-r--r--compiler/luci/service/src/Nodes/CircleMaximum.test.cpp33
-rw-r--r--compiler/luci/service/src/Nodes/CircleMean.cpp14
-rw-r--r--compiler/luci/service/src/Nodes/CircleMean.test.cpp35
-rw-r--r--compiler/luci/service/src/Nodes/CircleMinimum.cpp27
-rw-r--r--compiler/luci/service/src/Nodes/CircleMinimum.test.cpp33
-rw-r--r--compiler/luci/service/src/Nodes/CircleMirrorPad.cpp33
-rw-r--r--compiler/luci/service/src/Nodes/CircleMirrorPad.test.cpp46
-rw-r--r--compiler/luci/service/src/Nodes/CircleMul.cpp33
-rw-r--r--compiler/luci/service/src/Nodes/CircleMul.test.cpp46
-rw-r--r--compiler/luci/service/src/Nodes/CircleNeg.cpp27
-rw-r--r--compiler/luci/service/src/Nodes/CircleNeg.test.cpp33
-rw-r--r--compiler/luci/service/src/Nodes/CircleNonMaxSuppressionV4.cpp27
-rw-r--r--compiler/luci/service/src/Nodes/CircleNonMaxSuppressionV4.test.cpp33
-rw-r--r--compiler/luci/service/src/Nodes/CircleNonMaxSuppressionV4Out.cpp30
-rw-r--r--compiler/luci/service/src/Nodes/CircleNonMaxSuppressionV4Out.test.cpp35
-rw-r--r--compiler/luci/service/src/Nodes/CircleNonMaxSuppressionV5.cpp27
-rw-r--r--compiler/luci/service/src/Nodes/CircleNonMaxSuppressionV5.test.cpp33
-rw-r--r--compiler/luci/service/src/Nodes/CircleNonMaxSuppressionV5Out.cpp30
-rw-r--r--compiler/luci/service/src/Nodes/CircleNonMaxSuppressionV5Out.test.cpp35
-rw-r--r--compiler/luci/service/src/Nodes/CircleNotEqual.cpp27
-rw-r--r--compiler/luci/service/src/Nodes/CircleNotEqual.test.cpp33
-rw-r--r--compiler/luci/service/src/Nodes/CircleOneHot.cpp30
-rw-r--r--compiler/luci/service/src/Nodes/CircleOneHot.test.cpp35
-rw-r--r--compiler/luci/service/src/Nodes/CircleOutputDummy.cpp11
-rw-r--r--compiler/luci/service/src/Nodes/CircleOutputDummy.test.cpp33
-rw-r--r--compiler/luci/service/src/Nodes/CircleOutputExclude.cpp10
-rw-r--r--compiler/luci/service/src/Nodes/CircleOutputExclude.test.cpp33
-rw-r--r--compiler/luci/service/src/Nodes/CirclePRelu.cpp27
-rw-r--r--compiler/luci/service/src/Nodes/CirclePRelu.test.cpp33
-rw-r--r--compiler/luci/service/src/Nodes/CirclePack.cpp30
-rw-r--r--compiler/luci/service/src/Nodes/CirclePack.test.cpp36
-rw-r--r--compiler/luci/service/src/Nodes/CirclePad.cpp27
-rw-r--r--compiler/luci/service/src/Nodes/CirclePad.test.cpp33
-rw-r--r--compiler/luci/service/src/Nodes/CirclePadV2.cpp27
-rw-r--r--compiler/luci/service/src/Nodes/CirclePadV2.test.cpp33
-rw-r--r--compiler/luci/service/src/Nodes/CirclePow.cpp27
-rw-r--r--compiler/luci/service/src/Nodes/CirclePow.test.cpp33
-rw-r--r--compiler/luci/service/src/Nodes/CircleRange.cpp27
-rw-r--r--compiler/luci/service/src/Nodes/CircleRange.test.cpp33
-rw-r--r--compiler/luci/service/src/Nodes/CircleRank.cpp27
-rw-r--r--compiler/luci/service/src/Nodes/CircleRank.test.cpp33
-rw-r--r--compiler/luci/service/src/Nodes/CircleReduceAny.cpp14
-rw-r--r--compiler/luci/service/src/Nodes/CircleReduceAny.test.cpp35
-rw-r--r--compiler/luci/service/src/Nodes/CircleReduceMax.cpp14
-rw-r--r--compiler/luci/service/src/Nodes/CircleReduceMax.test.cpp35
-rw-r--r--compiler/luci/service/src/Nodes/CircleReduceMin.cpp14
-rw-r--r--compiler/luci/service/src/Nodes/CircleReduceMin.test.cpp35
-rw-r--r--compiler/luci/service/src/Nodes/CircleReduceProd.cpp14
-rw-r--r--compiler/luci/service/src/Nodes/CircleReduceProd.test.cpp35
-rw-r--r--compiler/luci/service/src/Nodes/CircleRelu.cpp10
-rw-r--r--compiler/luci/service/src/Nodes/CircleRelu.test.cpp74
-rw-r--r--compiler/luci/service/src/Nodes/CircleRelu6.cpp10
-rw-r--r--compiler/luci/service/src/Nodes/CircleRelu6.test.cpp35
-rw-r--r--compiler/luci/service/src/Nodes/CircleReluN1To1.cpp10
-rw-r--r--compiler/luci/service/src/Nodes/CircleReluN1To1.test.cpp35
-rw-r--r--compiler/luci/service/src/Nodes/CircleReshape.cpp37
-rw-r--r--compiler/luci/service/src/Nodes/CircleReshape.test.cpp39
-rw-r--r--compiler/luci/service/src/Nodes/CircleResizeBilinear.cpp33
-rw-r--r--compiler/luci/service/src/Nodes/CircleResizeBilinear.test.cpp73
-rw-r--r--compiler/luci/service/src/Nodes/CircleResizeNearestNeighbor.cpp30
-rw-r--r--compiler/luci/service/src/Nodes/CircleResizeNearestNeighbor.test.cpp71
-rw-r--r--compiler/luci/service/src/Nodes/CircleReverseSequence.cpp33
-rw-r--r--compiler/luci/service/src/Nodes/CircleReverseSequence.test.cpp37
-rw-r--r--compiler/luci/service/src/Nodes/CircleReverseV2.cpp27
-rw-r--r--compiler/luci/service/src/Nodes/CircleReverseV2.test.cpp33
-rw-r--r--compiler/luci/service/src/Nodes/CircleRound.cpp27
-rw-r--r--compiler/luci/service/src/Nodes/CircleRound.test.cpp33
-rw-r--r--compiler/luci/service/src/Nodes/CircleRsqrt.cpp27
-rw-r--r--compiler/luci/service/src/Nodes/CircleRsqrt.test.cpp33
-rw-r--r--compiler/luci/service/src/Nodes/CircleScatterNd.cpp27
-rw-r--r--compiler/luci/service/src/Nodes/CircleScatterNd.test.cpp33
-rw-r--r--compiler/luci/service/src/Nodes/CircleSegmentSum.cpp27
-rw-r--r--compiler/luci/service/src/Nodes/CircleSegmentSum.test.cpp33
-rw-r--r--compiler/luci/service/src/Nodes/CircleSelect.cpp27
-rw-r--r--compiler/luci/service/src/Nodes/CircleSelect.test.cpp33
-rw-r--r--compiler/luci/service/src/Nodes/CircleSelectV2.cpp27
-rw-r--r--compiler/luci/service/src/Nodes/CircleSelectV2.test.cpp33
-rw-r--r--compiler/luci/service/src/Nodes/CircleShape.cpp30
-rw-r--r--compiler/luci/service/src/Nodes/CircleShape.test.cpp35
-rw-r--r--compiler/luci/service/src/Nodes/CircleSin.cpp27
-rw-r--r--compiler/luci/service/src/Nodes/CircleSin.test.cpp33
-rw-r--r--compiler/luci/service/src/Nodes/CircleSlice.cpp27
-rw-r--r--compiler/luci/service/src/Nodes/CircleSlice.test.cpp33
-rw-r--r--compiler/luci/service/src/Nodes/CircleSoftmax.cpp30
-rw-r--r--compiler/luci/service/src/Nodes/CircleSoftmax.test.cpp35
-rw-r--r--compiler/luci/service/src/Nodes/CircleSpaceToBatchND.cpp27
-rw-r--r--compiler/luci/service/src/Nodes/CircleSpaceToBatchND.test.cpp33
-rw-r--r--compiler/luci/service/src/Nodes/CircleSpaceToDepth.cpp30
-rw-r--r--compiler/luci/service/src/Nodes/CircleSpaceToDepth.test.cpp35
-rw-r--r--compiler/luci/service/src/Nodes/CircleSparseToDense.cpp30
-rw-r--r--compiler/luci/service/src/Nodes/CircleSparseToDense.test.cpp35
-rw-r--r--compiler/luci/service/src/Nodes/CircleSplit.cpp30
-rw-r--r--compiler/luci/service/src/Nodes/CircleSplit.test.cpp35
-rw-r--r--compiler/luci/service/src/Nodes/CircleSplitOut.cpp30
-rw-r--r--compiler/luci/service/src/Nodes/CircleSplitOut.test.cpp35
-rw-r--r--compiler/luci/service/src/Nodes/CircleSplitV.cpp30
-rw-r--r--compiler/luci/service/src/Nodes/CircleSplitV.test.cpp35
-rw-r--r--compiler/luci/service/src/Nodes/CircleSplitVOut.cpp30
-rw-r--r--compiler/luci/service/src/Nodes/CircleSplitVOut.test.cpp35
-rw-r--r--compiler/luci/service/src/Nodes/CircleSqrt.cpp27
-rw-r--r--compiler/luci/service/src/Nodes/CircleSqrt.test.cpp33
-rw-r--r--compiler/luci/service/src/Nodes/CircleSquare.cpp27
-rw-r--r--compiler/luci/service/src/Nodes/CircleSquare.test.cpp33
-rw-r--r--compiler/luci/service/src/Nodes/CircleSquaredDifference.cpp27
-rw-r--r--compiler/luci/service/src/Nodes/CircleSquaredDifference.test.cpp33
-rw-r--r--compiler/luci/service/src/Nodes/CircleSqueeze.cpp30
-rw-r--r--compiler/luci/service/src/Nodes/CircleSqueeze.test.cpp83
-rw-r--r--compiler/luci/service/src/Nodes/CircleStridedSlice.cpp36
-rw-r--r--compiler/luci/service/src/Nodes/CircleStridedSlice.test.cpp43
-rw-r--r--compiler/luci/service/src/Nodes/CircleSub.cpp33
-rw-r--r--compiler/luci/service/src/Nodes/CircleSub.test.cpp46
-rw-r--r--compiler/luci/service/src/Nodes/CircleSum.cpp14
-rw-r--r--compiler/luci/service/src/Nodes/CircleSum.test.cpp35
-rw-r--r--compiler/luci/service/src/Nodes/CircleTanh.cpp27
-rw-r--r--compiler/luci/service/src/Nodes/CircleTanh.test.cpp33
-rw-r--r--compiler/luci/service/src/Nodes/CircleTile.cpp27
-rw-r--r--compiler/luci/service/src/Nodes/CircleTile.test.cpp33
-rw-r--r--compiler/luci/service/src/Nodes/CircleTopKV2.cpp27
-rw-r--r--compiler/luci/service/src/Nodes/CircleTopKV2.test.cpp33
-rw-r--r--compiler/luci/service/src/Nodes/CircleTopKV2Out.cpp30
-rw-r--r--compiler/luci/service/src/Nodes/CircleTopKV2Out.test.cpp35
-rw-r--r--compiler/luci/service/src/Nodes/CircleTranspose.cpp27
-rw-r--r--compiler/luci/service/src/Nodes/CircleTranspose.test.cpp69
-rw-r--r--compiler/luci/service/src/Nodes/CircleTransposeConv.cpp37
-rw-r--r--compiler/luci/service/src/Nodes/CircleTransposeConv.test.cpp46
-rw-r--r--compiler/luci/service/src/Nodes/CircleUnidirectionalSequenceLSTM.cpp39
-rw-r--r--compiler/luci/service/src/Nodes/CircleUnidirectionalSequenceLSTM.test.cpp54
-rw-r--r--compiler/luci/service/src/Nodes/CircleUnique.cpp30
-rw-r--r--compiler/luci/service/src/Nodes/CircleUnique.test.cpp35
-rw-r--r--compiler/luci/service/src/Nodes/CircleUniqueOut.cpp30
-rw-r--r--compiler/luci/service/src/Nodes/CircleUniqueOut.test.cpp35
-rw-r--r--compiler/luci/service/src/Nodes/CircleUnpack.cpp33
-rw-r--r--compiler/luci/service/src/Nodes/CircleUnpack.test.cpp37
-rw-r--r--compiler/luci/service/src/Nodes/CircleUnpackOut.cpp30
-rw-r--r--compiler/luci/service/src/Nodes/CircleUnpackOut.test.cpp35
-rw-r--r--compiler/luci/service/src/Nodes/CircleWhere.cpp27
-rw-r--r--compiler/luci/service/src/Nodes/CircleWhere.test.cpp33
-rw-r--r--compiler/luci/service/src/Nodes/CircleZerosLike.cpp27
-rw-r--r--compiler/luci/service/src/Nodes/CircleZerosLike.test.cpp33
-rw-r--r--compiler/luci/service/src/ShapeDescription.cpp85
-rw-r--r--compiler/luci/service/src/ShapeDescription.test.cpp56
-rw-r--r--compiler/luci/service/src/ShapeInfer_StridedSlice.cpp4
-rw-r--r--compiler/luci/service/src/Validate.cpp123
-rw-r--r--compiler/luci/service/src/Validate.test.cpp139
-rw-r--r--compiler/luci/tester/CMakeLists.txt21
-rw-r--r--compiler/luci/tester/src/ReadModule.cpp65
-rw-r--r--compiler/luci/tester/src/ReadModule.h28
-rw-r--r--compiler/luci/tester/src/ReadTester.cpp51
-rw-r--r--compiler/luci/tester/src/ReadTester.test.cpp43
-rw-r--r--compiler/luci/tester/src/WriteTester.cpp56
-rw-r--r--compiler/luci/tester/src/WriteTester.test.cpp44
-rw-r--r--compiler/luci/testhelper/CMakeLists.txt25
-rw-r--r--compiler/luci/testhelper/README.md3
-rw-r--r--compiler/luci/testhelper/include/luci/test/TestIOGraph.h198
-rw-r--r--compiler/luci/testhelper/include/luci/test/TestShape.h40
-rw-r--r--compiler/luci/testhelper/src/TestIOGraph.test.cpp182
-rw-r--r--compiler/luci/testhelper/src/TestShape.test.cpp57
-rw-r--r--compiler/luci/tests/test.lst8
779 files changed, 30293 insertions, 5708 deletions
diff --git a/compiler/luci/CMakeLists.txt b/compiler/luci/CMakeLists.txt
index 214a1bbf2..3771176f0 100644
--- a/compiler/luci/CMakeLists.txt
+++ b/compiler/luci/CMakeLists.txt
@@ -1,8 +1,11 @@
add_subdirectory(env)
add_subdirectory(log)
add_subdirectory(lang)
+add_subdirectory(testhelper)
add_subdirectory(service)
add_subdirectory(pass)
+add_subdirectory(profile)
+add_subdirectory(partition)
add_subdirectory(logex)
add_subdirectory(import)
add_subdirectory(export)
diff --git a/compiler/luci/env/include/luci/UserSettings.h b/compiler/luci/env/include/luci/UserSettings.h
index bcfd16071..b56bd65e2 100644
--- a/compiler/luci/env/include/luci/UserSettings.h
+++ b/compiler/luci/env/include/luci/UserSettings.h
@@ -32,6 +32,7 @@ struct UserSettings
Undefined,
MuteWarnings,
DisableValidation,
+ ProfilingDataGen,
};
static UserSettings *settings();
diff --git a/compiler/luci/env/src/UserSettings.cpp b/compiler/luci/env/src/UserSettings.cpp
index 27dec762d..b4c661190 100644
--- a/compiler/luci/env/src/UserSettings.cpp
+++ b/compiler/luci/env/src/UserSettings.cpp
@@ -30,6 +30,7 @@ public:
private:
bool _MuteWarnings{false};
bool _DisableValidation{false};
+ bool _ProfilingDataGen{false};
};
void UserSettingsImpl::set(const Key key, bool value)
@@ -42,6 +43,9 @@ void UserSettingsImpl::set(const Key key, bool value)
case Key::DisableValidation:
_DisableValidation = value;
break;
+ case Key::ProfilingDataGen:
+ _ProfilingDataGen = value;
+ break;
default:
throw std::runtime_error("Invalid key in boolean set");
break;
@@ -56,6 +60,8 @@ bool UserSettingsImpl::get(const Key key) const
return _MuteWarnings;
case Key::DisableValidation:
return _DisableValidation;
+ case Key::ProfilingDataGen:
+ return _ProfilingDataGen;
default:
throw std::runtime_error("Invalid key in boolean get");
break;
diff --git a/compiler/luci/env/src/UserSettings.test.cpp b/compiler/luci/env/src/UserSettings.test.cpp
index 8d9d1875b..899c0c2a1 100644
--- a/compiler/luci/env/src/UserSettings.test.cpp
+++ b/compiler/luci/env/src/UserSettings.test.cpp
@@ -51,6 +51,18 @@ TEST(UserSettings, DisableValidation)
ASSERT_TRUE(settings->get(luci::UserSettings::Key::DisableValidation));
}
+TEST(UserSettings, ProfilingDataGen)
+{
+ auto settings = luci::UserSettings::settings();
+ ASSERT_NE(nullptr, settings);
+
+ settings->set(luci::UserSettings::Key::ProfilingDataGen, false);
+ ASSERT_FALSE(settings->get(luci::UserSettings::Key::ProfilingDataGen));
+
+ settings->set(luci::UserSettings::Key::ProfilingDataGen, true);
+ ASSERT_TRUE(settings->get(luci::UserSettings::Key::ProfilingDataGen));
+}
+
TEST(UserSettings, undefined_set_NEG)
{
auto settings = luci::UserSettings::settings();
diff --git a/compiler/luci/export/CMakeLists.txt b/compiler/luci/export/CMakeLists.txt
index fe4382ecd..01f737110 100644
--- a/compiler/luci/export/CMakeLists.txt
+++ b/compiler/luci/export/CMakeLists.txt
@@ -13,6 +13,7 @@ target_link_libraries(luci_export PRIVATE mio_circle)
target_link_libraries(luci_export PRIVATE luci_env)
target_link_libraries(luci_export PRIVATE luci_log)
target_link_libraries(luci_export PRIVATE luci_logex)
+target_link_libraries(luci_export PRIVATE luci_profile)
target_link_libraries(luci_export PRIVATE nncc_common)
target_link_libraries(luci_export PRIVATE locop)
target_link_libraries(luci_export PRIVATE oops)
diff --git a/compiler/luci/export/include/luci/CircleFileExpContract.h b/compiler/luci/export/include/luci/CircleFileExpContract.h
index eeaf2d9bb..8ef1b5e0c 100644
--- a/compiler/luci/export/include/luci/CircleFileExpContract.h
+++ b/compiler/luci/export/include/luci/CircleFileExpContract.h
@@ -33,7 +33,7 @@ struct CircleFileExpContract : public luci::CircleExporter::Contract
{
public:
CircleFileExpContract(luci::Module *module, const std::string &filename)
- : _module(module), _filepath(filename)
+ : _module(module), _filepath(filename)
{
// NOTHING TO DO
}
diff --git a/compiler/luci/export/src/CircleExportMetadata.cpp b/compiler/luci/export/src/CircleExportMetadata.cpp
new file mode 100644
index 000000000..ef905a882
--- /dev/null
+++ b/compiler/luci/export/src/CircleExportMetadata.cpp
@@ -0,0 +1,121 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "CircleExportMetadata.h"
+
+#include <luci/UserSettings.h>
+
+namespace
+{
+
+void write_u32(std::vector<uint8_t> &to, uint32_t value)
+{
+ to.emplace_back(0xFF & (value >> 0 * 8));
+ to.emplace_back(0xFF & (value >> 1 * 8));
+ to.emplace_back(0xFF & (value >> 2 * 8));
+ to.emplace_back(0xFF & (value >> 3 * 8));
+}
+
+flatbuffers::Offset<circle::Metadata> metadata_offset(flatbuffers::FlatBufferBuilder &builder,
+ luci::SerializedModelData &md,
+ const std::vector<uint8_t> &data,
+ const std::string &metadata_name)
+{
+ auto buffer_id = static_cast<uint32_t>(md._buffers.size());
+ md._buffers.push_back(circle::CreateBufferDirect(builder, &data));
+ return circle::CreateMetadataDirect(builder, metadata_name.c_str(), buffer_id);
+}
+
+} // namespace
+
+namespace luci
+{
+
+// 'source_table' is encoded to binary format.
+const std::vector<uint8_t> CircleExportMetadata::encoded_source_table(void)
+{
+ std::vector<uint8_t> data;
+
+ write_u32(data, _source_table.size());
+
+ for (auto &kv : _source_table)
+ {
+ const auto id = kv.first;
+ write_u32(data, id);
+
+ const auto origin_name = kv.second;
+ const auto length = origin_name.length();
+ write_u32(data, length + 1); // name + '\0
+
+ for (uint32_t i = 0; i < length; ++i)
+ {
+ data.emplace_back(origin_name.at(i));
+ }
+ data.emplace_back('\0');
+ }
+
+ return data;
+}
+
+// 'op_table' is encoded to binary format.
+const std::vector<uint8_t> CircleExportMetadata::encoded_op_table(void)
+{
+ std::vector<uint8_t> data;
+
+ write_u32(data, _op_table.size());
+
+ for (auto &kv : _op_table)
+ {
+ const auto id = kv.first;
+ write_u32(data, id);
+
+ const auto origins = kv.second;
+ const auto node_num = origins.size();
+ write_u32(data, node_num);
+
+ for (auto origin : origins)
+ {
+ write_u32(data, origin);
+ }
+ }
+
+ return data;
+}
+
+} // namespace luci
+
+namespace luci
+{
+
+std::vector<flatbuffers::Offset<circle::Metadata>>
+createCircleMetadataVector(flatbuffers::FlatBufferBuilder &builder, luci::SerializedModelData &md)
+{
+ std::vector<flatbuffers::Offset<circle::Metadata>> metadata_vec;
+
+ auto settings = luci::UserSettings::settings();
+ if (settings->get(luci::UserSettings::Key::ProfilingDataGen))
+ {
+ metadata_vec.emplace_back(
+ metadata_offset(builder, md, md._metadata.encoded_source_table(), "ONE_source_table"));
+
+ metadata_vec.emplace_back(
+ metadata_offset(builder, md, md._metadata.encoded_op_table(), "ONE_op_table"));
+ }
+
+ return metadata_vec;
+}
+
+} // namespace luci
diff --git a/compiler/luci/export/src/CircleExportMetadata.h b/compiler/luci/export/src/CircleExportMetadata.h
new file mode 100644
index 000000000..10cda421e
--- /dev/null
+++ b/compiler/luci/export/src/CircleExportMetadata.h
@@ -0,0 +1,36 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LUCI_CIRCLE_EXPORT_METADATA_H__
+#define __LUCI_CIRCLE_EXPORT_METADATA_H__
+
+#include "SerializedData.h"
+
+#include <flatbuffers/flatbuffers.h>
+#include <mio/circle/schema_generated.h>
+
+namespace luci
+{
+
+/**
+ * @brief Create Metadata corresponding to model metadata
+ */
+std::vector<flatbuffers::Offset<circle::Metadata>>
+createCircleMetadataVector(flatbuffers::FlatBufferBuilder &builder, SerializedModelData &md);
+
+} // namespace luci
+
+#endif // __LUCI_CIRCLE_EXPORT_METADATA_H__
diff --git a/compiler/luci/export/src/CircleExporterImpl.cpp b/compiler/luci/export/src/CircleExporterImpl.cpp
index df7542797..7e218191c 100644
--- a/compiler/luci/export/src/CircleExporterImpl.cpp
+++ b/compiler/luci/export/src/CircleExporterImpl.cpp
@@ -16,10 +16,13 @@
#include "CircleExporterImpl.h"
#include "Optimize.h"
+#include "CircleExportMetadata.h"
#include "CircleTensorExporter.h"
#include "CircleOperationExporter.h"
#include "CircleExporterUtils.h"
+#include <luci/IR/CircleNodes.h>
+
#include <oops/InternalExn.h>
#include <mio/circle/schema_generated.h>
#include <flatbuffers/flatbuffers.h>
@@ -27,46 +30,16 @@
#include <cassert>
#include <unordered_map>
#include <string>
-#include <stdexcept>
+#include <vector>
namespace
{
-luci::CircleInput *input_node(loco::Graph *g, const loco::GraphInputIndex &index)
-{
- for (uint32_t n = 0; n < g->nodes()->size(); ++n)
- {
- if (auto input = dynamic_cast<luci::CircleInput *>(g->nodes()->at(n)))
- {
- if (input->indexed() && input->index() == index)
- {
- return input;
- }
- }
- }
- return nullptr;
-}
-
-luci::CircleOutput *output_node(loco::Graph *g, const loco::GraphOutputIndex &index)
-{
- for (uint32_t n = 0; n < g->nodes()->size(); ++n)
- {
- if (auto output = dynamic_cast<luci::CircleOutput *>(g->nodes()->at(n)))
- {
- if (output->indexed() && output->index() == index)
- {
- return output;
- }
- }
- }
- return nullptr;
-}
-
void registerGraphInputTensors(loco::Graph *graph, luci::SubGraphContext &ctx)
{
for (uint32_t n = 0; n < graph->inputs()->size(); ++n)
{
- auto node = input_node(graph, n);
+ auto node = luci::input_node(graph, n);
assert(node != nullptr);
ctx._inputs.push_back(luci::get_tensor_index(node));
}
@@ -76,7 +49,7 @@ void registerGraphOutputTensors(loco::Graph *graph, luci::SubGraphContext &ctx)
{
for (uint32_t n = 0; n < graph->outputs()->size(); ++n)
{
- auto push = output_node(graph, n);
+ auto push = luci::output_node(graph, n);
assert(push != nullptr);
auto node = push->from();
assert(node != nullptr);
@@ -113,7 +86,7 @@ encodeOperatorCodes(FlatBufferBuilder &builder, std::unordered_map<luci::OpCode,
else
{
operator_codes_vec[idx] =
- CreateOperatorCode(builder, it.first.opcode, builder.CreateString(it.first.custom_code));
+ CreateOperatorCode(builder, it.first.opcode, builder.CreateString(it.first.custom_code));
}
}
@@ -186,16 +159,16 @@ void CircleExporterImpl::exportGraph(loco::Graph *graph)
std::string description_str = "nnpackage";
auto description = _builder.CreateString(description_str);
+ // Metadata
+ auto metadata_vec = createCircleMetadataVector(_builder, md);
+ auto metadata = _builder.CreateVector(std::vector<Offset<Metadata>>(metadata_vec));
+
// create array of buffers
auto buffers = _builder.CreateVector(md._buffers);
- // empty metadata
- std::vector<int> metadata_buffer_vec;
- auto metadata_buffer = _builder.CreateVector(metadata_buffer_vec);
-
// Model
auto model_offset = CreateModel(_builder, version, operator_codes, subgraphs, description,
- buffers, metadata_buffer);
+ buffers, 0 /* metadata_buffer */, metadata);
FinishModelBuffer(_builder, model_offset);
}
@@ -250,19 +223,19 @@ void CircleExporterImpl::exportModule(Module *module)
std::string description_str = "nnpackage";
auto description = _builder.CreateString(description_str);
+ // Metadata
+ auto metadata_vec = createCircleMetadataVector(_builder, md);
+ auto metadata = _builder.CreateVector(std::vector<Offset<Metadata>>(metadata_vec));
+
// create array of buffers
auto buffers = _builder.CreateVector(md._buffers);
- // empty metadata
- std::vector<int> metadata_buffer_vec;
- auto metadata_buffer = _builder.CreateVector(metadata_buffer_vec);
-
// This version is taken from comment in fbs
constexpr uint32_t version = 0;
// Model
auto model_offset = CreateModel(_builder, version, operator_codes, subgraphs, description,
- buffers, metadata_buffer);
+ buffers, 0 /* metadata_buffer */, metadata);
FinishModelBuffer(_builder, model_offset);
}
diff --git a/compiler/luci/export/src/CircleExporterImpl.h b/compiler/luci/export/src/CircleExporterImpl.h
index e5d5b5a00..069f62afd 100644
--- a/compiler/luci/export/src/CircleExporterImpl.h
+++ b/compiler/luci/export/src/CircleExporterImpl.h
@@ -22,8 +22,6 @@
#include "SerializedData.h"
-#include "SerializedData.h"
-
#include <mio/circle/schema_generated.h>
#include <loco.h>
diff --git a/compiler/luci/export/src/CircleExporterUtils.cpp b/compiler/luci/export/src/CircleExporterUtils.cpp
index 3715513e0..1b21fdd86 100644
--- a/compiler/luci/export/src/CircleExporterUtils.cpp
+++ b/compiler/luci/export/src/CircleExporterUtils.cpp
@@ -208,13 +208,13 @@ circle::Padding getOpPadding(const loco::Padding2D *pad, const loco::Stride<2> *
//
// NOTE input and output 'feature' map are shape of NHWC
bool same_padding_criterion_1 =
- (static_cast<uint32_t>(ofm._dims[1]) == (ifm._dims[1] - 1) / stride->vertical() + 1) &&
- (static_cast<uint32_t>(ofm._dims[2]) == (ifm._dims[2] - 1) / stride->horizontal() + 1);
+ (static_cast<uint32_t>(ofm._dims[1]) == (ifm._dims[1] - 1) / stride->vertical() + 1) &&
+ (static_cast<uint32_t>(ofm._dims[2]) == (ifm._dims[2] - 1) / stride->horizontal() + 1);
// For same padding, rear padding is same or bigger than front padding by at most 1
bool same_padding_criterion_2 =
- (pad->top() <= pad->bottom()) && (pad->bottom() <= pad->top() + 1) &&
- (pad->left() <= pad->right()) && (pad->right() <= pad->left() + 1);
+ (pad->top() <= pad->bottom()) && (pad->bottom() <= pad->top() + 1) &&
+ (pad->left() <= pad->right()) && (pad->right() <= pad->left() + 1);
if (same_padding_criterion_1 && same_padding_criterion_2)
return circle::Padding_SAME;
diff --git a/compiler/luci/export/src/CircleOperationExporter.cpp b/compiler/luci/export/src/CircleOperationExporter.cpp
index 4343cf3c9..4bf674b9b 100644
--- a/compiler/luci/export/src/CircleOperationExporter.cpp
+++ b/compiler/luci/export/src/CircleOperationExporter.cpp
@@ -21,6 +21,7 @@
#include <luci/IR/CircleNode.h>
#include <luci/IR/CircleNodes.h>
#include <luci/IR/CircleNodeVisitor.h>
+#include <luci/Profile/CircleNodeOrigin.h>
#include <luci/UserSettings.h>
#include <luci/Log.h>
@@ -53,8 +54,8 @@ template <class CirclePool2D>
void export_pool_2d(ExportContext &ctx, CirclePool2D *node, circle::BuiltinOperator builtin_op)
{
LUCI_ASSERT(builtin_op == circle::BuiltinOperator_MAX_POOL_2D ||
- builtin_op == circle::BuiltinOperator_L2_POOL_2D ||
- builtin_op == circle::BuiltinOperator_AVERAGE_POOL_2D,
+ builtin_op == circle::BuiltinOperator_L2_POOL_2D ||
+ builtin_op == circle::BuiltinOperator_AVERAGE_POOL_2D,
"Should be L2Pool, MaxPool or AvgPool");
LUCI_ASSERT(node->padding() != luci::Padding::UNDEFINED, "Padding is not set");
@@ -81,7 +82,7 @@ void export_node(ExportContext &ctx, loco::Node *node, circle::BuiltinOperator b
circle::BuiltinOptions bot, flatbuffers::Offset<void> options_offset)
{
uint32_t op_idx =
- ctx.md.registerBuiltinOpcode(bop, loco::must_cast<luci::CircleNode *>(node)->op_version());
+ ctx.md.registerBuiltinOpcode(bop, loco::must_cast<luci::CircleNode *>(node)->op_version());
std::vector<int32_t> inputs_vec;
std::vector<int32_t> outputs_vec{get_tensor_index(node)};
for (uint32_t i = 0; i < node->arity(); ++i)
@@ -98,7 +99,7 @@ void export_node(ExportContext &ctx, loco::Node *node, circle::BuiltinOperator b
void export_node(ExportContext &ctx, loco::Node *node, circle::BuiltinOperator bop)
{
uint32_t op_idx =
- ctx.md.registerBuiltinOpcode(bop, loco::must_cast<luci::CircleNode *>(node)->op_version());
+ ctx.md.registerBuiltinOpcode(bop, loco::must_cast<luci::CircleNode *>(node)->op_version());
std::vector<int32_t> inputs_vec;
std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))};
for (uint32_t i = 0; i < node->arity(); ++i)
@@ -152,7 +153,7 @@ void export_node(ExportContext &ctx, luci::CircleCast *node)
void export_node(ExportContext &ctx, luci::CircleConcatenation *node)
{
uint32_t op_idx =
- ctx.md.registerBuiltinOpcode(circle::BuiltinOperator_CONCATENATION, node->op_version());
+ ctx.md.registerBuiltinOpcode(circle::BuiltinOperator_CONCATENATION, node->op_version());
std::vector<int32_t> inputs_vec;
std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))};
@@ -171,6 +172,7 @@ void export_node(ExportContext &ctx, luci::CircleConcatenation *node)
void export_node(ExportContext &ctx, luci::CircleCustom *node)
{
auto custom_outputs = loco::succs(node);
+ assert(custom_outputs.size() == node->numOutputs());
uint32_t op_idx = ctx.md.registerCustomOpcode(node->custom_code());
std::vector<int32_t> inputs_vec;
@@ -260,9 +262,9 @@ void export_node(ExportContext &ctx, luci::CircleNonMaxSuppressionV4 *node)
uint32_t op_idx = ctx.md.registerBuiltinOpcode(circle::BuiltinOperator_NON_MAX_SUPPRESSION_V4,
node->op_version());
std::vector<int32_t> inputs_vec{
- get_tensor_index(node->boxes()), get_tensor_index(node->scores()),
- get_tensor_index(node->max_output_size()), get_tensor_index(node->iou_threshold()),
- get_tensor_index(node->score_threshold()),
+ get_tensor_index(node->boxes()), get_tensor_index(node->scores()),
+ get_tensor_index(node->max_output_size()), get_tensor_index(node->iou_threshold()),
+ get_tensor_index(node->score_threshold()),
};
std::vector<int32_t> outputs_vec;
@@ -290,8 +292,8 @@ void export_node(ExportContext &ctx, luci::CircleNonMaxSuppressionV4 *node)
auto outputs = ctx.builder.CreateVector(outputs_vec);
auto options = CreateNonMaxSuppressionV4Options(ctx.builder);
auto op_offset =
- CreateOperator(ctx.builder, op_idx, inputs, outputs,
- circle::BuiltinOptions_NonMaxSuppressionV4Options, options.Union());
+ CreateOperator(ctx.builder, op_idx, inputs, outputs,
+ circle::BuiltinOptions_NonMaxSuppressionV4Options, options.Union());
ctx.gd._operators.push_back(op_offset);
}
@@ -303,9 +305,9 @@ void export_node(ExportContext &ctx, luci::CircleNonMaxSuppressionV5 *node)
uint32_t op_idx = ctx.md.registerBuiltinOpcode(circle::BuiltinOperator_NON_MAX_SUPPRESSION_V5,
node->op_version());
std::vector<int32_t> inputs_vec{
- get_tensor_index(node->boxes()), get_tensor_index(node->scores()),
- get_tensor_index(node->max_output_size()), get_tensor_index(node->iou_threshold()),
- get_tensor_index(node->score_threshold()), get_tensor_index(node->soft_nms_sigma()),
+ get_tensor_index(node->boxes()), get_tensor_index(node->scores()),
+ get_tensor_index(node->max_output_size()), get_tensor_index(node->iou_threshold()),
+ get_tensor_index(node->score_threshold()), get_tensor_index(node->soft_nms_sigma()),
};
std::vector<int32_t> outputs_vec;
@@ -333,15 +335,15 @@ void export_node(ExportContext &ctx, luci::CircleNonMaxSuppressionV5 *node)
auto outputs = ctx.builder.CreateVector(outputs_vec);
auto options = CreateNonMaxSuppressionV5Options(ctx.builder);
auto op_offset =
- CreateOperator(ctx.builder, op_idx, inputs, outputs,
- circle::BuiltinOptions_NonMaxSuppressionV5Options, options.Union());
+ CreateOperator(ctx.builder, op_idx, inputs, outputs,
+ circle::BuiltinOptions_NonMaxSuppressionV5Options, options.Union());
ctx.gd._operators.push_back(op_offset);
}
void export_node(ExportContext &ctx, luci::CircleReverseV2 *node)
{
uint32_t op_idx =
- ctx.md.registerBuiltinOpcode(circle::BuiltinOperator_REVERSE_V2, node->op_version());
+ ctx.md.registerBuiltinOpcode(circle::BuiltinOperator_REVERSE_V2, node->op_version());
std::vector<int32_t> inputs_vec{get_tensor_index(node->tensor()), get_tensor_index(node->axis())};
std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))};
auto inputs = ctx.builder.CreateVector(inputs_vec);
@@ -397,7 +399,7 @@ void export_node(ExportContext &ctx, luci::CircleSplitV *node)
assert(int32_t(split_outs.size()) == node->num_split());
uint32_t op_idx =
- ctx.md.registerBuiltinOpcode(circle::BuiltinOperator_SPLIT_V, node->op_version());
+ ctx.md.registerBuiltinOpcode(circle::BuiltinOperator_SPLIT_V, node->op_version());
std::vector<int32_t> inputs_vec{get_tensor_index(node->input()),
get_tensor_index(node->size_splits()),
get_tensor_index(node->split_dim())};
@@ -438,7 +440,7 @@ void export_node(ExportContext &ctx, luci::CircleTopKV2 *node)
assert(outs_count == 2);
uint32_t op_idx =
- ctx.md.registerBuiltinOpcode(circle::BuiltinOperator_TOPK_V2, node->op_version());
+ ctx.md.registerBuiltinOpcode(circle::BuiltinOperator_TOPK_V2, node->op_version());
std::vector<int32_t> inputs_vec{get_tensor_index(node->input()), get_tensor_index(node->k())};
std::vector<int32_t> outputs_vec;
@@ -475,7 +477,7 @@ void export_node(ExportContext &ctx, luci::CircleUnique *node)
auto unique_outs = loco::succs(node);
assert(int32_t(unique_outs.size()) == 2);
uint32_t op_idx =
- ctx.md.registerBuiltinOpcode(circle::BuiltinOperator_UNIQUE, node->op_version());
+ ctx.md.registerBuiltinOpcode(circle::BuiltinOperator_UNIQUE, node->op_version());
std::vector<int32_t> inputs_vec{get_tensor_index(node->input())};
std::vector<int32_t> outputs_vec;
@@ -526,7 +528,7 @@ void export_node(ExportContext &ctx, luci::CircleUnpack *node)
}
uint32_t op_idx =
- ctx.md.registerBuiltinOpcode(circle::BuiltinOperator_UNPACK, node->op_version());
+ ctx.md.registerBuiltinOpcode(circle::BuiltinOperator_UNPACK, node->op_version());
std::vector<int32_t> inputs_vec{get_tensor_index(node->value())};
std::vector<int32_t> outputs_vec;
@@ -622,6 +624,7 @@ public:
void visit(luci::CircleAveragePool2D *) final;
void visit(luci::CircleBatchMatMul *) final;
void visit(luci::CircleBatchToSpaceND *) final;
+ void visit(luci::CircleBidirectionalSequenceLSTM *) final;
void visit(luci::CircleCast *) final;
void visit(luci::CircleCeil *) final;
void visit(luci::CircleConcatenation *) final;
@@ -637,6 +640,7 @@ public:
void visit(luci::CircleEqual *) final;
void visit(luci::CircleExp *) final;
void visit(luci::CircleExpandDims *) final;
+ void visit(luci::CircleFakeQuant *) final;
void visit(luci::CircleFill *) final;
void visit(luci::CircleFloor *) final;
void visit(luci::CircleFloorDiv *) final;
@@ -734,6 +738,7 @@ public:
void visit(luci::CircleOutputDummy *) final {}
void visit(luci::CircleOutputExclude *) final {}
// Virtual for multiple-outputs
+ void visit(luci::CircleBidirectionalSequenceLSTMOut *) final {}
void visit(luci::CircleCustomOut *) final {}
void visit(luci::CircleIfOut *) final {}
void visit(luci::CircleNonMaxSuppressionV4Out *) final {}
@@ -782,8 +787,8 @@ void OperationExporter::visit(luci::CircleAbs *node)
void OperationExporter::visit(luci::CircleAdd *node)
{
export_simple(
- node, circle::BuiltinOperator_ADD, circle::BuiltinOptions_AddOptions,
- CreateAddOptions(_ctx.builder, to_circle_actfunc(node->fusedActivationFunction())).Union());
+ node, circle::BuiltinOperator_ADD, circle::BuiltinOptions_AddOptions,
+ CreateAddOptions(_ctx.builder, to_circle_actfunc(node->fusedActivationFunction())).Union());
}
void OperationExporter::visit(luci::CircleAddN *node) { export_node(_ctx, node); }
@@ -791,15 +796,15 @@ void OperationExporter::visit(luci::CircleAddN *node) { export_node(_ctx, node);
void OperationExporter::visit(luci::CircleArgMax *node)
{
export_simple(
- node, circle::BuiltinOperator_ARG_MAX, circle::BuiltinOptions_ArgMaxOptions,
- CreateArgMaxOptions(_ctx.builder, to_circle_tensortype(node->output_type())).Union());
+ node, circle::BuiltinOperator_ARG_MAX, circle::BuiltinOptions_ArgMaxOptions,
+ CreateArgMaxOptions(_ctx.builder, to_circle_tensortype(node->output_type())).Union());
}
void OperationExporter::visit(luci::CircleArgMin *node)
{
export_simple(
- node, circle::BuiltinOperator_ARG_MIN, circle::BuiltinOptions_ArgMinOptions,
- CreateArgMinOptions(_ctx.builder, to_circle_tensortype(node->output_type())).Union());
+ node, circle::BuiltinOperator_ARG_MIN, circle::BuiltinOptions_ArgMinOptions,
+ CreateArgMinOptions(_ctx.builder, to_circle_tensortype(node->output_type())).Union());
}
void OperationExporter::visit(luci::CircleAveragePool2D *node)
@@ -814,6 +819,48 @@ void OperationExporter::visit(luci::CircleBatchMatMul *node)
CreateBatchMatMulOptions(_ctx.builder, node->adj_x(), node->adj_y()).Union());
}
+void OperationExporter::visit(luci::CircleBidirectionalSequenceLSTM *node)
+{
+ auto bidi_lstm_outs = loco::succs(node);
+ assert((bidi_lstm_outs.size() == 1) || (bidi_lstm_outs.size() == 2));
+ uint32_t op_idx = _ctx.md.registerBuiltinOpcode(
+ circle::BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM, node->op_version());
+
+ std::vector<int32_t> inputs_vec{get_tensor_index(node->input())};
+ std::vector<int32_t> outputs_vec;
+
+ for (int32_t index = 0; index < 2; index++)
+ {
+ // store in order of index
+ bool found = false;
+ for (auto out : bidi_lstm_outs)
+ {
+ auto bidi_lstm_out = loco::must_cast<luci::CircleBidirectionalSequenceLSTMOut *>(out);
+ if (bidi_lstm_out->index() == index)
+ {
+ outputs_vec.push_back(get_tensor_index(bidi_lstm_out));
+ found = true;
+ break;
+ }
+ }
+ if (!found)
+ {
+ INTERNAL_EXN("Invalid BidirectionalSequenceLSTM output");
+ }
+ }
+
+ auto inputs = _ctx.builder.CreateVector(inputs_vec);
+ auto outputs = _ctx.builder.CreateVector(outputs_vec);
+ auto options = CreateBidirectionalSequenceLSTMOptions(
+ _ctx.builder, to_circle_actfunc(node->fusedActivationFunction()), node->cell_clip(),
+ node->proj_clip(), node->merge_outputs(), node->time_major(),
+ node->asymmetric_quantize_inputs());
+ auto op_offset =
+ CreateOperator(_ctx.builder, op_idx, inputs, outputs,
+ circle::BuiltinOptions_BidirectionalSequenceLSTMOptions, options.Union());
+ _ctx.gd._operators.push_back(op_offset);
+}
+
void OperationExporter::visit(luci::CircleCast *node) { export_node(_ctx, node); }
void OperationExporter::visit(luci::CircleCeil *node)
@@ -837,7 +884,7 @@ void OperationExporter::visit(luci::CircleConv2D *node)
node->stride()->w(), node->stride()->h(),
to_circle_actfunc(node->fusedActivationFunction()),
node->dilation()->w(), node->dilation()->h())
- .Union());
+ .Union());
}
void OperationExporter::visit(luci::CircleCos *node)
@@ -857,14 +904,13 @@ void OperationExporter::visit(luci::CircleDepthToSpace *node)
void OperationExporter::visit(luci::CircleDepthwiseConv2D *node)
{
- export_simple(node, circle::BuiltinOperator_DEPTHWISE_CONV_2D,
- circle::BuiltinOptions_DepthwiseConv2DOptions,
- CreateDepthwiseConv2DOptions(_ctx.builder, getOpPadding(node->padding()),
- node->stride()->w(), node->stride()->h(),
- node->depthMultiplier(),
- to_circle_actfunc(node->fusedActivationFunction()),
- node->dilation()->w(), node->dilation()->h())
- .Union());
+ export_simple(
+ node, circle::BuiltinOperator_DEPTHWISE_CONV_2D, circle::BuiltinOptions_DepthwiseConv2DOptions,
+ CreateDepthwiseConv2DOptions(_ctx.builder, getOpPadding(node->padding()), node->stride()->w(),
+ node->stride()->h(), node->depthMultiplier(),
+ to_circle_actfunc(node->fusedActivationFunction()),
+ node->dilation()->w(), node->dilation()->h())
+ .Union());
}
void OperationExporter::visit(luci::CircleDequantize *node)
@@ -875,8 +921,8 @@ void OperationExporter::visit(luci::CircleDequantize *node)
void OperationExporter::visit(luci::CircleDiv *node)
{
export_simple(
- node, circle::BuiltinOperator_DIV, circle::BuiltinOptions_DivOptions,
- CreateDivOptions(_ctx.builder, to_circle_actfunc(node->fusedActivationFunction())).Union());
+ node, circle::BuiltinOperator_DIV, circle::BuiltinOptions_DivOptions,
+ CreateDivOptions(_ctx.builder, to_circle_actfunc(node->fusedActivationFunction())).Union());
}
void OperationExporter::visit(luci::CircleElu *node)
@@ -902,6 +948,14 @@ void OperationExporter::visit(luci::CircleExpandDims *node)
CreateExpandDimsOptions(_ctx.builder).Union());
}
+void OperationExporter::visit(luci::CircleFakeQuant *node)
+{
+ export_simple(node, circle::BuiltinOperator_FAKE_QUANT, circle::BuiltinOptions_FakeQuantOptions,
+ CreateFakeQuantOptions(_ctx.builder, node->min(), node->max(), node->num_bits(),
+ node->narrow_range())
+ .Union());
+}
+
void OperationExporter::visit(luci::CircleFill *node)
{
export_simple(node, circle::BuiltinOperator_FILL, circle::BuiltinOptions_FillOptions,
@@ -928,10 +982,10 @@ void OperationExporter::visit(luci::CircleFloorMod *node)
void OperationExporter::visit(luci::CircleFullyConnected *node)
{
export_simple(
- node, circle::BuiltinOperator_FULLY_CONNECTED, circle::BuiltinOptions_FullyConnectedOptions,
- CreateFullyConnectedOptions(_ctx.builder, to_circle_actfunc(node->fusedActivationFunction()),
- to_circle_weightsformat(node->weights_format()))
- .Union());
+ node, circle::BuiltinOperator_FULLY_CONNECTED, circle::BuiltinOptions_FullyConnectedOptions,
+ CreateFullyConnectedOptions(_ctx.builder, to_circle_actfunc(node->fusedActivationFunction()),
+ to_circle_weightsformat(node->weights_format()))
+ .Union());
}
void OperationExporter::visit(luci::CircleGather *node)
@@ -964,9 +1018,8 @@ void OperationExporter::visit(luci::CircleIf *node) { export_node(_ctx, node); }
void OperationExporter::visit(luci::CircleL2Normalize *node)
{
export_simple(
- node, circle::BuiltinOperator_L2_NORMALIZATION, circle::BuiltinOptions_L2NormOptions,
- CreateL2NormOptions(_ctx.builder, to_circle_actfunc(node->fusedActivationFunction()))
- .Union());
+ node, circle::BuiltinOperator_L2_NORMALIZATION, circle::BuiltinOptions_L2NormOptions,
+ CreateL2NormOptions(_ctx.builder, to_circle_actfunc(node->fusedActivationFunction())).Union());
}
void OperationExporter::visit(luci::CircleL2Pool2D *node)
@@ -998,7 +1051,7 @@ void OperationExporter::visit(luci::CircleLocalResponseNormalization *node)
circle::BuiltinOptions_LocalResponseNormalizationOptions,
CreateLocalResponseNormalizationOptions(_ctx.builder, node->radius(), node->bias(),
node->alpha(), node->beta())
- .Union());
+ .Union());
}
void OperationExporter::visit(luci::CircleLog *node)
@@ -1074,15 +1127,15 @@ void OperationExporter::visit(luci::CircleMinimum *node)
void OperationExporter::visit(luci::CircleMirrorPad *node)
{
export_simple(
- node, circle::BuiltinOperator_MIRROR_PAD, circle::BuiltinOptions_MirrorPadOptions,
- CreateMirrorPadOptions(_ctx.builder, to_circle_mirrorpadmode(node->mode())).Union());
+ node, circle::BuiltinOperator_MIRROR_PAD, circle::BuiltinOptions_MirrorPadOptions,
+ CreateMirrorPadOptions(_ctx.builder, to_circle_mirrorpadmode(node->mode())).Union());
}
void OperationExporter::visit(luci::CircleMul *node)
{
export_simple(
- node, circle::BuiltinOperator_MUL, circle::BuiltinOptions_MulOptions,
- CreateMulOptions(_ctx.builder, to_circle_actfunc(node->fusedActivationFunction())).Union());
+ node, circle::BuiltinOperator_MUL, circle::BuiltinOptions_MulOptions,
+ CreateMulOptions(_ctx.builder, to_circle_actfunc(node->fusedActivationFunction())).Union());
}
void OperationExporter::visit(luci::CircleNeg *node)
@@ -1190,7 +1243,7 @@ void OperationExporter::visit(luci::CircleReluN1To1 *node)
void OperationExporter::visit(luci::CircleReshape *node)
{
auto new_shape = _ctx.builder.CreateVector<int32_t>(
- node->newShape()->rank(), [node](size_t i) { return node->newShape()->dim(i); });
+ node->newShape()->rank(), [node](size_t i) { return node->newShape()->dim(i); });
export_simple(node, circle::BuiltinOperator_RESHAPE, circle::BuiltinOptions_ReshapeOptions,
CreateReshapeOptions(_ctx.builder, new_shape).Union());
@@ -1199,9 +1252,9 @@ void OperationExporter::visit(luci::CircleReshape *node)
void OperationExporter::visit(luci::CircleResizeBilinear *node)
{
export_simple(
- node, circle::BuiltinOperator_RESIZE_BILINEAR, circle::BuiltinOptions_ResizeBilinearOptions,
- CreateResizeBilinearOptions(_ctx.builder, node->align_corners(), node->half_pixel_centers())
- .Union());
+ node, circle::BuiltinOperator_RESIZE_BILINEAR, circle::BuiltinOptions_ResizeBilinearOptions,
+ CreateResizeBilinearOptions(_ctx.builder, node->align_corners(), node->half_pixel_centers())
+ .Union());
}
void OperationExporter::visit(luci::CircleResizeNearestNeighbor *node)
@@ -1214,8 +1267,8 @@ void OperationExporter::visit(luci::CircleResizeNearestNeighbor *node)
void OperationExporter::visit(luci::CircleReverseSequence *node)
{
export_simple(
- node, circle::BuiltinOperator_REVERSE_SEQUENCE, circle::BuiltinOptions_ReverseSequenceOptions,
- CreateReverseSequenceOptions(_ctx.builder, node->seq_axis(), node->batch_axis()).Union());
+ node, circle::BuiltinOperator_REVERSE_SEQUENCE, circle::BuiltinOptions_ReverseSequenceOptions,
+ CreateReverseSequenceOptions(_ctx.builder, node->seq_axis(), node->batch_axis()).Union());
}
void OperationExporter::visit(luci::CircleReverseV2 *node) { export_node(_ctx, node); }
@@ -1334,14 +1387,14 @@ void OperationExporter::visit(luci::CircleStridedSlice *node)
CreateStridedSliceOptions(_ctx.builder, node->begin_mask(), node->end_mask(),
node->ellipsis_mask(), node->new_axis_mask(),
node->shrink_axis_mask())
- .Union());
+ .Union());
}
void OperationExporter::visit(luci::CircleSub *node)
{
export_simple(
- node, circle::BuiltinOperator_SUB, circle::BuiltinOptions_SubOptions,
- CreateSubOptions(_ctx.builder, to_circle_actfunc(node->fusedActivationFunction())).Union());
+ node, circle::BuiltinOperator_SUB, circle::BuiltinOptions_SubOptions,
+ CreateSubOptions(_ctx.builder, to_circle_actfunc(node->fusedActivationFunction())).Union());
}
void OperationExporter::visit(luci::CircleSum *node)
@@ -1375,7 +1428,7 @@ void OperationExporter::visit(luci::CircleTransposeConv *node)
circle::BuiltinOptions_TransposeConvOptions,
CreateTransposeConvOptions(_ctx.builder, getOpPadding(node->padding()),
node->stride()->w(), node->stride()->h())
- .Union());
+ .Union());
}
void OperationExporter::visit(luci::CircleUnidirectionalSequenceLSTM *node)
@@ -1383,10 +1436,10 @@ void OperationExporter::visit(luci::CircleUnidirectionalSequenceLSTM *node)
export_simple(node, circle::BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM,
circle::BuiltinOptions_UnidirectionalSequenceLSTMOptions,
CreateUnidirectionalSequenceLSTMOptions(
- _ctx.builder, to_circle_actfunc(node->fusedActivationFunction()),
- node->cell_clip(), node->proj_clip(), node->time_major(),
- node->asymmetric_quantize_inputs())
- .Union());
+ _ctx.builder, to_circle_actfunc(node->fusedActivationFunction()),
+ node->cell_clip(), node->proj_clip(), node->time_major(),
+ node->asymmetric_quantize_inputs())
+ .Union());
}
void OperationExporter::visit(luci::CircleUnique *node) { export_node(_ctx, node); }
@@ -1413,14 +1466,14 @@ void OperationExporter::visit(luci::CircleBCQFullyConnected *node)
circle::BuiltinOptions_BCQFullyConnectedOptions,
CreateBCQFullyConnectedOptions(_ctx.builder, node->weights_hidden_size(),
to_circle_actfunc(node->fusedActivationFunction()))
- .Union());
+ .Union());
}
void OperationExporter::visit(luci::CircleBCQGather *node)
{
export_simple(
- node, circle::BuiltinOperator_BCQ_GATHER, circle::BuiltinOptions_BCQGatherOptions,
- CreateBCQGatherOptions(_ctx.builder, node->input_hidden_size(), node->axis()).Union());
+ node, circle::BuiltinOperator_BCQ_GATHER, circle::BuiltinOptions_BCQGatherOptions,
+ CreateBCQGatherOptions(_ctx.builder, node->input_hidden_size(), node->axis()).Union());
}
void OperationExporter::visit(luci::CircleInstanceNorm *node)
@@ -1429,7 +1482,7 @@ void OperationExporter::visit(luci::CircleInstanceNorm *node)
circle::BuiltinOptions_InstanceNormOptions,
CreateInstanceNormOptions(_ctx.builder, node->epsilon(),
to_circle_actfunc(node->fusedActivationFunction()))
- .Union());
+ .Union());
}
void exportNode(loco::Node *node, flatbuffers::FlatBufferBuilder &builder, SerializedModelData &md,
@@ -1439,7 +1492,19 @@ void exportNode(loco::Node *node, flatbuffers::FlatBufferBuilder &builder, Seria
{
ExportContext ctx{builder, md, gd};
OperationExporter exporter{ctx};
+
+ const auto ops_size = gd._operators.size();
+
circle_node->accept(&exporter);
+ if (has_origin(circle_node) && ops_size != gd._operators.size())
+ {
+ const auto node_id = gd._operators.size() - 1;
+ for (auto source : get_origin(circle_node)->sources())
+ {
+ md._metadata.add_source_table(source->id(), source->name());
+ md._metadata.add_op_table(node_id, source->id());
+ }
+ }
}
else
{
diff --git a/compiler/luci/export/src/CircleTensorExporter.cpp b/compiler/luci/export/src/CircleTensorExporter.cpp
index 9bdfa0079..fefdf4e73 100644
--- a/compiler/luci/export/src/CircleTensorExporter.cpp
+++ b/compiler/luci/export/src/CircleTensorExporter.cpp
@@ -15,11 +15,9 @@
*/
#include "CircleTensorExporter.h"
-#include "TypeBridge.h"
#include <luci/IR/CircleNodes.h>
#include <luci/IR/CircleNodeVisitor.h>
-#include <luci/IR/CircleShapeSignature.h>
#include <luci/Service/CircleTypeInference.h>
#include <luci/Service/CircleShapeInference.h>
#include <luci/Log.h>
@@ -38,10 +36,10 @@ namespace
using namespace luci;
-class CircleTensoInfo
+class CircleTensorInfo
{
public:
- CircleTensoInfo() = default;
+ CircleTensorInfo() = default;
public:
void name(const std::string &name) { _name = name; }
@@ -54,9 +52,6 @@ public:
const ShapeDescription &shape(void) const { return _shape; }
void shape(const ShapeDescription &shape) { _shape = shape; }
- const ShapeSignature &shape_signature(void) const { return _shape_signature; }
- void shape_signature(const ShapeSignature &ss) { _shape_signature = ss; }
-
luci::ShapeStatus shape_status(void) const { return _shape_status; }
void shape_status(luci::ShapeStatus ss) { _shape_status = ss; }
@@ -75,7 +70,6 @@ private:
circle::TensorType _dtype{circle::TensorType_FLOAT32};
ShapeDescription _shape{};
- ShapeSignature _shape_signature;
luci::ShapeStatus _shape_status{luci::ShapeStatus::UNDEFINED};
luci::CircleConst *_content = nullptr;
@@ -83,7 +77,29 @@ private:
luci::SparsityParam *_sparsityparam = nullptr;
};
-using CircleTensorContext = std::vector<CircleTensoInfo>;
+class CircleTensorContext
+{
+public:
+ CircleTensorContext() = default;
+
+public:
+ void emplace_back(CircleTensorInfo &ti)
+ {
+ assert(_names.find(ti.name()) == _names.end());
+ _tis.emplace_back(ti);
+ _names.insert(ti.name());
+ }
+ size_t size(void) const { return _tis.size(); }
+ std::vector<CircleTensorInfo>::iterator begin(void) { return _tis.begin(); }
+ std::vector<CircleTensorInfo>::iterator end(void) { return _tis.end(); }
+
+public:
+ bool exist(const std::string &name) const { return _names.find(name) != _names.end(); }
+
+private:
+ std::vector<CircleTensorInfo> _tis;
+ std::set<std::string> _names;
+};
struct NoOpDetector final : public luci::CircleNodeMutableVisitor<bool>
{
@@ -102,17 +118,23 @@ void allocateCircleTensorInfo(CircleNode *node, CircleTensorContext &ctx)
auto tensor_index = static_cast<CircleTensorIndex>(ctx.size());
// TODO Use Graph-level metadata for Input & Output
- // auto tensor_name = "t_" + std::to_string(tensor_index);
std::string tensor_name = node->name();
- if (tensor_name.empty())
- tensor_name = "t_" + std::to_string(tensor_index);
+ // NOTE tensor_name maybe empty. this assertion will alert when this happens.
+ // currently we require tensor should have a name.
+ // TODO if this breaks, fix the cause or permit empty tensor_name.
+ assert(!tensor_name.empty());
+ if (ctx.exist(tensor_name))
+ {
+ // NOTE this should assign unique name for a Tensor.
+ tensor_name = tensor_name + "_" + std::to_string(tensor_index);
+ assert(!ctx.exist(tensor_name));
+ }
INFO(l) << "[luci] Tensor for " << tensor_name << ": " << tensor_index << std::endl;
- CircleTensoInfo tensor_info;
+ CircleTensorInfo tensor_info;
tensor_info.name(tensor_name);
tensor_info.dtype(to_circle_tensortype(node->dtype()));
- tensor_info.shape_signature(node->shape_signature());
if (node->shape_status() == ShapeStatus::VALID)
tensor_info.shape(to_shape_description(node));
tensor_info.shape_status(node->shape_status());
@@ -146,19 +168,55 @@ private:
}
public:
+ bool visit(luci::CircleBidirectionalSequenceLSTMOut *) final { return true; }
+ bool visit(luci::CircleCustomOut *) final { return true; }
bool visit(luci::CircleIfOut *) final { return true; }
+ bool visit(luci::CircleNonMaxSuppressionV4Out *) final { return true; }
+ bool visit(luci::CircleNonMaxSuppressionV5Out *) final { return true; }
bool visit(luci::CircleSplitOut *) final { return true; }
bool visit(luci::CircleSplitVOut *) final { return true; }
bool visit(luci::CircleTopKV2Out *) final { return true; }
bool visit(luci::CircleUnpackOut *) final { return true; }
+ bool visit(luci::CircleUniqueOut *) final { return true; }
bool visit(luci::CircleWhileOut *) final { return true; }
+ bool visit(luci::CircleBidirectionalSequenceLSTM *node) final
+ {
+ if (node->merge_outputs())
+ {
+ store_outputs(node, 1);
+ }
+ else
+ {
+ store_outputs(node, 2);
+ }
+ return true;
+ }
+
+ bool visit(luci::CircleCustom *node) final
+ {
+ store_outputs(node, node->numOutputs());
+ return true;
+ }
+
bool visit(luci::CircleIf *node) final
{
store_outputs(node, node->output_count());
return true;
}
+ bool visit(luci::CircleNonMaxSuppressionV4 *node) final
+ {
+ store_outputs(node, 2);
+ return true;
+ }
+
+ bool visit(luci::CircleNonMaxSuppressionV5 *node) final
+ {
+ store_outputs(node, 3);
+ return true;
+ }
+
bool visit(luci::CircleSplit *node) final
{
store_outputs(node, uint32_t(node->num_split()));
@@ -183,6 +241,12 @@ public:
return true;
}
+ bool visit(luci::CircleUnique *node) final
+ {
+ store_outputs(node, 2);
+ return true;
+ }
+
bool visit(luci::CircleWhile *node) final
{
store_outputs(node, node->output_count());
@@ -237,16 +301,26 @@ flatbuffers::Offset<Vector<int32_t>> encodeShape(FlatBufferBuilder &builder,
const ShapeDescription &shape)
{
assert(shape._rank_known && "unknown number of dimensions is not supported");
- return builder.CreateVector(shape._dims);
+
+ std::vector<int32_t> encoded_shape;
+ encoded_shape.resize(shape._dims.size());
+ for (uint32_t i = 0; i < shape._dims.size(); ++i)
+ encoded_shape.at(i) = shape._dims.at(i) == -1 ? 1 : shape._dims.at(i);
+
+ return builder.CreateVector(encoded_shape);
}
flatbuffers::Offset<Vector<int32_t>> encodeShapeSignature(FlatBufferBuilder &builder,
- const ShapeSignature &shape_signature)
+ const ShapeDescription &shape)
{
- if (shape_signature.rank() == 0)
- return 0;
+ assert(shape._rank_known && "unknown number of dimensions is not supported");
+
+ // shape_signature is set if and only if at least one of dimensions are unknown.
+ for (uint32_t i = 0; i < shape._dims.size(); ++i)
+ if (shape._dims.at(i) == -1)
+ return builder.CreateVector(shape._dims);
- return builder.CreateVector(shape_signature.as_vector());
+ return flatbuffers::Offset<Vector<int32_t>>();
}
flatbuffers::Offset<circle::Buffer> encodeOpBuffer(FlatBufferBuilder &builder)
@@ -343,14 +417,14 @@ encodeSparsityParameters(FlatBufferBuilder &builder, luci::SparsityParam *sparsi
// array_segments
auto circle_array_segments = to_circle_sparse_index_vector(builder, it.array_segments());
auto circle_array_segments_type =
- to_circle_sparse_index_vector_type(it.array_segments().type());
+ to_circle_sparse_index_vector_type(it.array_segments().type());
// array_indices
auto circle_array_indices = to_circle_sparse_index_vector(builder, it.array_indices());
auto circle_array_indices_type = to_circle_sparse_index_vector_type(it.array_indices().type());
auto dim_metadata = circle::CreateDimensionMetadata(
- builder, to_circle_dimensiontype(it.format()), it.dense_size(), circle_array_segments_type,
- circle_array_segments, circle_array_indices_type, circle_array_indices);
+ builder, to_circle_dimensiontype(it.format()), it.dense_size(), circle_array_segments_type,
+ circle_array_segments, circle_array_indices_type, circle_array_indices);
dim_metadata_vec.emplace_back(dim_metadata);
}
@@ -358,6 +432,18 @@ encodeSparsityParameters(FlatBufferBuilder &builder, luci::SparsityParam *sparsi
&sparsityparam->block_map, &dim_metadata_vec);
}
+template <loco::DataType DT> bool has_same_elements(luci::CircleConst *lhs, luci::CircleConst *rhs)
+{
+ assert(lhs->dtype() == DT);
+ assert(rhs->dtype() == DT);
+ assert(lhs->size<DT>() == rhs->size<DT>());
+
+ for (uint32_t i = 0; i < lhs->size<DT>(); ++i)
+ if (lhs->at<DT>(i) != rhs->at<DT>(i))
+ return false;
+ return true;
+}
+
bool has_same_values(luci::CircleConst *lhs, luci::CircleConst *rhs)
{
if (lhs->dtype() != rhs->dtype())
@@ -373,34 +459,31 @@ bool has_same_values(luci::CircleConst *lhs, luci::CircleConst *rhs)
switch (lhs->dtype())
{
case loco::DataType::FLOAT32:
- for (uint32_t i = 0; i < lhs->size<loco::DataType::FLOAT32>(); ++i)
- if (lhs->at<loco::DataType::FLOAT32>(i) != rhs->at<loco::DataType::FLOAT32>(i))
- return false;
- break;
+ return has_same_elements<loco::DataType::FLOAT32>(lhs, rhs);
+
+ case loco::DataType::S8:
+ return has_same_elements<loco::DataType::S8>(lhs, rhs);
+
+ case loco::DataType::S16:
+ return has_same_elements<loco::DataType::S16>(lhs, rhs);
case loco::DataType::S32:
- for (uint32_t i = 0; i < lhs->size<loco::DataType::S32>(); ++i)
- if (lhs->at<loco::DataType::S32>(i) != rhs->at<loco::DataType::S32>(i))
- return false;
- break;
+ return has_same_elements<loco::DataType::S32>(lhs, rhs);
case loco::DataType::S64:
- for (uint32_t i = 0; i < lhs->size<loco::DataType::S64>(); ++i)
- if (lhs->at<loco::DataType::S64>(i) != rhs->at<loco::DataType::S64>(i))
- return false;
- break;
+ return has_same_elements<loco::DataType::S64>(lhs, rhs);
+
+ case loco::DataType::U8:
+ return has_same_elements<loco::DataType::U8>(lhs, rhs);
case loco::DataType::BOOL:
- for (uint32_t i = 0; i < lhs->size<loco::DataType::BOOL>(); ++i)
- if (lhs->at<loco::DataType::BOOL>(i) != rhs->at<loco::DataType::BOOL>(i))
- return false;
- break;
+ return has_same_elements<loco::DataType::BOOL>(lhs, rhs);
default:
- return false;
+ break;
}
- return true;
+ return false;
}
uint32_t get_buffer_id(FlatBufferBuilder &builder, SerializedModelData &md, luci::CircleConst *node)
@@ -433,26 +516,28 @@ uint32_t get_buffer_id(FlatBufferBuilder &builder, SerializedModelData &md, luci
}
}
-void exportOpDefinedTensor(const CircleTensoInfo &info, FlatBufferBuilder &builder,
+void exportOpDefinedTensor(const CircleTensorInfo &info, FlatBufferBuilder &builder,
SerializedModelData &md, SerializedGraphData &gd)
{
// Create and register output tensor shape
flatbuffers::Offset<Vector<int32_t>> shape_offset;
+ flatbuffers::Offset<Vector<int32_t>> shape_signature_offset;
if (info.shape_status() == ShapeStatus::VALID)
+ {
shape_offset = encodeShape(builder, info.shape());
+ shape_signature_offset = encodeShapeSignature(builder, info.shape());
+ }
auto quantparam = encodeQuantizationParameters(builder, info.quantparam());
auto sparsityparam = encodeSparsityParameters(builder, info.sparsityparam());
- auto shape_signature_offset = encodeShapeSignature(builder, info.shape_signature());
-
auto buffer_id = get_buffer_id(builder, md, info.content());
auto name_offset = builder.CreateString(info.name());
auto tensor_offset =
- CreateTensor(builder, shape_offset, info.dtype(), buffer_id, name_offset, quantparam,
- /*is_variable*/ false, sparsityparam, shape_signature_offset);
+ CreateTensor(builder, shape_offset, info.dtype(), buffer_id, name_offset, quantparam,
+ /*is_variable*/ false, sparsityparam, shape_signature_offset);
gd._tensors.push_back(tensor_offset);
}
diff --git a/compiler/luci/export/src/Optimize.cpp b/compiler/luci/export/src/Optimize.cpp
index 036a4a2f9..e59f15204 100644
--- a/compiler/luci/export/src/Optimize.cpp
+++ b/compiler/luci/export/src/Optimize.cpp
@@ -17,9 +17,8 @@
#include "Optimize.h"
#include "ProgressReporter.h"
-#include <luci/Pass/ShapeInferencePass.h>
-#include <luci/Pass/ShapeSignatureInferencePass.h>
-#include <luci/Pass/TypeInferencePass.h>
+#include <luci/Pass/CircleShapeInferencePass.h>
+#include <luci/Pass/CircleTypeInferencePass.h>
#include <logo/Phase.h>
@@ -33,9 +32,8 @@ void optimize(loco::Graph *g)
logo::Phase phase;
{
// prepare type and shape before optimization
- phase.emplace_back(std::make_unique<TypeInferencePass>());
- phase.emplace_back(std::make_unique<ShapeInferencePass>());
- phase.emplace_back(std::make_unique<ShapeSignatureInferencePass>());
+ phase.emplace_back(std::make_unique<luci::CircleShapeInferencePass>());
+ phase.emplace_back(std::make_unique<luci::CircleTypeInferencePass>());
// TODO add more optimization passes (with a knob)
}
diff --git a/compiler/luci/export/src/ProgressReporter.h b/compiler/luci/export/src/ProgressReporter.h
index e91f42592..5d55bcd07 100644
--- a/compiler/luci/export/src/ProgressReporter.h
+++ b/compiler/luci/export/src/ProgressReporter.h
@@ -28,7 +28,7 @@ class ProgressReporter : public logo::PhaseEventListener
{
public:
ProgressReporter(loco::Graph *graph, logo::PhaseStrategy strategy)
- : _graph{graph}, _strategy{strategy}
+ : _graph{graph}, _strategy{strategy}
{
// DO NOTHING
}
diff --git a/compiler/luci/export/src/SerializedData.h b/compiler/luci/export/src/SerializedData.h
index c41f50edd..df71e5c21 100644
--- a/compiler/luci/export/src/SerializedData.h
+++ b/compiler/luci/export/src/SerializedData.h
@@ -48,6 +48,37 @@ struct OpCode
}
};
+class CircleExportMetadata
+{
+public:
+ void add_source_table(uint32_t source_id, std::string origin_name)
+ {
+ // Model with multiple subgraph may have different origin_name
+ // even if source_id is same. However, as we do not consider about
+ // multiple subgraph in profiling for now, just do not care those cases
+ // and support them correctly in the future.
+ _source_table.emplace(source_id, origin_name);
+ }
+
+ void add_op_table(uint32_t node_id, uint32_t source_id)
+ {
+ // Model with multiple subgraph may have duplicated node id.
+ // For now, as we do not consider about multiple subgraph in profiling,
+ // just ignore those cases and support them in the future.
+ if (_op_table.find(node_id) == _op_table.end())
+ _op_table.emplace(node_id, std::set<uint32_t>());
+ _op_table.at(node_id).emplace(source_id);
+ }
+
+public:
+ const std::vector<uint8_t> encoded_source_table(void);
+ const std::vector<uint8_t> encoded_op_table(void);
+
+private:
+ std::map<uint32_t, std::string> _source_table;
+ std::map<uint32_t, std::set<uint32_t>> _op_table;
+};
+
} // namespace luci
namespace std
@@ -86,6 +117,7 @@ struct SerializedModelData final
std::unordered_map<OpCode, uint32_t> _operator_codes;
std::vector<flatbuffers::Offset<circle::Buffer>> _buffers;
+ CircleExportMetadata _metadata;
// This is used for removing buffers with same values
std::map<luci::CircleConst *, uint32_t> _cached_buffer_id;
diff --git a/compiler/luci/export/src/TypeBridge.cpp b/compiler/luci/export/src/TypeBridge.cpp
deleted file mode 100644
index 9ccd52376..000000000
--- a/compiler/luci/export/src/TypeBridge.cpp
+++ /dev/null
@@ -1,105 +0,0 @@
-/*
- * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "TypeBridge.h"
-
-#include "CircleExporterUtils.h"
-
-#include <luci/IR/CircleNodes.h>
-#include <luci/IR/CircleNodeVisitor.h>
-#include <luci/Service/CircleTypeInference.h>
-#include <luci/Service/CircleShapeInference.h>
-
-#include <loco/Service/TypeInference.h>
-#include <loco/Service/ShapeInference.h>
-
-namespace
-{
-
-/**
- * @brief CopySelector will return condition of copy shape/type inference to node
- */
-struct CopySelector final : public luci::CircleNodeVisitor<bool>
-{
- // return false(don't copy) for nodes that provides shape/type from nature
- bool visit(const luci::CircleInput *) final { return false; }
- bool visit(const luci::CircleConst *) final { return false; }
-
- // default is copy attributes
- bool visit(const luci::CircleNode *) { return true; }
-};
-
-} // namespace
-
-namespace luci
-{
-
-loco::TensorShape node_shape(CircleNode *node)
-{
- loco::TensorShape shape;
-
- shape.rank(node->rank());
- for (uint32_t r = 0; r < node->rank(); ++r)
- {
- shape.dim(r) = loco::Dimension(node->dim(r).value());
- }
- return shape;
-}
-
-loco::DataType node_dtype(CircleNode *node) { return node->dtype(); }
-
-void copy_shape_dtype(loco::Graph *graph)
-{
- /**
- * @note We will iterate all the nodes in the graph to include dangle nodes
- */
- auto nodes = graph->nodes();
- for (uint32_t n = 0; n < nodes->size(); ++n)
- {
- auto node = loco::must_cast<luci::CircleNode *>(nodes->at(n));
-
- CopySelector cs;
- if (node->accept(&cs))
- {
- // NOTE not all nodes have infered shape/dtype: multiple outs may not be
- // visited when outputs are not used
- // TODO fix shape inference traversal
- // NOTE when loco supports multiple outputs in nature this issue should be
- // resolved also
-
- if (loco::dtype_known(node))
- {
- node->dtype(loco::dtype_get(node));
- }
-
- if (loco::shape_known(node))
- {
- auto shape = loco::shape_get(node).as<loco::TensorShape>();
- node->rank(shape.rank());
- for (uint32_t r = 0; r < shape.rank(); ++r)
- {
- node->dim(r) = loco::Dimension(shape.dim(r).value());
- }
-
- // ShapeStatus should be update only when the status was UNDEFINED
- if (node->shape_status() == ShapeStatus::UNDEFINED)
- node->shape_status(ShapeStatus::VALID);
- }
- }
- }
-}
-
-} // namespace luci
diff --git a/compiler/luci/import/CMakeLists.txt b/compiler/luci/import/CMakeLists.txt
index 2ae00b837..642751ca6 100644
--- a/compiler/luci/import/CMakeLists.txt
+++ b/compiler/luci/import/CMakeLists.txt
@@ -6,6 +6,7 @@ add_library(luci_import SHARED ${SOURCES})
target_include_directories(luci_import PRIVATE src)
target_include_directories(luci_import PUBLIC include)
target_link_libraries(luci_import PUBLIC luci_lang)
+target_link_libraries(luci_import PUBLIC luci_profile)
target_link_libraries(luci_import PUBLIC mio_circle)
target_link_libraries(luci_import PRIVATE luci_env)
target_link_libraries(luci_import PRIVATE luci_log)
diff --git a/compiler/luci/import/include/luci/Import/CircleReader.h b/compiler/luci/import/include/luci/Import/CircleReader.h
index 8e210dd77..b9697fb86 100644
--- a/compiler/luci/import/include/luci/Import/CircleReader.h
+++ b/compiler/luci/import/include/luci/Import/CircleReader.h
@@ -23,7 +23,6 @@
#include <luci/IR/AttrPadding.h>
#include <luci/IR/CircleNode.h>
#include <luci/IR/CircleQuantParam.h>
-#include <luci/IR/CircleShapeSignature.h>
#include <luci/IR/SparsityParam.h>
#include <loco.h>
@@ -64,6 +63,7 @@ private:
using CircleTensors_t = std::vector<std::unique_ptr<circle::TensorT>>;
using CircleOperators_t = std::vector<std::unique_ptr<circle::OperatorT>>;
using CircleOperatorCodes_t = std::vector<std::unique_ptr<circle::OperatorCodeT>>;
+ using CircleMetadata_t = std::vector<std::unique_ptr<circle::MetadataT>>;
using CircleSubGraphsPtr_t = flatbuffers::Vector<flatbuffers::Offset<circle::SubGraph>>;
using CircleTensorsPtr_t = flatbuffers::Vector<flatbuffers::Offset<circle::Tensor>>;
@@ -79,6 +79,8 @@ public:
const std::vector<int32_t> &inputs() const { return _current_subgraph->inputs; }
const std::vector<int32_t> &outputs() const { return _current_subgraph->outputs; }
const std::string &name() const { return _current_subgraph->name; }
+ const circle::DataFormat &data_format() const { return _current_subgraph->data_format; }
+ const CircleMetadata_t &metadata() const { return _model->metadata; }
const CircleTensorsPtr_t *tensors_ptr() const { return _tensors_ptr; }
diff --git a/compiler/luci/import/include/luci/Import/GraphBuilder.h b/compiler/luci/import/include/luci/Import/GraphBuilder.h
index 548264dac..0db612652 100644
--- a/compiler/luci/import/include/luci/Import/GraphBuilder.h
+++ b/compiler/luci/import/include/luci/Import/GraphBuilder.h
@@ -33,7 +33,13 @@ class GraphBuilder : public GraphBuilderBase
public:
virtual ~GraphBuilder() = default;
- void build(const circle::OperatorT &op, GraphBuilderContext *context) const final;
+ // common validate method to check number of inputs and single output
+ bool validate(const ValidateArgs &args, size_t input_cnt) const
+ {
+ return (args.op.inputs.size() == input_cnt && args.op.outputs.size() == 1);
+ }
+
+ CircleNode *build(const circle::OperatorT &op, GraphBuilderContext *context) const final;
private:
virtual CircleNode *build_node(const circle::OperatorT &op,
diff --git a/compiler/luci/import/include/luci/Import/GraphBuilderBase.h b/compiler/luci/import/include/luci/Import/GraphBuilderBase.h
index a0cd008e0..ddd4445cd 100644
--- a/compiler/luci/import/include/luci/Import/GraphBuilderBase.h
+++ b/compiler/luci/import/include/luci/Import/GraphBuilderBase.h
@@ -19,6 +19,8 @@
#include "GraphBuilderContext.h"
+#include <luci/IR/CircleNode.h>
+
#include <mio/circle/schema_generated.h>
namespace luci
@@ -38,7 +40,7 @@ struct GraphBuilderBase
};
virtual bool validate(const ValidateArgs &) const = 0;
- virtual void build(const circle::OperatorT &op, GraphBuilderContext *context) const = 0;
+ virtual CircleNode *build(const circle::OperatorT &op, GraphBuilderContext *context) const = 0;
virtual ~GraphBuilderBase() = default;
};
diff --git a/compiler/luci/import/include/luci/Import/GraphBuilderContext.h b/compiler/luci/import/include/luci/Import/GraphBuilderContext.h
index 72e237abc..1673df43d 100644
--- a/compiler/luci/import/include/luci/Import/GraphBuilderContext.h
+++ b/compiler/luci/import/include/luci/Import/GraphBuilderContext.h
@@ -71,7 +71,7 @@ class GraphBuilderContext
public:
GraphBuilderContext(loco::Graph *g, CircleReader *reader, IndexNodeFinder *nodefinder,
IndexTensorOutputs *tensoroutputs)
- : _g(g), _reader(reader), _indexnodefinder(nodefinder), _indextensoroutputs(tensoroutputs)
+ : _g(g), _reader(reader), _indexnodefinder(nodefinder), _indextensoroutputs(tensoroutputs)
{
// DO NOTHING
}
diff --git a/compiler/luci/import/include/luci/Import/GraphBuilderMultiOutput.h b/compiler/luci/import/include/luci/Import/GraphBuilderMultiOutput.h
new file mode 100644
index 000000000..6e8791b62
--- /dev/null
+++ b/compiler/luci/import/include/luci/Import/GraphBuilderMultiOutput.h
@@ -0,0 +1,67 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LUCI_IMPORT_GRAPH_BUILDER_MULTI_OUTPUT_H__
+#define __LUCI_IMPORT_GRAPH_BUILDER_MULTI_OUTPUT_H__
+
+#include "GraphBuilderContext.h"
+#include "GraphBuilderBase.h"
+
+#include <mio/circle/schema_generated.h>
+
+namespace luci
+{
+
+/**
+ * @brief Base of general multiple outputs graph builder(e.g., CircleIfGraphBuilder)
+ */
+class GraphBuilderMultiOutput : public GraphBuilderBase
+{
+public:
+ virtual ~GraphBuilderMultiOutput() = default;
+
+ CircleNode *build(const circle::OperatorT &op, GraphBuilderContext *context) const final;
+
+protected:
+ struct BuildNodeArgs
+ {
+ BuildNodeArgs(const circle::OperatorT &o, GraphBuilderContext *c,
+ const std::vector<CircleNode *> &i)
+ : op(o), context(c), input_nodes(i)
+ {
+ }
+
+ const circle::OperatorT &op;
+ GraphBuilderContext *context;
+ const std::vector<CircleNode *> &input_nodes;
+ };
+
+ struct BuildOutArgs
+ {
+ BuildOutArgs(CircleNode *nd, uint32_t n) : node(nd), index(n) {}
+
+ CircleNode *node;
+ uint32_t index;
+ };
+
+private:
+ virtual CircleNode *build_node(const BuildNodeArgs &) const = 0;
+ virtual CircleNode *build_out(const BuildOutArgs &) const = 0;
+};
+
+} // namespace luci
+
+#endif // __LUCI_IMPORT_GRAPH_BUILDER_MULTI_OUTPUT_H__
diff --git a/compiler/luci/import/include/luci/Import/Nodes.h b/compiler/luci/import/include/luci/Import/Nodes.h
index 28741064e..b084c7dbc 100644
--- a/compiler/luci/import/include/luci/Import/Nodes.h
+++ b/compiler/luci/import/include/luci/Import/Nodes.h
@@ -27,6 +27,7 @@
#include "Nodes/CircleBatchToSpaceND.h"
#include "Nodes/CircleBCQFullyConnected.h"
#include "Nodes/CircleBCQGather.h"
+#include "Nodes/CircleBidirectionalSequenceLSTM.h"
#include "Nodes/CircleCast.h"
#include "Nodes/CircleCeil.h"
#include "Nodes/CircleConcatenation.h"
@@ -42,6 +43,7 @@
#include "Nodes/CircleEqual.h"
#include "Nodes/CircleExp.h"
#include "Nodes/CircleExpandDims.h"
+#include "Nodes/CircleFakeQuant.h"
#include "Nodes/CircleFill.h"
#include "Nodes/CircleFloor.h"
#include "Nodes/CircleFloorDiv.h"
diff --git a/compiler/luci/import/include/luci/Import/Nodes/CircleBidirectionalSequenceLSTM.h b/compiler/luci/import/include/luci/Import/Nodes/CircleBidirectionalSequenceLSTM.h
new file mode 100644
index 000000000..491517268
--- /dev/null
+++ b/compiler/luci/import/include/luci/Import/Nodes/CircleBidirectionalSequenceLSTM.h
@@ -0,0 +1,37 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LUCI_IMPORT_OP_CIRCLE_BIDIRECTIONALSEQUENCE_LSTM_H__
+#define __LUCI_IMPORT_OP_CIRCLE_BIDIRECTIONALSEQUENCE_LSTM_H__
+
+#include "luci/Import/GraphBuilderMultiOutput.h"
+
+namespace luci
+{
+
+class CircleBidirectionalSequenceLSTMGraphBuilder : public GraphBuilderMultiOutput
+{
+public:
+ bool validate(const ValidateArgs &args) const final;
+
+private:
+ CircleNode *build_node(const BuildNodeArgs &) const final;
+ CircleNode *build_out(const BuildOutArgs &) const final;
+};
+
+} // namespace luci
+
+#endif // __LUCI_IMPORT_OP_CIRCLE_BIDIRECTIONALSEQUENCE_LSTM_H__
diff --git a/compiler/luci/import/include/luci/Import/Nodes/CircleCustom.h b/compiler/luci/import/include/luci/Import/Nodes/CircleCustom.h
index 65745be4b..f0d7e303d 100644
--- a/compiler/luci/import/include/luci/Import/Nodes/CircleCustom.h
+++ b/compiler/luci/import/include/luci/Import/Nodes/CircleCustom.h
@@ -17,17 +17,19 @@
#ifndef __LUCI_IMPORT_OP_CIRCLE_CUSTOM_H__
#define __LUCI_IMPORT_OP_CIRCLE_CUSTOM_H__
-#include "luci/Import/GraphBuilder.h"
+#include "luci/Import/GraphBuilderMultiOutput.h"
namespace luci
{
-class CircleCustomGraphBuilder : public GraphBuilderBase
+class CircleCustomGraphBuilder : public GraphBuilderMultiOutput
{
public:
bool validate(const ValidateArgs &args) const final;
- void build(const circle::OperatorT &op, GraphBuilderContext *context) const final;
+private:
+ CircleNode *build_node(const BuildNodeArgs &) const final;
+ CircleNode *build_out(const BuildOutArgs &) const final;
};
} // namespace luci
diff --git a/compiler/luci/import/include/luci/Import/Nodes/CircleFakeQuant.h b/compiler/luci/import/include/luci/Import/Nodes/CircleFakeQuant.h
new file mode 100644
index 000000000..9d9f7b07b
--- /dev/null
+++ b/compiler/luci/import/include/luci/Import/Nodes/CircleFakeQuant.h
@@ -0,0 +1,37 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LUCI_IMPORT_OP_CIRCLE_FAKE_QUANT_H__
+#define __LUCI_IMPORT_OP_CIRCLE_FAKE_QUANT_H__
+
+#include "luci/Import/GraphBuilder.h"
+
+namespace luci
+{
+
+class CircleFakeQuantGraphBuilder : public GraphBuilder
+{
+public:
+ bool validate(const ValidateArgs &args) const final;
+
+private:
+ CircleNode *build_node(const circle::OperatorT &op, const std::vector<CircleNode *> &inputs,
+ loco::Graph *graph) const final;
+};
+
+} // namespace luci
+
+#endif // __LUCI_IMPORT_OP_CIRCLE_FAKE_QUANT_H__
diff --git a/compiler/luci/import/include/luci/Import/Nodes/CircleIf.h b/compiler/luci/import/include/luci/Import/Nodes/CircleIf.h
index 8faf09cae..94052f5be 100644
--- a/compiler/luci/import/include/luci/Import/Nodes/CircleIf.h
+++ b/compiler/luci/import/include/luci/Import/Nodes/CircleIf.h
@@ -17,17 +17,19 @@
#ifndef __LUCI_IMPORT_OP_CIRCLE_IF_H__
#define __LUCI_IMPORT_OP_CIRCLE_IF_H__
-#include "luci/Import/GraphBuilderBase.h"
+#include "luci/Import/GraphBuilderMultiOutput.h"
namespace luci
{
-class CircleIfGraphBuilder : public GraphBuilderBase
+class CircleIfGraphBuilder : public GraphBuilderMultiOutput
{
public:
bool validate(const ValidateArgs &args) const final;
- void build(const circle::OperatorT &op, GraphBuilderContext *context) const final;
+private:
+ CircleNode *build_node(const BuildNodeArgs &) const final;
+ CircleNode *build_out(const BuildOutArgs &) const final;
};
} // namespace luci
diff --git a/compiler/luci/import/include/luci/Import/Nodes/CircleNonMaxSuppressionV4.h b/compiler/luci/import/include/luci/Import/Nodes/CircleNonMaxSuppressionV4.h
index f193aae35..4e8388b3e 100644
--- a/compiler/luci/import/include/luci/Import/Nodes/CircleNonMaxSuppressionV4.h
+++ b/compiler/luci/import/include/luci/Import/Nodes/CircleNonMaxSuppressionV4.h
@@ -17,17 +17,19 @@
#ifndef __LUCI_IMPORT_OP_CIRCLE_NON_MAX_SUPPRESSION_V4_H__
#define __LUCI_IMPORT_OP_CIRCLE_NON_MAX_SUPPRESSION_V4_H__
-#include "luci/Import/GraphBuilderBase.h"
+#include "luci/Import/GraphBuilderMultiOutput.h"
namespace luci
{
-class CircleNonMaxSuppressionV4GraphBuilder : public GraphBuilderBase
+class CircleNonMaxSuppressionV4GraphBuilder : public GraphBuilderMultiOutput
{
public:
bool validate(const ValidateArgs &args) const final;
- void build(const circle::OperatorT &op, GraphBuilderContext *context) const final;
+private:
+ CircleNode *build_node(const BuildNodeArgs &) const final;
+ CircleNode *build_out(const BuildOutArgs &) const final;
};
} // namespace luci
diff --git a/compiler/luci/import/include/luci/Import/Nodes/CircleNonMaxSuppressionV5.h b/compiler/luci/import/include/luci/Import/Nodes/CircleNonMaxSuppressionV5.h
index 62be0758e..4120a30eb 100644
--- a/compiler/luci/import/include/luci/Import/Nodes/CircleNonMaxSuppressionV5.h
+++ b/compiler/luci/import/include/luci/Import/Nodes/CircleNonMaxSuppressionV5.h
@@ -17,17 +17,19 @@
#ifndef __LUCI_IMPORT_OP_CIRCLE_NON_MAX_SUPPRESSION_V5_H__
#define __LUCI_IMPORT_OP_CIRCLE_NON_MAX_SUPPRESSION_V5_H__
-#include "luci/Import/GraphBuilderBase.h"
+#include "luci/Import/GraphBuilderMultiOutput.h"
namespace luci
{
-class CircleNonMaxSuppressionV5GraphBuilder : public GraphBuilderBase
+class CircleNonMaxSuppressionV5GraphBuilder : public GraphBuilderMultiOutput
{
public:
bool validate(const ValidateArgs &args) const final;
- void build(const circle::OperatorT &op, GraphBuilderContext *context) const final;
+private:
+ CircleNode *build_node(const BuildNodeArgs &) const final;
+ CircleNode *build_out(const BuildOutArgs &) const final;
};
} // namespace luci
diff --git a/compiler/luci/import/include/luci/Import/Nodes/CircleSplit.h b/compiler/luci/import/include/luci/Import/Nodes/CircleSplit.h
index 3395e40fd..5b45c9a9e 100644
--- a/compiler/luci/import/include/luci/Import/Nodes/CircleSplit.h
+++ b/compiler/luci/import/include/luci/Import/Nodes/CircleSplit.h
@@ -17,17 +17,19 @@
#ifndef __LUCI_IMPORT_OP_CIRCLE_SPLIT_H__
#define __LUCI_IMPORT_OP_CIRCLE_SPLIT_H__
-#include "luci/Import/GraphBuilderBase.h"
+#include "luci/Import/GraphBuilderMultiOutput.h"
namespace luci
{
-class CircleSplitGraphBuilder : public GraphBuilderBase
+class CircleSplitGraphBuilder : public GraphBuilderMultiOutput
{
public:
bool validate(const ValidateArgs &args) const final;
- void build(const circle::OperatorT &op, GraphBuilderContext *context) const final;
+private:
+ CircleNode *build_node(const BuildNodeArgs &) const final;
+ CircleNode *build_out(const BuildOutArgs &) const final;
};
} // namespace luci
diff --git a/compiler/luci/import/include/luci/Import/Nodes/CircleSplitV.h b/compiler/luci/import/include/luci/Import/Nodes/CircleSplitV.h
index 3e53df362..de712f90c 100644
--- a/compiler/luci/import/include/luci/Import/Nodes/CircleSplitV.h
+++ b/compiler/luci/import/include/luci/Import/Nodes/CircleSplitV.h
@@ -17,17 +17,19 @@
#ifndef __LUCI_IMPORT_OP_CIRCLE_SPLIT_V_H__
#define __LUCI_IMPORT_OP_CIRCLE_SPLIT_V_H__
-#include "luci/Import/GraphBuilderBase.h"
+#include "luci/Import/GraphBuilderMultiOutput.h"
namespace luci
{
-class CircleSplitVGraphBuilder : public GraphBuilderBase
+class CircleSplitVGraphBuilder : public GraphBuilderMultiOutput
{
public:
bool validate(const ValidateArgs &args) const final;
- void build(const circle::OperatorT &op, GraphBuilderContext *context) const final;
+private:
+ CircleNode *build_node(const BuildNodeArgs &) const final;
+ CircleNode *build_out(const BuildOutArgs &) const final;
};
} // namespace luci
diff --git a/compiler/luci/import/include/luci/Import/Nodes/CircleTopKV2.h b/compiler/luci/import/include/luci/Import/Nodes/CircleTopKV2.h
index 8ec3f3311..b4ad97130 100644
--- a/compiler/luci/import/include/luci/Import/Nodes/CircleTopKV2.h
+++ b/compiler/luci/import/include/luci/Import/Nodes/CircleTopKV2.h
@@ -17,17 +17,19 @@
#ifndef __LUCI_IMPORT_OP_CIRCLE_TOPK_V2_H__
#define __LUCI_IMPORT_OP_CIRCLE_TOPK_V2_H__
-#include "luci/Import/GraphBuilderBase.h"
+#include "luci/Import/GraphBuilderMultiOutput.h"
namespace luci
{
-class CircleTopKV2GraphBuilder : public GraphBuilderBase
+class CircleTopKV2GraphBuilder : public GraphBuilderMultiOutput
{
public:
bool validate(const ValidateArgs &args) const final;
- void build(const circle::OperatorT &op, GraphBuilderContext *context) const final;
+private:
+ CircleNode *build_node(const BuildNodeArgs &) const final;
+ CircleNode *build_out(const BuildOutArgs &) const final;
};
} // namespace luci
diff --git a/compiler/luci/import/include/luci/Import/Nodes/CircleUnique.h b/compiler/luci/import/include/luci/Import/Nodes/CircleUnique.h
index ed5b5035d..40e75ec73 100644
--- a/compiler/luci/import/include/luci/Import/Nodes/CircleUnique.h
+++ b/compiler/luci/import/include/luci/Import/Nodes/CircleUnique.h
@@ -17,17 +17,19 @@
#ifndef __LUCI_IMPORT_OP_CIRCLE_UNIQUE_H__
#define __LUCI_IMPORT_OP_CIRCLE_UNIQUE_H__
-#include "luci/Import/GraphBuilderBase.h"
+#include "luci/Import/GraphBuilderMultiOutput.h"
namespace luci
{
-class CircleUniqueGraphBuilder : public GraphBuilderBase
+class CircleUniqueGraphBuilder : public GraphBuilderMultiOutput
{
public:
bool validate(const ValidateArgs &args) const final;
- void build(const circle::OperatorT &op, GraphBuilderContext *context) const final;
+private:
+ CircleNode *build_node(const BuildNodeArgs &) const final;
+ CircleNode *build_out(const BuildOutArgs &) const final;
};
} // namespace luci
diff --git a/compiler/luci/import/include/luci/Import/Nodes/CircleUnpack.h b/compiler/luci/import/include/luci/Import/Nodes/CircleUnpack.h
index f1a21de22..0b623655f 100644
--- a/compiler/luci/import/include/luci/Import/Nodes/CircleUnpack.h
+++ b/compiler/luci/import/include/luci/Import/Nodes/CircleUnpack.h
@@ -17,17 +17,19 @@
#ifndef __LUCI_IMPORT_OP_CIRCLE_UNPACK_H__
#define __LUCI_IMPORT_OP_CIRCLE_UNPACK_H__
-#include "luci/Import/GraphBuilderBase.h"
+#include "luci/Import/GraphBuilderMultiOutput.h"
namespace luci
{
-class CircleUnpackGraphBuilder : public GraphBuilderBase
+class CircleUnpackGraphBuilder : public GraphBuilderMultiOutput
{
public:
bool validate(const ValidateArgs &args) const final;
- void build(const circle::OperatorT &op, GraphBuilderContext *context) const final;
+private:
+ CircleNode *build_node(const BuildNodeArgs &) const final;
+ CircleNode *build_out(const BuildOutArgs &) const final;
};
} // namespace luci
diff --git a/compiler/luci/import/include/luci/Import/Nodes/CircleWhile.h b/compiler/luci/import/include/luci/Import/Nodes/CircleWhile.h
index 68c56b3c6..69d23f823 100644
--- a/compiler/luci/import/include/luci/Import/Nodes/CircleWhile.h
+++ b/compiler/luci/import/include/luci/Import/Nodes/CircleWhile.h
@@ -27,7 +27,7 @@ class CircleWhileGraphBuilder : public GraphBuilderBase
public:
bool validate(const ValidateArgs &args) const final;
- void build(const circle::OperatorT &op, GraphBuilderContext *context) const final;
+ CircleNode *build(const circle::OperatorT &op, GraphBuilderContext *context) const final;
};
} // namespace luci
diff --git a/compiler/luci/import/src/CircleImportMetadata.cpp b/compiler/luci/import/src/CircleImportMetadata.cpp
new file mode 100644
index 000000000..f68f3301a
--- /dev/null
+++ b/compiler/luci/import/src/CircleImportMetadata.cpp
@@ -0,0 +1,185 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "CircleImportMetadata.h"
+
+#include <vector>
+
+namespace
+{
+
+uint32_t read_u32(const std::vector<uint8_t> &buffer, uint32_t idx)
+{
+ uint32_t val = 0;
+ val += (buffer.at(idx + 0) << 0 * 8);
+ val += (buffer.at(idx + 1) << 1 * 8);
+ val += (buffer.at(idx + 2) << 2 * 8);
+ val += (buffer.at(idx + 3) << 3 * 8);
+ return val;
+}
+
+} // namespace
+
+namespace
+{
+
+// 'source_table' is decoded to std::map<uint32_t, std::string> format.
+const std::map<uint32_t, std::string>
+decoded_source_table(const std::vector<uint8_t> &source_table_data)
+{
+ std::map<uint32_t, std::string> source_id_name_map;
+ uint32_t idx = 0;
+
+ if (source_table_data.size() < 4)
+ throw std::runtime_error("Source table decode error : invalid entry number");
+
+ uint32_t entry_number = read_u32(source_table_data, idx);
+ idx += sizeof(uint32_t);
+
+ while (idx < source_table_data.size())
+ {
+ if (idx + 2 * sizeof(uint32_t) > source_table_data.size())
+ throw std::runtime_error("Source table decode error : invalid entry item");
+
+ uint32_t id = read_u32(source_table_data, idx);
+ idx += sizeof(uint32_t);
+
+ uint32_t length = read_u32(source_table_data, idx);
+ idx += sizeof(uint32_t);
+
+ if (idx + sizeof(char) * length > source_table_data.size())
+ throw std::runtime_error("Source table decode error : invalid entry data");
+
+ // The last character of name is '\0'.
+ // However, as std::string do not use '\0' for finding the end of string,
+ // we ignore the character and do not include it in the string.
+ std::string origin_name;
+ for (uint32_t j = 0; j < length - 1; ++j)
+ origin_name += source_table_data.at(idx + j);
+ assert(source_table_data.at(idx + length - 1) == '\0');
+ idx += sizeof(char) * length;
+
+ if (source_id_name_map.insert({id, origin_name}).second == false)
+ throw std::runtime_error("Source table decode error : duplicated origin ID");
+ }
+
+ if (idx != source_table_data.size())
+ throw std::runtime_error("Source table decode error : data size invalid");
+
+ if (source_id_name_map.size() != entry_number)
+ throw std::runtime_error("Source table decode error : result size mismatch");
+
+ return source_id_name_map;
+}
+
+// 'op_table' is decoded to std::map<uint32_t, std::set<uint32_t>> format.
+const std::map<uint32_t, std::set<uint32_t>>
+decoded_op_table(const std::vector<uint8_t> &op_table_data)
+{
+ std::map<uint32_t, std::set<uint32_t>> node_source_ids_map;
+ uint32_t idx = 0;
+
+ if (op_table_data.size() < 4)
+ throw std::runtime_error("Op table decode error : invalid entry number");
+
+ uint32_t entry_number = read_u32(op_table_data, idx);
+ idx += sizeof(uint32_t);
+
+ while (idx < op_table_data.size())
+ {
+ if (idx + 2 * sizeof(uint32_t) > op_table_data.size())
+ throw std::runtime_error("Op table decode error : invalid entry item");
+
+ uint32_t id = read_u32(op_table_data, idx);
+ idx += sizeof(uint32_t);
+
+ uint32_t node_num = read_u32(op_table_data, idx);
+ idx += sizeof(uint32_t);
+
+ if (idx + sizeof(uint32_t) * node_num > op_table_data.size())
+ throw std::runtime_error("Source table decode error : invalid entry data");
+
+ std::set<uint32_t> source_ids;
+ for (uint32_t j = 0; j < node_num; ++j)
+ {
+ uint32_t origin = read_u32(op_table_data, idx);
+ idx += sizeof(uint32_t);
+
+ source_ids.insert(origin);
+ }
+
+ if (node_source_ids_map.insert({id, source_ids}).second == false)
+ throw std::runtime_error("Op table decode error : duplicated origin ID");
+ }
+
+ if (idx != op_table_data.size())
+ throw std::runtime_error("Op table decode error : data size invalid");
+
+ if (node_source_ids_map.size() != entry_number)
+ throw std::runtime_error("Op table decode error : entry number invalid");
+
+ return node_source_ids_map;
+}
+
+} // namespace
+
+namespace luci
+{
+
+CircleImportMetadata::CircleImportMetadata(const luci::CircleReader &reader)
+{
+ const auto &metadata = reader.metadata();
+ for (uint32_t i = 0; i < metadata.size(); ++i)
+ {
+ const circle::MetadataT &meta = *metadata[i];
+
+ assert(meta.buffer < reader.buffers().size());
+ const std::vector<uint8_t> &buffer = reader.buffers()[meta.buffer]->data;
+
+ if (meta.name.compare("ONE_op_table") == 0)
+ _op_table = decoded_op_table(buffer);
+ else if (meta.name.compare("ONE_source_table") == 0)
+ _source_table = decoded_source_table(buffer);
+ }
+}
+
+const OriginTable CircleImportMetadata::origin_table(void)
+{
+ OriginTable origin_table;
+
+ if (_op_table.size() > 0 && _source_table.size() > 0)
+ {
+ for (auto &kv : _op_table)
+ {
+ const auto node_id = kv.first;
+ const auto &source_ids = kv.second;
+
+ std::vector<std::shared_ptr<CircleNodeOrigin>> origins;
+ for (auto source_id : source_ids)
+ {
+ const auto source_name = _source_table.at(source_id);
+ origins.push_back(single_origin(source_id, source_name));
+ }
+
+ auto origin = composite_origin(origins);
+ origin_table.emplace(node_id, origin);
+ }
+ }
+
+ return origin_table;
+}
+
+} // namespace luci
diff --git a/compiler/luci/import/src/CircleImportMetadata.h b/compiler/luci/import/src/CircleImportMetadata.h
new file mode 100644
index 000000000..80176db94
--- /dev/null
+++ b/compiler/luci/import/src/CircleImportMetadata.h
@@ -0,0 +1,56 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LUCI_CIRCLE_IMPORT_METADATA_H__
+#define __LUCI_CIRCLE_IMPORT_METADATA_H__
+
+#include "luci/Import/CircleReader.h"
+
+#include <luci/Profile/CircleNodeOrigin.h>
+
+#include <map>
+#include <set>
+#include <string>
+
+namespace luci
+{
+
+using OriginTable = std::map<uint32_t, std::shared_ptr<CircleNodeOrigin>>;
+
+class CircleImportMetadata
+{
+public:
+ CircleImportMetadata() = delete;
+
+ CircleImportMetadata(const luci::CircleReader &reader);
+
+public:
+ /**
+ * @brief Create origin table using _source_table and _op_table in CircleImportMetadata
+ * @note For creating origin table, both _op_table and _source_table should exist.
+ * If one of them does not exist, empty table is returned.
+ */
+ const OriginTable origin_table(void);
+
+private:
+ // Decoded metadata is stored
+ std::map<uint32_t, std::string> _source_table;
+ std::map<uint32_t, std::set<uint32_t>> _op_table;
+};
+
+} // namespace luci
+
+#endif // __LUCI_CIRCLE_IMPORT_METADATA_H__
diff --git a/compiler/luci/import/src/CircleReader.cpp b/compiler/luci/import/src/CircleReader.cpp
index b33c920b1..861c1bbe3 100644
--- a/compiler/luci/import/src/CircleReader.cpp
+++ b/compiler/luci/import/src/CircleReader.cpp
@@ -190,19 +190,19 @@ luci_sparse_index_vector(const circle::SparseIndexVectorUnion &sparse_index_vect
case circle::SparseIndexVector_Int32Vector:
{
const auto const_vec_ptr =
- static_cast<const void *>(&(sparse_index_vector.AsInt32Vector()->values));
+ static_cast<const void *>(&(sparse_index_vector.AsInt32Vector()->values));
return SparseIndexVector{SparseIndexVectorType::I32, const_vec_ptr};
}
case circle::SparseIndexVector_Uint16Vector:
{
const auto const_vec_ptr =
- static_cast<const void *>(&(sparse_index_vector.AsUint16Vector()->values));
+ static_cast<const void *>(&(sparse_index_vector.AsUint16Vector()->values));
return SparseIndexVector{SparseIndexVectorType::U16, const_vec_ptr};
}
case circle::SparseIndexVector_Uint8Vector:
{
const auto const_vec_ptr =
- static_cast<const void *>(&(sparse_index_vector.AsUint8Vector()->values));
+ static_cast<const void *>(&(sparse_index_vector.AsUint8Vector()->values));
return SparseIndexVector{SparseIndexVectorType::U8, const_vec_ptr};
}
default:
@@ -262,15 +262,19 @@ void copy_tensor_attributes(const circle::TensorT &tensor, CircleNode *node)
node->name(tensor_name(tensor));
node->dtype(luci_datatype(tensor.type));
+ assert(tensor.shape_signature.size() == 0 ||
+ tensor.shape_signature.size() == tensor.shape.size());
+
std::vector<int32_t> dims = tensor.shape; // in NHWC
node->rank(dims.size());
for (uint32_t r = 0; r < dims.size(); ++r)
{
- node->dim(r) = loco::Dimension(dims[r]);
+ if (tensor.shape_signature.size() > 0 && tensor.shape_signature.at(r) == -1)
+ node->dim(r).unset();
+ else
+ node->dim(r).set(dims[r]);
}
- node->shape_signature(tensor.shape_signature);
-
const auto *quantization = tensor.quantization.get();
if (quantization != nullptr)
{
diff --git a/compiler/luci/import/src/GraphBuilder.cpp b/compiler/luci/import/src/GraphBuilder.cpp
index 80a9f986a..356501c2f 100644
--- a/compiler/luci/import/src/GraphBuilder.cpp
+++ b/compiler/luci/import/src/GraphBuilder.cpp
@@ -21,7 +21,7 @@
namespace luci
{
-void GraphBuilder::build(const circle::OperatorT &op, GraphBuilderContext *context) const
+CircleNode *GraphBuilder::build(const circle::OperatorT &op, GraphBuilderContext *context) const
{
LOGGER(l);
@@ -47,7 +47,11 @@ void GraphBuilder::build(const circle::OperatorT &op, GraphBuilderContext *conte
else
{
// If there is no tensor, insert CircleOutputExclude.
- input_nodes.push_back(context->graph()->nodes()->create<luci::CircleOutputExclude>());
+ auto *node = context->graph()->nodes()->create<luci::CircleOutputExclude>();
+ // CircleOutputExclude doesn't need a type, but since all nodes must have a type,
+ // a dummy type is inserted.
+ node->dtype(loco::DataType::FLOAT32);
+ input_nodes.push_back(node);
}
}
@@ -73,6 +77,8 @@ void GraphBuilder::build(const circle::OperatorT &op, GraphBuilderContext *conte
{
context->nodefinder()->enroll(outputs[0], node);
}
+
+ return node;
}
} // namespace luci
diff --git a/compiler/luci/import/src/GraphBuilderMultiOutput.cpp b/compiler/luci/import/src/GraphBuilderMultiOutput.cpp
new file mode 100644
index 000000000..9b42e997e
--- /dev/null
+++ b/compiler/luci/import/src/GraphBuilderMultiOutput.cpp
@@ -0,0 +1,97 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Import/GraphBuilderMultiOutput.h"
+
+#include <luci/Log.h>
+
+namespace luci
+{
+
+CircleNode *GraphBuilderMultiOutput::build(const circle::OperatorT &op,
+ GraphBuilderContext *context) const
+{
+ LOGGER(l);
+
+ assert(context != nullptr);
+
+ const std::vector<int32_t> &inputs = op.inputs;
+ const std::vector<int32_t> &outputs = op.outputs;
+ const auto &tensors = context->reader()->tensors();
+ const auto &opcodes = context->reader()->opcodes();
+ auto tensors_ptr = context->reader()->tensors_ptr();
+ assert(tensors_ptr != nullptr);
+
+ std::vector<CircleNode *> input_nodes;
+ for (const int32_t input_tensor_index : inputs)
+ {
+ if (input_tensor_index >= 0)
+ {
+ auto input = context->nodefinder()->node(input_tensor_index);
+ if (input == nullptr)
+ INFO(l) << "[luci] Warning: input node is null " << input_tensor_index << std::endl;
+ input_nodes.push_back(input);
+ }
+ else
+ {
+ // If there is no tensor, insert CircleOutputExclude.
+ auto *node = context->graph()->nodes()->create<luci::CircleOutputExclude>();
+ // CircleOutputExclude doesn't need a type, but since all nodes must have a type,
+ // a dummy type is inserted.
+ node->dtype(loco::DataType::FLOAT32);
+ input_nodes.push_back(node);
+ }
+ }
+
+ BuildNodeArgs bna(op, context, input_nodes);
+ auto *node = build_node(bna);
+
+ uint32_t output_count = outputs.size();
+ assert(output_count > 0);
+ {
+ // Let's use attributes from output 0 for this node
+ const circle::TensorT &output_tensor = *tensors[outputs[0]];
+ node->name(tensor_name(output_tensor));
+ node->dtype(luci_datatype(output_tensor.type));
+
+ // mark operator version
+ node->op_version(opcodes[op.opcode_index].get()->version);
+
+ // NOTE We don't set quantization for multiple output nodes but to virtual outputs
+ }
+
+ // Create virtual outputs of Virtual Output node(s)
+ for (uint32_t n = 0; n < output_count; ++n)
+ {
+ const circle::TensorT &output_tensor = *tensors[outputs[n]];
+
+ BuildOutArgs boa(node, n);
+ auto *nodeout = build_out(boa);
+
+ copy_tensor_attributes(output_tensor, nodeout);
+ // mark shape_status
+ if (tensors_ptr->Get(outputs[n])->shape() == nullptr)
+ nodeout->shape_status(ShapeStatus::NOSHAPE);
+ else
+ nodeout->shape_status(ShapeStatus::VALID);
+
+ context->nodefinder()->enroll(outputs[n], nodeout);
+ }
+
+ return node;
+}
+
+} // namespace luci
diff --git a/compiler/luci/import/src/GraphBuilderRegistry.cpp b/compiler/luci/import/src/GraphBuilderRegistry.cpp
index d598d30f4..7f98aab78 100644
--- a/compiler/luci/import/src/GraphBuilderRegistry.cpp
+++ b/compiler/luci/import/src/GraphBuilderRegistry.cpp
@@ -37,6 +37,7 @@ GraphBuilderRegistry::GraphBuilderRegistry()
CIRCLE_NODE(BATCH_TO_SPACE_ND, CircleBatchToSpaceNDGraphBuilder); // 37
CIRCLE_NODE(BCQ_FULLY_CONNECTED, CircleBCQFullyConnectedGraphBuilder); // 253
CIRCLE_NODE(BCQ_GATHER, CircleBCQGatherGraphBuilder); // 252
+ CIRCLE_NODE(BIDIRECTIONAL_SEQUENCE_LSTM, CircleBidirectionalSequenceLSTMGraphBuilder); // 52
CIRCLE_NODE(CAST, CircleCastGraphBuilder); // 53
CIRCLE_NODE(CEIL, CircleCeilGraphBuilder); // 104
CIRCLE_NODE(CUSTOM, CircleCustomGraphBuilder); // 32
@@ -51,6 +52,7 @@ GraphBuilderRegistry::GraphBuilderRegistry()
CIRCLE_NODE(EQUAL, CircleEqualGraphBuilder); // 71
CIRCLE_NODE(EXP, CircleExpGraphBuilder); // 47
CIRCLE_NODE(EXPAND_DIMS, CircleExpandDimsGraphBuilder); // 70
+ CIRCLE_NODE(FAKE_QUANT, CircleFakeQuantGraphBuilder); // 80
CIRCLE_NODE(FILL, CircleFillGraphBuilder); // 94
CIRCLE_NODE(FLOOR, CircleFloorGraphBuilder); // 8
CIRCLE_NODE(FLOOR_DIV, CircleFloorDivGraphBuilder); // 90
@@ -155,9 +157,7 @@ GraphBuilderRegistry::GraphBuilderRegistry()
// BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN = 35,
// BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN = 46,
// BuiltinOperator_DELEGATE = 51,
- // BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM = 52,
// BuiltinOperator_ARG_MAX = 56,
- // BuiltinOperator_FAKE_QUANT = 80,
// BuiltinOperator_QUANTIZE = 114,
// BuiltinOperator_HARD_SWISH = 117,
// BuiltinOperator_DENSIFY = 124,
diff --git a/compiler/luci/import/src/Importer.cpp b/compiler/luci/import/src/Importer.cpp
index ab89f3587..193afffcb 100644
--- a/compiler/luci/import/src/Importer.cpp
+++ b/compiler/luci/import/src/Importer.cpp
@@ -15,6 +15,7 @@
*/
#include "luci/Importer.h"
+#include "CircleImportMetadata.h"
#include "PostImport.h"
#include "luci/Import/GraphBuilder.h"
@@ -25,6 +26,8 @@
#include <luci/IR/Module.h>
#include <luci/IR/CircleNodes.h>
+#include <luci/Profile/CircleNodeID.h>
+#include <luci/Profile/CircleNodeOrigin.h>
#include <luci/Log.h>
#include <luci/LogHelper.h>
@@ -50,6 +53,7 @@ void convert_graph(const luci::GraphBuilderSource &source, luci::CircleReader &r
const auto &tensors = reader.tensors();
auto tensors_ptr = reader.tensors_ptr();
assert(tensors_ptr != nullptr);
+ auto circle_metadata = std::make_unique<luci::CircleImportMetadata>(reader);
// build a cache to identify if a tensor is output of an operator
// if this is set, we should not create a CircleConst for this tensor
@@ -96,12 +100,20 @@ void convert_graph(const luci::GraphBuilderSource &source, luci::CircleReader &r
// Data type
graph_input->dtype(input_node->dtype());
+ assert(tensor.shape_signature.size() == 0 ||
+ tensor.shape_signature.size() == tensor.shape.size());
+
// Shape of GraphInput
auto input_shape = std::make_unique<loco::TensorShape>();
const std::vector<int32_t> &input_dims = tensor.shape; // in NHWC
input_shape->rank(input_dims.size());
for (uint32_t r = 0; r < input_dims.size(); ++r)
- input_shape->dim(r) = loco::Dimension(input_dims[r]);
+ {
+ if (tensor.shape_signature.size() > 0 && tensor.shape_signature.at(r) == -1)
+ input_shape->dim(r).unset();
+ else
+ input_shape->dim(r).set(input_dims[r]);
+ }
graph_input->shape(std::move(input_shape));
}
@@ -117,6 +129,7 @@ void convert_graph(const luci::GraphBuilderSource &source, luci::CircleReader &r
// Note that operators in model are stored in execution order. This means that when importing
// an operator, its input operators have already been imported. We exploit this fact to set up
// node's inputs right after creating the node.
+ auto origin_table = circle_metadata->origin_table();
for (uint32_t i = 0; i < operators.size(); ++i)
{
const circle::OperatorT &op = *operators[i];
@@ -130,7 +143,12 @@ void convert_graph(const luci::GraphBuilderSource &source, luci::CircleReader &r
throw oops::UserExn("Invalid operator", reader.opcode_name(op));
}
- builder->build(op, &gb_context);
+ auto built_op = builder->build(op, &gb_context);
+ set_node_id(built_op, i);
+ if (origin_table.find(i) != origin_table.end())
+ add_origin(built_op, origin_table.at(i));
+ else
+ add_origin(built_op, luci::single_origin(i, built_op->name()));
}
else
{
@@ -169,19 +187,28 @@ void convert_graph(const luci::GraphBuilderSource &source, luci::CircleReader &r
// set the graph output name and node object
auto graph_output = graph->outputs()->create();
std::string tname = luci::tensor_name(tensor);
- graph_output->name("output_" + tname);
+ assert(tname.length() > 0);
+ graph_output->name(tname);
luci::copy_tensor_attributes(tensor, output_node);
// Set GraphInputOutputIndex for graph
output_node->index(graph_output->index());
+ assert(tensor.shape_signature.size() == 0 ||
+ tensor.shape_signature.size() == tensor.shape.size());
+
// Shape of Output
auto output_shape = std::make_unique<loco::TensorShape>();
const std::vector<int32_t> &output_dims = tensor.shape; // in NHWC
output_shape->rank(output_dims.size());
for (uint32_t r = 0; r < output_dims.size(); ++r)
- output_shape->dim(r) = loco::Dimension(output_dims[r]);
+ {
+ if (tensor.shape_signature.size() > 0 && tensor.shape_signature.at(r) == -1)
+ output_shape->dim(r).unset();
+ else
+ output_shape->dim(r).set(output_dims[r]);
+ }
graph_output->shape(std::move(output_shape));
// Data type
diff --git a/compiler/luci/import/src/Nodes/CircleAbs.cpp b/compiler/luci/import/src/Nodes/CircleAbs.cpp
index 3556dc7fa..2a1601a21 100644
--- a/compiler/luci/import/src/Nodes/CircleAbs.cpp
+++ b/compiler/luci/import/src/Nodes/CircleAbs.cpp
@@ -24,11 +24,8 @@ namespace luci
{
bool CircleAbsGraphBuilder::validate(const ValidateArgs &args) const
{
- if (args.op.inputs.size() != 1)
- return false;
-
// TODO Support type check
- return true;
+ return GraphBuilder::validate(args, 1);
}
CircleNode *CircleAbsGraphBuilder::build_node(const circle::OperatorT &,
diff --git a/compiler/luci/import/src/Nodes/CircleAdd.cpp b/compiler/luci/import/src/Nodes/CircleAdd.cpp
index b767d4af2..94cbdf081 100644
--- a/compiler/luci/import/src/Nodes/CircleAdd.cpp
+++ b/compiler/luci/import/src/Nodes/CircleAdd.cpp
@@ -25,10 +25,7 @@ namespace luci
bool CircleAddGraphBuilder::validate(const ValidateArgs &args) const
{
- if (args.op.inputs.size() != 2)
- return false;
-
- return true;
+ return GraphBuilder::validate(args, 2);
}
CircleNode *CircleAddGraphBuilder::build_node(const circle::OperatorT &op,
diff --git a/compiler/luci/import/src/Nodes/CircleArgMax.cpp b/compiler/luci/import/src/Nodes/CircleArgMax.cpp
index 10e8516f4..fd8a84289 100644
--- a/compiler/luci/import/src/Nodes/CircleArgMax.cpp
+++ b/compiler/luci/import/src/Nodes/CircleArgMax.cpp
@@ -25,10 +25,7 @@ namespace luci
bool CircleArgMaxGraphBuilder::validate(const ValidateArgs &args) const
{
- if (args.op.inputs.size() != 2)
- return false;
-
- return true;
+ return GraphBuilder::validate(args, 2);
}
CircleNode *CircleArgMaxGraphBuilder::build_node(const circle::OperatorT &op,
diff --git a/compiler/luci/import/src/Nodes/CircleArgMin.cpp b/compiler/luci/import/src/Nodes/CircleArgMin.cpp
index 5ff534dbb..63ca8db03 100644
--- a/compiler/luci/import/src/Nodes/CircleArgMin.cpp
+++ b/compiler/luci/import/src/Nodes/CircleArgMin.cpp
@@ -25,10 +25,7 @@ namespace luci
bool CircleArgMinGraphBuilder::validate(const ValidateArgs &args) const
{
- if (args.op.inputs.size() != 2)
- return false;
-
- return true;
+ return GraphBuilder::validate(args, 2);
}
CircleNode *CircleArgMinGraphBuilder::build_node(const circle::OperatorT &op,
diff --git a/compiler/luci/import/src/Nodes/CircleAveragePool2D.cpp b/compiler/luci/import/src/Nodes/CircleAveragePool2D.cpp
index ad011f71f..a351cf5e7 100644
--- a/compiler/luci/import/src/Nodes/CircleAveragePool2D.cpp
+++ b/compiler/luci/import/src/Nodes/CircleAveragePool2D.cpp
@@ -23,10 +23,7 @@ namespace luci
bool CircleAveragePool2DGraphBuilder::validate(const ValidateArgs &args) const
{
- if (args.op.inputs.size() != 1)
- return false;
-
- return true;
+ return GraphBuilder::validate(args, 1);
}
CircleNode *CircleAveragePool2DGraphBuilder::build_node(const circle::OperatorT &op,
diff --git a/compiler/luci/import/src/Nodes/CircleBCQFullyConnected.cpp b/compiler/luci/import/src/Nodes/CircleBCQFullyConnected.cpp
index 16ecebd5c..4c86399ce 100644
--- a/compiler/luci/import/src/Nodes/CircleBCQFullyConnected.cpp
+++ b/compiler/luci/import/src/Nodes/CircleBCQFullyConnected.cpp
@@ -25,10 +25,7 @@ namespace luci
bool CircleBCQFullyConnectedGraphBuilder::validate(const ValidateArgs &args) const
{
- if (args.op.inputs.size() != 5)
- return false;
-
- return true;
+ return GraphBuilder::validate(args, 5);
}
CircleNode *CircleBCQFullyConnectedGraphBuilder::build_node(const circle::OperatorT &op,
@@ -43,15 +40,6 @@ CircleNode *CircleBCQFullyConnectedGraphBuilder::build_node(const circle::Operat
node->bias(inputs.at(3));
node->weights_clusters(inputs.at(4));
- // TODO Find and move to appropriate place for setting optional input
- if (auto bias = dynamic_cast<luci::CircleOutputExclude *>(node->bias()))
- {
- // bias is not used for type inference, but node itself should have a type
- bias->dtype(loco::DataType::FLOAT32);
-
- // bias is not used for shape inference
- }
-
const auto *options = op.builtin_options.AsBCQFullyConnectedOptions();
node->weights_hidden_size(options->weights_hidden_size);
node->fusedActivationFunction(luci_actfunc(options->fused_activation_function));
diff --git a/compiler/luci/import/src/Nodes/CircleBCQGather.cpp b/compiler/luci/import/src/Nodes/CircleBCQGather.cpp
index 464f1ac18..ee1358197 100644
--- a/compiler/luci/import/src/Nodes/CircleBCQGather.cpp
+++ b/compiler/luci/import/src/Nodes/CircleBCQGather.cpp
@@ -25,10 +25,7 @@ namespace luci
bool CircleBCQGatherGraphBuilder::validate(const ValidateArgs &args) const
{
- if (args.op.inputs.size() != 4)
- return false;
-
- return true;
+ return GraphBuilder::validate(args, 4);
}
CircleNode *CircleBCQGatherGraphBuilder::build_node(const circle::OperatorT &op,
diff --git a/compiler/luci/import/src/Nodes/CircleBatchMatMul.cpp b/compiler/luci/import/src/Nodes/CircleBatchMatMul.cpp
index 330775691..390719061 100644
--- a/compiler/luci/import/src/Nodes/CircleBatchMatMul.cpp
+++ b/compiler/luci/import/src/Nodes/CircleBatchMatMul.cpp
@@ -23,10 +23,7 @@ namespace luci
bool CircleBatchMatMulGraphBuilder::validate(const ValidateArgs &args) const
{
- if (args.op.inputs.size() != 2)
- return false;
-
- return true;
+ return GraphBuilder::validate(args, 2);
}
CircleNode *CircleBatchMatMulGraphBuilder::build_node(const circle::OperatorT &op,
diff --git a/compiler/luci/import/src/Nodes/CircleBidirectionalSequenceLSTM.cpp b/compiler/luci/import/src/Nodes/CircleBidirectionalSequenceLSTM.cpp
new file mode 100644
index 000000000..f8bdcff72
--- /dev/null
+++ b/compiler/luci/import/src/Nodes/CircleBidirectionalSequenceLSTM.cpp
@@ -0,0 +1,112 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Import/Nodes/CircleBidirectionalSequenceLSTM.h"
+
+#include <luci/IR/Nodes/CircleBidirectionalSequenceLSTM.h>
+#include <luci/IR/Nodes/CircleBidirectionalSequenceLSTMOut.h>
+
+#include <loco.h>
+
+namespace luci
+{
+
+bool CircleBidirectionalSequenceLSTMGraphBuilder::validate(const ValidateArgs &args) const
+{
+ if (args.op.inputs.size() != 48)
+ return false;
+ if (args.op.outputs.size() != 2)
+ return false;
+
+ return true;
+}
+
+CircleNode *CircleBidirectionalSequenceLSTMGraphBuilder::build_node(const BuildNodeArgs &bna) const
+{
+ auto *node = bna.context->graph()->nodes()->create<CircleBidirectionalSequenceLSTM>();
+ auto &inputs = bna.input_nodes;
+ node->input(inputs.at(0));
+ node->fw_input_to_input_weights(inputs.at(1)); // Optional
+ node->fw_input_to_cell_weights(inputs.at(2));
+ node->fw_input_to_forget_weights(inputs.at(3));
+ node->fw_input_to_output_weights(inputs.at(4));
+ node->fw_recurrent_to_input_weights(inputs.at(5)); // Optional
+ node->fw_recurrent_to_cell_weights(inputs.at(6));
+ node->fw_recurrent_to_forget_weights(inputs.at(7));
+ node->fw_recurrent_to_output_weights(inputs.at(8));
+ node->fw_cell_to_input_weights(inputs.at(9)); // Optional
+ node->fw_cell_to_forget_weights(inputs.at(10)); // Optional
+ node->fw_cell_to_output_weights(inputs.at(11)); // Optional
+ node->fw_input_gate_bias(inputs.at(12)); // Optional
+ node->fw_forget_gate_bias(inputs.at(13));
+ node->fw_cell_gate_bias(inputs.at(14));
+ node->fw_output_gate_bias(inputs.at(15));
+ node->fw_projection_weights(inputs.at(16)); // Optional
+ node->fw_projection_bias(inputs.at(17)); // Optional
+ node->bw_input_to_input_weights(inputs.at(18)); // Optional
+ node->bw_input_to_cell_weights(inputs.at(19));
+ node->bw_input_to_forget_weights(inputs.at(20));
+ node->bw_input_to_output_weights(inputs.at(21));
+ node->bw_recurrent_to_input_weights(inputs.at(22)); // Optional
+ node->bw_recurrent_to_cell_weights(inputs.at(23));
+ node->bw_recurrent_to_forget_weights(inputs.at(24));
+ node->bw_recurrent_to_output_weights(inputs.at(25));
+ node->bw_cell_to_input_weights(inputs.at(26)); // Optional
+ node->bw_cell_to_forget_weights(inputs.at(27)); // Optional
+ node->bw_cell_to_output_weights(inputs.at(28)); // Optional
+ node->bw_input_gate_bias(inputs.at(29)); // Optional
+ node->bw_forget_gate_bias(inputs.at(30));
+ node->bw_cell_gate_bias(inputs.at(31));
+ node->bw_output_gate_bias(inputs.at(32));
+ node->bw_projection_weights(inputs.at(33)); // Optional
+ node->bw_projection_bias(inputs.at(34)); // Optional
+ node->fw_activation_state(inputs.at(35));
+ node->fw_cell_state(inputs.at(36));
+ node->bw_activation_state(inputs.at(37));
+ node->bw_cell_state(inputs.at(38));
+
+ node->auxillary_input(inputs.at(39)); // Optional
+ node->fw_auxillary_input_to_input_weights(inputs.at(40)); // Optional
+ node->fw_auxillary_input_to_forget_weights(inputs.at(41)); // Optional
+ node->fw_auxillary_input_to_cell_weights(inputs.at(42)); // Optional
+ node->fw_auxillary_input_to_output_weights(inputs.at(43)); // Optional
+ node->bw_auxillary_input_to_input_weights(inputs.at(44)); // Optional
+ node->bw_auxillary_input_to_forget_weights(inputs.at(45)); // Optional
+ node->bw_auxillary_input_to_cell_weights(inputs.at(46)); // Optional
+ node->bw_auxillary_input_to_output_weights(inputs.at(47)); // Optional
+
+ const auto *options = bna.op.builtin_options.AsBidirectionalSequenceLSTMOptions();
+ node->fusedActivationFunction(luci_actfunc(options->fused_activation_function));
+ node->cell_clip(options->cell_clip);
+ node->proj_clip(options->proj_clip);
+ node->merge_outputs(options->merge_outputs);
+ node->time_major(options->time_major);
+ node->asymmetric_quantize_inputs(options->asymmetric_quantize_inputs);
+
+ return node;
+}
+
+CircleNode *CircleBidirectionalSequenceLSTMGraphBuilder::build_out(const BuildOutArgs &boa) const
+{
+ auto *nodeout = boa.node->graph()->nodes()->create<CircleBidirectionalSequenceLSTMOut>();
+
+ nodeout->input(boa.node);
+ nodeout->index(boa.index);
+
+ return nodeout;
+}
+
+} // namespace luci
diff --git a/compiler/luci/import/src/Nodes/CircleCast.cpp b/compiler/luci/import/src/Nodes/CircleCast.cpp
index 7bdb63044..3e8c08bfa 100644
--- a/compiler/luci/import/src/Nodes/CircleCast.cpp
+++ b/compiler/luci/import/src/Nodes/CircleCast.cpp
@@ -30,14 +30,13 @@ bool CircleCastGraphBuilder::validate(const ValidateArgs &args) const
{
LOGGER(l);
+ if (!GraphBuilder::validate(args, 1))
+ return false;
+
auto settings = luci::UserSettings::settings();
const auto &inputs = args.op.inputs;
const auto &outputs = args.op.outputs;
- if (inputs.size() != 1)
- return false;
- if (outputs.size() != 1)
- return false;
// NOTE real models do have type mismatch
const auto *options = args.op.builtin_options.AsCastOptions();
diff --git a/compiler/luci/import/src/Nodes/CircleCeil.cpp b/compiler/luci/import/src/Nodes/CircleCeil.cpp
index 2e1aaa295..d439f41cd 100644
--- a/compiler/luci/import/src/Nodes/CircleCeil.cpp
+++ b/compiler/luci/import/src/Nodes/CircleCeil.cpp
@@ -25,16 +25,8 @@ namespace luci
bool CircleCeilGraphBuilder::validate(const ValidateArgs &args) const
{
- const auto &inputs = args.op.inputs;
- const auto &outputs = args.op.outputs;
- if (inputs.size() != 1)
- return false;
- if (outputs.size() != 1)
- return false;
-
// TODO dtype check
-
- return true;
+ return GraphBuilder::validate(args, 1);
}
CircleNode *CircleCeilGraphBuilder::build_node(const circle::OperatorT &,
diff --git a/compiler/luci/import/src/Nodes/CircleConv2D.cpp b/compiler/luci/import/src/Nodes/CircleConv2D.cpp
index 9516ef16a..8cbecdc00 100644
--- a/compiler/luci/import/src/Nodes/CircleConv2D.cpp
+++ b/compiler/luci/import/src/Nodes/CircleConv2D.cpp
@@ -28,10 +28,7 @@ namespace luci
bool CircleConv2DGraphBuilder::validate(const ValidateArgs &args) const
{
// Circle Conv2D may not have a bias but we won't support this
- if (args.op.inputs.size() != 3)
- return false;
-
- return true;
+ return GraphBuilder::validate(args, 3);
}
CircleNode *CircleConv2DGraphBuilder::build_node(const circle::OperatorT &op,
diff --git a/compiler/luci/import/src/Nodes/CircleCos.cpp b/compiler/luci/import/src/Nodes/CircleCos.cpp
index 27d60c62c..9705202ee 100644
--- a/compiler/luci/import/src/Nodes/CircleCos.cpp
+++ b/compiler/luci/import/src/Nodes/CircleCos.cpp
@@ -25,10 +25,7 @@ namespace luci
bool CircleCosGraphBuilder::validate(const ValidateArgs &args) const
{
- if (args.op.inputs.size() != 1)
- return false;
-
- return true;
+ return GraphBuilder::validate(args, 1);
}
CircleNode *CircleCosGraphBuilder::build_node(const circle::OperatorT &,
diff --git a/compiler/luci/import/src/Nodes/CircleCustom.cpp b/compiler/luci/import/src/Nodes/CircleCustom.cpp
index d541ee87b..01ac3e2a0 100644
--- a/compiler/luci/import/src/Nodes/CircleCustom.cpp
+++ b/compiler/luci/import/src/Nodes/CircleCustom.cpp
@@ -27,62 +27,39 @@ bool CircleCustomGraphBuilder::validate(const ValidateArgs &) const
return true;
}
-void CircleCustomGraphBuilder::build(const circle::OperatorT &op,
- GraphBuilderContext *context) const
+CircleNode *CircleCustomGraphBuilder::build_node(const BuildNodeArgs &bna) const
{
- assert(context != nullptr);
+ uint32_t input_count = bna.op.inputs.size();
+ uint32_t output_count = bna.op.outputs.size();
- auto graph = context->graph();
+ auto *node = bna.context->graph()->nodes()->create<CircleCustom>(input_count, output_count);
- const std::vector<int32_t> &inputs = op.inputs;
- const std::vector<int32_t> &outputs = op.outputs;
- const auto &tensors = context->reader()->tensors();
- auto tensors_ptr = context->reader()->tensors_ptr();
- assert(tensors_ptr != nullptr);
+ for (uint32_t idx = 0; idx < input_count; ++idx)
+ {
+ node->inputs(idx, bna.input_nodes[idx]);
+ }
- // Create CircleCustom
- const auto &opcodes = context->reader()->opcodes();
- const uint32_t opcode_index = op.opcode_index;
+ const auto &opcodes = bna.context->reader()->opcodes();
+ const uint32_t opcode_index = bna.op.opcode_index;
const circle::OperatorCodeT &opcode = *opcodes[opcode_index];
- auto *node = graph->nodes()->create<CircleCustom>(inputs.size());
- uint32_t input_idx = 0;
- for (const int32_t input_tensor_index : inputs)
- {
- node->inputs(input_idx++, context->nodefinder()->node(input_tensor_index));
- }
- node->custom_options(std::vector<uint8_t>{op.custom_options.begin(), op.custom_options.end()});
+ node->custom_options(
+ std::vector<uint8_t>{bna.op.custom_options.begin(), bna.op.custom_options.end()});
node->custom_code(opcode.custom_code);
- // Operator version of custom is always 1, so do nothing
- uint32_t output_count = outputs.size();
+ // NOTE Operator version of custom is always 1
- assert(output_count > 0);
- {
- // Let's use attributes from output 0 for this node
- const circle::TensorT &output_tensor = *tensors[outputs[0]];
- node->name(tensor_name(output_tensor));
- node->dtype(luci_datatype(output_tensor.type));
- }
-
- // Create virtual outputs of Custom
- for (uint32_t n = 0; n < output_count; ++n)
- {
- const circle::TensorT &output_tensor = *tensors[outputs[n]];
+ return node;
+}
- auto *nodeout = graph->nodes()->create<CircleCustomOut>();
- copy_tensor_attributes(output_tensor, nodeout);
- // mark shape_status
- if (tensors_ptr->Get(outputs[n])->shape() == nullptr)
- nodeout->shape_status(ShapeStatus::NOSHAPE);
- else
- nodeout->shape_status(ShapeStatus::VALID);
+CircleNode *CircleCustomGraphBuilder::build_out(const BuildOutArgs &boa) const
+{
+ auto *nodeout = boa.node->graph()->nodes()->create<CircleCustomOut>();
- nodeout->input(node);
- nodeout->index(n);
+ nodeout->input(boa.node);
+ nodeout->index(boa.index);
- context->nodefinder()->enroll(outputs[n], nodeout);
- }
+ return nodeout;
}
} // namespace luci
diff --git a/compiler/luci/import/src/Nodes/CircleDepthToSpace.cpp b/compiler/luci/import/src/Nodes/CircleDepthToSpace.cpp
index 49d31bb99..49eb30a83 100644
--- a/compiler/luci/import/src/Nodes/CircleDepthToSpace.cpp
+++ b/compiler/luci/import/src/Nodes/CircleDepthToSpace.cpp
@@ -27,17 +27,13 @@ namespace luci
bool CircleDepthToSpaceGraphBuilder::validate(const ValidateArgs &args) const
{
+ if (!GraphBuilder::validate(args, 1))
+ return false;
+
const auto &inputs = args.op.inputs;
const auto &outputs = args.op.outputs;
const auto *options = args.op.builtin_options.AsDepthToSpaceOptions();
-
- if (inputs.size() != 1)
- return false;
-
- if (outputs.size() != 1)
- return false;
-
const auto &tensors = args.reader.tensors();
if (tensors[outputs[0]]->type != tensors[inputs.at(0)]->type)
diff --git a/compiler/luci/import/src/Nodes/CircleDepthwiseConv2D.cpp b/compiler/luci/import/src/Nodes/CircleDepthwiseConv2D.cpp
index 53f85f2f5..727487c6a 100644
--- a/compiler/luci/import/src/Nodes/CircleDepthwiseConv2D.cpp
+++ b/compiler/luci/import/src/Nodes/CircleDepthwiseConv2D.cpp
@@ -32,6 +32,32 @@ bool CircleDepthwiseConv2DGraphBuilder::validate(const ValidateArgs &args) const
if (args.op.outputs.size() != 1)
return false;
+ const auto &tensors = args.reader.tensors();
+
+ // input shape
+ const auto &input = tensors.at(args.op.inputs.at(0));
+ const auto &input_shape = input->shape;
+
+ // input shape must be rank 4
+ if (input_shape.size() != 4)
+ return false;
+
+ // filter shape
+ const auto &filter = tensors.at(args.op.inputs.at(1));
+ const auto &filter_shape = filter->shape;
+
+ // filter shape must be rank 4
+ if (filter_shape.size() != 4)
+ return false;
+
+ // multiplier
+ const auto *options = args.op.builtin_options.AsDepthwiseConv2DOptions();
+ const auto &multiplier = options->depth_multiplier;
+
+ // filter represents as [1, H, W, C*M] where M is multiplier.
+ if (filter_shape.at(3) != input_shape.at(3) * multiplier)
+ return false;
+
return true;
}
diff --git a/compiler/luci/import/src/Nodes/CircleDequantize.cpp b/compiler/luci/import/src/Nodes/CircleDequantize.cpp
index 1936da97c..3db546bd0 100644
--- a/compiler/luci/import/src/Nodes/CircleDequantize.cpp
+++ b/compiler/luci/import/src/Nodes/CircleDequantize.cpp
@@ -25,10 +25,7 @@ namespace luci
bool CircleDequantizeGraphBuilder::validate(const ValidateArgs &args) const
{
- if (args.op.inputs.size() != 1)
- return false;
-
- return true;
+ return GraphBuilder::validate(args, 1);
}
CircleNode *CircleDequantizeGraphBuilder::build_node(const circle::OperatorT &,
diff --git a/compiler/luci/import/src/Nodes/CircleDiv.cpp b/compiler/luci/import/src/Nodes/CircleDiv.cpp
index 615c224d7..7ea1afd95 100644
--- a/compiler/luci/import/src/Nodes/CircleDiv.cpp
+++ b/compiler/luci/import/src/Nodes/CircleDiv.cpp
@@ -23,13 +23,7 @@ namespace luci
bool CircleDivGraphBuilder::validate(const ValidateArgs &args) const
{
- if (args.op.inputs.size() != 2)
- return false;
-
- if (args.op.outputs.size() != 1)
- return false;
-
- return true;
+ return GraphBuilder::validate(args, 2);
}
CircleNode *CircleDivGraphBuilder::build_node(const circle::OperatorT &op,
diff --git a/compiler/luci/import/src/Nodes/CircleElu.cpp b/compiler/luci/import/src/Nodes/CircleElu.cpp
index 919e95ee4..461da9517 100644
--- a/compiler/luci/import/src/Nodes/CircleElu.cpp
+++ b/compiler/luci/import/src/Nodes/CircleElu.cpp
@@ -25,14 +25,11 @@ namespace luci
bool CircleEluGraphBuilder::validate(const ValidateArgs &args) const
{
- const auto &inputs = args.op.inputs;
- const auto &outputs = args.op.outputs;
-
- if (inputs.size() != 1)
+ if (!GraphBuilder::validate(args, 1))
return false;
- if (outputs.size() != 1)
- return false;
+ const auto &inputs = args.op.inputs;
+ const auto &outputs = args.op.outputs;
const auto &tensors = args.reader.tensors();
const auto &tensor = tensors.at(inputs.at(0));
diff --git a/compiler/luci/import/src/Nodes/CircleEqual.cpp b/compiler/luci/import/src/Nodes/CircleEqual.cpp
index 1db33b8ac..4909692b4 100644
--- a/compiler/luci/import/src/Nodes/CircleEqual.cpp
+++ b/compiler/luci/import/src/Nodes/CircleEqual.cpp
@@ -25,13 +25,10 @@ namespace luci
bool CircleEqualGraphBuilder::validate(const ValidateArgs &args) const
{
- const auto &inputs = args.op.inputs;
-
- if (inputs.size() != 2)
- {
+ if (!GraphBuilder::validate(args, 2))
return false;
- }
+ const auto &inputs = args.op.inputs;
const auto &tensors = args.reader.tensors();
return tensors[inputs.at(0)]->type == tensors[inputs.at(1)]->type;
diff --git a/compiler/luci/import/src/Nodes/CircleExp.cpp b/compiler/luci/import/src/Nodes/CircleExp.cpp
index 2c031d6b3..64f18fbd4 100644
--- a/compiler/luci/import/src/Nodes/CircleExp.cpp
+++ b/compiler/luci/import/src/Nodes/CircleExp.cpp
@@ -25,10 +25,10 @@ namespace luci
bool CircleExpGraphBuilder::validate(const ValidateArgs &args) const
{
- const auto &inputs = args.op.inputs;
- if (inputs.size() != 1)
+ if (!GraphBuilder::validate(args, 1))
return false;
+ const auto &inputs = args.op.inputs;
// input type check
const auto &tensors = args.reader.tensors();
const auto &tensor = tensors.at(inputs.at(0));
diff --git a/compiler/luci/import/src/Nodes/CircleExpandDims.cpp b/compiler/luci/import/src/Nodes/CircleExpandDims.cpp
index ab537c710..ee0fbdc7e 100644
--- a/compiler/luci/import/src/Nodes/CircleExpandDims.cpp
+++ b/compiler/luci/import/src/Nodes/CircleExpandDims.cpp
@@ -25,13 +25,10 @@ namespace luci
bool CircleExpandDimsGraphBuilder::validate(const ValidateArgs &args) const
{
- const auto &inputs = args.op.inputs;
-
- if (inputs.size() != 2)
- {
+ if (!GraphBuilder::validate(args, 2))
return false;
- }
+ const auto &inputs = args.op.inputs;
const auto &tensors = args.reader.tensors();
return tensors[inputs.at(1)]->type == circle::TensorType_INT32;
diff --git a/compiler/luci/import/src/Nodes/CircleFakeQuant.cpp b/compiler/luci/import/src/Nodes/CircleFakeQuant.cpp
new file mode 100644
index 000000000..7cf40b225
--- /dev/null
+++ b/compiler/luci/import/src/Nodes/CircleFakeQuant.cpp
@@ -0,0 +1,49 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Import/Nodes/CircleFakeQuant.h"
+
+#include <luci/IR/Nodes/CircleFullyConnected.h>
+#include <luci/IR/Nodes/CircleOutput.h>
+
+#include <loco.h>
+#include <oops/UserExn.h>
+
+namespace luci
+{
+
+bool CircleFakeQuantGraphBuilder::validate(const ValidateArgs &args) const
+{
+ return GraphBuilder::validate(args, 1);
+}
+
+CircleNode *CircleFakeQuantGraphBuilder::build_node(const circle::OperatorT &op,
+ const std::vector<CircleNode *> &inputs,
+ loco::Graph *graph) const
+{
+ auto *node = graph->nodes()->create<CircleFakeQuant>();
+ node->inputs(inputs.at(0));
+
+ const auto *options = op.builtin_options.AsFakeQuantOptions();
+ node->min(options->min);
+ node->max(options->max);
+ node->num_bits(options->num_bits);
+ node->narrow_range(options->narrow_range);
+
+ return node;
+}
+
+} // namespace luci
diff --git a/compiler/luci/import/src/Nodes/CircleFill.cpp b/compiler/luci/import/src/Nodes/CircleFill.cpp
index 95d5b876b..9aacddcbe 100644
--- a/compiler/luci/import/src/Nodes/CircleFill.cpp
+++ b/compiler/luci/import/src/Nodes/CircleFill.cpp
@@ -23,13 +23,7 @@ namespace luci
bool CircleFillGraphBuilder::validate(const ValidateArgs &args) const
{
- if (args.op.inputs.size() != 2)
- return false;
-
- if (args.op.outputs.size() != 1)
- return false;
-
- return true;
+ return GraphBuilder::validate(args, 2);
}
CircleNode *CircleFillGraphBuilder::build_node(const circle::OperatorT &op,
diff --git a/compiler/luci/import/src/Nodes/CircleFloor.cpp b/compiler/luci/import/src/Nodes/CircleFloor.cpp
index ce756b3b1..9651259c7 100644
--- a/compiler/luci/import/src/Nodes/CircleFloor.cpp
+++ b/compiler/luci/import/src/Nodes/CircleFloor.cpp
@@ -25,16 +25,8 @@ namespace luci
bool CircleFloorGraphBuilder::validate(const ValidateArgs &args) const
{
- const auto &inputs = args.op.inputs;
- const auto &outputs = args.op.outputs;
- if (inputs.size() != 1)
- return false;
- if (outputs.size() != 1)
- return false;
-
// TODO dtype check
-
- return true;
+ return GraphBuilder::validate(args, 1);
}
CircleNode *CircleFloorGraphBuilder::build_node(const circle::OperatorT &,
diff --git a/compiler/luci/import/src/Nodes/CircleFloorDiv.cpp b/compiler/luci/import/src/Nodes/CircleFloorDiv.cpp
index 55f385d60..ce329326a 100644
--- a/compiler/luci/import/src/Nodes/CircleFloorDiv.cpp
+++ b/compiler/luci/import/src/Nodes/CircleFloorDiv.cpp
@@ -25,19 +25,11 @@ namespace luci
bool CircleFloorDivGraphBuilder::validate(const ValidateArgs &args) const
{
- const auto &inputs = args.op.inputs;
- const auto &outputs = args.op.outputs;
-
- if (inputs.size() != 2)
- {
+ if (!GraphBuilder::validate(args, 2))
return false;
- }
-
- if (outputs.size() != 1)
- {
- return false;
- }
+ const auto &inputs = args.op.inputs;
+ const auto &outputs = args.op.outputs;
const auto &tensors = args.reader.tensors();
const auto &tensor_in_0 = tensors.at(inputs.at(0));
const auto &tensor_in_1 = tensors.at(inputs.at(1));
diff --git a/compiler/luci/import/src/Nodes/CircleFloorMod.cpp b/compiler/luci/import/src/Nodes/CircleFloorMod.cpp
index 2101e417e..d8420a43c 100644
--- a/compiler/luci/import/src/Nodes/CircleFloorMod.cpp
+++ b/compiler/luci/import/src/Nodes/CircleFloorMod.cpp
@@ -25,13 +25,10 @@ namespace luci
bool CircleFloorModGraphBuilder::validate(const ValidateArgs &args) const
{
- const auto &inputs = args.op.inputs;
- const auto &outputs = args.op.outputs;
- if (inputs.size() != 2)
- return false;
- if (outputs.size() != 1)
+ if (!GraphBuilder::validate(args, 2))
return false;
+ const auto &inputs = args.op.inputs;
const auto &tensors = args.reader.tensors();
const auto &tensor_in_0 = tensors.at(inputs.at(0));
const auto &tensor_in_1 = tensors.at(inputs.at(1));
diff --git a/compiler/luci/import/src/Nodes/CircleFullyConnected.cpp b/compiler/luci/import/src/Nodes/CircleFullyConnected.cpp
index 17293ad7a..58750d79a 100644
--- a/compiler/luci/import/src/Nodes/CircleFullyConnected.cpp
+++ b/compiler/luci/import/src/Nodes/CircleFullyConnected.cpp
@@ -27,10 +27,7 @@ namespace luci
bool CircleFullyConnectedGraphBuilder::validate(const ValidateArgs &args) const
{
- if (args.op.inputs.size() != 3)
- return false;
-
- return true;
+ return GraphBuilder::validate(args, 3);
}
CircleNode *CircleFullyConnectedGraphBuilder::build_node(const circle::OperatorT &op,
@@ -42,15 +39,6 @@ CircleNode *CircleFullyConnectedGraphBuilder::build_node(const circle::OperatorT
node->weights(inputs.at(1));
node->bias(inputs.at(2)); // bias is optional
- // TODO Find and move to appropriate place for setting optional input
- if (auto bias = dynamic_cast<luci::CircleOutputExclude *>(node->bias()))
- {
- // bias is not used for type inference, but node itself should have a type
- bias->dtype(loco::DataType::FLOAT32);
-
- // bias is not used for shape inference
- }
-
const auto *options = op.builtin_options.AsFullyConnectedOptions();
node->fusedActivationFunction(luci_actfunc(options->fused_activation_function));
node->weights_format(luci_weights_format(options->weights_format));
diff --git a/compiler/luci/import/src/Nodes/CircleGather.cpp b/compiler/luci/import/src/Nodes/CircleGather.cpp
index 75447a38a..8317a3340 100644
--- a/compiler/luci/import/src/Nodes/CircleGather.cpp
+++ b/compiler/luci/import/src/Nodes/CircleGather.cpp
@@ -26,18 +26,14 @@ namespace luci
bool CircleGatherGraphBuilder::validate(const ValidateArgs &args) const
{
+ if (!GraphBuilder::validate(args, 2))
+ return false;
+
const auto &inputs = args.op.inputs;
- const auto &outputs = args.op.outputs;
const auto *options = args.op.builtin_options.AsGatherOptions();
int32_t axis = options->axis;
- if (inputs.size() != 2)
- return false;
-
- if (outputs.size() != 1)
- return false;
-
if (axis < 0)
axis += inputs.size();
diff --git a/compiler/luci/import/src/Nodes/CircleGatherNd.cpp b/compiler/luci/import/src/Nodes/CircleGatherNd.cpp
index 981adbf63..a4bb26a10 100644
--- a/compiler/luci/import/src/Nodes/CircleGatherNd.cpp
+++ b/compiler/luci/import/src/Nodes/CircleGatherNd.cpp
@@ -27,15 +27,10 @@ namespace luci
bool CircleGatherNdGraphBuilder::validate(const ValidateArgs &args) const
{
- const auto &inputs = args.op.inputs;
- const auto &outputs = args.op.outputs;
-
- if (inputs.size() != 2)
- return false;
-
- if (outputs.size() != 1)
+ if (!GraphBuilder::validate(args, 2))
return false;
+ const auto &inputs = args.op.inputs;
auto &indices_tensor = args.reader.tensors()[inputs.at(1)];
if (!(indices_tensor->type == circle::TensorType::TensorType_INT32 ||
diff --git a/compiler/luci/import/src/Nodes/CircleGreater.cpp b/compiler/luci/import/src/Nodes/CircleGreater.cpp
index 1ad0467e4..f9c00346c 100644
--- a/compiler/luci/import/src/Nodes/CircleGreater.cpp
+++ b/compiler/luci/import/src/Nodes/CircleGreater.cpp
@@ -30,17 +30,13 @@ bool CircleGreaterGraphBuilder::validate(const ValidateArgs &args) const
{
LOGGER(l);
+ if (!GraphBuilder::validate(args, 2))
+ return false;
+
auto settings = luci::UserSettings::settings();
const auto &inputs = args.op.inputs;
const auto &outputs = args.op.outputs;
-
- if (inputs.size() != 2)
- return false;
-
- if (outputs.size() != 1)
- return false;
-
const auto &tensors = args.reader.tensors();
if (tensors[inputs.at(0)]->type != tensors[inputs.at(1)]->type)
diff --git a/compiler/luci/import/src/Nodes/CircleGreaterEqual.cpp b/compiler/luci/import/src/Nodes/CircleGreaterEqual.cpp
index 0ac63b017..e20038fd9 100644
--- a/compiler/luci/import/src/Nodes/CircleGreaterEqual.cpp
+++ b/compiler/luci/import/src/Nodes/CircleGreaterEqual.cpp
@@ -25,19 +25,11 @@ namespace luci
bool CircleGreaterEqualGraphBuilder::validate(const ValidateArgs &args) const
{
- const auto &inputs = args.op.inputs;
- const auto &outputs = args.op.outputs;
-
- if (inputs.size() != 2)
- {
+ if (!GraphBuilder::validate(args, 2))
return false;
- }
-
- if (outputs.size() != 1)
- {
- return false;
- }
+ const auto &inputs = args.op.inputs;
+ const auto &outputs = args.op.outputs;
const auto &tensors = args.reader.tensors();
if (tensors[inputs.at(0)]->type != tensors[inputs.at(1)]->type)
diff --git a/compiler/luci/import/src/Nodes/CircleIf.cpp b/compiler/luci/import/src/Nodes/CircleIf.cpp
index db9ffe1cd..ffdbf0b79 100644
--- a/compiler/luci/import/src/Nodes/CircleIf.cpp
+++ b/compiler/luci/import/src/Nodes/CircleIf.cpp
@@ -70,69 +70,34 @@ bool CircleIfGraphBuilder::validate(const ValidateArgs &args) const
* \- CircleIfOut --- Node ---
*/
-void CircleIfGraphBuilder::build(const circle::OperatorT &op, GraphBuilderContext *context) const
+CircleNode *CircleIfGraphBuilder::build_node(const BuildNodeArgs &bna) const
{
- assert(context != nullptr);
+ uint32_t input_count = bna.op.inputs.size() - 1;
+ uint32_t output_count = bna.op.outputs.size();
- auto graph = context->graph();
+ auto *node = bna.context->graph()->nodes()->create<CircleIf>(input_count, output_count);
- const std::vector<int32_t> &inputs = op.inputs;
- const std::vector<int32_t> &outputs = op.outputs;
- const auto &tensors = context->reader()->tensors();
- const auto &opcodes = context->reader()->opcodes();
- auto tensors_ptr = context->reader()->tensors_ptr();
- assert(tensors_ptr != nullptr);
-
- std::vector<CircleNode *> input_nodes;
- for (const int32_t input_tensor_index : inputs)
- {
- input_nodes.push_back(context->nodefinder()->node(input_tensor_index));
- }
-
- uint32_t input_count = inputs.size() - 1;
- uint32_t output_count = outputs.size();
-
- // Create CircleIf
- CircleIf *node = graph->nodes()->create<CircleIf>(input_count, output_count);
-
- node->cond(input_nodes[0]);
+ node->cond(bna.input_nodes[0]);
for (uint32_t idx = 0; idx < input_count; ++idx)
{
- node->input(idx, input_nodes[idx + 1]);
+ node->input(idx, bna.input_nodes[idx + 1]);
}
- const auto *options = op.builtin_options.AsIfOptions();
+ const auto *options = bna.op.builtin_options.AsIfOptions();
node->then_branch(options->then_subgraph_index);
node->else_branch(options->else_subgraph_index);
- assert(outputs.size() > 0);
- {
- // Lets use name of output 0 as If name
- const circle::TensorT &output_tensor = *tensors[outputs[0]];
- node->name(tensor_name(output_tensor));
- node->op_version(opcodes[op.opcode_index].get()->version);
-
- // NOTE We don't set quantization for If itself but to virtual outputs
- }
-
- // Create virtual outputs of If
- for (uint32_t n = 0; n < output_count; ++n)
- {
- const circle::TensorT &output_tensor = *tensors[outputs[n]];
+ return node;
+}
- auto *nodeout = graph->nodes()->create<CircleIfOut>();
- copy_tensor_attributes(output_tensor, nodeout);
- // mark shape_status
- if (tensors_ptr->Get(outputs[n])->shape() == nullptr)
- nodeout->shape_status(ShapeStatus::NOSHAPE);
- else
- nodeout->shape_status(ShapeStatus::VALID);
+CircleNode *CircleIfGraphBuilder::build_out(const BuildOutArgs &boa) const
+{
+ auto *nodeout = boa.node->graph()->nodes()->create<CircleIfOut>();
- nodeout->input(node);
- nodeout->index(n);
+ nodeout->input(boa.node);
+ nodeout->index(boa.index);
- context->nodefinder()->enroll(outputs[n], nodeout);
- }
+ return nodeout;
}
} // namespace luci
diff --git a/compiler/luci/import/src/Nodes/CircleInstanceNorm.cpp b/compiler/luci/import/src/Nodes/CircleInstanceNorm.cpp
index 6349fd3b7..977b53406 100644
--- a/compiler/luci/import/src/Nodes/CircleInstanceNorm.cpp
+++ b/compiler/luci/import/src/Nodes/CircleInstanceNorm.cpp
@@ -25,12 +25,8 @@ namespace luci
bool CircleInstanceNormGraphBuilder::validate(const ValidateArgs &args) const
{
- if (args.op.inputs.size() != 3)
- return false;
-
// TODO check dtypes
-
- return true;
+ return GraphBuilder::validate(args, 3);
}
CircleNode *CircleInstanceNormGraphBuilder::build_node(const circle::OperatorT &op,
diff --git a/compiler/luci/import/src/Nodes/CircleL2Normalize.cpp b/compiler/luci/import/src/Nodes/CircleL2Normalize.cpp
index e4fdc200c..7e1faedfb 100644
--- a/compiler/luci/import/src/Nodes/CircleL2Normalize.cpp
+++ b/compiler/luci/import/src/Nodes/CircleL2Normalize.cpp
@@ -25,20 +25,7 @@ namespace luci
bool CircleL2NormalizeGraphBuilder::validate(const ValidateArgs &args) const
{
- const auto &inputs = args.op.inputs;
- const auto &outputs = args.op.outputs;
-
- if (inputs.size() != 1)
- {
- return false;
- }
-
- if (outputs.size() != 1)
- {
- return false;
- }
-
- return true;
+ return GraphBuilder::validate(args, 1);
}
CircleNode *CircleL2NormalizeGraphBuilder::build_node(const circle::OperatorT &op,
diff --git a/compiler/luci/import/src/Nodes/CircleL2Pool2D.cpp b/compiler/luci/import/src/Nodes/CircleL2Pool2D.cpp
index 202d9d6fb..849c7c5ed 100644
--- a/compiler/luci/import/src/Nodes/CircleL2Pool2D.cpp
+++ b/compiler/luci/import/src/Nodes/CircleL2Pool2D.cpp
@@ -25,12 +25,8 @@ namespace luci
bool CircleL2Pool2DGraphBuilder::validate(const ValidateArgs &args) const
{
- if (args.op.inputs.size() != 1)
- return false;
-
// TODO check dtypes
-
- return true;
+ return GraphBuilder::validate(args, 1);
}
CircleNode *CircleL2Pool2DGraphBuilder::build_node(const circle::OperatorT &op,
diff --git a/compiler/luci/import/src/Nodes/CircleLeakyRelu.cpp b/compiler/luci/import/src/Nodes/CircleLeakyRelu.cpp
index ad4979f39..880fa6428 100644
--- a/compiler/luci/import/src/Nodes/CircleLeakyRelu.cpp
+++ b/compiler/luci/import/src/Nodes/CircleLeakyRelu.cpp
@@ -25,13 +25,7 @@ namespace luci
bool CircleLeakyReluGraphBuilder::validate(const ValidateArgs &args) const
{
- if (args.op.inputs.size() != 1)
- return false;
-
- if (args.op.outputs.size() != 1)
- return false;
-
- return true;
+ return GraphBuilder::validate(args, 1);
}
CircleNode *CircleLeakyReluGraphBuilder::build_node(const circle::OperatorT &op,
diff --git a/compiler/luci/import/src/Nodes/CircleLess.cpp b/compiler/luci/import/src/Nodes/CircleLess.cpp
index 506036908..f9b99bebe 100644
--- a/compiler/luci/import/src/Nodes/CircleLess.cpp
+++ b/compiler/luci/import/src/Nodes/CircleLess.cpp
@@ -25,19 +25,11 @@ namespace luci
bool CircleLessGraphBuilder::validate(const ValidateArgs &args) const
{
- const auto &inputs = args.op.inputs;
- const auto &outputs = args.op.outputs;
-
- if (inputs.size() != 2)
- {
+ if (!GraphBuilder::validate(args, 2))
return false;
- }
-
- if (outputs.size() != 1)
- {
- return false;
- }
+ const auto &inputs = args.op.inputs;
+ const auto &outputs = args.op.outputs;
const auto &tensors = args.reader.tensors();
const auto &tensor = tensors.at(inputs.at(0));
diff --git a/compiler/luci/import/src/Nodes/CircleLessEqual.cpp b/compiler/luci/import/src/Nodes/CircleLessEqual.cpp
index 9b4f934a5..bb1712137 100644
--- a/compiler/luci/import/src/Nodes/CircleLessEqual.cpp
+++ b/compiler/luci/import/src/Nodes/CircleLessEqual.cpp
@@ -25,19 +25,11 @@ namespace luci
bool CircleLessEqualGraphBuilder::validate(const ValidateArgs &args) const
{
- const auto &inputs = args.op.inputs;
- const auto &outputs = args.op.outputs;
-
- if (inputs.size() != 2)
- {
+ if (!GraphBuilder::validate(args, 2))
return false;
- }
-
- if (outputs.size() != 1)
- {
- return false;
- }
+ const auto &inputs = args.op.inputs;
+ const auto &outputs = args.op.outputs;
const auto &tensors = args.reader.tensors();
if (tensors[inputs.at(0)]->type != tensors[inputs.at(1)]->type)
diff --git a/compiler/luci/import/src/Nodes/CircleLocalResponseNormalization.cpp b/compiler/luci/import/src/Nodes/CircleLocalResponseNormalization.cpp
index 0e32f62de..d03c47d12 100644
--- a/compiler/luci/import/src/Nodes/CircleLocalResponseNormalization.cpp
+++ b/compiler/luci/import/src/Nodes/CircleLocalResponseNormalization.cpp
@@ -25,16 +25,12 @@ namespace luci
bool CircleLocalResponseNormalizationGraphBuilder::validate(const ValidateArgs &args) const
{
- if (args.op.inputs.size() != 1)
- return false;
-
// TODO do attribute checks
-
- return true;
+ return GraphBuilder::validate(args, 1);
}
CircleNode *CircleLocalResponseNormalizationGraphBuilder::build_node(
- const circle::OperatorT &op, const std::vector<CircleNode *> &inputs, loco::Graph *graph) const
+ const circle::OperatorT &op, const std::vector<CircleNode *> &inputs, loco::Graph *graph) const
{
auto *node = graph->nodes()->create<CircleLocalResponseNormalization>();
node->input(inputs.at(0));
diff --git a/compiler/luci/import/src/Nodes/CircleLog.cpp b/compiler/luci/import/src/Nodes/CircleLog.cpp
index 346fc43bb..26b575070 100644
--- a/compiler/luci/import/src/Nodes/CircleLog.cpp
+++ b/compiler/luci/import/src/Nodes/CircleLog.cpp
@@ -25,12 +25,10 @@ namespace luci
bool CircleLogGraphBuilder::validate(const ValidateArgs &args) const
{
- const auto &inputs = args.op.inputs;
- if (inputs.size() != 1)
- return false;
- if (args.op.outputs.size() != 1)
+ if (!GraphBuilder::validate(args, 1))
return false;
+ const auto &inputs = args.op.inputs;
// input type check
// Must be one of bfloat16, half, float32, float64, complex64, complex128.
// Currently circle supports half(float16), float32, float64, complex64.
diff --git a/compiler/luci/import/src/Nodes/CircleLogSoftmax.cpp b/compiler/luci/import/src/Nodes/CircleLogSoftmax.cpp
index ef69e868a..4361db691 100644
--- a/compiler/luci/import/src/Nodes/CircleLogSoftmax.cpp
+++ b/compiler/luci/import/src/Nodes/CircleLogSoftmax.cpp
@@ -25,12 +25,8 @@ namespace luci
bool CircleLogSoftmaxGraphBuilder::validate(const ValidateArgs &args) const
{
- if (args.op.inputs.size() != 1)
- return false;
-
// TODO do attribute checks
-
- return true;
+ return GraphBuilder::validate(args, 1);
}
CircleNode *CircleLogSoftmaxGraphBuilder::build_node(const circle::OperatorT &,
diff --git a/compiler/luci/import/src/Nodes/CircleLogicalAnd.cpp b/compiler/luci/import/src/Nodes/CircleLogicalAnd.cpp
index 7844da0f6..b13fc2735 100644
--- a/compiler/luci/import/src/Nodes/CircleLogicalAnd.cpp
+++ b/compiler/luci/import/src/Nodes/CircleLogicalAnd.cpp
@@ -25,11 +25,11 @@ namespace luci
bool CircleLogicalAndGraphBuilder::validate(const ValidateArgs &args) const
{
- // Only BOOL type is allowed for inputs
- const auto &inputs = args.op.inputs;
- if (inputs.size() != 2)
+ if (!GraphBuilder::validate(args, 2))
return false;
+ // Only BOOL type is allowed for inputs
+ const auto &inputs = args.op.inputs;
const auto &tensors = args.reader.tensors();
for (auto input : inputs)
{
diff --git a/compiler/luci/import/src/Nodes/CircleLogicalNot.cpp b/compiler/luci/import/src/Nodes/CircleLogicalNot.cpp
index 3758642e4..f68218349 100644
--- a/compiler/luci/import/src/Nodes/CircleLogicalNot.cpp
+++ b/compiler/luci/import/src/Nodes/CircleLogicalNot.cpp
@@ -25,7 +25,7 @@ namespace luci
bool CircleLogicalNotGraphBuilder::validate(const ValidateArgs &args) const
{
- if (args.op.inputs.size() != 1)
+ if (!GraphBuilder::validate(args, 1))
return false;
// Only BOOL type is allowed for the input
diff --git a/compiler/luci/import/src/Nodes/CircleLogicalOr.cpp b/compiler/luci/import/src/Nodes/CircleLogicalOr.cpp
index 1b87e6f9c..8c9023dd3 100644
--- a/compiler/luci/import/src/Nodes/CircleLogicalOr.cpp
+++ b/compiler/luci/import/src/Nodes/CircleLogicalOr.cpp
@@ -25,7 +25,7 @@ namespace luci
bool CircleLogicalOrGraphBuilder::validate(const ValidateArgs &args) const
{
- if (args.op.inputs.size() != 2)
+ if (!GraphBuilder::validate(args, 2))
return false;
// Only BOOL type is allowed for inputs
diff --git a/compiler/luci/import/src/Nodes/CircleLogistic.cpp b/compiler/luci/import/src/Nodes/CircleLogistic.cpp
index 9606e19cd..0f92a9bb4 100644
--- a/compiler/luci/import/src/Nodes/CircleLogistic.cpp
+++ b/compiler/luci/import/src/Nodes/CircleLogistic.cpp
@@ -25,13 +25,11 @@ namespace luci
bool CircleLogisticGraphBuilder::validate(const ValidateArgs &args) const
{
- const auto &inputs = args.op.inputs;
- if (inputs.size() != 1)
- return false;
- const auto &outputs = args.op.outputs;
- if (outputs.size() != 1)
+ if (!GraphBuilder::validate(args, 1))
return false;
+ const auto &inputs = args.op.inputs;
+ const auto &outputs = args.op.outputs;
const auto &tensors = args.reader.tensors();
if (tensors.at(inputs.at(0))->type != tensors.at(outputs[0])->type)
return false;
diff --git a/compiler/luci/import/src/Nodes/CircleMatrixDiag.cpp b/compiler/luci/import/src/Nodes/CircleMatrixDiag.cpp
index a4a21a8b7..590a07f2d 100644
--- a/compiler/luci/import/src/Nodes/CircleMatrixDiag.cpp
+++ b/compiler/luci/import/src/Nodes/CircleMatrixDiag.cpp
@@ -25,15 +25,11 @@ namespace luci
bool CircleMatrixDiagGraphBuilder::validate(const ValidateArgs &args) const
{
- const auto &inputs = args.op.inputs;
- const auto &outputs = args.op.outputs;
-
- if (inputs.size() != 1)
- return false;
-
- if (outputs.size() != 1)
+ if (!GraphBuilder::validate(args, 1))
return false;
+ const auto &inputs = args.op.inputs;
+ const auto &outputs = args.op.outputs;
const auto &tensors = args.reader.tensors();
const auto &tensor = tensors.at(inputs.at(0));
diff --git a/compiler/luci/import/src/Nodes/CircleMatrixSetDiag.cpp b/compiler/luci/import/src/Nodes/CircleMatrixSetDiag.cpp
index cf0313149..edd7d2ae2 100644
--- a/compiler/luci/import/src/Nodes/CircleMatrixSetDiag.cpp
+++ b/compiler/luci/import/src/Nodes/CircleMatrixSetDiag.cpp
@@ -25,15 +25,11 @@ namespace luci
bool CircleMatrixSetDiagGraphBuilder::validate(const ValidateArgs &args) const
{
- const auto &inputs = args.op.inputs;
- const auto &outputs = args.op.outputs;
-
- if (inputs.size() != 2)
- return false;
-
- if (outputs.size() != 1)
+ if (!GraphBuilder::validate(args, 2))
return false;
+ const auto &inputs = args.op.inputs;
+ const auto &outputs = args.op.outputs;
const auto &tensors = args.reader.tensors();
const auto &tensor = tensors.at(inputs.at(0));
diff --git a/compiler/luci/import/src/Nodes/CircleMaxPool2D.cpp b/compiler/luci/import/src/Nodes/CircleMaxPool2D.cpp
index 4bca0f40b..5c03fff18 100644
--- a/compiler/luci/import/src/Nodes/CircleMaxPool2D.cpp
+++ b/compiler/luci/import/src/Nodes/CircleMaxPool2D.cpp
@@ -25,10 +25,7 @@ namespace luci
bool CircleMaxPool2DGraphBuilder::validate(const ValidateArgs &args) const
{
- if (args.op.inputs.size() != 1)
- return false;
-
- return true;
+ return GraphBuilder::validate(args, 1);
}
CircleNode *CircleMaxPool2DGraphBuilder::build_node(const circle::OperatorT &op,
diff --git a/compiler/luci/import/src/Nodes/CircleMean.cpp b/compiler/luci/import/src/Nodes/CircleMean.cpp
index d8fa9a53d..7882f17fc 100644
--- a/compiler/luci/import/src/Nodes/CircleMean.cpp
+++ b/compiler/luci/import/src/Nodes/CircleMean.cpp
@@ -23,10 +23,7 @@ namespace luci
bool CircleMeanGraphBuilder::validate(const ValidateArgs &args) const
{
- if (args.op.inputs.size() != 2)
- return false;
-
- return true;
+ return GraphBuilder::validate(args, 2);
}
CircleNode *CircleMeanGraphBuilder::build_node(const circle::OperatorT &op,
diff --git a/compiler/luci/import/src/Nodes/CircleMirrorPad.cpp b/compiler/luci/import/src/Nodes/CircleMirrorPad.cpp
index e0ddd4c11..e40ce2249 100644
--- a/compiler/luci/import/src/Nodes/CircleMirrorPad.cpp
+++ b/compiler/luci/import/src/Nodes/CircleMirrorPad.cpp
@@ -25,12 +25,8 @@ namespace luci
bool CircleMirrorPadGraphBuilder::validate(const ValidateArgs &args) const
{
- if (args.op.inputs.size() != 2)
- return false;
-
// TODO check others
-
- return true;
+ return GraphBuilder::validate(args, 2);
}
CircleNode *CircleMirrorPadGraphBuilder::build_node(const circle::OperatorT &op,
diff --git a/compiler/luci/import/src/Nodes/CircleMul.cpp b/compiler/luci/import/src/Nodes/CircleMul.cpp
index e3c4a7ee5..28421f8c4 100644
--- a/compiler/luci/import/src/Nodes/CircleMul.cpp
+++ b/compiler/luci/import/src/Nodes/CircleMul.cpp
@@ -23,13 +23,7 @@ namespace luci
bool CircleMulGraphBuilder::validate(const ValidateArgs &args) const
{
- if (args.op.inputs.size() != 2)
- return false;
-
- if (args.op.outputs.size() != 1)
- return false;
-
- return true;
+ return GraphBuilder::validate(args, 2);
}
CircleNode *CircleMulGraphBuilder::build_node(const circle::OperatorT &op,
diff --git a/compiler/luci/import/src/Nodes/CircleNeg.cpp b/compiler/luci/import/src/Nodes/CircleNeg.cpp
index a64a69560..9dd1458f4 100644
--- a/compiler/luci/import/src/Nodes/CircleNeg.cpp
+++ b/compiler/luci/import/src/Nodes/CircleNeg.cpp
@@ -24,11 +24,8 @@ namespace luci
{
bool CircleNegGraphBuilder::validate(const ValidateArgs &args) const
{
- if (args.op.inputs.size() != 1)
- return false;
-
// TODO Support type check
- return true;
+ return GraphBuilder::validate(args, 1);
}
CircleNode *CircleNegGraphBuilder::build_node(const circle::OperatorT &,
diff --git a/compiler/luci/import/src/Nodes/CircleNonMaxSuppressionV4.cpp b/compiler/luci/import/src/Nodes/CircleNonMaxSuppressionV4.cpp
index a4ad4a53d..d3d69506b 100644
--- a/compiler/luci/import/src/Nodes/CircleNonMaxSuppressionV4.cpp
+++ b/compiler/luci/import/src/Nodes/CircleNonMaxSuppressionV4.cpp
@@ -61,63 +61,27 @@ bool CircleNonMaxSuppressionV4GraphBuilder::validate(const ValidateArgs &args) c
* We will create multiple NonMasSuppressionV4Oout nodes to emulate this
*/
-void CircleNonMaxSuppressionV4GraphBuilder::build(const circle::OperatorT &op,
- GraphBuilderContext *context) const
+CircleNode *CircleNonMaxSuppressionV4GraphBuilder::build_node(const BuildNodeArgs &bna) const
{
- assert(context != nullptr);
-
- auto graph = context->graph();
-
- const std::vector<int32_t> &inputs = op.inputs;
- const std::vector<int32_t> &outputs = op.outputs;
- const auto &tensors = context->reader()->tensors();
- const auto &opcodes = context->reader()->opcodes();
- auto tensors_ptr = context->reader()->tensors_ptr();
- assert(tensors_ptr != nullptr);
-
- std::vector<CircleNode *> input_nodes;
- for (const int32_t input_tensor_index : inputs)
- {
- input_nodes.push_back(context->nodefinder()->node(input_tensor_index));
- }
-
- // Create CircleNonMaxSuppressionV4
- auto node = graph->nodes()->create<CircleNonMaxSuppressionV4>();
- node->boxes(input_nodes[0]);
- node->scores(input_nodes[1]);
- node->max_output_size(input_nodes[2]);
- node->iou_threshold(input_nodes[3]);
- node->score_threshold(input_nodes[4]);
-
- assert(outputs.size() == 2);
- {
- // Let's use name of output 0 as NonMaxSuppressionV4 name
- const circle::TensorT &output_tensor = *tensors[outputs[0]];
- node->name(tensor_name(output_tensor));
- node->op_version(opcodes[op.opcode_index].get()->version);
-
- // NOTE We don't set quantization for NonMaxSuppressionV4 itself but to virtual outputs
- }
-
- // Create virtual outputs of NonMaxSuppressionV4
- for (size_t n = 0; n < outputs.size(); ++n)
- {
- const circle::TensorT &output_tensor = *tensors[outputs[n]];
-
- auto *nodeout = graph->nodes()->create<CircleNonMaxSuppressionV4Out>();
- copy_tensor_attributes(output_tensor, nodeout);
-
- // mark shape_status
- if (tensors_ptr->Get(outputs[n])->shape() == nullptr)
- nodeout->shape_status(ShapeStatus::NOSHAPE);
- else
- nodeout->shape_status(ShapeStatus::VALID);
-
- nodeout->input(node);
- nodeout->index(n);
-
- context->nodefinder()->enroll(outputs[n], nodeout);
- }
+ auto node = bna.context->graph()->nodes()->create<CircleNonMaxSuppressionV4>();
+
+ node->boxes(bna.input_nodes[0]);
+ node->scores(bna.input_nodes[1]);
+ node->max_output_size(bna.input_nodes[2]);
+ node->iou_threshold(bna.input_nodes[3]);
+ node->score_threshold(bna.input_nodes[4]);
+
+ return node;
+}
+
+CircleNode *CircleNonMaxSuppressionV4GraphBuilder::build_out(const BuildOutArgs &boa) const
+{
+ auto *nodeout = boa.node->graph()->nodes()->create<CircleNonMaxSuppressionV4Out>();
+
+ nodeout->input(boa.node);
+ nodeout->index(boa.index);
+
+ return nodeout;
}
} // namespace luci
diff --git a/compiler/luci/import/src/Nodes/CircleNonMaxSuppressionV5.cpp b/compiler/luci/import/src/Nodes/CircleNonMaxSuppressionV5.cpp
index 241dbf5ff..d797d4cb7 100644
--- a/compiler/luci/import/src/Nodes/CircleNonMaxSuppressionV5.cpp
+++ b/compiler/luci/import/src/Nodes/CircleNonMaxSuppressionV5.cpp
@@ -63,64 +63,28 @@ bool CircleNonMaxSuppressionV5GraphBuilder::validate(const ValidateArgs &args) c
* We will create multiple NonMasSuppressionV5Oout nodes to emulate this
*/
-void CircleNonMaxSuppressionV5GraphBuilder::build(const circle::OperatorT &op,
- GraphBuilderContext *context) const
+CircleNode *CircleNonMaxSuppressionV5GraphBuilder::build_node(const BuildNodeArgs &bna) const
{
- assert(context != nullptr);
-
- auto graph = context->graph();
-
- const std::vector<int32_t> &inputs = op.inputs;
- const std::vector<int32_t> &outputs = op.outputs;
- const auto &tensors = context->reader()->tensors();
- const auto &opcodes = context->reader()->opcodes();
- auto tensors_ptr = context->reader()->tensors_ptr();
- assert(tensors_ptr != nullptr);
-
- std::vector<CircleNode *> input_nodes;
- for (const int32_t input_tensor_index : inputs)
- {
- input_nodes.push_back(context->nodefinder()->node(input_tensor_index));
- }
-
- // Create CircleNonMaxSuppressionV5
- auto node = graph->nodes()->create<CircleNonMaxSuppressionV5>();
- node->boxes(input_nodes[0]);
- node->scores(input_nodes[1]);
- node->max_output_size(input_nodes[2]);
- node->iou_threshold(input_nodes[3]);
- node->score_threshold(input_nodes[4]);
- node->soft_nms_sigma(input_nodes[5]);
-
- assert(outputs.size() == 3);
- {
- // Let's use name of output 0 as NonMaxSuppressionV5 name
- const circle::TensorT &output_tensor = *tensors[outputs[0]];
- node->name(tensor_name(output_tensor));
- node->op_version(opcodes[op.opcode_index].get()->version);
-
- // NOTE We don't set quantization for NonMaxSuppressionV5 itself but to virtual outputs
- }
-
- // Create virtual outputs of NonMaxSuppressionV5
- for (size_t n = 0; n < outputs.size(); ++n)
- {
- const circle::TensorT &output_tensor = *tensors[outputs[n]];
-
- auto *nodeout = graph->nodes()->create<CircleNonMaxSuppressionV5Out>();
- copy_tensor_attributes(output_tensor, nodeout);
-
- // mark shape_status
- if (tensors_ptr->Get(outputs[n])->shape() == nullptr)
- nodeout->shape_status(ShapeStatus::NOSHAPE);
- else
- nodeout->shape_status(ShapeStatus::VALID);
-
- nodeout->input(node);
- nodeout->index(n);
-
- context->nodefinder()->enroll(outputs[n], nodeout);
- }
+ auto node = bna.context->graph()->nodes()->create<CircleNonMaxSuppressionV5>();
+
+ node->boxes(bna.input_nodes[0]);
+ node->scores(bna.input_nodes[1]);
+ node->max_output_size(bna.input_nodes[2]);
+ node->iou_threshold(bna.input_nodes[3]);
+ node->score_threshold(bna.input_nodes[4]);
+ node->soft_nms_sigma(bna.input_nodes[5]);
+
+ return node;
+}
+
+CircleNode *CircleNonMaxSuppressionV5GraphBuilder::build_out(const BuildOutArgs &boa) const
+{
+ auto *nodeout = boa.node->graph()->nodes()->create<CircleNonMaxSuppressionV5Out>();
+
+ nodeout->input(boa.node);
+ nodeout->index(boa.index);
+
+ return nodeout;
}
} // namespace luci
diff --git a/compiler/luci/import/src/Nodes/CircleNotEqual.cpp b/compiler/luci/import/src/Nodes/CircleNotEqual.cpp
index 77e986de1..a0b8f9e4f 100644
--- a/compiler/luci/import/src/Nodes/CircleNotEqual.cpp
+++ b/compiler/luci/import/src/Nodes/CircleNotEqual.cpp
@@ -25,19 +25,11 @@ namespace luci
bool CircleNotEqualGraphBuilder::validate(const ValidateArgs &args) const
{
- const auto &inputs = args.op.inputs;
- const auto &outputs = args.op.outputs;
-
- if (inputs.size() != 2)
- {
+ if (!GraphBuilder::validate(args, 2))
return false;
- }
-
- if (outputs.size() != 1)
- {
- return false;
- }
+ const auto &inputs = args.op.inputs;
+ const auto &outputs = args.op.outputs;
const auto &tensors = args.reader.tensors();
if (tensors[inputs.at(0)]->type != tensors[inputs.at(1)]->type)
diff --git a/compiler/luci/import/src/Nodes/CircleOneHot.cpp b/compiler/luci/import/src/Nodes/CircleOneHot.cpp
index 69294e1ed..3952cc21a 100644
--- a/compiler/luci/import/src/Nodes/CircleOneHot.cpp
+++ b/compiler/luci/import/src/Nodes/CircleOneHot.cpp
@@ -26,17 +26,12 @@ namespace luci
bool CircleOneHotGraphBuilder::validate(const ValidateArgs &args) const
{
- const auto &inputs = args.op.inputs;
- const auto &outputs = args.op.outputs;
- const auto *options = args.op.builtin_options.AsOneHotOptions();
-
// Only 4 Input come refered from
- if (inputs.size() != 4)
- return false;
-
- if (outputs.size() != 1)
+ if (!GraphBuilder::validate(args, 4))
return false;
+ const auto &inputs = args.op.inputs;
+ const auto *options = args.op.builtin_options.AsOneHotOptions();
const auto &tensors = args.reader.tensors();
const auto &indices = tensors.at(inputs.at(0));
const auto &depth = tensors.at(inputs.at(1));
diff --git a/compiler/luci/import/src/Nodes/CirclePRelu.cpp b/compiler/luci/import/src/Nodes/CirclePRelu.cpp
index c07920f7c..7c81f04bb 100644
--- a/compiler/luci/import/src/Nodes/CirclePRelu.cpp
+++ b/compiler/luci/import/src/Nodes/CirclePRelu.cpp
@@ -25,13 +25,7 @@ namespace luci
bool CirclePReluGraphBuilder::validate(const ValidateArgs &args) const
{
- if (args.op.inputs.size() != 2)
- return false;
-
- if (args.op.outputs.size() != 1)
- return false;
-
- return true;
+ return GraphBuilder::validate(args, 2);
}
CircleNode *CirclePReluGraphBuilder::build_node(const circle::OperatorT &,
diff --git a/compiler/luci/import/src/Nodes/CirclePad.cpp b/compiler/luci/import/src/Nodes/CirclePad.cpp
index 999173b90..67dce6dee 100644
--- a/compiler/luci/import/src/Nodes/CirclePad.cpp
+++ b/compiler/luci/import/src/Nodes/CirclePad.cpp
@@ -25,12 +25,8 @@ namespace luci
bool CirclePadGraphBuilder::validate(const ValidateArgs &args) const
{
- if (args.op.inputs.size() != 2)
- return false;
-
// TODO do attribute checks
-
- return true;
+ return GraphBuilder::validate(args, 2);
}
CircleNode *CirclePadGraphBuilder::build_node(const circle::OperatorT &op,
diff --git a/compiler/luci/import/src/Nodes/CirclePadV2.cpp b/compiler/luci/import/src/Nodes/CirclePadV2.cpp
index 493876e68..84a45722a 100644
--- a/compiler/luci/import/src/Nodes/CirclePadV2.cpp
+++ b/compiler/luci/import/src/Nodes/CirclePadV2.cpp
@@ -25,13 +25,7 @@ namespace luci
bool CirclePadV2GraphBuilder::validate(const ValidateArgs &args) const
{
- if (args.op.inputs.size() != 3)
- return false;
-
- if (args.op.outputs.size() != 1)
- return false;
-
- return true;
+ return GraphBuilder::validate(args, 3);
}
CircleNode *CirclePadV2GraphBuilder::build_node(const circle::OperatorT &op,
diff --git a/compiler/luci/import/src/Nodes/CirclePow.cpp b/compiler/luci/import/src/Nodes/CirclePow.cpp
index def012614..1d2d41607 100644
--- a/compiler/luci/import/src/Nodes/CirclePow.cpp
+++ b/compiler/luci/import/src/Nodes/CirclePow.cpp
@@ -25,13 +25,7 @@ namespace luci
bool CirclePowGraphBuilder::validate(const ValidateArgs &args) const
{
- if (args.op.inputs.size() != 2)
- return false;
-
- if (args.op.outputs.size() != 1)
- return false;
-
- return true;
+ return GraphBuilder::validate(args, 2);
}
CircleNode *CirclePowGraphBuilder::build_node(const circle::OperatorT &,
diff --git a/compiler/luci/import/src/Nodes/CircleRange.cpp b/compiler/luci/import/src/Nodes/CircleRange.cpp
index 38dc44ed6..d3b5afc95 100644
--- a/compiler/luci/import/src/Nodes/CircleRange.cpp
+++ b/compiler/luci/import/src/Nodes/CircleRange.cpp
@@ -24,11 +24,8 @@ namespace luci
{
bool CircleRangeGraphBuilder::validate(const ValidateArgs &args) const
{
- if (args.op.inputs.size() != 3)
- return false;
-
// TODO Support type check
- return true;
+ return GraphBuilder::validate(args, 3);
}
CircleNode *CircleRangeGraphBuilder::build_node(const circle::OperatorT &,
diff --git a/compiler/luci/import/src/Nodes/CircleRank.cpp b/compiler/luci/import/src/Nodes/CircleRank.cpp
index 12658b192..afebb9509 100644
--- a/compiler/luci/import/src/Nodes/CircleRank.cpp
+++ b/compiler/luci/import/src/Nodes/CircleRank.cpp
@@ -24,13 +24,7 @@ namespace luci
{
bool CircleRankGraphBuilder::validate(const ValidateArgs &args) const
{
- if (args.op.inputs.size() != 1)
- return false;
-
- if (args.op.outputs.size() != 1)
- return false;
-
- return true;
+ return GraphBuilder::validate(args, 1);
}
CircleNode *CircleRankGraphBuilder::build_node(const circle::OperatorT &,
diff --git a/compiler/luci/import/src/Nodes/CircleReduceAny.cpp b/compiler/luci/import/src/Nodes/CircleReduceAny.cpp
index 21a821951..13205dd7a 100644
--- a/compiler/luci/import/src/Nodes/CircleReduceAny.cpp
+++ b/compiler/luci/import/src/Nodes/CircleReduceAny.cpp
@@ -23,13 +23,11 @@ namespace luci
bool CircleReduceAnyGraphBuilder::validate(const ValidateArgs &args) const
{
- const auto &inputs = args.op.inputs;
- const auto &outputs = args.op.outputs;
- if (inputs.size() != 2)
- return false;
- if (outputs.size() != 1)
+ if (!GraphBuilder::validate(args, 2))
return false;
+ const auto &inputs = args.op.inputs;
+ const auto &outputs = args.op.outputs;
const auto &tensors = args.reader.tensors();
const auto &tensor_0 = tensors.at(inputs.at(0));
const auto &tensor_1 = tensors.at(inputs.at(1));
diff --git a/compiler/luci/import/src/Nodes/CircleReduceProd.cpp b/compiler/luci/import/src/Nodes/CircleReduceProd.cpp
index 5f054586e..3549c1a18 100644
--- a/compiler/luci/import/src/Nodes/CircleReduceProd.cpp
+++ b/compiler/luci/import/src/Nodes/CircleReduceProd.cpp
@@ -23,12 +23,10 @@ namespace luci
bool CircleReduceProdGraphBuilder::validate(const ValidateArgs &args) const
{
- const auto &inputs = args.op.inputs;
- if (inputs.size() != 2)
- return false;
- if (args.op.outputs.size() != 1)
+ if (!GraphBuilder::validate(args, 2))
return false;
+ const auto &inputs = args.op.inputs;
const auto &tensors = args.reader.tensors();
const auto &tensor_1 = tensors.at(inputs.at(1));
diff --git a/compiler/luci/import/src/Nodes/CircleRelu.cpp b/compiler/luci/import/src/Nodes/CircleRelu.cpp
index 8e1c32a3a..73b8ffee8 100644
--- a/compiler/luci/import/src/Nodes/CircleRelu.cpp
+++ b/compiler/luci/import/src/Nodes/CircleRelu.cpp
@@ -25,13 +25,7 @@ namespace luci
bool CircleReluGraphBuilder::validate(const ValidateArgs &args) const
{
- if (args.op.inputs.size() != 1)
- return false;
-
- if (args.op.outputs.size() != 1)
- return false;
-
- return true;
+ return GraphBuilder::validate(args, 1);
}
CircleNode *CircleReluGraphBuilder::build_node(const circle::OperatorT &,
diff --git a/compiler/luci/import/src/Nodes/CircleRelu6.cpp b/compiler/luci/import/src/Nodes/CircleRelu6.cpp
index 0283d7350..ab957eda8 100644
--- a/compiler/luci/import/src/Nodes/CircleRelu6.cpp
+++ b/compiler/luci/import/src/Nodes/CircleRelu6.cpp
@@ -25,13 +25,7 @@ namespace luci
bool CircleRelu6GraphBuilder::validate(const ValidateArgs &args) const
{
- if (args.op.inputs.size() != 1)
- return false;
-
- if (args.op.outputs.size() != 1)
- return false;
-
- return true;
+ return GraphBuilder::validate(args, 1);
}
CircleNode *CircleRelu6GraphBuilder::build_node(const circle::OperatorT &,
diff --git a/compiler/luci/import/src/Nodes/CircleReluN1To1.cpp b/compiler/luci/import/src/Nodes/CircleReluN1To1.cpp
index 7f517bc0d..4987f3be2 100644
--- a/compiler/luci/import/src/Nodes/CircleReluN1To1.cpp
+++ b/compiler/luci/import/src/Nodes/CircleReluN1To1.cpp
@@ -25,15 +25,8 @@ namespace luci
bool CircleReluN1To1GraphBuilder::validate(const ValidateArgs &args) const
{
- if (args.op.inputs.size() != 1)
- return false;
-
- if (args.op.outputs.size() != 1)
- return false;
-
// TODO check dtypes
-
- return true;
+ return GraphBuilder::validate(args, 1);
}
CircleNode *CircleReluN1To1GraphBuilder::build_node(const circle::OperatorT &,
diff --git a/compiler/luci/import/src/Nodes/CircleReshape.cpp b/compiler/luci/import/src/Nodes/CircleReshape.cpp
index 996ae9d20..401dff0fc 100644
--- a/compiler/luci/import/src/Nodes/CircleReshape.cpp
+++ b/compiler/luci/import/src/Nodes/CircleReshape.cpp
@@ -30,6 +30,19 @@ bool CircleReshapeGraphBuilder::validate(const ValidateArgs &args) const
if (args.op.outputs.size() != 1)
return false;
+ // for two inputs, check if type is S32
+ if (args.op.inputs.size() == 2)
+ {
+ const auto &inputs = args.op.inputs;
+ const auto &tensors = args.reader.tensors();
+ const auto &tensor_in = tensors.at(inputs.at(1));
+
+ // NOTE fix this if there is any other case
+ // TensorFlow lite and circle only supports S32
+ if (tensor_in->type != circle::TensorType::TensorType_INT32)
+ return false;
+ }
+
return true;
}
@@ -53,6 +66,7 @@ static CircleNode *create_shape_node(const std::vector<int32_t> &shape, loco::Gr
{
shape_node->at<loco::DataType::S32>(i) = shape[i];
}
+ shape_node->name("Reshape/shape");
return shape_node;
}
@@ -73,6 +87,7 @@ CircleNode *CircleReshapeGraphBuilder::build_node(const circle::OperatorT &op,
shape_node = graph->nodes()->create<CircleOutputDummy>();
shape_node->dtype(loco::DataType::S32);
shape_node->rank(0);
+ shape_node->name("Reshape/dummy");
}
}
diff --git a/compiler/luci/import/src/Nodes/CircleResizeBilinear.cpp b/compiler/luci/import/src/Nodes/CircleResizeBilinear.cpp
index 0fccb7b44..c751b245c 100644
--- a/compiler/luci/import/src/Nodes/CircleResizeBilinear.cpp
+++ b/compiler/luci/import/src/Nodes/CircleResizeBilinear.cpp
@@ -16,7 +16,6 @@
#include "luci/Import/Nodes/CircleResizeBilinear.h"
-#include <luci/IR/Nodes/CircleConst.h>
#include <luci/IR/Nodes/CircleResizeBilinear.h>
namespace luci
@@ -24,13 +23,7 @@ namespace luci
bool CircleResizeBilinearGraphBuilder::validate(const ValidateArgs &args) const
{
- if (args.op.inputs.size() != 2)
- return false;
-
- if (args.op.outputs.size() != 1)
- return false;
-
- return true;
+ return GraphBuilder::validate(args, 2);
}
CircleNode *CircleResizeBilinearGraphBuilder::build_node(const circle::OperatorT &op,
diff --git a/compiler/luci/import/src/Nodes/CircleResizeNearestNeighbor.cpp b/compiler/luci/import/src/Nodes/CircleResizeNearestNeighbor.cpp
index 324323f59..df7517fe9 100644
--- a/compiler/luci/import/src/Nodes/CircleResizeNearestNeighbor.cpp
+++ b/compiler/luci/import/src/Nodes/CircleResizeNearestNeighbor.cpp
@@ -16,7 +16,6 @@
#include "luci/Import/Nodes/CircleResizeNearestNeighbor.h"
-#include <luci/IR/Nodes/CircleConst.h>
#include <luci/IR/Nodes/CircleResizeNearestNeighbor.h>
namespace luci
@@ -24,17 +23,11 @@ namespace luci
bool CircleResizeNearestNeighborGraphBuilder::validate(const ValidateArgs &args) const
{
- if (args.op.inputs.size() != 2)
- return false;
-
- if (args.op.outputs.size() != 1)
- return false;
-
- return true;
+ return GraphBuilder::validate(args, 2);
}
CircleNode *CircleResizeNearestNeighborGraphBuilder::build_node(
- const circle::OperatorT &op, const std::vector<CircleNode *> &inputs, loco::Graph *graph) const
+ const circle::OperatorT &op, const std::vector<CircleNode *> &inputs, loco::Graph *graph) const
{
auto *node = graph->nodes()->create<CircleResizeNearestNeighbor>();
node->input(inputs.at(0));
diff --git a/compiler/luci/import/src/Nodes/CircleReverseSequence.cpp b/compiler/luci/import/src/Nodes/CircleReverseSequence.cpp
index ad11d4c63..2fbb7a87c 100644
--- a/compiler/luci/import/src/Nodes/CircleReverseSequence.cpp
+++ b/compiler/luci/import/src/Nodes/CircleReverseSequence.cpp
@@ -25,14 +25,11 @@ namespace luci
bool CircleReverseSequenceGraphBuilder::validate(const ValidateArgs &args) const
{
- const auto &inputs = args.op.inputs;
- const auto &outputs = args.op.outputs;
-
- if (inputs.size() != 2)
- return false;
- if (outputs.size() != 1)
+ if (!GraphBuilder::validate(args, 2))
return false;
+ const auto &inputs = args.op.inputs;
+ const auto &outputs = args.op.outputs;
const auto &tensors = args.reader.tensors();
const auto &tensor_in = tensors.at(inputs.at(0));
const auto &tensor_lengths = tensors.at(inputs.at(1));
diff --git a/compiler/luci/import/src/Nodes/CircleReverseV2.cpp b/compiler/luci/import/src/Nodes/CircleReverseV2.cpp
index e2e53bb4b..ca7653201 100644
--- a/compiler/luci/import/src/Nodes/CircleReverseV2.cpp
+++ b/compiler/luci/import/src/Nodes/CircleReverseV2.cpp
@@ -25,14 +25,11 @@ namespace luci
bool CircleReverseV2GraphBuilder::validate(const ValidateArgs &args) const
{
- const auto &inputs = args.op.inputs;
- const auto &outputs = args.op.outputs;
-
- if (inputs.size() != 2)
- return false;
- if (outputs.size() != 1)
+ if (!GraphBuilder::validate(args, 2))
return false;
+ const auto &inputs = args.op.inputs;
+ const auto &outputs = args.op.outputs;
const auto &tensors = args.reader.tensors();
const auto &tensor_in = tensors.at(inputs.at(0));
const auto &tensor_axis = tensors.at(inputs.at(1));
diff --git a/compiler/luci/import/src/Nodes/CircleRound.cpp b/compiler/luci/import/src/Nodes/CircleRound.cpp
index ad77f9f03..d13e0fafe 100644
--- a/compiler/luci/import/src/Nodes/CircleRound.cpp
+++ b/compiler/luci/import/src/Nodes/CircleRound.cpp
@@ -25,14 +25,11 @@ namespace luci
bool CircleRoundGraphBuilder::validate(const ValidateArgs &args) const
{
- const auto &inputs = args.op.inputs;
- const auto &outputs = args.op.outputs;
-
- if (inputs.size() != 1)
- return false;
- if (outputs.size() != 1)
+ if (!GraphBuilder::validate(args, 1))
return false;
+ const auto &inputs = args.op.inputs;
+ const auto &outputs = args.op.outputs;
// Must be one of the following types
// bfloat16, half (float16), float32, float64, complex64, complex128
// Currently, circle supports float16, float32, complex64
diff --git a/compiler/luci/import/src/Nodes/CircleRsqrt.cpp b/compiler/luci/import/src/Nodes/CircleRsqrt.cpp
index ae05fbbf9..a9ca90832 100644
--- a/compiler/luci/import/src/Nodes/CircleRsqrt.cpp
+++ b/compiler/luci/import/src/Nodes/CircleRsqrt.cpp
@@ -25,10 +25,10 @@ namespace luci
bool CircleRsqrtGraphBuilder::validate(const ValidateArgs &args) const
{
- const auto &inputs = args.op.inputs;
- if (inputs.size() != 1)
+ if (!GraphBuilder::validate(args, 1))
return false;
+ const auto &inputs = args.op.inputs;
// Must be one of the following types
// bfloat16, half (float16), float32, float64, complex64, complex128
// Currently, circle supports float16, float32, complex64
@@ -36,6 +36,8 @@ bool CircleRsqrtGraphBuilder::validate(const ValidateArgs &args) const
const auto &tensor = tensors.at(inputs.at(0));
switch (tensor->type)
{
+ case circle::TensorType_UINT8:
+ case circle::TensorType_INT16:
case circle::TensorType_FLOAT16:
case circle::TensorType_FLOAT32:
case circle::TensorType_COMPLEX64:
diff --git a/compiler/luci/import/src/Nodes/CircleScatterNd.cpp b/compiler/luci/import/src/Nodes/CircleScatterNd.cpp
index 7f86aeb74..f8c175110 100644
--- a/compiler/luci/import/src/Nodes/CircleScatterNd.cpp
+++ b/compiler/luci/import/src/Nodes/CircleScatterNd.cpp
@@ -25,10 +25,10 @@ namespace luci
bool CircleScatterNdGraphBuilder::validate(const ValidateArgs &args) const
{
- const auto &inputs = args.op.inputs;
- if (inputs.size() != 3)
+ if (!GraphBuilder::validate(args, 3))
return false;
+ const auto &inputs = args.op.inputs;
// indices must have the same type as shape
const auto &tensors = args.reader.tensors();
diff --git a/compiler/luci/import/src/Nodes/CircleSegmentSum.cpp b/compiler/luci/import/src/Nodes/CircleSegmentSum.cpp
index fb84e5d52..bfa333e8d 100644
--- a/compiler/luci/import/src/Nodes/CircleSegmentSum.cpp
+++ b/compiler/luci/import/src/Nodes/CircleSegmentSum.cpp
@@ -25,13 +25,11 @@ namespace luci
bool CircleSegmentSumGraphBuilder::validate(const ValidateArgs &args) const
{
- const auto &inputs = args.op.inputs;
- const auto &outputs = args.op.outputs;
- if (inputs.size() != 2)
- return false;
- if (outputs.size() != 1)
+ if (!GraphBuilder::validate(args, 2))
return false;
+ const auto &inputs = args.op.inputs;
+ const auto &outputs = args.op.outputs;
const auto &tensors = args.reader.tensors();
const auto &tensor_in = tensors.at(inputs.at(0));
const auto &tensor_out = tensors.at(outputs[0]);
diff --git a/compiler/luci/import/src/Nodes/CircleSelect.cpp b/compiler/luci/import/src/Nodes/CircleSelect.cpp
index 1e649f1e0..36a5fa8a8 100644
--- a/compiler/luci/import/src/Nodes/CircleSelect.cpp
+++ b/compiler/luci/import/src/Nodes/CircleSelect.cpp
@@ -25,13 +25,10 @@ namespace luci
bool CircleSelectGraphBuilder::validate(const ValidateArgs &args) const
{
- const auto &inputs = args.op.inputs;
- const auto &outputs = args.op.outputs;
- if (inputs.size() != 3)
- return false;
- if (outputs.size() != 1)
+ if (!GraphBuilder::validate(args, 3))
return false;
+ const auto &inputs = args.op.inputs;
const auto &tensors = args.reader.tensors();
const auto &tensor = tensors.at(inputs.at(0));
if (tensor->type != circle::TensorType_BOOL)
diff --git a/compiler/luci/import/src/Nodes/CircleSelectV2.cpp b/compiler/luci/import/src/Nodes/CircleSelectV2.cpp
index e6dd04de0..556c8fa33 100644
--- a/compiler/luci/import/src/Nodes/CircleSelectV2.cpp
+++ b/compiler/luci/import/src/Nodes/CircleSelectV2.cpp
@@ -25,13 +25,10 @@ namespace luci
bool CircleSelectV2GraphBuilder::validate(const ValidateArgs &args) const
{
- const auto &inputs = args.op.inputs;
- const auto &outputs = args.op.outputs;
- if (inputs.size() != 3)
- return false;
- if (outputs.size() != 1)
+ if (!GraphBuilder::validate(args, 3))
return false;
+ const auto &inputs = args.op.inputs;
const auto &tensors = args.reader.tensors();
const auto &condition = tensors.at(inputs.at(0));
if (condition->type != circle::TensorType_BOOL)
diff --git a/compiler/luci/import/src/Nodes/CircleShape.cpp b/compiler/luci/import/src/Nodes/CircleShape.cpp
index bd7dfc9d9..86c0bf59b 100644
--- a/compiler/luci/import/src/Nodes/CircleShape.cpp
+++ b/compiler/luci/import/src/Nodes/CircleShape.cpp
@@ -25,16 +25,8 @@ namespace luci
bool CircleShapeGraphBuilder::validate(const ValidateArgs &args) const
{
- const auto &inputs = args.op.inputs;
- const auto &outputs = args.op.outputs;
- if (inputs.size() != 1)
- return false;
- if (outputs.size() != 1)
- return false;
-
// TODO check shape, dtype
-
- return true;
+ return GraphBuilder::validate(args, 1);
}
CircleNode *CircleShapeGraphBuilder::build_node(const circle::OperatorT &op,
diff --git a/compiler/luci/import/src/Nodes/CircleSin.cpp b/compiler/luci/import/src/Nodes/CircleSin.cpp
index 4b245ef6b..22f461123 100644
--- a/compiler/luci/import/src/Nodes/CircleSin.cpp
+++ b/compiler/luci/import/src/Nodes/CircleSin.cpp
@@ -25,12 +25,10 @@ namespace luci
bool CircleSinGraphBuilder::validate(const ValidateArgs &args) const
{
- const auto &inputs = args.op.inputs;
- if (inputs.size() != 1)
- return false;
- if (args.op.outputs.size() != 1)
+ if (!GraphBuilder::validate(args, 1))
return false;
+ const auto &inputs = args.op.inputs;
// input type check
const auto &tensors = args.reader.tensors();
const auto &tensor = tensors.at(inputs.at(0));
diff --git a/compiler/luci/import/src/Nodes/CircleSlice.cpp b/compiler/luci/import/src/Nodes/CircleSlice.cpp
index 8601fbf21..4166040b3 100644
--- a/compiler/luci/import/src/Nodes/CircleSlice.cpp
+++ b/compiler/luci/import/src/Nodes/CircleSlice.cpp
@@ -27,14 +27,8 @@ namespace luci
bool CircleSliceGraphBuilder::validate(const ValidateArgs &args) const
{
- if (args.op.inputs.size() != 3)
- return false;
- if (args.op.outputs.size() != 1)
- return false;
-
// TODO check shapes and types
-
- return true;
+ return GraphBuilder::validate(args, 3);
}
CircleNode *CircleSliceGraphBuilder::build_node(const circle::OperatorT &,
diff --git a/compiler/luci/import/src/Nodes/CircleSoftmax.cpp b/compiler/luci/import/src/Nodes/CircleSoftmax.cpp
index 0ef0b5418..e79914455 100644
--- a/compiler/luci/import/src/Nodes/CircleSoftmax.cpp
+++ b/compiler/luci/import/src/Nodes/CircleSoftmax.cpp
@@ -25,12 +25,8 @@ namespace luci
bool CircleSoftmaxGraphBuilder::validate(const ValidateArgs &args) const
{
- if (args.op.inputs.size() != 1)
- return false;
-
// TODO do attribute checks
-
- return true;
+ return GraphBuilder::validate(args, 1);
}
CircleNode *CircleSoftmaxGraphBuilder::build_node(const circle::OperatorT &op,
diff --git a/compiler/luci/import/src/Nodes/CircleSpaceToDepth.cpp b/compiler/luci/import/src/Nodes/CircleSpaceToDepth.cpp
index 8ccd55dc6..2152b65c9 100644
--- a/compiler/luci/import/src/Nodes/CircleSpaceToDepth.cpp
+++ b/compiler/luci/import/src/Nodes/CircleSpaceToDepth.cpp
@@ -27,13 +27,8 @@ namespace luci
bool CircleSpaceToDepthGraphBuilder::validate(const ValidateArgs &args) const
{
- const auto &inputs = args.op.inputs;
- if (inputs.size() != 1)
- return false;
-
// TODO do attribute checks
-
- return true;
+ return GraphBuilder::validate(args, 1);
}
CircleNode *CircleSpaceToDepthGraphBuilder::build_node(const circle::OperatorT &op,
diff --git a/compiler/luci/import/src/Nodes/CircleSparseToDense.cpp b/compiler/luci/import/src/Nodes/CircleSparseToDense.cpp
index ac756b1f3..ce0688bb9 100644
--- a/compiler/luci/import/src/Nodes/CircleSparseToDense.cpp
+++ b/compiler/luci/import/src/Nodes/CircleSparseToDense.cpp
@@ -25,10 +25,7 @@ namespace luci
bool CircleSparseToDenseGraphBuilder::validate(const ValidateArgs &args) const
{
- if (args.op.inputs.size() != 4)
- return false;
-
- return true;
+ return GraphBuilder::validate(args, 4);
}
CircleNode *CircleSparseToDenseGraphBuilder::build_node(const circle::OperatorT &op,
diff --git a/compiler/luci/import/src/Nodes/CircleSplit.cpp b/compiler/luci/import/src/Nodes/CircleSplit.cpp
index 07b6cc939..d0a24aae3 100644
--- a/compiler/luci/import/src/Nodes/CircleSplit.cpp
+++ b/compiler/luci/import/src/Nodes/CircleSplit.cpp
@@ -58,62 +58,27 @@ bool CircleSplitGraphBuilder::validate(const ValidateArgs &args) const
* \- CircleSplitOut --- FullyConnected ---
*/
-void CircleSplitGraphBuilder::build(const circle::OperatorT &op, GraphBuilderContext *context) const
+CircleNode *CircleSplitGraphBuilder::build_node(const BuildNodeArgs &bna) const
{
- assert(context != nullptr);
+ auto node = bna.context->graph()->nodes()->create<CircleSplit>();
- auto graph = context->graph();
+ node->split_dim(bna.input_nodes[0]);
+ node->input(bna.input_nodes[1]);
- const std::vector<int32_t> &inputs = op.inputs;
- const std::vector<int32_t> &outputs = op.outputs;
- const auto &tensors = context->reader()->tensors();
- const auto &opcodes = context->reader()->opcodes();
- auto tensors_ptr = context->reader()->tensors_ptr();
- assert(tensors_ptr != nullptr);
+ const auto *options = bna.op.builtin_options.AsSplitOptions();
+ node->num_split(options->num_splits);
- std::vector<CircleNode *> input_nodes;
- for (const int32_t input_tensor_index : inputs)
- {
- input_nodes.push_back(context->nodefinder()->node(input_tensor_index));
- }
+ return node;
+}
- // Create CircleSplit
- auto node = graph->nodes()->create<CircleSplit>();
- node->split_dim(input_nodes[0]);
- node->input(input_nodes[1]);
+CircleNode *CircleSplitGraphBuilder::build_out(const BuildOutArgs &boa) const
+{
+ auto *nodeout = boa.node->graph()->nodes()->create<CircleSplitOut>();
- const auto *options = op.builtin_options.AsSplitOptions();
- node->num_split(options->num_splits);
+ nodeout->input(boa.node);
+ nodeout->index(boa.index);
- assert(outputs.size() > 0);
- assert(int32_t(outputs.size()) == options->num_splits);
- {
- // Let's use name of output 0 as Split name
- const circle::TensorT &output_tensor = *tensors[outputs[0]];
- node->name(tensor_name(output_tensor));
- node->op_version(opcodes[op.opcode_index].get()->version);
-
- // NOTE We don't set quantization for Split itself but to virtual outputs
- }
-
- // Create virtual outputs of Split
- for (int32_t n = 0; n < options->num_splits; ++n)
- {
- const circle::TensorT &output_tensor = *tensors[outputs[n]];
-
- auto *nodeout = graph->nodes()->create<CircleSplitOut>();
- copy_tensor_attributes(output_tensor, nodeout);
- // mark shape_status
- if (tensors_ptr->Get(outputs[n])->shape() == nullptr)
- nodeout->shape_status(ShapeStatus::NOSHAPE);
- else
- nodeout->shape_status(ShapeStatus::VALID);
-
- nodeout->input(node);
- nodeout->index(n);
-
- context->nodefinder()->enroll(outputs[n], nodeout);
- }
+ return nodeout;
}
} // namespace luci
diff --git a/compiler/luci/import/src/Nodes/CircleSplitV.cpp b/compiler/luci/import/src/Nodes/CircleSplitV.cpp
index 7c6e83e17..76cbf7046 100644
--- a/compiler/luci/import/src/Nodes/CircleSplitV.cpp
+++ b/compiler/luci/import/src/Nodes/CircleSplitV.cpp
@@ -58,64 +58,30 @@ bool CircleSplitVGraphBuilder::validate(const ValidateArgs &args) const
* \- CircleSplitVOut --- FullyConnected ---
*/
-void CircleSplitVGraphBuilder::build(const circle::OperatorT &op,
- GraphBuilderContext *context) const
+CircleNode *CircleSplitVGraphBuilder::build_node(const BuildNodeArgs &bna) const
{
- assert(context != nullptr);
-
- auto graph = context->graph();
-
- const std::vector<int32_t> &inputs = op.inputs;
- const std::vector<int32_t> &outputs = op.outputs;
- const auto &tensors = context->reader()->tensors();
- const auto &opcodes = context->reader()->opcodes();
- auto tensors_ptr = context->reader()->tensors_ptr();
- assert(tensors_ptr != nullptr);
-
- std::vector<CircleNode *> input_nodes;
- for (const int32_t input_tensor_index : inputs)
- {
- input_nodes.push_back(context->nodefinder()->node(input_tensor_index));
- }
-
- // Create CircleSplitV
- auto node = graph->nodes()->create<CircleSplitV>();
- node->input(input_nodes[0]);
- node->size_splits(input_nodes[1]);
- node->split_dim(input_nodes[2]);
-
- const auto *options = op.builtin_options.AsSplitVOptions();
+ auto node = bna.context->graph()->nodes()->create<CircleSplitV>();
+
+ node->input(bna.input_nodes[0]);
+ node->size_splits(bna.input_nodes[1]);
+ node->split_dim(bna.input_nodes[2]);
+
+ const auto *options = bna.op.builtin_options.AsSplitVOptions();
node->num_split(options->num_splits);
- assert(outputs.size() > 0);
- assert(int32_t(outputs.size()) == options->num_splits);
- {
- // Let's use name of output 0 as Split name
- const circle::TensorT &output_tensor = *tensors[outputs[0]];
- node->name(tensor_name(output_tensor));
- node->op_version(opcodes[op.opcode_index].get()->version);
-
- // NOTE We don't set quantization for Split itself but to virtual outputs
- }
-
- // Create virtual outputs of Split
- for (int32_t n = 0; n < options->num_splits; ++n)
- {
- const circle::TensorT &output_tensor = *tensors[outputs[n]];
-
- auto *nodeout = graph->nodes()->create<CircleSplitVOut>();
- copy_tensor_attributes(output_tensor, nodeout);
- // mark shape_status
- if (tensors_ptr->Get(outputs[n])->shape() == nullptr)
- nodeout->shape_status(ShapeStatus::NOSHAPE);
- else
- nodeout->shape_status(ShapeStatus::VALID);
-
- nodeout->input(node);
- nodeout->index(n);
-
- context->nodefinder()->enroll(outputs[n], nodeout);
- }
+ assert(int32_t(bna.op.outputs.size()) == options->num_splits);
+
+ return node;
+}
+
+CircleNode *CircleSplitVGraphBuilder::build_out(const BuildOutArgs &boa) const
+{
+ auto *nodeout = boa.node->graph()->nodes()->create<CircleSplitVOut>();
+
+ nodeout->input(boa.node);
+ nodeout->index(boa.index);
+
+ return nodeout;
}
} // namespace luci
diff --git a/compiler/luci/import/src/Nodes/CircleSqrt.cpp b/compiler/luci/import/src/Nodes/CircleSqrt.cpp
index c8beaee0d..b1fdf7996 100644
--- a/compiler/luci/import/src/Nodes/CircleSqrt.cpp
+++ b/compiler/luci/import/src/Nodes/CircleSqrt.cpp
@@ -25,10 +25,7 @@ namespace luci
bool CircleSqrtGraphBuilder::validate(const ValidateArgs &args) const
{
- if (args.op.inputs.size() != 1)
- return false;
-
- return true;
+ return GraphBuilder::validate(args, 1);
}
CircleNode *CircleSqrtGraphBuilder::build_node(const circle::OperatorT &,
diff --git a/compiler/luci/import/src/Nodes/CircleSquare.cpp b/compiler/luci/import/src/Nodes/CircleSquare.cpp
index b5ba048d7..7ff2b84e6 100644
--- a/compiler/luci/import/src/Nodes/CircleSquare.cpp
+++ b/compiler/luci/import/src/Nodes/CircleSquare.cpp
@@ -25,10 +25,10 @@ namespace luci
bool CircleSquareGraphBuilder::validate(const ValidateArgs &args) const
{
- const auto &inputs = args.op.inputs;
- if (inputs.size() != 1)
+ if (!GraphBuilder::validate(args, 1))
return false;
+ const auto &inputs = args.op.inputs;
// Must be one of the following types
// bfloat16, half (float16), float32, float64, complex64, complex128
// Currently, circle supports float16, float32, complex64
diff --git a/compiler/luci/import/src/Nodes/CircleSquaredDifference.cpp b/compiler/luci/import/src/Nodes/CircleSquaredDifference.cpp
index 6deae94c5..f4e193713 100644
--- a/compiler/luci/import/src/Nodes/CircleSquaredDifference.cpp
+++ b/compiler/luci/import/src/Nodes/CircleSquaredDifference.cpp
@@ -25,15 +25,11 @@ namespace luci
bool CircleSquaredDifferenceGraphBuilder::validate(const ValidateArgs &args) const
{
- const auto &inputs = args.op.inputs;
- const auto &outputs = args.op.outputs;
-
- if (inputs.size() != 2)
- return false;
-
- if (outputs.size() != 1)
+ if (!GraphBuilder::validate(args, 2))
return false;
+ const auto &inputs = args.op.inputs;
+ const auto &outputs = args.op.outputs;
// Inputs must be one of the following types
// bfloat16, half(float16), float32, float64, int32, int64, complex64, complex128
const auto &tensors = args.reader.tensors();
diff --git a/compiler/luci/import/src/Nodes/CircleSqueeze.cpp b/compiler/luci/import/src/Nodes/CircleSqueeze.cpp
index 32792c266..d24d8166c 100644
--- a/compiler/luci/import/src/Nodes/CircleSqueeze.cpp
+++ b/compiler/luci/import/src/Nodes/CircleSqueeze.cpp
@@ -16,7 +16,6 @@
#include "luci/Import/Nodes/CircleSqueeze.h"
-#include <luci/IR/Nodes/CircleConst.h>
#include <luci/IR/Nodes/CircleSqueeze.h>
namespace luci
@@ -24,13 +23,7 @@ namespace luci
bool CircleSqueezeGraphBuilder::validate(const ValidateArgs &args) const
{
- if (args.op.inputs.size() != 1)
- return false;
-
- if (args.op.outputs.size() != 1)
- return false;
-
- return true;
+ return GraphBuilder::validate(args, 1);
}
CircleNode *CircleSqueezeGraphBuilder::build_node(const circle::OperatorT &op,
diff --git a/compiler/luci/import/src/Nodes/CircleStridedSlice.cpp b/compiler/luci/import/src/Nodes/CircleStridedSlice.cpp
index 8f943a682..ca8259cac 100644
--- a/compiler/luci/import/src/Nodes/CircleStridedSlice.cpp
+++ b/compiler/luci/import/src/Nodes/CircleStridedSlice.cpp
@@ -27,14 +27,8 @@ namespace luci
bool CircleStridedSliceGraphBuilder::validate(const ValidateArgs &args) const
{
- if (args.op.inputs.size() != 4)
- return false;
- if (args.op.outputs.size() != 1)
- return false;
-
// TODO check shapes and types
-
- return true;
+ return GraphBuilder::validate(args, 4);
}
CircleNode *CircleStridedSliceGraphBuilder::build_node(const circle::OperatorT &op,
diff --git a/compiler/luci/import/src/Nodes/CircleSub.cpp b/compiler/luci/import/src/Nodes/CircleSub.cpp
index 9acf83d40..c3978f218 100644
--- a/compiler/luci/import/src/Nodes/CircleSub.cpp
+++ b/compiler/luci/import/src/Nodes/CircleSub.cpp
@@ -25,13 +25,7 @@ namespace luci
bool CircleSubGraphBuilder::validate(const ValidateArgs &args) const
{
- if (args.op.inputs.size() != 2)
- return false;
-
- if (args.op.outputs.size() != 1)
- return false;
-
- return true;
+ return GraphBuilder::validate(args, 2);
}
CircleNode *CircleSubGraphBuilder::build_node(const circle::OperatorT &op,
diff --git a/compiler/luci/import/src/Nodes/CircleSum.cpp b/compiler/luci/import/src/Nodes/CircleSum.cpp
index bd3cb6239..e348a62d9 100644
--- a/compiler/luci/import/src/Nodes/CircleSum.cpp
+++ b/compiler/luci/import/src/Nodes/CircleSum.cpp
@@ -23,10 +23,7 @@ namespace luci
bool CircleSumGraphBuilder::validate(const ValidateArgs &args) const
{
- if (args.op.inputs.size() != 2)
- return false;
-
- return true;
+ return GraphBuilder::validate(args, 2);
}
CircleNode *CircleSumGraphBuilder::build_node(const circle::OperatorT &op,
diff --git a/compiler/luci/import/src/Nodes/CircleTanh.cpp b/compiler/luci/import/src/Nodes/CircleTanh.cpp
index 018f5701b..95625a0e4 100644
--- a/compiler/luci/import/src/Nodes/CircleTanh.cpp
+++ b/compiler/luci/import/src/Nodes/CircleTanh.cpp
@@ -25,13 +25,11 @@ namespace luci
bool CircleTanhGraphBuilder::validate(const ValidateArgs &args) const
{
- const auto &inputs = args.op.inputs;
- if (inputs.size() != 1)
- return false;
- const auto &outputs = args.op.outputs;
- if (outputs.size() != 1)
+ if (!GraphBuilder::validate(args, 1))
return false;
+ const auto &inputs = args.op.inputs;
+ const auto &outputs = args.op.outputs;
const auto &tensors = args.reader.tensors();
if (tensors.at(inputs.at(0))->type != tensors.at(outputs[0])->type)
return false;
diff --git a/compiler/luci/import/src/Nodes/CircleTile.cpp b/compiler/luci/import/src/Nodes/CircleTile.cpp
index bc6f320ba..6da44130c 100644
--- a/compiler/luci/import/src/Nodes/CircleTile.cpp
+++ b/compiler/luci/import/src/Nodes/CircleTile.cpp
@@ -25,15 +25,11 @@ namespace luci
bool CircleTileGraphBuilder::validate(const ValidateArgs &args) const
{
- auto inputs = args.op.inputs;
- auto outputs = args.op.outputs;
-
- if (inputs.size() != 2)
- return false;
-
- if (outputs.size() != 1)
+ if (!GraphBuilder::validate(args, 2))
return false;
+ auto inputs = args.op.inputs;
+ auto outputs = args.op.outputs;
// Multiples (inputs.at(1)) must be one of the following types
// int32, int64
const auto &tensors = args.reader.tensors();
diff --git a/compiler/luci/import/src/Nodes/CircleTopKV2.cpp b/compiler/luci/import/src/Nodes/CircleTopKV2.cpp
index f0677de86..49f858798 100644
--- a/compiler/luci/import/src/Nodes/CircleTopKV2.cpp
+++ b/compiler/luci/import/src/Nodes/CircleTopKV2.cpp
@@ -59,59 +59,24 @@ bool CircleTopKV2GraphBuilder::validate(const ValidateArgs &args) const
* \- CircleTopKV2Out --- FullyConnected ---
*/
-void CircleTopKV2GraphBuilder::build(const circle::OperatorT &op,
- GraphBuilderContext *context) const
+CircleNode *CircleTopKV2GraphBuilder::build_node(const BuildNodeArgs &bna) const
{
- assert(context != nullptr);
-
- auto graph = context->graph();
-
- const std::vector<int32_t> &inputs = op.inputs;
- const std::vector<int32_t> &outputs = op.outputs;
- const auto &tensors = context->reader()->tensors();
- const auto &opcodes = context->reader()->opcodes();
- auto tensors_ptr = context->reader()->tensors_ptr();
- assert(tensors_ptr != nullptr);
-
- std::vector<CircleNode *> input_nodes;
- for (const int32_t input_tensor_index : inputs)
- {
- input_nodes.push_back(context->nodefinder()->node(input_tensor_index));
- }
-
- // Create CircleTopKV2
- auto node = graph->nodes()->create<CircleTopKV2>();
- node->input(input_nodes[0]);
- node->k(input_nodes[1]);
-
- assert(outputs.size() == 2);
- {
- // Let's use name of output 0 as TopKV2 name
- const circle::TensorT &output_tensor = *tensors[outputs[0]];
- node->name(tensor_name(output_tensor));
- node->op_version(opcodes[op.opcode_index].get()->version);
-
- // NOTE We don't set quantization for TopKV2 itself but to virtual outputs
- }
-
- // Create virtual outputs of TopKV2
- for (size_t n = 0; n < outputs.size(); ++n)
- {
- const circle::TensorT &output_tensor = *tensors[outputs[n]];
-
- auto *nodeout = graph->nodes()->create<CircleTopKV2Out>();
- copy_tensor_attributes(output_tensor, nodeout);
- // mark shape_status
- if (tensors_ptr->Get(outputs[n])->shape() == nullptr)
- nodeout->shape_status(ShapeStatus::NOSHAPE);
- else
- nodeout->shape_status(ShapeStatus::VALID);
-
- nodeout->input(node);
- nodeout->index(n);
-
- context->nodefinder()->enroll(outputs[n], nodeout);
- }
+ auto node = bna.context->graph()->nodes()->create<CircleTopKV2>();
+
+ node->input(bna.input_nodes[0]);
+ node->k(bna.input_nodes[1]);
+
+ return node;
+}
+
+CircleNode *CircleTopKV2GraphBuilder::build_out(const BuildOutArgs &boa) const
+{
+ auto *nodeout = boa.node->graph()->nodes()->create<CircleTopKV2Out>();
+
+ nodeout->input(boa.node);
+ nodeout->index(boa.index);
+
+ return nodeout;
}
} // namespace luci
diff --git a/compiler/luci/import/src/Nodes/CircleTranspose.cpp b/compiler/luci/import/src/Nodes/CircleTranspose.cpp
index cc3153085..01095239e 100644
--- a/compiler/luci/import/src/Nodes/CircleTranspose.cpp
+++ b/compiler/luci/import/src/Nodes/CircleTranspose.cpp
@@ -25,13 +25,7 @@ namespace luci
bool CircleTransposeGraphBuilder::validate(const ValidateArgs &args) const
{
- if (args.op.inputs.size() != 2)
- return false;
-
- if (args.op.outputs.size() != 1)
- return false;
-
- return true;
+ return GraphBuilder::validate(args, 2);
}
CircleNode *CircleTransposeGraphBuilder::build_node(const circle::OperatorT &op,
diff --git a/compiler/luci/import/src/Nodes/CircleTransposeConv.cpp b/compiler/luci/import/src/Nodes/CircleTransposeConv.cpp
index c280faaf5..5a60e2f54 100644
--- a/compiler/luci/import/src/Nodes/CircleTransposeConv.cpp
+++ b/compiler/luci/import/src/Nodes/CircleTransposeConv.cpp
@@ -61,16 +61,15 @@ CircleNode *CircleTransposeConvGraphBuilder::build_node(const circle::OperatorT
node->filter(inputs.at(1));
node->outBackprop(inputs.at(2));
if (inputs.size() == 3)
- node->bias(graph->nodes()->create<CircleOutputExclude>());
- else
- node->bias(inputs.at(3));
-
- if (auto bias = dynamic_cast<luci::CircleOutputExclude *>(node->bias()))
{
- // CircleOutputExclude doesn't need a type, but since all nodes must have a type, a dummy type
- // is inserted.
+ auto *bias = graph->nodes()->create<CircleOutputExclude>();
+ // CircleOutputExclude doesn't need a type, but since all nodes must have a type,
+ // a dummy type is inserted.
bias->dtype(loco::DataType::FLOAT32);
+ node->bias(bias);
}
+ else
+ node->bias(inputs.at(3));
const auto *options = op.builtin_options.AsTransposeConvOptions();
node->padding(luci_padding(options->padding));
diff --git a/compiler/luci/import/src/Nodes/CircleUnidirectionalSequenceLSTM.cpp b/compiler/luci/import/src/Nodes/CircleUnidirectionalSequenceLSTM.cpp
index c41cf4def..d9cc3f8d0 100644
--- a/compiler/luci/import/src/Nodes/CircleUnidirectionalSequenceLSTM.cpp
+++ b/compiler/luci/import/src/Nodes/CircleUnidirectionalSequenceLSTM.cpp
@@ -25,14 +25,11 @@ namespace luci
bool CircleUnidirectionalSequenceLSTMGraphBuilder::validate(const ValidateArgs &args) const
{
- if (args.op.inputs.size() != 24)
- return false;
-
- return true;
+ return GraphBuilder::validate(args, 24);
}
CircleNode *CircleUnidirectionalSequenceLSTMGraphBuilder::build_node(
- const circle::OperatorT &op, const std::vector<CircleNode *> &inputs, loco::Graph *graph) const
+ const circle::OperatorT &op, const std::vector<CircleNode *> &inputs, loco::Graph *graph) const
{
auto *node = graph->nodes()->create<CircleUnidirectionalSequenceLSTM>();
node->input(inputs.at(0));
@@ -59,16 +56,6 @@ CircleNode *CircleUnidirectionalSequenceLSTMGraphBuilder::build_node(
node->forget_layer_norm_coefficients(inputs.at(21)); // Optional
node->cell_layer_norm_coefficients(inputs.at(22)); // Optional
node->output_layer_norm_coefficients(inputs.at(23)); // Optional
- const std::vector<int32_t> optionals = {1, 5, 9, 10, 11, 12, 16, 17, 20, 21, 22, 23};
- for (auto optional : optionals)
- {
- if (auto inp = dynamic_cast<luci::CircleOutputExclude *>(node->arg(optional)))
- {
- // CircleOutputExclude doesn't need a type, but since all nodes must have a type, a dummy type
- // is inserted.
- inp->dtype(loco::DataType::FLOAT32);
- }
- }
const auto *options = op.builtin_options.AsUnidirectionalSequenceLSTMOptions();
node->fusedActivationFunction(luci_actfunc(options->fused_activation_function));
diff --git a/compiler/luci/import/src/Nodes/CircleUnique.cpp b/compiler/luci/import/src/Nodes/CircleUnique.cpp
index 5e79a2920..f6914c24a 100644
--- a/compiler/luci/import/src/Nodes/CircleUnique.cpp
+++ b/compiler/luci/import/src/Nodes/CircleUnique.cpp
@@ -35,55 +35,26 @@ bool CircleUniqueGraphBuilder::validate(const ValidateArgs &args) const
return true;
}
-void CircleUniqueGraphBuilder::build(const circle::OperatorT &op,
- GraphBuilderContext *context) const
+CircleNode *CircleUniqueGraphBuilder::build_node(const BuildNodeArgs &bna) const
{
- assert(context != nullptr);
+ auto node = bna.context->graph()->nodes()->create<CircleUnique>();
- auto graph = context->graph();
+ node->input(bna.input_nodes[0]);
- const std::vector<int32_t> &inputs = op.inputs;
- const std::vector<int32_t> &outputs = op.outputs;
- const auto &tensors = context->reader()->tensors();
- auto tensors_ptr = context->reader()->tensors_ptr();
- assert(tensors_ptr != nullptr);
+ const auto *options = bna.op.builtin_options.AsUniqueOptions();
+ node->idx_out_type(luci_datatype(options->idx_out_type));
- std::vector<CircleNode *> input_nodes;
- for (const int32_t input_tensor_index : inputs)
- {
- input_nodes.push_back(context->nodefinder()->node(input_tensor_index));
- }
-
- // Create CircleUnique
- auto node = graph->nodes()->create<CircleUnique>();
- node->input(input_nodes[0]);
-
- const auto *options = op.builtin_options.AsUniqueOptions();
- node->output_type(luci_datatype(options->idx_out_type));
-
- assert(int32_t(outputs.size()) == 2);
- // Let's use name of output 0 as Unique name
- const circle::TensorT &output_tensor = *tensors[outputs[0]];
- node->name(tensor_name(output_tensor));
-
- // Create virtual outputs of Unique
- for (int32_t n = 0; n < 2; ++n)
- {
- const circle::TensorT &output_tensor = *tensors[outputs[n]];
+ return node;
+}
- auto *nodeout = graph->nodes()->create<CircleUniqueOut>();
- copy_tensor_attributes(output_tensor, nodeout);
- // mark shape_status
- if (tensors_ptr->Get(outputs[n])->shape() == nullptr)
- nodeout->shape_status(ShapeStatus::NOSHAPE);
- else
- nodeout->shape_status(ShapeStatus::VALID);
+CircleNode *CircleUniqueGraphBuilder::build_out(const BuildOutArgs &boa) const
+{
+ auto *nodeout = boa.node->graph()->nodes()->create<CircleUniqueOut>();
- nodeout->input(node);
- nodeout->index(n);
+ nodeout->input(boa.node);
+ nodeout->index(boa.index);
- context->nodefinder()->enroll(outputs[n], nodeout);
- }
+ return nodeout;
}
} // namespace luci
diff --git a/compiler/luci/import/src/Nodes/CircleUnpack.cpp b/compiler/luci/import/src/Nodes/CircleUnpack.cpp
index 9e7f3d3e1..9bfc76b57 100644
--- a/compiler/luci/import/src/Nodes/CircleUnpack.cpp
+++ b/compiler/luci/import/src/Nodes/CircleUnpack.cpp
@@ -88,64 +88,27 @@ bool CircleUnpackGraphBuilder::validate(const ValidateArgs &args) const
* \- CircleUnpackOut --- FullyConnected ---
*/
-void CircleUnpackGraphBuilder::build(const circle::OperatorT &op,
- GraphBuilderContext *context) const
+CircleNode *CircleUnpackGraphBuilder::build_node(const BuildNodeArgs &bna) const
{
- assert(context != nullptr);
+ auto node = bna.context->graph()->nodes()->create<CircleUnpack>();
- auto graph = context->graph();
+ node->value(bna.input_nodes[0]);
- const std::vector<int32_t> &inputs = op.inputs;
- const std::vector<int32_t> &outputs = op.outputs;
- const auto &tensors = context->reader()->tensors();
- const auto &opcodes = context->reader()->opcodes();
- auto tensors_ptr = context->reader()->tensors_ptr();
- assert(tensors_ptr != nullptr);
-
- // NOTE Unpack has only one input so running a loop is not necessary
- // This is provided as a reference for other Ops as a reference
- std::vector<CircleNode *> input_nodes;
- for (const int32_t input_tensor_index : inputs)
- {
- input_nodes.push_back(context->nodefinder()->node(input_tensor_index));
- }
-
- // Create CircleUnpack
- CircleUnpack *node = graph->nodes()->create<CircleUnpack>();
- node->value(input_nodes[0]);
-
- const auto *options = op.builtin_options.AsUnpackOptions();
+ const auto *options = bna.op.builtin_options.AsUnpackOptions();
node->num(options->num);
node->axis(options->axis);
- assert(outputs.size() > 0);
- {
- // Let's use name of output 0 as Unpack name
- const circle::TensorT &output_tensor = *tensors[outputs[0]];
- node->name(tensor_name(output_tensor));
- node->op_version(opcodes[op.opcode_index].get()->version);
-
- // NOTE We don't set quantization for Unpack itself but to virtual outputs
- }
-
- // Create virtual outputs of Unpack
- for (int32_t n = 0; n < options->num; ++n)
- {
- const circle::TensorT &output_tensor = *tensors[outputs[n]];
+ return node;
+}
- auto *nodeout = graph->nodes()->create<CircleUnpackOut>();
- copy_tensor_attributes(output_tensor, nodeout);
- // mark shape_status
- if (tensors_ptr->Get(outputs[n])->shape() == nullptr)
- nodeout->shape_status(ShapeStatus::NOSHAPE);
- else
- nodeout->shape_status(ShapeStatus::VALID);
+CircleNode *CircleUnpackGraphBuilder::build_out(const BuildOutArgs &boa) const
+{
+ auto *nodeout = boa.node->graph()->nodes()->create<CircleUnpackOut>();
- nodeout->input(node);
- nodeout->index(n);
+ nodeout->input(boa.node);
+ nodeout->index(boa.index);
- context->nodefinder()->enroll(outputs[n], nodeout);
- }
+ return nodeout;
}
} // namespace luci
diff --git a/compiler/luci/import/src/Nodes/CircleWhere.cpp b/compiler/luci/import/src/Nodes/CircleWhere.cpp
index f4c5f0c66..8e4f1a0c4 100644
--- a/compiler/luci/import/src/Nodes/CircleWhere.cpp
+++ b/compiler/luci/import/src/Nodes/CircleWhere.cpp
@@ -25,15 +25,11 @@ namespace luci
bool CircleWhereGraphBuilder::validate(const ValidateArgs &args) const
{
- const auto &inputs = args.op.inputs;
- const auto &outputs = args.op.outputs;
-
- if (inputs.size() != 1)
- return false;
-
- if (outputs.size() != 1)
+ if (!GraphBuilder::validate(args, 1))
return false;
+ const auto &inputs = args.op.inputs;
+ const auto &outputs = args.op.outputs;
const auto &tensors = args.reader.tensors();
const auto &tensor_condition = tensors.at(inputs.at(0));
const auto &tensor_out = tensors.at(outputs[0]);
diff --git a/compiler/luci/import/src/Nodes/CircleWhile.cpp b/compiler/luci/import/src/Nodes/CircleWhile.cpp
index aead25071..26147562f 100644
--- a/compiler/luci/import/src/Nodes/CircleWhile.cpp
+++ b/compiler/luci/import/src/Nodes/CircleWhile.cpp
@@ -58,7 +58,8 @@ bool CircleWhileGraphBuilder::validate(const ValidateArgs &args) const
* \- CircleWhileOut --- Node ---
*/
-void CircleWhileGraphBuilder::build(const circle::OperatorT &op, GraphBuilderContext *context) const
+CircleNode *CircleWhileGraphBuilder::build(const circle::OperatorT &op,
+ GraphBuilderContext *context) const
{
assert(context != nullptr);
@@ -118,6 +119,8 @@ void CircleWhileGraphBuilder::build(const circle::OperatorT &op, GraphBuilderCon
context->nodefinder()->enroll(outputs[n], nodeout);
}
+
+ return node;
}
} // namespace luci
diff --git a/compiler/luci/import/src/Nodes/CircleZerosLike.cpp b/compiler/luci/import/src/Nodes/CircleZerosLike.cpp
index e60424def..ddb05e8a4 100644
--- a/compiler/luci/import/src/Nodes/CircleZerosLike.cpp
+++ b/compiler/luci/import/src/Nodes/CircleZerosLike.cpp
@@ -25,13 +25,7 @@ namespace luci
bool CircleZerosLikeGraphBuilder::validate(const ValidateArgs &args) const
{
- if (args.op.inputs.size() != 1)
- return false;
-
- if (args.op.outputs.size() != 1)
- return false;
-
- return true;
+ return GraphBuilder::validate(args, 1);
}
CircleNode *CircleZerosLikeGraphBuilder::build_node(const circle::OperatorT &,
diff --git a/compiler/luci/import/src/PostImport.cpp b/compiler/luci/import/src/PostImport.cpp
index f436b48e8..63b16bb95 100644
--- a/compiler/luci/import/src/PostImport.cpp
+++ b/compiler/luci/import/src/PostImport.cpp
@@ -130,7 +130,10 @@ private:
namespace
{
/**
- * @brief ValidateNodeProp will validate inter graph connections for each Nodes
+ * @brief ValidateNodeProp will validate inter graph connections for each Nodes.
+ * @note In here, only loco::GraphInput and loco::GraphOutput are validated,
+ * since this class is for checking inter graph connections.
+ * CircleNodes such as CircleInput and CircleOutput will be validated at later steps.
*/
class ValidateNodeProp final : public luci::CircleNodeMutableVisitor<void>
{
@@ -172,9 +175,19 @@ public:
auto then_graph_output = then_graph_outputs->at(then_out->index());
auto else_graph_output = else_graph_outputs->at(else_out->index());
- if (!(*then_graph_output->shape() == *else_graph_output->shape()))
+ if (then_graph_output->shape()->rank() != else_graph_output->shape()->rank())
{
- INTERNAL_EXN_V("CircleIf THEN and ELSE Graph Output shape mismatch ", idx);
+ INTERNAL_EXN_V("CircleIf THEN and ELSE Graph Output rank mismatch ", idx);
+ }
+ for (uint32_t i = 0; i < then_graph_output->shape()->rank(); ++i)
+ {
+ if (then_graph_output->shape()->dim(i).known() &&
+ else_graph_output->shape()->dim(i).known() &&
+ then_graph_output->shape()->dim(i).value() !=
+ else_graph_output->shape()->dim(i).value())
+ {
+ INTERNAL_EXN_V("CircleIf THEN and ELSE Graph Output dimension mismatch ", idx);
+ }
}
if (then_graph_output->dtype() != else_graph_output->dtype())
{
@@ -231,18 +244,20 @@ public:
auto cond_graph_input = cond_graph_inputs->at(cond_in->index());
auto body_graph_input = body_graph_inputs->at(body_in->index());
- if ((cond_in->rank() != body_in->rank()))
+ if (cond_graph_input->shape()->rank() != body_graph_input->shape()->rank())
{
- INTERNAL_EXN_V("CircleWhile COND input and BODY input shape mismatch ", idx);
+ INTERNAL_EXN_V("CircleWhile COND input and BODY input rank mismatch ", idx);
}
- if (cond_in->rank() > 0 && body_in->rank() > 0)
+ for (uint32_t i = 0; i < cond_graph_input->shape()->rank(); ++i)
{
- if (!(*cond_graph_input->shape() == *body_graph_input->shape()))
+ if (cond_graph_input->shape()->dim(i).known() &&
+ body_graph_input->shape()->dim(i).known() &&
+ cond_graph_input->shape()->dim(i).value() != body_graph_input->shape()->dim(i).value())
{
- INTERNAL_EXN_V("CircleWhile COND input and BODY input shape mismatch ", idx);
+ INTERNAL_EXN_V("CircleWhile COND input and BODY input dimension mismatch ", idx);
}
}
- if (cond_in->dtype() != body_in->dtype())
+ if (cond_graph_input->dtype() != body_graph_input->dtype())
{
INTERNAL_EXN_V("CircleWhile COND input and BODY input type mismatch ", idx);
}
@@ -257,18 +272,20 @@ public:
auto cond_graph_input = cond_graph_inputs->at(cond_in->index());
auto body_graph_output = body_graph_outputs->at(body_out->index());
- if ((cond_in->rank() != body_out->rank()))
+ if (cond_graph_input->shape()->rank() != body_graph_output->shape()->rank())
{
- INTERNAL_EXN_V("CircleWhile COND input and BODY output shape mismatch ", idx);
+ INTERNAL_EXN_V("CircleWhile COND input and BODY output rank mismatch ", idx);
}
- if (cond_in->rank() > 0 && body_out->rank() > 0)
+ for (uint32_t i = 0; i < cond_graph_input->shape()->rank(); ++i)
{
- if (!(*cond_graph_input->shape() == *body_graph_output->shape()))
+ if (cond_graph_input->shape()->dim(i).known() &&
+ body_graph_output->shape()->dim(i).known() &&
+ cond_graph_input->shape()->dim(i).value() != body_graph_output->shape()->dim(i).value())
{
- INTERNAL_EXN_V("CircleWhile COND input and BODY output shape mismatch ", idx);
+ INTERNAL_EXN_V("CircleWhile COND input and BODY output dimension mismatch ", idx);
}
}
- if (cond_in->dtype() != body_out->dtype())
+ if (cond_graph_input->dtype() != body_graph_output->dtype())
{
INTERNAL_EXN_V("CircleWhile COND input and BODY output type mismatch ", idx);
}
diff --git a/compiler/luci/lang/CMakeLists.txt b/compiler/luci/lang/CMakeLists.txt
index 32d0a890d..c618fdd6f 100644
--- a/compiler/luci/lang/CMakeLists.txt
+++ b/compiler/luci/lang/CMakeLists.txt
@@ -7,6 +7,7 @@ target_include_directories(luci_lang PRIVATE src)
target_include_directories(luci_lang PUBLIC include)
target_link_libraries(luci_lang PUBLIC loco)
target_link_libraries(luci_lang PUBLIC oops)
+target_link_libraries(luci_lang PUBLIC nncc_coverage)
target_link_libraries(luci_lang PRIVATE logo)
target_link_libraries(luci_lang PRIVATE nncc_common)
diff --git a/compiler/luci/lang/include/luci/IR/CircleNodeDecl.h b/compiler/luci/lang/include/luci/IR/CircleNodeDecl.h
index e6410d154..edec9d18b 100644
--- a/compiler/luci/lang/include/luci/IR/CircleNodeDecl.h
+++ b/compiler/luci/lang/include/luci/IR/CircleNodeDecl.h
@@ -20,7 +20,6 @@
#include <loco/IR/Dialect.h>
#include <loco/IR/Node.h>
#include <loco/IR/NodeMixins.h>
-#include <luci/IR/CircleShapeSignature.h>
#include <luci/IR/PropertyShapeStatus.h>
#include "CircleOpcode.h"
@@ -62,9 +61,6 @@ struct CircleNode : public loco::Node,
_sparsityparam = std::move(sparsityparam);
}
- const ShapeSignature &shape_signature(void) const { return _shape_signature; }
- void shape_signature(const ShapeSignature &ss) { _shape_signature = ss; }
-
ShapeStatus shape_status(void) const { return _shape_status; }
void shape_status(ShapeStatus ss) { _shape_status = ss; }
@@ -75,7 +71,6 @@ private:
NodeName _name;
std::unique_ptr<CircleQuantParam> _quantparam;
std::unique_ptr<SparsityParam> _sparsityparam;
- ShapeSignature _shape_signature;
ShapeStatus _shape_status{ShapeStatus::UNDEFINED};
int32_t _op_version = 1;
};
diff --git a/compiler/luci/lang/include/luci/IR/CircleNodeImpl.h b/compiler/luci/lang/include/luci/IR/CircleNodeImpl.h
index a6b9488db..4b3178b9b 100644
--- a/compiler/luci/lang/include/luci/IR/CircleNodeImpl.h
+++ b/compiler/luci/lang/include/luci/IR/CircleNodeImpl.h
@@ -34,8 +34,10 @@ template <typename T> T CircleNode::accept(CircleNodeVisitorBase<T> *v) const
\
case CircleOpcode::OPCODE: \
return v->visit(dynamic_cast<const CLASS *>(this));
+#define CIRCLE_VNODE CIRCLE_NODE
#include "CircleNodes.lst"
+#undef CIRCLE_VNODE
#undef CIRCLE_NODE
default:
@@ -53,8 +55,10 @@ template <typename T> T CircleNode::accept(CircleNodeMutableVisitorBase<T> *v)
\
case CircleOpcode::OPCODE: \
return v->visit(dynamic_cast<CLASS *>(this));
+#define CIRCLE_VNODE CIRCLE_NODE
#include "CircleNodes.lst"
+#undef CIRCLE_VNODE
#undef CIRCLE_NODE
default:
diff --git a/compiler/luci/lang/include/luci/IR/CircleNodeMixins.h b/compiler/luci/lang/include/luci/IR/CircleNodeMixins.h
new file mode 100644
index 000000000..3f8ab7d61
--- /dev/null
+++ b/compiler/luci/lang/include/luci/IR/CircleNodeMixins.h
@@ -0,0 +1,107 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LUCI_IR_CIRCLE_NODE_MIXINS_H__
+#define __LUCI_IR_CIRCLE_NODE_MIXINS_H__
+
+#include "luci/IR/AttrFusedActFunc.h"
+
+#include <loco/IR/Node.h>
+#include <loco/IR/NodeMixins.h>
+
+#include <vector>
+
+namespace luci
+{
+
+/// @brief enumeration of mixin class
+enum class CircleNodeTrait
+{
+ FusedActFunc,
+ Bias
+};
+
+template <CircleNodeTrait T> class CircleNodeMixin;
+
+template <> class CircleNodeMixin<CircleNodeTrait::FusedActFunc>
+{
+public:
+ CircleNodeMixin() = default;
+
+public:
+ FusedActFunc fusedActivationFunction() const { return _fused_act_fun; }
+ void fusedActivationFunction(FusedActFunc fused_act_fun) { _fused_act_fun = fused_act_fun; }
+
+private:
+ FusedActFunc _fused_act_fun = FusedActFunc::UNDEFINED;
+};
+
+/**
+ * @brief Mixin class for nodes that has a bias input
+ */
+template <> class CircleNodeMixin<CircleNodeTrait::Bias>
+{
+public:
+ CircleNodeMixin() = default;
+
+public:
+ virtual loco::Node *bias(void) const = 0; /// @brief get the input for bias.
+ virtual void bias(loco::Node *node) = 0; /// @brief set the input for bias.
+};
+
+/**
+ * @brief Nodes with the fixed number of inputs
+ *
+ * TODO Deprecated this class, and use loco::FixedArity instead
+ */
+template <unsigned N, typename Base> class FixedArityNode : public Base
+{
+public:
+ FixedArityNode()
+ {
+ _args.resize(N);
+ for (uint32_t n = 0; n < N; ++n)
+ {
+ _args[n] = std::make_unique<loco::Use>(this);
+ }
+ }
+
+ virtual ~FixedArityNode() = default;
+
+public:
+ unsigned arity(void) const final { return N; }
+
+ loco::Node *arg(uint32_t n) const final { return _args.at(n)->node(); }
+
+ void drop(void) final
+ {
+ for (uint32_t n = 0; n < N; ++n)
+ {
+ _args.at(n)->node(nullptr);
+ }
+ }
+
+protected:
+ // This API allows inherited classes to access "_args" field.
+ loco::Use *at(unsigned n) const { return _args.at(n).get(); }
+
+private:
+ std::vector<std::unique_ptr<loco::Use>> _args{};
+};
+
+} // namespace luci
+
+#endif // __LUCI_IR_CIRCLE_NODE_MIXINS_H__
diff --git a/compiler/luci/lang/include/luci/IR/CircleNodeVisitor.h b/compiler/luci/lang/include/luci/IR/CircleNodeVisitor.h
index 43339fe84..599e4bcd9 100644
--- a/compiler/luci/lang/include/luci/IR/CircleNodeVisitor.h
+++ b/compiler/luci/lang/include/luci/IR/CircleNodeVisitor.h
@@ -33,8 +33,10 @@ template <typename T> struct CircleNodeVisitorBase
virtual ~CircleNodeVisitorBase() = default;
#define CIRCLE_NODE(OPCODE, CIRCLE_CLASS) virtual T visit(const CIRCLE_CLASS *) = 0;
+#define CIRCLE_VNODE CIRCLE_NODE
#include "CircleNodes.lst"
+#undef CIRCLE_VNODE
#undef CIRCLE_NODE
};
@@ -44,9 +46,11 @@ template <typename T> struct CircleNodeVisitor : public CircleNodeVisitorBase<T>
#define CIRCLE_NODE(OPCODE, CIRCLE_CLASS) \
virtual T visit(const CIRCLE_CLASS *node) { return visit(static_cast<const CircleNode *>(node)); }
+#define CIRCLE_VNODE CIRCLE_NODE
#include "CircleNodes.lst"
+#undef CIRCLE_VNODE
#undef CIRCLE_NODE
/// @brief Default fallback
@@ -61,9 +65,11 @@ template <typename T> struct CircleNodeMutableVisitorBase
virtual ~CircleNodeMutableVisitorBase() = default;
#define CIRCLE_NODE(OPCODE, CIRCLE_CLASS) virtual T visit(CIRCLE_CLASS *) = 0;
+#define CIRCLE_VNODE CIRCLE_NODE
#include "CircleNodes.lst"
+#undef CIRCLE_VNODE
#undef CIRCLE_NODE
};
@@ -73,9 +79,11 @@ template <typename T> struct CircleNodeMutableVisitor : public CircleNodeMutable
#define CIRCLE_NODE(OPCODE, CIRCLE_CLASS) \
virtual T visit(CIRCLE_CLASS *node) { return visit(static_cast<CircleNode *>(node)); }
+#define CIRCLE_VNODE CIRCLE_NODE
#include "CircleNodes.lst"
+#undef CIRCLE_VNODE
#undef CIRCLE_NODE
/// @brief Default fallback
diff --git a/compiler/luci/lang/include/luci/IR/CircleNodes.h b/compiler/luci/lang/include/luci/IR/CircleNodes.h
index fde0b612b..69a82a7b9 100644
--- a/compiler/luci/lang/include/luci/IR/CircleNodes.h
+++ b/compiler/luci/lang/include/luci/IR/CircleNodes.h
@@ -25,6 +25,7 @@
#include "Nodes/CircleAveragePool2D.h"
#include "Nodes/CircleBatchMatMul.h"
#include "Nodes/CircleBatchToSpaceND.h"
+#include "Nodes/CircleBidirectionalSequenceLSTM.h"
#include "Nodes/CircleCast.h"
#include "Nodes/CircleCeil.h"
#include "Nodes/CircleConcatenation.h"
@@ -40,6 +41,7 @@
#include "Nodes/CircleEqual.h"
#include "Nodes/CircleExp.h"
#include "Nodes/CircleExpandDims.h"
+#include "Nodes/CircleFakeQuant.h"
#include "Nodes/CircleFill.h"
#include "Nodes/CircleFloor.h"
#include "Nodes/CircleFloorDiv.h"
@@ -134,6 +136,7 @@
// Virtual nodes
#include "Nodes/CircleInput.h"
#include "Nodes/CircleOutput.h"
+#include "Nodes/CircleBidirectionalSequenceLSTMOut.h"
#include "Nodes/CircleCustomOut.h"
#include "Nodes/CircleIfOut.h"
#include "Nodes/CircleNonMaxSuppressionV4Out.h"
@@ -150,15 +153,6 @@
namespace luci
{
-/**
- * @brief Set both CircleReshape's 2nd input as CircleConst, and newShape attribute
- * with same value
- * @note Shape inference for TFLReshape forces them to be same
- *
- * TODO find better place for this helper
- */
-void set_new_shape(CircleReshape *node, int32_t *base, uint32_t size);
-
/// @brief Link GraphOutput with CircleOutput node
void link(loco::GraphOutput *, CircleOutput *);
diff --git a/compiler/luci/lang/include/luci/IR/CircleNodes.lst b/compiler/luci/lang/include/luci/IR/CircleNodes.lst
index b9d545893..b93fdc89d 100644
--- a/compiler/luci/lang/include/luci/IR/CircleNodes.lst
+++ b/compiler/luci/lang/include/luci/IR/CircleNodes.lst
@@ -2,6 +2,10 @@
#error "Define CIRCLE_NODE"
#endif // CIRCLE_NODE
+#ifndef CIRCLE_VNODE
+#error "Define CIRCLE_VNODE"
+#endif // CIRCLE_VNODE
+
//
// PLEASE SORT NODE DECLS IN ALPHABETICAL ORDER
//
@@ -18,7 +22,8 @@ CIRCLE_NODE(ARG_MAX, luci::CircleArgMax)
CIRCLE_NODE(ARG_MIN, luci::CircleArgMin)
CIRCLE_NODE(AVERAGE_POOL_2D, luci::CircleAveragePool2D)
CIRCLE_NODE(BATCH_TO_SPACE_ND, luci::CircleBatchToSpaceND)
-CIRCLE_NODE(BATCHMATMUL, luci::CircleBatchMatMul)
+CIRCLE_NODE(BATCH_MATMUL, luci::CircleBatchMatMul)
+CIRCLE_NODE(BIDIRECTIONAL_SEQUENCE_LSTM, luci::CircleBidirectionalSequenceLSTM)
CIRCLE_NODE(CAST, luci::CircleCast)
CIRCLE_NODE(CEIL, luci::CircleCeil)
CIRCLE_NODE(CONCATENATION, luci::CircleConcatenation)
@@ -33,6 +38,7 @@ CIRCLE_NODE(ELU, luci::CircleElu)
CIRCLE_NODE(EQUAL, luci::CircleEqual)
CIRCLE_NODE(EXP, luci::CircleExp)
CIRCLE_NODE(EXPAND_DIMS, luci::CircleExpandDims)
+CIRCLE_NODE(FAKE_QUANT, luci::CircleFakeQuant)
CIRCLE_NODE(FILL, luci::CircleFill)
CIRCLE_NODE(FLOOR, luci::CircleFloor)
CIRCLE_NODE(FLOOR_DIV, luci::CircleFloorDiv)
@@ -125,18 +131,19 @@ CIRCLE_NODE(BCQ_FULLY_CONNECTED, luci::CircleBCQFullyConnected)
CIRCLE_NODE(BCQ_GATHER, luci::CircleBCQGather)
CIRCLE_NODE(INSTANCE_NORM, luci::CircleInstanceNorm)
// Virtual node(s)
-CIRCLE_NODE(CIRCLECONST, luci::CircleConst)
-CIRCLE_NODE(CIRCLEINPUT, luci::CircleInput)
-CIRCLE_NODE(CIRCLEOUTPUT, luci::CircleOutput)
-CIRCLE_NODE(CIRCLEOUTPUTDUMMY, luci::CircleOutputDummy)
-CIRCLE_NODE(CIRCLEOUTPUTEXCLUDE, luci::CircleOutputExclude)
-CIRCLE_NODE(CIRCLECUSTOMOUT, luci::CircleCustomOut)
-CIRCLE_NODE(CIRCLEIFOUT, luci::CircleIfOut)
-CIRCLE_NODE(CIRCLENONMAXSUPPRESSIONV4OUT, luci::CircleNonMaxSuppressionV4Out)
-CIRCLE_NODE(CIRCLENONMAXSUPPRESSIONV5OUT, luci::CircleNonMaxSuppressionV5Out)
-CIRCLE_NODE(CIRCLESPLITOUT, luci::CircleSplitOut)
-CIRCLE_NODE(CIRCLESPLITVOUT, luci::CircleSplitVOut)
-CIRCLE_NODE(CIRCLETOPKV2OUT, luci::CircleTopKV2Out)
-CIRCLE_NODE(CIRCLEUNIQUEOUT, luci::CircleUniqueOut)
-CIRCLE_NODE(CIRCLEUNPACKOUT, luci::CircleUnpackOut)
-CIRCLE_NODE(CIRCLEWHILEOUT, luci::CircleWhileOut)
+CIRCLE_VNODE(CIRCLEBIDIRECTIONAL_SEQUENCE_LSTM_OUT, luci::CircleBidirectionalSequenceLSTMOut)
+CIRCLE_VNODE(CIRCLECONST, luci::CircleConst)
+CIRCLE_VNODE(CIRCLEINPUT, luci::CircleInput)
+CIRCLE_VNODE(CIRCLEOUTPUT, luci::CircleOutput)
+CIRCLE_VNODE(CIRCLEOUTPUTDUMMY, luci::CircleOutputDummy)
+CIRCLE_VNODE(CIRCLEOUTPUTEXCLUDE, luci::CircleOutputExclude)
+CIRCLE_VNODE(CIRCLECUSTOMOUT, luci::CircleCustomOut)
+CIRCLE_VNODE(CIRCLEIFOUT, luci::CircleIfOut)
+CIRCLE_VNODE(CIRCLENONMAXSUPPRESSIONV4OUT, luci::CircleNonMaxSuppressionV4Out)
+CIRCLE_VNODE(CIRCLENONMAXSUPPRESSIONV5OUT, luci::CircleNonMaxSuppressionV5Out)
+CIRCLE_VNODE(CIRCLESPLITOUT, luci::CircleSplitOut)
+CIRCLE_VNODE(CIRCLESPLITVOUT, luci::CircleSplitVOut)
+CIRCLE_VNODE(CIRCLETOPKV2OUT, luci::CircleTopKV2Out)
+CIRCLE_VNODE(CIRCLEUNIQUEOUT, luci::CircleUniqueOut)
+CIRCLE_VNODE(CIRCLEUNPACKOUT, luci::CircleUnpackOut)
+CIRCLE_VNODE(CIRCLEWHILEOUT, luci::CircleWhileOut)
diff --git a/compiler/luci/lang/include/luci/IR/CircleOpcode.h b/compiler/luci/lang/include/luci/IR/CircleOpcode.h
index 703b70da2..be3069f94 100644
--- a/compiler/luci/lang/include/luci/IR/CircleOpcode.h
+++ b/compiler/luci/lang/include/luci/IR/CircleOpcode.h
@@ -23,7 +23,9 @@ namespace luci
enum class CircleOpcode
{
#define CIRCLE_NODE(OPCODE, CLASS) OPCODE,
+#define CIRCLE_VNODE CIRCLE_NODE
#include "CircleNodes.lst"
+#undef CIRCLE_VNODE
#undef CIRCLE_NODE
};
diff --git a/compiler/luci/lang/include/luci/IR/CircleShapeSignature.h b/compiler/luci/lang/include/luci/IR/CircleShapeSignature.h
deleted file mode 100644
index 18a260486..000000000
--- a/compiler/luci/lang/include/luci/IR/CircleShapeSignature.h
+++ /dev/null
@@ -1,53 +0,0 @@
-/*
- * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef __LUCI_IR_SHAPE_SIGNATURE_H__
-#define __LUCI_IR_SHAPE_SIGNATURE_H__
-
-#include <stdint.h>
-#include <vector>
-
-namespace luci
-{
-
-class ShapeSignature
-{
-public:
- ShapeSignature() = default;
-
- ShapeSignature(const std::vector<int32_t> &shape_signature)
- {
- _shape_signature = shape_signature;
- }
-
-public:
- const std::vector<int32_t> &as_vector() const { return _shape_signature; }
-
- int32_t dim(uint32_t d) const { return _shape_signature.at(d); }
- int32_t &dim(uint32_t d) { return _shape_signature.at(d); }
-
- uint32_t rank(void) const { return _shape_signature.size(); }
- void rank(uint32_t rank) { _shape_signature.resize(rank); }
-
-private:
- std::vector<int32_t> _shape_signature{};
-};
-
-bool operator==(const ShapeSignature &lhs, const ShapeSignature &rhs);
-
-} // namespace luci
-
-#endif // __LUCI_IR_SHAPE_SIGNATURE_H__
diff --git a/compiler/luci/lang/src/DeadNodeQueryService.h b/compiler/luci/lang/include/luci/IR/DeadNodeQueryService.h
index d10696667..d10696667 100644
--- a/compiler/luci/lang/src/DeadNodeQueryService.h
+++ b/compiler/luci/lang/include/luci/IR/DeadNodeQueryService.h
diff --git a/compiler/luci/lang/include/luci/IR/LuciNodeMixins.h b/compiler/luci/lang/include/luci/IR/LuciNodeMixins.h
index c1bb0db11..2078495c6 100644
--- a/compiler/luci/lang/include/luci/IR/LuciNodeMixins.h
+++ b/compiler/luci/lang/include/luci/IR/LuciNodeMixins.h
@@ -17,90 +17,16 @@
#ifndef __LUCI_IR_LUCINODEMIXINS_H__
#define __LUCI_IR_LUCINODEMIXINS_H__
-#include "luci/IR/AttrFusedActFunc.h"
+// TODO remove this file after LuciNodeTrait and LuciNodeMixin are not used in backend
-#include <loco/IR/Node.h>
-#include <loco/IR/NodeMixins.h>
-
-#include <vector>
+#include "luci/IR/CircleNodeMixins.h"
namespace luci
{
-/// @brief enumeration of mixin class
-enum class LuciNodeTrait
-{
- FusedActFunc,
- Bias
-};
-
-template <LuciNodeTrait T> class LuciNodeMixin;
-
-template <> class LuciNodeMixin<LuciNodeTrait::FusedActFunc>
-{
-public:
- LuciNodeMixin() = default;
-
-public:
- FusedActFunc fusedActivationFunction() const { return _fused_act_fun; }
- void fusedActivationFunction(FusedActFunc fused_act_fun) { _fused_act_fun = fused_act_fun; }
-
-private:
- FusedActFunc _fused_act_fun = FusedActFunc::UNDEFINED;
-};
-
-/**
- * @brief Mixin class for nodes that has a bias input
- */
-template <> class LuciNodeMixin<LuciNodeTrait::Bias>
-{
-public:
- LuciNodeMixin() = default;
-
-public:
- virtual loco::Node *bias(void) const = 0; /// @brief get the input for bias.
- virtual void bias(loco::Node *node) = 0; /// @brief set the input for bias.
-};
-
-/**
- * @brief Nodes with the fixed number of inputs
- *
- * TODO Deprecated this class, and use loco::FixedArity instead
- */
-template <unsigned N, typename Base> class FixedArityNode : public Base
-{
-public:
- FixedArityNode()
- {
- _args.resize(N);
- for (uint32_t n = 0; n < N; ++n)
- {
- _args[n] = std::make_unique<loco::Use>(this);
- }
- }
-
- virtual ~FixedArityNode() = default;
-
-public:
- unsigned arity(void) const final { return N; }
-
- loco::Node *arg(uint32_t n) const final { return _args.at(n)->node(); }
-
- void drop(void) final
- {
- for (uint32_t n = 0; n < N; ++n)
- {
- _args.at(n)->node(nullptr);
- }
- }
-
-protected:
- // This API allows inherited classes to access "_args" field.
- loco::Use *at(unsigned n) const { return _args.at(n).get(); }
+using LuciNodeTrait = CircleNodeTrait;
-private:
- std::vector<std::unique_ptr<loco::Use>> _args{};
-};
+template <LuciNodeTrait T> using LuciNodeMixin = CircleNodeMixin<T>;
} // namespace luci
diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleAbs.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleAbs.h
index 45dba15bf..7a73f37cd 100644
--- a/compiler/luci/lang/include/luci/IR/Nodes/CircleAbs.h
+++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleAbs.h
@@ -20,7 +20,7 @@
#include "luci/IR/CircleNodeDecl.h"
#include "luci/IR/CircleOpcode.h"
-#include "luci/IR/LuciNodeMixins.h"
+#include "luci/IR/CircleNodeMixins.h"
namespace luci
{
diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleAdd.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleAdd.h
index f26eccd1a..92563de4c 100644
--- a/compiler/luci/lang/include/luci/IR/Nodes/CircleAdd.h
+++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleAdd.h
@@ -21,7 +21,7 @@
#include "luci/IR/CircleOpcode.h"
#include "luci/IR/AttrFusedActFunc.h"
-#include "luci/IR/LuciNodeMixins.h"
+#include "luci/IR/CircleNodeMixins.h"
namespace luci
{
@@ -30,7 +30,7 @@ namespace luci
* @brief ADD in Circle
*/
class CircleAdd final : public FixedArityNode<2, CircleNodeImpl<CircleOpcode::ADD>>,
- public LuciNodeMixin<LuciNodeTrait::FusedActFunc>
+ public CircleNodeMixin<CircleNodeTrait::FusedActFunc>
{
public:
loco::Node *x(void) const { return at(0)->node(); }
diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleArgMax.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleArgMax.h
index dbc4b2b3a..c1e4631e4 100644
--- a/compiler/luci/lang/include/luci/IR/Nodes/CircleArgMax.h
+++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleArgMax.h
@@ -20,7 +20,7 @@
#include "luci/IR/CircleNodeDecl.h"
#include "luci/IR/CircleOpcode.h"
-#include "luci/IR/LuciNodeMixins.h"
+#include "luci/IR/CircleNodeMixins.h"
namespace luci
{
diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleArgMin.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleArgMin.h
index 8cb561983..b4d026201 100644
--- a/compiler/luci/lang/include/luci/IR/Nodes/CircleArgMin.h
+++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleArgMin.h
@@ -20,7 +20,7 @@
#include "luci/IR/CircleNodeDecl.h"
#include "luci/IR/CircleOpcode.h"
-#include "luci/IR/LuciNodeMixins.h"
+#include "luci/IR/CircleNodeMixins.h"
namespace luci
{
diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleAveragePool2D.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleAveragePool2D.h
index 0b43b40c8..4aa45c2d8 100644
--- a/compiler/luci/lang/include/luci/IR/Nodes/CircleAveragePool2D.h
+++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleAveragePool2D.h
@@ -24,7 +24,7 @@
#include "luci/IR/AttrPadding.h"
#include "luci/IR/AttrStride.h"
#include "luci/IR/AttrFusedActFunc.h"
-#include "luci/IR/LuciNodeMixins.h"
+#include "luci/IR/CircleNodeMixins.h"
namespace luci
{
@@ -33,16 +33,14 @@ namespace luci
* @brief AVERAGE_POOL_2D in Circle
*/
class CircleAveragePool2D final
- : public FixedArityNode<1, CircleNodeImpl<CircleOpcode::AVERAGE_POOL_2D>>,
- public LuciNodeMixin<LuciNodeTrait::FusedActFunc>
+ : public FixedArityNode<1, CircleNodeImpl<CircleOpcode::AVERAGE_POOL_2D>>,
+ public CircleNodeMixin<CircleNodeTrait::FusedActFunc>
{
public:
- CircleAveragePool2D() : _padding(Padding::UNDEFINED) { /* empty */}
-
-public:
loco::Node *value(void) const { return at(0)->node(); }
void value(loco::Node *node) { at(0)->node(node); }
+public:
Padding padding() const { return _padding; }
void padding(Padding padding) { _padding = padding; }
@@ -53,7 +51,7 @@ public:
Stride *stride(void) { return &_stride; }
private:
- Padding _padding;
+ Padding _padding{Padding::UNDEFINED};
Stride _stride;
Filter _filter;
};
diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleBCQFullyConnected.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleBCQFullyConnected.h
index 7d12d593a..4c164ebca 100644
--- a/compiler/luci/lang/include/luci/IR/Nodes/CircleBCQFullyConnected.h
+++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleBCQFullyConnected.h
@@ -21,7 +21,7 @@
#include "luci/IR/CircleOpcode.h"
#include "luci/IR/AttrFusedActFunc.h"
-#include "luci/IR/LuciNodeMixins.h"
+#include "luci/IR/CircleNodeMixins.h"
namespace luci
{
@@ -30,9 +30,9 @@ namespace luci
* @brief BCQ_FULLY_CONNECTED in Circle
*/
class CircleBCQFullyConnected final
- : public FixedArityNode<5, CircleNodeImpl<CircleOpcode::BCQ_FULLY_CONNECTED>>,
- public LuciNodeMixin<LuciNodeTrait::FusedActFunc>,
- public LuciNodeMixin<LuciNodeTrait::Bias>
+ : public FixedArityNode<5, CircleNodeImpl<CircleOpcode::BCQ_FULLY_CONNECTED>>,
+ public CircleNodeMixin<CircleNodeTrait::FusedActFunc>,
+ public CircleNodeMixin<CircleNodeTrait::Bias>
{
public:
loco::Node *input(void) const { return at(0)->node(); }
@@ -58,7 +58,7 @@ public:
}
private:
- int32_t _weights_hidden_size = 0;
+ int32_t _weights_hidden_size{0};
};
} // namespace luci
diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleBCQGather.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleBCQGather.h
index f7638261d..1a0bf4f19 100644
--- a/compiler/luci/lang/include/luci/IR/Nodes/CircleBCQGather.h
+++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleBCQGather.h
@@ -20,7 +20,7 @@
#include "luci/IR/CircleNodeDecl.h"
#include "luci/IR/CircleOpcode.h"
-#include "luci/IR/LuciNodeMixins.h"
+#include "luci/IR/CircleNodeMixins.h"
namespace luci
{
@@ -51,8 +51,8 @@ public:
void input_hidden_size(int32_t input_hidden_size) { _input_hidden_size = input_hidden_size; }
private:
- int32_t _axis = 0;
- int32_t _input_hidden_size = 0;
+ int32_t _axis{0};
+ int32_t _input_hidden_size{0};
};
} // namespace luci
diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleBatchMatMul.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleBatchMatMul.h
index 19999924e..864b033ed 100644
--- a/compiler/luci/lang/include/luci/IR/Nodes/CircleBatchMatMul.h
+++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleBatchMatMul.h
@@ -20,15 +20,15 @@
#include "luci/IR/CircleNodeDecl.h"
#include "luci/IR/CircleOpcode.h"
-#include "luci/IR/LuciNodeMixins.h"
+#include "luci/IR/CircleNodeMixins.h"
namespace luci
{
/**
- * @brief BATCHMATMUL in Circle
+ * @brief BATCH_MATMUL in Circle
*/
-class CircleBatchMatMul final : public FixedArityNode<2, CircleNodeImpl<CircleOpcode::BATCHMATMUL>>
+class CircleBatchMatMul final : public FixedArityNode<2, CircleNodeImpl<CircleOpcode::BATCH_MATMUL>>
{
public:
loco::Node *x(void) const { return at(0)->node(); }
@@ -45,8 +45,8 @@ public:
void adj_y(bool arg) { _adj_y = arg; }
private:
- bool _adj_x = false;
- bool _adj_y = false;
+ bool _adj_x{false};
+ bool _adj_y{false};
};
} // namespace luci
diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleBatchToSpaceND.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleBatchToSpaceND.h
index 67c0a2102..80fa53b8e 100644
--- a/compiler/luci/lang/include/luci/IR/Nodes/CircleBatchToSpaceND.h
+++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleBatchToSpaceND.h
@@ -20,7 +20,7 @@
#include "luci/IR/CircleNodeDecl.h"
#include "luci/IR/CircleOpcode.h"
-#include "luci/IR/LuciNodeMixins.h"
+#include "luci/IR/CircleNodeMixins.h"
namespace luci
{
@@ -29,7 +29,7 @@ namespace luci
* @brief BATCH_TO_SPACE_ND in Circle
*/
class CircleBatchToSpaceND final
- : public FixedArityNode<3, CircleNodeImpl<CircleOpcode::BATCH_TO_SPACE_ND>>
+ : public FixedArityNode<3, CircleNodeImpl<CircleOpcode::BATCH_TO_SPACE_ND>>
{
public:
loco::Node *input(void) const { return at(0)->node(); }
diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleBidirectionalSequenceLSTM.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleBidirectionalSequenceLSTM.h
new file mode 100644
index 000000000..d16281b69
--- /dev/null
+++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleBidirectionalSequenceLSTM.h
@@ -0,0 +1,172 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LUCI_IR_CIRCLEBIDIRECTIONALSEQUENCE_LSTM_H__
+#define __LUCI_IR_CIRCLEBIDIRECTIONALSEQUENCE_LSTM_H__
+
+#include "luci/IR/CircleNodeDecl.h"
+#include "luci/IR/CircleOpcode.h"
+
+#include "luci/IR/AttrFusedActFunc.h"
+#include "luci/IR/CircleNodeMixins.h"
+
+namespace luci
+{
+
+/**
+ * @brief BIDIRECTIONAL_SEQUENCE_LSTM in Circle
+ */
+class CircleBidirectionalSequenceLSTM final
+ : public FixedArityNode<48, CircleNodeImpl<CircleOpcode::BIDIRECTIONAL_SEQUENCE_LSTM>>,
+ public CircleNodeMixin<CircleNodeTrait::FusedActFunc>
+{
+public:
+ loco::Node *input(void) const { return at(0)->node(); }
+ void input(loco::Node *node) { at(0)->node(node); }
+
+ loco::Node *fw_input_to_input_weights(void) const { return at(1)->node(); }
+ void fw_input_to_input_weights(loco::Node *node) { at(1)->node(node); }
+ loco::Node *fw_input_to_forget_weights(void) const { return at(2)->node(); }
+ void fw_input_to_forget_weights(loco::Node *node) { at(2)->node(node); }
+ loco::Node *fw_input_to_cell_weights(void) const { return at(3)->node(); }
+ void fw_input_to_cell_weights(loco::Node *node) { at(3)->node(node); }
+ loco::Node *fw_input_to_output_weights(void) const { return at(4)->node(); }
+ void fw_input_to_output_weights(loco::Node *node) { at(4)->node(node); }
+
+ loco::Node *fw_recurrent_to_input_weights(void) const { return at(5)->node(); }
+ void fw_recurrent_to_input_weights(loco::Node *node) { at(5)->node(node); }
+ loco::Node *fw_recurrent_to_forget_weights(void) const { return at(6)->node(); }
+ void fw_recurrent_to_forget_weights(loco::Node *node) { at(6)->node(node); }
+ loco::Node *fw_recurrent_to_cell_weights(void) const { return at(7)->node(); }
+ void fw_recurrent_to_cell_weights(loco::Node *node) { at(7)->node(node); }
+ loco::Node *fw_recurrent_to_output_weights(void) const { return at(8)->node(); }
+ void fw_recurrent_to_output_weights(loco::Node *node) { at(8)->node(node); }
+
+ loco::Node *fw_cell_to_input_weights(void) const { return at(9)->node(); }
+ void fw_cell_to_input_weights(loco::Node *node) { at(9)->node(node); }
+ loco::Node *fw_cell_to_forget_weights(void) const { return at(10)->node(); }
+ void fw_cell_to_forget_weights(loco::Node *node) { at(10)->node(node); }
+ loco::Node *fw_cell_to_output_weights(void) const { return at(11)->node(); }
+ void fw_cell_to_output_weights(loco::Node *node) { at(11)->node(node); }
+
+ loco::Node *fw_input_gate_bias(void) const { return at(12)->node(); }
+ void fw_input_gate_bias(loco::Node *node) { at(12)->node(node); }
+ loco::Node *fw_forget_gate_bias(void) const { return at(13)->node(); }
+ void fw_forget_gate_bias(loco::Node *node) { at(13)->node(node); }
+ loco::Node *fw_cell_gate_bias(void) const { return at(14)->node(); }
+ void fw_cell_gate_bias(loco::Node *node) { at(14)->node(node); }
+ loco::Node *fw_output_gate_bias(void) const { return at(15)->node(); }
+ void fw_output_gate_bias(loco::Node *node) { at(15)->node(node); }
+
+ loco::Node *fw_projection_weights(void) const { return at(16)->node(); }
+ void fw_projection_weights(loco::Node *node) { at(16)->node(node); }
+ loco::Node *fw_projection_bias(void) const { return at(17)->node(); }
+ void fw_projection_bias(loco::Node *node) { at(17)->node(node); }
+
+ loco::Node *bw_input_to_input_weights(void) const { return at(18)->node(); }
+ void bw_input_to_input_weights(loco::Node *node) { at(18)->node(node); }
+ loco::Node *bw_input_to_forget_weights(void) const { return at(19)->node(); }
+ void bw_input_to_forget_weights(loco::Node *node) { at(19)->node(node); }
+ loco::Node *bw_input_to_cell_weights(void) const { return at(20)->node(); }
+ void bw_input_to_cell_weights(loco::Node *node) { at(20)->node(node); }
+ loco::Node *bw_input_to_output_weights(void) const { return at(21)->node(); }
+ void bw_input_to_output_weights(loco::Node *node) { at(21)->node(node); }
+
+ loco::Node *bw_recurrent_to_input_weights(void) const { return at(22)->node(); }
+ void bw_recurrent_to_input_weights(loco::Node *node) { at(22)->node(node); }
+ loco::Node *bw_recurrent_to_forget_weights(void) const { return at(23)->node(); }
+ void bw_recurrent_to_forget_weights(loco::Node *node) { at(23)->node(node); }
+ loco::Node *bw_recurrent_to_cell_weights(void) const { return at(24)->node(); }
+ void bw_recurrent_to_cell_weights(loco::Node *node) { at(24)->node(node); }
+ loco::Node *bw_recurrent_to_output_weights(void) const { return at(25)->node(); }
+ void bw_recurrent_to_output_weights(loco::Node *node) { at(25)->node(node); }
+
+ loco::Node *bw_cell_to_input_weights(void) const { return at(26)->node(); }
+ void bw_cell_to_input_weights(loco::Node *node) { at(26)->node(node); }
+ loco::Node *bw_cell_to_forget_weights(void) const { return at(27)->node(); }
+ void bw_cell_to_forget_weights(loco::Node *node) { at(27)->node(node); }
+ loco::Node *bw_cell_to_output_weights(void) const { return at(28)->node(); }
+ void bw_cell_to_output_weights(loco::Node *node) { at(28)->node(node); }
+
+ loco::Node *bw_input_gate_bias(void) const { return at(29)->node(); }
+ void bw_input_gate_bias(loco::Node *node) { at(29)->node(node); }
+ loco::Node *bw_forget_gate_bias(void) const { return at(30)->node(); }
+ void bw_forget_gate_bias(loco::Node *node) { at(30)->node(node); }
+ loco::Node *bw_cell_gate_bias(void) const { return at(31)->node(); }
+ void bw_cell_gate_bias(loco::Node *node) { at(31)->node(node); }
+ loco::Node *bw_output_gate_bias(void) const { return at(32)->node(); }
+ void bw_output_gate_bias(loco::Node *node) { at(32)->node(node); }
+
+ loco::Node *bw_projection_weights(void) const { return at(33)->node(); }
+ void bw_projection_weights(loco::Node *node) { at(33)->node(node); }
+ loco::Node *bw_projection_bias(void) const { return at(34)->node(); }
+ void bw_projection_bias(loco::Node *node) { at(34)->node(node); }
+
+ loco::Node *fw_activation_state(void) const { return at(35)->node(); }
+ void fw_activation_state(loco::Node *node) { at(35)->node(node); }
+ loco::Node *fw_cell_state(void) const { return at(36)->node(); }
+ void fw_cell_state(loco::Node *node) { at(36)->node(node); }
+
+ loco::Node *bw_activation_state(void) const { return at(37)->node(); }
+ void bw_activation_state(loco::Node *node) { at(37)->node(node); }
+ loco::Node *bw_cell_state(void) const { return at(38)->node(); }
+ void bw_cell_state(loco::Node *node) { at(38)->node(node); }
+
+ loco::Node *auxillary_input(void) const { return at(39)->node(); }
+ void auxillary_input(loco::Node *node) { at(39)->node(node); }
+ loco::Node *fw_auxillary_input_to_input_weights(void) const { return at(40)->node(); }
+ void fw_auxillary_input_to_input_weights(loco::Node *node) { at(40)->node(node); }
+ loco::Node *fw_auxillary_input_to_forget_weights(void) const { return at(41)->node(); }
+ void fw_auxillary_input_to_forget_weights(loco::Node *node) { at(41)->node(node); }
+ loco::Node *fw_auxillary_input_to_cell_weights(void) const { return at(42)->node(); }
+ void fw_auxillary_input_to_cell_weights(loco::Node *node) { at(42)->node(node); }
+ loco::Node *fw_auxillary_input_to_output_weights(void) const { return at(43)->node(); }
+ void fw_auxillary_input_to_output_weights(loco::Node *node) { at(43)->node(node); }
+ loco::Node *bw_auxillary_input_to_input_weights(void) const { return at(44)->node(); }
+ void bw_auxillary_input_to_input_weights(loco::Node *node) { at(44)->node(node); }
+ loco::Node *bw_auxillary_input_to_forget_weights(void) const { return at(45)->node(); }
+ void bw_auxillary_input_to_forget_weights(loco::Node *node) { at(45)->node(node); }
+ loco::Node *bw_auxillary_input_to_cell_weights(void) const { return at(46)->node(); }
+ void bw_auxillary_input_to_cell_weights(loco::Node *node) { at(46)->node(node); }
+ loco::Node *bw_auxillary_input_to_output_weights(void) const { return at(47)->node(); }
+ void bw_auxillary_input_to_output_weights(loco::Node *node) { at(47)->node(node); }
+
+public:
+ float cell_clip(void) const { return _cell_clip; }
+ void cell_clip(float cell_clip) { _cell_clip = cell_clip; }
+ float proj_clip(void) const { return _proj_clip; }
+ void proj_clip(float proj_clip) { _proj_clip = proj_clip; }
+ bool merge_outputs(void) const { return _merge_outputs; }
+ void merge_outputs(bool merge_outputs) { _merge_outputs = merge_outputs; }
+ bool time_major(void) const { return _time_major; }
+ void time_major(bool time_major) { _time_major = time_major; }
+ bool asymmetric_quantize_inputs(void) const { return _asymmetric_quantize_inputs; }
+ void asymmetric_quantize_inputs(bool asymmetric_quantize_inputs)
+ {
+ _asymmetric_quantize_inputs = asymmetric_quantize_inputs;
+ }
+
+private:
+ float _cell_clip{0.0f};
+ float _proj_clip{0.0f};
+ bool _merge_outputs{false};
+ bool _time_major{false};
+ bool _asymmetric_quantize_inputs{false};
+};
+
+} // namespace luci
+
+#endif // __LUCI_IR_CIRCLEBIDIRECTIONALSEQUENCE_LSTM_H__
diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleBidirectionalSequenceLSTMOut.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleBidirectionalSequenceLSTMOut.h
new file mode 100644
index 000000000..fb2eb0831
--- /dev/null
+++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleBidirectionalSequenceLSTMOut.h
@@ -0,0 +1,48 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LUCI_IR_CIRCLE_BIDIRECTIONAL_SEQUENCE_LSTM_OUT_H__
+#define __LUCI_IR_CIRCLE_BIDIRECTIONAL_SEQUENCE_LSTM_OUT_H__
+
+#include "luci/IR/CircleNodeDecl.h"
+#include "luci/IR/CircleOpcode.h"
+
+#include "luci/IR/CircleNodeMixins.h"
+
+namespace luci
+{
+
+/**
+ * @brief Virtual CIRCLEBIDIRECTIONAL_SEQUENCE_LSTM_OUT in Circle
+ */
+class CircleBidirectionalSequenceLSTMOut final
+ : public FixedArityNode<1, CircleNodeImpl<CircleOpcode::CIRCLEBIDIRECTIONAL_SEQUENCE_LSTM_OUT>>
+{
+public:
+ loco::Node *input(void) const { return at(0)->node(); }
+ void input(loco::Node *node) { at(0)->node(node); }
+
+public:
+ int32_t index(void) const { return _index; }
+ void index(int32_t index) { _index = index; }
+
+private:
+ int32_t _index{-1};
+};
+
+} // namespace luci
+
+#endif // __LUCI_IR_CIRCLE_BIDIRECTIONAL_SEQUENCE_LSTM_OUT_H__
diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleCast.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleCast.h
index 9a89d0b2b..0b793607f 100644
--- a/compiler/luci/lang/include/luci/IR/Nodes/CircleCast.h
+++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleCast.h
@@ -20,7 +20,7 @@
#include "luci/IR/CircleNodeDecl.h"
#include "luci/IR/CircleOpcode.h"
-#include "luci/IR/LuciNodeMixins.h"
+#include "luci/IR/CircleNodeMixins.h"
namespace luci
{
diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleCeil.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleCeil.h
index 8a8715dcf..3d7a7ebc7 100644
--- a/compiler/luci/lang/include/luci/IR/Nodes/CircleCeil.h
+++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleCeil.h
@@ -20,7 +20,7 @@
#include "luci/IR/CircleNodeDecl.h"
#include "luci/IR/CircleOpcode.h"
-#include "luci/IR/LuciNodeMixins.h"
+#include "luci/IR/CircleNodeMixins.h"
namespace luci
{
diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleConcatenation.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleConcatenation.h
index dea1a4613..2746a0a2e 100644
--- a/compiler/luci/lang/include/luci/IR/Nodes/CircleConcatenation.h
+++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleConcatenation.h
@@ -21,7 +21,7 @@
#include "luci/IR/CircleOpcode.h"
#include "luci/IR/AttrFusedActFunc.h"
-#include "luci/IR/LuciNodeMixins.h"
+#include "luci/IR/CircleNodeMixins.h"
#include "luci/IR/VariadicArityNode.h"
#include <cassert>
@@ -33,12 +33,12 @@ namespace luci
* @brief CONCATENATION in Circle
*/
class CircleConcatenation final
- : public VariadicArityNode<CircleNodeImpl<CircleOpcode::CONCATENATION>>,
- public LuciNodeMixin<LuciNodeTrait::FusedActFunc>
+ : public VariadicArityNode<CircleNodeImpl<CircleOpcode::CONCATENATION>>,
+ public CircleNodeMixin<CircleNodeTrait::FusedActFunc>
{
public:
CircleConcatenation(uint32_t arity)
- : VariadicArityNode<CircleNodeImpl<CircleOpcode::CONCATENATION>>(arity)
+ : VariadicArityNode<CircleNodeImpl<CircleOpcode::CONCATENATION>>(arity)
{
// TODO Support when arity is 0
assert(arity >= 1);
diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleConst.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleConst.h
index 250282049..e44363d14 100644
--- a/compiler/luci/lang/include/luci/IR/Nodes/CircleConst.h
+++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleConst.h
@@ -20,7 +20,7 @@
#include "luci/IR/CircleNodeDecl.h"
#include "luci/IR/CircleOpcode.h"
-#include "luci/IR/LuciNodeMixins.h"
+#include "luci/IR/CircleNodeMixins.h"
#include <loco/IR/DataTypeTraits.h>
@@ -34,9 +34,6 @@ namespace luci
class CircleConst final : public FixedArityNode<0, CircleNodeImpl<CircleOpcode::CIRCLECONST>>
{
public:
- CircleConst() = default;
-
-public:
template <loco::DataType DT> uint32_t size(void) const;
template <loco::DataType DT> void size(uint32_t size);
template <loco::DataType DT> const typename loco::DataTypeImpl<DT>::Type &at(uint32_t n) const;
diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleConv2D.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleConv2D.h
index 13657cee4..7c390940e 100644
--- a/compiler/luci/lang/include/luci/IR/Nodes/CircleConv2D.h
+++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleConv2D.h
@@ -24,7 +24,7 @@
#include "luci/IR/AttrStride.h"
#include "luci/IR/AttrDilation.h"
#include "luci/IR/AttrFusedActFunc.h"
-#include "luci/IR/LuciNodeMixins.h"
+#include "luci/IR/CircleNodeMixins.h"
namespace luci
{
@@ -33,8 +33,8 @@ namespace luci
* @brief CONV_2D in Circle
*/
class CircleConv2D final : public FixedArityNode<3, CircleNodeImpl<CircleOpcode::CONV_2D>>,
- public LuciNodeMixin<LuciNodeTrait::FusedActFunc>,
- public LuciNodeMixin<LuciNodeTrait::Bias>
+ public CircleNodeMixin<CircleNodeTrait::FusedActFunc>,
+ public CircleNodeMixin<CircleNodeTrait::Bias>
{
public:
loco::Node *input(void) const { return at(0)->node(); }
@@ -57,7 +57,7 @@ public:
Dilation *dilation(void) { return &_dilation; }
private:
- Padding _padding = Padding::UNDEFINED;
+ Padding _padding{Padding::UNDEFINED};
Stride _stride;
Dilation _dilation;
};
diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleCos.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleCos.h
index 07ced620a..cff04906d 100644
--- a/compiler/luci/lang/include/luci/IR/Nodes/CircleCos.h
+++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleCos.h
@@ -20,7 +20,7 @@
#include "luci/IR/CircleNodeDecl.h"
#include "luci/IR/CircleOpcode.h"
-#include "luci/IR/LuciNodeMixins.h"
+#include "luci/IR/CircleNodeMixins.h"
namespace luci
{
diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleCustom.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleCustom.h
index 6c722b766..b21cc679f 100644
--- a/compiler/luci/lang/include/luci/IR/Nodes/CircleCustom.h
+++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleCustom.h
@@ -29,19 +29,23 @@ namespace luci
class CircleCustom final : public VariadicArityNode<CircleNodeImpl<CircleOpcode::CUSTOM>>
{
public:
- CircleCustom(uint32_t arity) : VariadicArityNode<CircleNodeImpl<CircleOpcode::CUSTOM>>(arity)
+ CircleCustom(uint32_t arity, uint32_t out)
+ : VariadicArityNode<CircleNodeImpl<CircleOpcode::CUSTOM>>(arity), _output_count(out)
{
// TODO Support when arity is 0
assert(arity >= 1);
+ assert(out > 0);
}
public:
uint32_t numInputs(void) const { return arity(); }
+ uint32_t numOutputs(void) const { return _output_count; }
public:
Node *inputs(uint32_t index) const { return at(index)->node(); }
void inputs(uint32_t index, Node *node) { at(index)->node(node); }
+public:
const std::vector<uint8_t> &custom_options(void) const { return _custom_options; }
void custom_options(const std::vector<uint8_t> &custom_options)
{
@@ -54,6 +58,7 @@ public:
private:
std::vector<uint8_t> _custom_options;
std::string _custom_code;
+ uint32_t _output_count{0};
};
} // namespace luci
diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleCustomOut.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleCustomOut.h
index 36b8e4aed..91a89c151 100644
--- a/compiler/luci/lang/include/luci/IR/Nodes/CircleCustomOut.h
+++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleCustomOut.h
@@ -20,7 +20,7 @@
#include "luci/IR/CircleNodeDecl.h"
#include "luci/IR/CircleOpcode.h"
-#include "luci/IR/LuciNodeMixins.h"
+#include "luci/IR/CircleNodeMixins.h"
namespace luci
{
@@ -29,12 +29,9 @@ namespace luci
* @brief Virtual CIRCLECUSTOMOUT in Circle
*/
class CircleCustomOut final
- : public FixedArityNode<1, CircleNodeImpl<CircleOpcode::CIRCLECUSTOMOUT>>
+ : public FixedArityNode<1, CircleNodeImpl<CircleOpcode::CIRCLECUSTOMOUT>>
{
public:
- CircleCustomOut() = default;
-
-public:
loco::Node *input(void) const { return at(0)->node(); }
void input(loco::Node *node) { at(0)->node(node); }
diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleDepthToSpace.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleDepthToSpace.h
index e19282b97..85b567fb7 100644
--- a/compiler/luci/lang/include/luci/IR/Nodes/CircleDepthToSpace.h
+++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleDepthToSpace.h
@@ -20,7 +20,7 @@
#include "luci/IR/CircleNodeDecl.h"
#include "luci/IR/CircleOpcode.h"
-#include "luci/IR/LuciNodeMixins.h"
+#include "luci/IR/CircleNodeMixins.h"
namespace luci
{
@@ -29,18 +29,18 @@ namespace luci
* @brief DEPTH_TO_SPACE in Circle
*/
class CircleDepthToSpace final
- : public FixedArityNode<1, CircleNodeImpl<CircleOpcode::DEPTH_TO_SPACE>>
+ : public FixedArityNode<1, CircleNodeImpl<CircleOpcode::DEPTH_TO_SPACE>>
{
public:
loco::Node *input(void) const { return at(0)->node(); }
void input(loco::Node *node) { at(0)->node(node); }
public:
- int block_size(void) const { return _block_size; }
- void block_size(int block_size) { _block_size = block_size; }
+ int32_t block_size(void) const { return _block_size; }
+ void block_size(int32_t block_size) { _block_size = block_size; }
private:
- int _block_size{0};
+ int32_t _block_size{0};
};
} // namespace luci
diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleDepthwiseConv2D.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleDepthwiseConv2D.h
index eb058cec1..046aa5908 100644
--- a/compiler/luci/lang/include/luci/IR/Nodes/CircleDepthwiseConv2D.h
+++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleDepthwiseConv2D.h
@@ -25,7 +25,7 @@
#include "luci/IR/AttrPadding.h"
#include "luci/IR/AttrStride.h"
#include "luci/IR/AttrFusedActFunc.h"
-#include "luci/IR/LuciNodeMixins.h"
+#include "luci/IR/CircleNodeMixins.h"
namespace luci
{
@@ -34,9 +34,9 @@ namespace luci
* @brief DEPTHWISE_CONV_2D in Circle
*/
class CircleDepthwiseConv2D final
- : public FixedArityNode<3, CircleNodeImpl<CircleOpcode::DEPTHWISE_CONV_2D>>,
- public LuciNodeMixin<LuciNodeTrait::FusedActFunc>,
- public LuciNodeMixin<LuciNodeTrait::Bias>
+ : public FixedArityNode<3, CircleNodeImpl<CircleOpcode::DEPTHWISE_CONV_2D>>,
+ public CircleNodeMixin<CircleNodeTrait::FusedActFunc>,
+ public CircleNodeMixin<CircleNodeTrait::Bias>
{
public:
loco::Node *input(void) const { return at(0)->node(); }
@@ -62,9 +62,9 @@ public:
Dilation *dilation(void) { return &_dilation; }
private:
- Padding _padding = Padding::UNDEFINED;
+ Padding _padding{Padding::UNDEFINED};
Stride _stride;
- int32_t _depth_multiplier = 0;
+ int32_t _depth_multiplier{0};
Dilation _dilation;
};
diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleDequantize.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleDequantize.h
index 847c5dfc5..c3ee44253 100644
--- a/compiler/luci/lang/include/luci/IR/Nodes/CircleDequantize.h
+++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleDequantize.h
@@ -20,7 +20,7 @@
#include "luci/IR/CircleNodeDecl.h"
#include "luci/IR/CircleOpcode.h"
-#include "luci/IR/LuciNodeMixins.h"
+#include "luci/IR/CircleNodeMixins.h"
namespace luci
{
diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleDiv.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleDiv.h
index 1d4d3a239..fcc3f427c 100644
--- a/compiler/luci/lang/include/luci/IR/Nodes/CircleDiv.h
+++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleDiv.h
@@ -24,7 +24,7 @@
#include "luci/IR/AttrPadding.h"
#include "luci/IR/AttrStride.h"
#include "luci/IR/AttrFusedActFunc.h"
-#include "luci/IR/LuciNodeMixins.h"
+#include "luci/IR/CircleNodeMixins.h"
namespace luci
{
@@ -33,12 +33,9 @@ namespace luci
* @brief DIV in Circle
*/
class CircleDiv final : public FixedArityNode<2, CircleNodeImpl<CircleOpcode::DIV>>,
- public LuciNodeMixin<LuciNodeTrait::FusedActFunc>
+ public CircleNodeMixin<CircleNodeTrait::FusedActFunc>
{
public:
- CircleDiv() = default;
-
-public:
loco::Node *x(void) const { return at(0)->node(); }
void x(loco::Node *node) { at(0)->node(node); }
diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleElu.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleElu.h
index fbb2f3533..721edd9ae 100644
--- a/compiler/luci/lang/include/luci/IR/Nodes/CircleElu.h
+++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleElu.h
@@ -20,7 +20,7 @@
#include "luci/IR/CircleNodeDecl.h"
#include "luci/IR/CircleOpcode.h"
-#include "luci/IR/LuciNodeMixins.h"
+#include "luci/IR/CircleNodeMixins.h"
namespace luci
{
@@ -31,9 +31,6 @@ namespace luci
class CircleElu final : public FixedArityNode<1, CircleNodeImpl<CircleOpcode::ELU>>
{
public:
- CircleElu() = default;
-
-public:
loco::Node *features(void) const { return at(0)->node(); }
void features(loco::Node *node) { at(0)->node(node); }
};
diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleEqual.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleEqual.h
index 2087d097a..69697ac7e 100644
--- a/compiler/luci/lang/include/luci/IR/Nodes/CircleEqual.h
+++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleEqual.h
@@ -20,7 +20,7 @@
#include "luci/IR/CircleNodeDecl.h"
#include "luci/IR/CircleOpcode.h"
-#include "luci/IR/LuciNodeMixins.h"
+#include "luci/IR/CircleNodeMixins.h"
namespace luci
{
diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleExp.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleExp.h
index 97aecb30a..b8a5d4561 100644
--- a/compiler/luci/lang/include/luci/IR/Nodes/CircleExp.h
+++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleExp.h
@@ -20,7 +20,7 @@
#include "luci/IR/CircleNodeDecl.h"
#include "luci/IR/CircleOpcode.h"
-#include "luci/IR/LuciNodeMixins.h"
+#include "luci/IR/CircleNodeMixins.h"
namespace luci
{
diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleExpandDims.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleExpandDims.h
index f70219614..15bfe6a29 100644
--- a/compiler/luci/lang/include/luci/IR/Nodes/CircleExpandDims.h
+++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleExpandDims.h
@@ -20,7 +20,7 @@
#include "luci/IR/CircleNodeDecl.h"
#include "luci/IR/CircleOpcode.h"
-#include "luci/IR/LuciNodeMixins.h"
+#include "luci/IR/CircleNodeMixins.h"
namespace luci
{
@@ -31,9 +31,6 @@ namespace luci
class CircleExpandDims final : public FixedArityNode<2, CircleNodeImpl<CircleOpcode::EXPAND_DIMS>>
{
public:
- CircleExpandDims() = default;
-
-public:
loco::Node *input(void) const { return at(0)->node(); }
void input(loco::Node *node) { at(0)->node(node); }
diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleFakeQuant.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleFakeQuant.h
new file mode 100644
index 000000000..9e3159685
--- /dev/null
+++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleFakeQuant.h
@@ -0,0 +1,60 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LUCI_IR_CIRCLE_FAKE_QUANT_H__
+#define __LUCI_IR_CIRCLE_FAKE_QUANT_H__
+
+#include "luci/IR/CircleNodeDecl.h"
+#include "luci/IR/CircleOpcode.h"
+
+#include "luci/IR/CircleNodeMixins.h"
+
+namespace luci
+{
+
+/**
+ * @brief FAKE_QUANT in Circle
+ * @note 'inputs' came from TF.quantize.fake_quant_from_min_max_vars
+ */
+class CircleFakeQuant final : public FixedArityNode<1, CircleNodeImpl<CircleOpcode::FAKE_QUANT>>
+{
+public:
+ loco::Node *inputs(void) const { return at(0)->node(); }
+ void inputs(loco::Node *node) { at(0)->node(node); }
+
+public:
+ float min(void) const { return _min; }
+ void min(float min) { _min = min; }
+
+ float max(void) const { return _max; }
+ void max(float max) { _max = max; }
+
+ int32_t num_bits(void) const { return _num_bits; }
+ void num_bits(int32_t num_bits) { _num_bits = num_bits; }
+
+ bool narrow_range(void) const { return _narrow_range; }
+ void narrow_range(bool narrow_range) { _narrow_range = narrow_range; }
+
+private:
+ float _min{0.0f};
+ float _max{0.0f};
+ int32_t _num_bits{0};
+ bool _narrow_range{false};
+};
+
+} // namespace luci
+
+#endif // __LUCI_IR_CIRCLEGATHER_H__
diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleFill.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleFill.h
index bfc65274a..183794d41 100644
--- a/compiler/luci/lang/include/luci/IR/Nodes/CircleFill.h
+++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleFill.h
@@ -20,7 +20,7 @@
#include "luci/IR/CircleNodeDecl.h"
#include "luci/IR/CircleOpcode.h"
-#include "luci/IR/LuciNodeMixins.h"
+#include "luci/IR/CircleNodeMixins.h"
namespace luci
{
diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleFloor.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleFloor.h
index 7e10547b6..ce6807e98 100644
--- a/compiler/luci/lang/include/luci/IR/Nodes/CircleFloor.h
+++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleFloor.h
@@ -20,7 +20,7 @@
#include "luci/IR/CircleNodeDecl.h"
#include "luci/IR/CircleOpcode.h"
-#include "luci/IR/LuciNodeMixins.h"
+#include "luci/IR/CircleNodeMixins.h"
namespace luci
{
diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleFloorDiv.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleFloorDiv.h
index ba9db010c..bf76e37b6 100644
--- a/compiler/luci/lang/include/luci/IR/Nodes/CircleFloorDiv.h
+++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleFloorDiv.h
@@ -20,7 +20,7 @@
#include "luci/IR/CircleNodeDecl.h"
#include "luci/IR/CircleOpcode.h"
-#include "luci/IR/LuciNodeMixins.h"
+#include "luci/IR/CircleNodeMixins.h"
namespace luci
{
diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleFloorMod.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleFloorMod.h
index 4d13717a0..1af0af758 100644
--- a/compiler/luci/lang/include/luci/IR/Nodes/CircleFloorMod.h
+++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleFloorMod.h
@@ -20,7 +20,7 @@
#include "luci/IR/CircleNodeDecl.h"
#include "luci/IR/CircleOpcode.h"
-#include "luci/IR/LuciNodeMixins.h"
+#include "luci/IR/CircleNodeMixins.h"
namespace luci
{
diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleFullyConnected.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleFullyConnected.h
index 952befc87..2862cadb2 100644
--- a/compiler/luci/lang/include/luci/IR/Nodes/CircleFullyConnected.h
+++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleFullyConnected.h
@@ -21,7 +21,7 @@
#include "luci/IR/CircleOpcode.h"
#include "luci/IR/AttrFusedActFunc.h"
-#include "luci/IR/LuciNodeMixins.h"
+#include "luci/IR/CircleNodeMixins.h"
namespace luci
{
@@ -30,9 +30,9 @@ namespace luci
* @brief FULLY_CONNECTED in Circle
*/
class CircleFullyConnected final
- : public FixedArityNode<3, CircleNodeImpl<CircleOpcode::FULLY_CONNECTED>>,
- public LuciNodeMixin<LuciNodeTrait::FusedActFunc>,
- public LuciNodeMixin<LuciNodeTrait::Bias>
+ : public FixedArityNode<3, CircleNodeImpl<CircleOpcode::FULLY_CONNECTED>>,
+ public CircleNodeMixin<CircleNodeTrait::FusedActFunc>,
+ public CircleNodeMixin<CircleNodeTrait::Bias>
{
public:
enum class WeightsFormat
diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleGather.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleGather.h
index 1e8c4982a..78fa2fc28 100644
--- a/compiler/luci/lang/include/luci/IR/Nodes/CircleGather.h
+++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleGather.h
@@ -20,7 +20,7 @@
#include "luci/IR/CircleNodeDecl.h"
#include "luci/IR/CircleOpcode.h"
-#include "luci/IR/LuciNodeMixins.h"
+#include "luci/IR/CircleNodeMixins.h"
namespace luci
{
@@ -42,7 +42,7 @@ public:
void axis(int32_t axis) { _axis = axis; }
private:
- int32_t _axis = 0;
+ int32_t _axis{0};
};
} // namespace luci
diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleGatherNd.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleGatherNd.h
index 3423a8216..d6f34f1ea 100644
--- a/compiler/luci/lang/include/luci/IR/Nodes/CircleGatherNd.h
+++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleGatherNd.h
@@ -20,7 +20,7 @@
#include "luci/IR/CircleNodeDecl.h"
#include "luci/IR/CircleOpcode.h"
-#include "luci/IR/LuciNodeMixins.h"
+#include "luci/IR/CircleNodeMixins.h"
namespace luci
{
diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleGreater.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleGreater.h
index 040a4e338..a03b6c749 100644
--- a/compiler/luci/lang/include/luci/IR/Nodes/CircleGreater.h
+++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleGreater.h
@@ -20,7 +20,7 @@
#include "luci/IR/CircleNodeDecl.h"
#include "luci/IR/CircleOpcode.h"
-#include "luci/IR/LuciNodeMixins.h"
+#include "luci/IR/CircleNodeMixins.h"
namespace luci
{
diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleGreaterEqual.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleGreaterEqual.h
index 82bdab212..e435320b2 100644
--- a/compiler/luci/lang/include/luci/IR/Nodes/CircleGreaterEqual.h
+++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleGreaterEqual.h
@@ -20,7 +20,7 @@
#include "luci/IR/CircleNodeDecl.h"
#include "luci/IR/CircleOpcode.h"
-#include "luci/IR/LuciNodeMixins.h"
+#include "luci/IR/CircleNodeMixins.h"
namespace luci
{
@@ -29,7 +29,7 @@ namespace luci
* @brief GREATER EQUAL in Circle
*/
class CircleGreaterEqual final
- : public FixedArityNode<2, CircleNodeImpl<CircleOpcode::GREATER_EQUAL>>
+ : public FixedArityNode<2, CircleNodeImpl<CircleOpcode::GREATER_EQUAL>>
{
public:
loco::Node *x(void) const { return at(0)->node(); }
diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleIf.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleIf.h
index 2f9eac211..1c037a406 100644
--- a/compiler/luci/lang/include/luci/IR/Nodes/CircleIf.h
+++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleIf.h
@@ -34,7 +34,7 @@ class CircleIf final : public VariadicArityNode<CircleNodeImpl<CircleOpcode::IF>
{
public:
CircleIf(uint32_t arity, uint32_t out)
- : VariadicArityNode<CircleNodeImpl<CircleOpcode::IF>>(arity + 1), _output_count(out)
+ : VariadicArityNode<CircleNodeImpl<CircleOpcode::IF>>(arity + 1), _output_count(out)
{
assert(arity > 0);
assert(out > 0);
diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleIfOut.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleIfOut.h
index 3654e943b..5adaaa447 100644
--- a/compiler/luci/lang/include/luci/IR/Nodes/CircleIfOut.h
+++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleIfOut.h
@@ -20,7 +20,7 @@
#include "luci/IR/CircleNodeDecl.h"
#include "luci/IR/CircleOpcode.h"
-#include "luci/IR/LuciNodeMixins.h"
+#include "luci/IR/CircleNodeMixins.h"
namespace luci
{
@@ -31,9 +31,6 @@ namespace luci
class CircleIfOut final : public FixedArityNode<1, CircleNodeImpl<CircleOpcode::CIRCLEIFOUT>>
{
public:
- CircleIfOut() = default;
-
-public:
loco::Node *input(void) const { return at(0)->node(); }
void input(loco::Node *node) { at(0)->node(node); }
diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleInput.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleInput.h
index 4a7d36a4e..e0be9aa6e 100644
--- a/compiler/luci/lang/include/luci/IR/Nodes/CircleInput.h
+++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleInput.h
@@ -20,7 +20,7 @@
#include "luci/IR/CircleNodeDecl.h"
#include "luci/IR/CircleOpcode.h"
-#include "luci/IR/LuciNodeMixins.h"
+#include "luci/IR/CircleNodeMixins.h"
#include <loco/IR/DataTypeTraits.h>
#include <loco/IR/GraphInputIndex.h>
@@ -35,16 +35,13 @@ namespace luci
class CircleInput final : public FixedArityNode<0, CircleNodeImpl<CircleOpcode::CIRCLEINPUT>>
{
public:
- CircleInput() = default;
-
-public:
void index(const loco::GraphInputIndex &index);
loco::GraphInputIndex index(void) const;
bool indexed(void) const { return _index != -1; }
private:
- int64_t _index = -1; // Uninitialized
+ int64_t _index{-1}; // Uninitialized
};
} // namespace luci
diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleInstanceNorm.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleInstanceNorm.h
index db0faa05e..65c34194d 100644
--- a/compiler/luci/lang/include/luci/IR/Nodes/CircleInstanceNorm.h
+++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleInstanceNorm.h
@@ -21,7 +21,7 @@
#include "luci/IR/CircleOpcode.h"
#include "luci/IR/AttrFusedActFunc.h"
-#include "luci/IR/LuciNodeMixins.h"
+#include "luci/IR/CircleNodeMixins.h"
namespace luci
{
@@ -30,8 +30,8 @@ namespace luci
* @brief INSTANCE_NORM in Circle
*/
class CircleInstanceNorm final
- : public FixedArityNode<3, CircleNodeImpl<CircleOpcode::INSTANCE_NORM>>,
- public LuciNodeMixin<LuciNodeTrait::FusedActFunc>
+ : public FixedArityNode<3, CircleNodeImpl<CircleOpcode::INSTANCE_NORM>>,
+ public CircleNodeMixin<CircleNodeTrait::FusedActFunc>
{
public:
/// @note Currently only support FLOAT32 as input node
@@ -44,11 +44,12 @@ public:
loco::Node *beta(void) const { return at(2)->node(); }
void beta(loco::Node *node) { at(2)->node(node); }
+public:
float epsilon() const { return _epsilon; }
void epsilon(float epsilon) { _epsilon = epsilon; }
private:
- float _epsilon = 1e-05;
+ float _epsilon{1e-05};
};
} // namespace luci
diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleL2Normalize.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleL2Normalize.h
index efa932d95..eb2b372ce 100644
--- a/compiler/luci/lang/include/luci/IR/Nodes/CircleL2Normalize.h
+++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleL2Normalize.h
@@ -21,7 +21,7 @@
#include "luci/IR/CircleOpcode.h"
#include "luci/IR/AttrFusedActFunc.h"
-#include "luci/IR/LuciNodeMixins.h"
+#include "luci/IR/CircleNodeMixins.h"
namespace luci
{
@@ -30,8 +30,8 @@ namespace luci
* @brief L2_NORMALIZATION in Circle
*/
class CircleL2Normalize final
- : public FixedArityNode<1, CircleNodeImpl<CircleOpcode::L2_NORMALIZATION>>,
- public LuciNodeMixin<LuciNodeTrait::FusedActFunc>
+ : public FixedArityNode<1, CircleNodeImpl<CircleOpcode::L2_NORMALIZATION>>,
+ public CircleNodeMixin<CircleNodeTrait::FusedActFunc>
{
public:
loco::Node *x(void) const { return at(0)->node(); }
diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleL2Pool2D.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleL2Pool2D.h
index 7c76ee5d0..624d29e9e 100644
--- a/compiler/luci/lang/include/luci/IR/Nodes/CircleL2Pool2D.h
+++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleL2Pool2D.h
@@ -24,7 +24,7 @@
#include "luci/IR/AttrPadding.h"
#include "luci/IR/AttrStride.h"
#include "luci/IR/AttrFusedActFunc.h"
-#include "luci/IR/LuciNodeMixins.h"
+#include "luci/IR/CircleNodeMixins.h"
namespace luci
{
@@ -33,15 +33,13 @@ namespace luci
* @brief L2_POOL_2D in Circle
*/
class CircleL2Pool2D final : public FixedArityNode<1, CircleNodeImpl<CircleOpcode::L2_POOL_2D>>,
- public LuciNodeMixin<LuciNodeTrait::FusedActFunc>
+ public CircleNodeMixin<CircleNodeTrait::FusedActFunc>
{
public:
- CircleL2Pool2D() : _padding(Padding::UNDEFINED) { /* empty */}
-
-public:
loco::Node *value(void) const { return at(0)->node(); }
void value(loco::Node *node) { at(0)->node(node); }
+public:
Padding padding() const { return _padding; }
void padding(Padding padding) { _padding = padding; }
@@ -52,7 +50,7 @@ public:
Stride *stride(void) { return &_stride; }
private:
- Padding _padding;
+ Padding _padding{Padding::UNDEFINED};
Stride _stride;
Filter _filter;
};
diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleLeakyRelu.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleLeakyRelu.h
index d6ac97fc0..c8e93af91 100644
--- a/compiler/luci/lang/include/luci/IR/Nodes/CircleLeakyRelu.h
+++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleLeakyRelu.h
@@ -20,7 +20,7 @@
#include "luci/IR/CircleNodeDecl.h"
#include "luci/IR/CircleOpcode.h"
-#include "luci/IR/LuciNodeMixins.h"
+#include "luci/IR/CircleNodeMixins.h"
namespace luci
{
@@ -31,17 +31,15 @@ namespace luci
class CircleLeakyRelu final : public FixedArityNode<1, CircleNodeImpl<CircleOpcode::LEAKY_RELU>>
{
public:
- CircleLeakyRelu() = default;
-
-public:
loco::Node *features(void) const { return at(0)->node(); }
void features(loco::Node *node) { at(0)->node(node); }
+public:
float alpha() const { return _alpha; }
void alpha(float alpha) { _alpha = alpha; }
private:
- float _alpha = 0.2f;
+ float _alpha{0.2f};
};
} // namespace luci
diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleLess.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleLess.h
index cd6cf1872..7adf67842 100644
--- a/compiler/luci/lang/include/luci/IR/Nodes/CircleLess.h
+++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleLess.h
@@ -20,7 +20,7 @@
#include "luci/IR/CircleNodeDecl.h"
#include "luci/IR/CircleOpcode.h"
-#include "luci/IR/LuciNodeMixins.h"
+#include "luci/IR/CircleNodeMixins.h"
namespace luci
{
diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleLessEqual.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleLessEqual.h
index 4c7c6a49b..eb8962494 100644
--- a/compiler/luci/lang/include/luci/IR/Nodes/CircleLessEqual.h
+++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleLessEqual.h
@@ -20,7 +20,7 @@
#include "luci/IR/CircleNodeDecl.h"
#include "luci/IR/CircleOpcode.h"
-#include "luci/IR/LuciNodeMixins.h"
+#include "luci/IR/CircleNodeMixins.h"
namespace luci
{
diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleLocalResponseNormalization.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleLocalResponseNormalization.h
index 8ad2b40fd..4d324700e 100644
--- a/compiler/luci/lang/include/luci/IR/Nodes/CircleLocalResponseNormalization.h
+++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleLocalResponseNormalization.h
@@ -20,7 +20,7 @@
#include "luci/IR/CircleNodeDecl.h"
#include "luci/IR/CircleOpcode.h"
-#include "luci/IR/LuciNodeMixins.h"
+#include "luci/IR/CircleNodeMixins.h"
namespace luci
{
@@ -29,7 +29,7 @@ namespace luci
* @brief LOCAL_RESPONSE_NORMALIZATION in Circle
*/
class CircleLocalResponseNormalization final
- : public FixedArityNode<1, CircleNodeImpl<CircleOpcode::LOCAL_RESPONSE_NORMALIZATION>>
+ : public FixedArityNode<1, CircleNodeImpl<CircleOpcode::LOCAL_RESPONSE_NORMALIZATION>>
{
public:
loco::Node *input(void) const { return at(0)->node(); }
diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleLog.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleLog.h
index aeb13fed9..2cc57ce2d 100644
--- a/compiler/luci/lang/include/luci/IR/Nodes/CircleLog.h
+++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleLog.h
@@ -20,7 +20,7 @@
#include "luci/IR/CircleNodeDecl.h"
#include "luci/IR/CircleOpcode.h"
-#include "luci/IR/LuciNodeMixins.h"
+#include "luci/IR/CircleNodeMixins.h"
namespace luci
{
diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleLogSoftmax.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleLogSoftmax.h
index 5dfd2c1f9..b73ff7c2a 100644
--- a/compiler/luci/lang/include/luci/IR/Nodes/CircleLogSoftmax.h
+++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleLogSoftmax.h
@@ -20,7 +20,7 @@
#include "luci/IR/CircleNodeDecl.h"
#include "luci/IR/CircleOpcode.h"
-#include "luci/IR/LuciNodeMixins.h"
+#include "luci/IR/CircleNodeMixins.h"
namespace luci
{
diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleLogicalAnd.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleLogicalAnd.h
index 975f6dbc7..9943c71cd 100644
--- a/compiler/luci/lang/include/luci/IR/Nodes/CircleLogicalAnd.h
+++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleLogicalAnd.h
@@ -20,7 +20,7 @@
#include "luci/IR/CircleNodeDecl.h"
#include "luci/IR/CircleOpcode.h"
-#include "luci/IR/LuciNodeMixins.h"
+#include "luci/IR/CircleNodeMixins.h"
namespace luci
{
diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleLogicalNot.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleLogicalNot.h
index 749dbe518..369a3e7bf 100644
--- a/compiler/luci/lang/include/luci/IR/Nodes/CircleLogicalNot.h
+++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleLogicalNot.h
@@ -20,7 +20,7 @@
#include "luci/IR/CircleNodeDecl.h"
#include "luci/IR/CircleOpcode.h"
-#include "luci/IR/LuciNodeMixins.h"
+#include "luci/IR/CircleNodeMixins.h"
namespace luci
{
diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleLogicalOr.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleLogicalOr.h
index 570be57af..c54ec3ebf 100644
--- a/compiler/luci/lang/include/luci/IR/Nodes/CircleLogicalOr.h
+++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleLogicalOr.h
@@ -20,7 +20,7 @@
#include "luci/IR/CircleNodeDecl.h"
#include "luci/IR/CircleOpcode.h"
-#include "luci/IR/LuciNodeMixins.h"
+#include "luci/IR/CircleNodeMixins.h"
namespace luci
{
diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleLogistic.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleLogistic.h
index 8328cb328..1f95e0f77 100644
--- a/compiler/luci/lang/include/luci/IR/Nodes/CircleLogistic.h
+++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleLogistic.h
@@ -20,7 +20,7 @@
#include "luci/IR/CircleNodeDecl.h"
#include "luci/IR/CircleOpcode.h"
-#include "luci/IR/LuciNodeMixins.h"
+#include "luci/IR/CircleNodeMixins.h"
namespace luci
{
@@ -31,9 +31,6 @@ namespace luci
class CircleLogistic final : public FixedArityNode<1, CircleNodeImpl<CircleOpcode::LOGISTIC>>
{
public:
- CircleLogistic() = default;
-
-public:
loco::Node *x(void) const { return at(0)->node(); }
void x(loco::Node *node) { at(0)->node(node); }
};
diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleMatrixDiag.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleMatrixDiag.h
index dca6538c3..f8bf259f9 100644
--- a/compiler/luci/lang/include/luci/IR/Nodes/CircleMatrixDiag.h
+++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleMatrixDiag.h
@@ -20,7 +20,7 @@
#include "luci/IR/CircleNodeDecl.h"
#include "luci/IR/CircleOpcode.h"
-#include "luci/IR/LuciNodeMixins.h"
+#include "luci/IR/CircleNodeMixins.h"
namespace luci
{
diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleMatrixSetDiag.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleMatrixSetDiag.h
index c1f5f3023..76aeaff40 100644
--- a/compiler/luci/lang/include/luci/IR/Nodes/CircleMatrixSetDiag.h
+++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleMatrixSetDiag.h
@@ -20,7 +20,7 @@
#include "luci/IR/CircleNodeDecl.h"
#include "luci/IR/CircleOpcode.h"
-#include "luci/IR/LuciNodeMixins.h"
+#include "luci/IR/CircleNodeMixins.h"
namespace luci
{
@@ -29,7 +29,7 @@ namespace luci
* @brief MATRIX_SET_DIAG in Circle
*/
class CircleMatrixSetDiag final
- : public FixedArityNode<2, CircleNodeImpl<CircleOpcode::MATRIX_SET_DIAG>>
+ : public FixedArityNode<2, CircleNodeImpl<CircleOpcode::MATRIX_SET_DIAG>>
{
public:
loco::Node *input(void) const { return at(0)->node(); }
diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleMaxPool2D.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleMaxPool2D.h
index 1eb6532ff..557240d54 100644
--- a/compiler/luci/lang/include/luci/IR/Nodes/CircleMaxPool2D.h
+++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleMaxPool2D.h
@@ -24,7 +24,7 @@
#include "luci/IR/AttrPadding.h"
#include "luci/IR/AttrStride.h"
#include "luci/IR/AttrFusedActFunc.h"
-#include "luci/IR/LuciNodeMixins.h"
+#include "luci/IR/CircleNodeMixins.h"
namespace luci
{
@@ -33,15 +33,13 @@ namespace luci
* @brief MAX_POOL_2D in Circle
*/
class CircleMaxPool2D final : public FixedArityNode<1, CircleNodeImpl<CircleOpcode::MAX_POOL_2D>>,
- public LuciNodeMixin<LuciNodeTrait::FusedActFunc>
+ public CircleNodeMixin<CircleNodeTrait::FusedActFunc>
{
public:
- CircleMaxPool2D() : _padding(Padding::UNDEFINED) { /* empty */}
-
-public:
loco::Node *value(void) const { return at(0)->node(); }
void value(loco::Node *node) { at(0)->node(node); }
+public:
Padding padding() const { return _padding; }
void padding(Padding padding) { _padding = padding; }
@@ -52,7 +50,7 @@ public:
Stride *stride(void) { return &_stride; }
private:
- Padding _padding;
+ Padding _padding{Padding::UNDEFINED};
Stride _stride;
Filter _filter;
};
diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleMaximum.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleMaximum.h
index 6f789bc14..317cea308 100644
--- a/compiler/luci/lang/include/luci/IR/Nodes/CircleMaximum.h
+++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleMaximum.h
@@ -20,7 +20,7 @@
#include "luci/IR/CircleNodeDecl.h"
#include "luci/IR/CircleOpcode.h"
-#include "luci/IR/LuciNodeMixins.h"
+#include "luci/IR/CircleNodeMixins.h"
namespace luci
{
diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleMean.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleMean.h
index 7f8aeb5aa..f56e4f4c0 100644
--- a/compiler/luci/lang/include/luci/IR/Nodes/CircleMean.h
+++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleMean.h
@@ -20,7 +20,7 @@
#include "luci/IR/CircleNodeDecl.h"
#include "luci/IR/CircleOpcode.h"
-#include "luci/IR/LuciNodeMixins.h"
+#include "luci/IR/CircleNodeMixins.h"
namespace luci
{
@@ -42,7 +42,7 @@ public:
void keep_dims(bool keep_dims) { _keep_dims = keep_dims; }
private:
- bool _keep_dims = false;
+ bool _keep_dims{false};
};
} // namespace luci
diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleMinimum.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleMinimum.h
index 79d5a6f17..959d9c93b 100644
--- a/compiler/luci/lang/include/luci/IR/Nodes/CircleMinimum.h
+++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleMinimum.h
@@ -20,7 +20,7 @@
#include "luci/IR/CircleNodeDecl.h"
#include "luci/IR/CircleOpcode.h"
-#include "luci/IR/LuciNodeMixins.h"
+#include "luci/IR/CircleNodeMixins.h"
namespace luci
{
diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleMirrorPad.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleMirrorPad.h
index 68db8f6f3..c69e8f7c1 100644
--- a/compiler/luci/lang/include/luci/IR/Nodes/CircleMirrorPad.h
+++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleMirrorPad.h
@@ -20,7 +20,7 @@
#include "luci/IR/CircleNodeDecl.h"
#include "luci/IR/CircleOpcode.h"
-#include "luci/IR/LuciNodeMixins.h"
+#include "luci/IR/CircleNodeMixins.h"
#include "luci/IR/AttrMirrorPadMode.h"
namespace luci
@@ -32,9 +32,6 @@ namespace luci
class CircleMirrorPad final : public FixedArityNode<2, CircleNodeImpl<CircleOpcode::MIRROR_PAD>>
{
public:
- CircleMirrorPad() = default;
-
-public:
loco::Node *input(void) const { return at(0)->node(); }
void input(loco::Node *node) { at(0)->node(node); }
diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleMul.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleMul.h
index 67e897170..85ed694b3 100644
--- a/compiler/luci/lang/include/luci/IR/Nodes/CircleMul.h
+++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleMul.h
@@ -21,7 +21,7 @@
#include "luci/IR/CircleOpcode.h"
#include "luci/IR/AttrFusedActFunc.h"
-#include "luci/IR/LuciNodeMixins.h"
+#include "luci/IR/CircleNodeMixins.h"
namespace luci
{
@@ -30,7 +30,7 @@ namespace luci
* @brief MUL in Circle
*/
class CircleMul final : public FixedArityNode<2, CircleNodeImpl<CircleOpcode::MUL>>,
- public LuciNodeMixin<LuciNodeTrait::FusedActFunc>
+ public CircleNodeMixin<CircleNodeTrait::FusedActFunc>
{
public:
loco::Node *x(void) const { return at(0)->node(); }
diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleNeg.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleNeg.h
index 4149ac4a7..adea3fb83 100644
--- a/compiler/luci/lang/include/luci/IR/Nodes/CircleNeg.h
+++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleNeg.h
@@ -20,7 +20,7 @@
#include "luci/IR/CircleNodeDecl.h"
#include "luci/IR/CircleOpcode.h"
-#include "luci/IR/LuciNodeMixins.h"
+#include "luci/IR/CircleNodeMixins.h"
namespace luci
{
diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleNonMaxSuppressionV4.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleNonMaxSuppressionV4.h
index 69f3368c0..b47404bb0 100644
--- a/compiler/luci/lang/include/luci/IR/Nodes/CircleNonMaxSuppressionV4.h
+++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleNonMaxSuppressionV4.h
@@ -20,7 +20,7 @@
#include "luci/IR/CircleNodeDecl.h"
#include "luci/IR/CircleOpcode.h"
-#include "luci/IR/LuciNodeMixins.h"
+#include "luci/IR/CircleNodeMixins.h"
namespace luci
{
@@ -29,7 +29,7 @@ namespace luci
* @brief NON_MAX_SUPPRESSION_V4 in Circle
*/
class CircleNonMaxSuppressionV4 final
- : public FixedArityNode<5, CircleNodeImpl<CircleOpcode::NON_MAX_SUPPRESSION_V4>>
+ : public FixedArityNode<5, CircleNodeImpl<CircleOpcode::NON_MAX_SUPPRESSION_V4>>
{
public:
loco::Node *boxes(void) const { return at(0)->node(); }
diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleNonMaxSuppressionV4Out.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleNonMaxSuppressionV4Out.h
index a24dc3e9c..7e6923b5e 100644
--- a/compiler/luci/lang/include/luci/IR/Nodes/CircleNonMaxSuppressionV4Out.h
+++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleNonMaxSuppressionV4Out.h
@@ -20,7 +20,7 @@
#include "luci/IR/CircleNodeDecl.h"
#include "luci/IR/CircleOpcode.h"
-#include "luci/IR/LuciNodeMixins.h"
+#include "luci/IR/CircleNodeMixins.h"
namespace luci
{
@@ -29,12 +29,9 @@ namespace luci
* @brief Virtual NONMAXSUPPRESSIONV4OUT in Circle
*/
class CircleNonMaxSuppressionV4Out final
- : public FixedArityNode<1, CircleNodeImpl<CircleOpcode::CIRCLENONMAXSUPPRESSIONV4OUT>>
+ : public FixedArityNode<1, CircleNodeImpl<CircleOpcode::CIRCLENONMAXSUPPRESSIONV4OUT>>
{
public:
- CircleNonMaxSuppressionV4Out() = default;
-
-public:
loco::Node *input(void) const { return at(0)->node(); }
void input(loco::Node *node) { at(0)->node(node); }
diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleNonMaxSuppressionV5.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleNonMaxSuppressionV5.h
index 52d682147..77086ede7 100644
--- a/compiler/luci/lang/include/luci/IR/Nodes/CircleNonMaxSuppressionV5.h
+++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleNonMaxSuppressionV5.h
@@ -20,7 +20,7 @@
#include "luci/IR/CircleNodeDecl.h"
#include "luci/IR/CircleOpcode.h"
-#include "luci/IR/LuciNodeMixins.h"
+#include "luci/IR/CircleNodeMixins.h"
namespace luci
{
@@ -29,7 +29,7 @@ namespace luci
* @brief NON_MAX_SUPPRESSION_V5 in Circle
*/
class CircleNonMaxSuppressionV5 final
- : public FixedArityNode<6, CircleNodeImpl<CircleOpcode::NON_MAX_SUPPRESSION_V5>>
+ : public FixedArityNode<6, CircleNodeImpl<CircleOpcode::NON_MAX_SUPPRESSION_V5>>
{
public:
loco::Node *boxes(void) const { return at(0)->node(); }
diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleNonMaxSuppressionV5Out.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleNonMaxSuppressionV5Out.h
index 0c6989cc7..63d061f11 100644
--- a/compiler/luci/lang/include/luci/IR/Nodes/CircleNonMaxSuppressionV5Out.h
+++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleNonMaxSuppressionV5Out.h
@@ -20,7 +20,7 @@
#include "luci/IR/CircleNodeDecl.h"
#include "luci/IR/CircleOpcode.h"
-#include "luci/IR/LuciNodeMixins.h"
+#include "luci/IR/CircleNodeMixins.h"
namespace luci
{
@@ -29,12 +29,9 @@ namespace luci
* @brief Virtual NONMAXSUPPRESSIONV5OUT in Circle
*/
class CircleNonMaxSuppressionV5Out final
- : public FixedArityNode<1, CircleNodeImpl<CircleOpcode::CIRCLENONMAXSUPPRESSIONV5OUT>>
+ : public FixedArityNode<1, CircleNodeImpl<CircleOpcode::CIRCLENONMAXSUPPRESSIONV5OUT>>
{
public:
- CircleNonMaxSuppressionV5Out() = default;
-
-public:
loco::Node *input(void) const { return at(0)->node(); }
void input(loco::Node *node) { at(0)->node(node); }
diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleNotEqual.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleNotEqual.h
index cca7a5e22..add6a0747 100644
--- a/compiler/luci/lang/include/luci/IR/Nodes/CircleNotEqual.h
+++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleNotEqual.h
@@ -20,7 +20,7 @@
#include "luci/IR/CircleNodeDecl.h"
#include "luci/IR/CircleOpcode.h"
-#include "luci/IR/LuciNodeMixins.h"
+#include "luci/IR/CircleNodeMixins.h"
namespace luci
{
diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleOneHot.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleOneHot.h
index 665e01d48..b3eb0f436 100644
--- a/compiler/luci/lang/include/luci/IR/Nodes/CircleOneHot.h
+++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleOneHot.h
@@ -20,7 +20,7 @@
#include "luci/IR/CircleNodeDecl.h"
#include "luci/IR/CircleOpcode.h"
-#include "luci/IR/LuciNodeMixins.h"
+#include "luci/IR/CircleNodeMixins.h"
namespace luci
{
@@ -48,7 +48,7 @@ public:
void axis(int32_t axis) { _axis = axis; }
private:
- int32_t _axis = -1;
+ int32_t _axis{-1};
};
} // namespace luci
diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleOutput.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleOutput.h
index 67e55f1a1..eb02f824e 100644
--- a/compiler/luci/lang/include/luci/IR/Nodes/CircleOutput.h
+++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleOutput.h
@@ -20,7 +20,7 @@
#include "luci/IR/CircleNodeDecl.h"
#include "luci/IR/CircleOpcode.h"
-#include "luci/IR/LuciNodeMixins.h"
+#include "luci/IR/CircleNodeMixins.h"
#include <loco/IR/GraphOutputIndex.h>
@@ -34,8 +34,6 @@ namespace luci
class CircleOutput final : public FixedArityNode<1, CircleNodeImpl<CircleOpcode::CIRCLEOUTPUT>>
{
public:
- CircleOutput() = default;
-
void index(const loco::GraphOutputIndex &index);
loco::GraphOutputIndex index(void) const;
@@ -46,7 +44,7 @@ public:
void from(loco::Node *node) { at(0)->node(node); }
private:
- int64_t _index = -1; // Uninitialized
+ int64_t _index{-1}; // Uninitialized
};
/**
@@ -54,7 +52,7 @@ private:
*/
// TODO remove CircleOutputDummy
class CircleOutputDummy final
- : public FixedArityNode<0, CircleNodeImpl<CircleOpcode::CIRCLEOUTPUTDUMMY>>
+ : public FixedArityNode<0, CircleNodeImpl<CircleOpcode::CIRCLEOUTPUTDUMMY>>
{
public:
CircleOutputDummy() = default;
@@ -64,7 +62,7 @@ public:
* @brief CircleOutputExclude is used to specifying not exported nodes
*/
class CircleOutputExclude final
- : public FixedArityNode<0, CircleNodeImpl<CircleOpcode::CIRCLEOUTPUTEXCLUDE>>
+ : public FixedArityNode<0, CircleNodeImpl<CircleOpcode::CIRCLEOUTPUTEXCLUDE>>
{
public:
CircleOutputExclude() = default;
diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CirclePRelu.h b/compiler/luci/lang/include/luci/IR/Nodes/CirclePRelu.h
index 693777512..3c5559db2 100644
--- a/compiler/luci/lang/include/luci/IR/Nodes/CirclePRelu.h
+++ b/compiler/luci/lang/include/luci/IR/Nodes/CirclePRelu.h
@@ -20,7 +20,7 @@
#include "luci/IR/CircleNodeDecl.h"
#include "luci/IR/CircleOpcode.h"
-#include "luci/IR/LuciNodeMixins.h"
+#include "luci/IR/CircleNodeMixins.h"
namespace luci
{
@@ -31,9 +31,6 @@ namespace luci
class CirclePRelu final : public FixedArityNode<2, CircleNodeImpl<CircleOpcode::PRELU>>
{
public:
- CirclePRelu() = default;
-
-public:
loco::Node *input(void) const { return at(0)->node(); }
void input(loco::Node *node) { at(0)->node(node); }
diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CirclePad.h b/compiler/luci/lang/include/luci/IR/Nodes/CirclePad.h
index 31599bda0..ede217789 100644
--- a/compiler/luci/lang/include/luci/IR/Nodes/CirclePad.h
+++ b/compiler/luci/lang/include/luci/IR/Nodes/CirclePad.h
@@ -20,7 +20,7 @@
#include "luci/IR/CircleNodeDecl.h"
#include "luci/IR/CircleOpcode.h"
-#include "luci/IR/LuciNodeMixins.h"
+#include "luci/IR/CircleNodeMixins.h"
namespace luci
{
@@ -31,9 +31,6 @@ namespace luci
class CirclePad final : public FixedArityNode<2, CircleNodeImpl<CircleOpcode::PAD>>
{
public:
- CirclePad() = default;
-
-public:
loco::Node *input(void) const { return at(0)->node(); }
void input(loco::Node *node) { at(0)->node(node); }
diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CirclePadV2.h b/compiler/luci/lang/include/luci/IR/Nodes/CirclePadV2.h
index 563cfd9a4..644e2bb27 100644
--- a/compiler/luci/lang/include/luci/IR/Nodes/CirclePadV2.h
+++ b/compiler/luci/lang/include/luci/IR/Nodes/CirclePadV2.h
@@ -20,7 +20,7 @@
#include "luci/IR/CircleNodeDecl.h"
#include "luci/IR/CircleOpcode.h"
-#include "luci/IR/LuciNodeMixins.h"
+#include "luci/IR/CircleNodeMixins.h"
namespace luci
{
@@ -31,9 +31,6 @@ namespace luci
class CirclePadV2 final : public FixedArityNode<3, CircleNodeImpl<CircleOpcode::PADV2>>
{
public:
- CirclePadV2() = default;
-
-public:
loco::Node *input(void) const { return at(0)->node(); }
void input(loco::Node *node) { at(0)->node(node); }
diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CirclePow.h b/compiler/luci/lang/include/luci/IR/Nodes/CirclePow.h
index 006e3dd86..40c5a829d 100644
--- a/compiler/luci/lang/include/luci/IR/Nodes/CirclePow.h
+++ b/compiler/luci/lang/include/luci/IR/Nodes/CirclePow.h
@@ -20,7 +20,7 @@
#include "luci/IR/CircleNodeDecl.h"
#include "luci/IR/CircleOpcode.h"
-#include "luci/IR/LuciNodeMixins.h"
+#include "luci/IR/CircleNodeMixins.h"
namespace luci
{
@@ -31,9 +31,6 @@ namespace luci
class CirclePow final : public FixedArityNode<2, CircleNodeImpl<CircleOpcode::POW>>
{
public:
- CirclePow() = default;
-
-public:
loco::Node *x(void) const { return at(0)->node(); }
void x(loco::Node *node) { at(0)->node(node); }
diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleRange.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleRange.h
index 977a37a52..56f8a2eba 100644
--- a/compiler/luci/lang/include/luci/IR/Nodes/CircleRange.h
+++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleRange.h
@@ -20,7 +20,7 @@
#include "luci/IR/CircleNodeDecl.h"
#include "luci/IR/CircleOpcode.h"
-#include "luci/IR/LuciNodeMixins.h"
+#include "luci/IR/CircleNodeMixins.h"
namespace luci
{
diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleRank.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleRank.h
index ba6d67f69..034f251bc 100644
--- a/compiler/luci/lang/include/luci/IR/Nodes/CircleRank.h
+++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleRank.h
@@ -20,7 +20,7 @@
#include "luci/IR/CircleNodeDecl.h"
#include "luci/IR/CircleOpcode.h"
-#include "luci/IR/LuciNodeMixins.h"
+#include "luci/IR/CircleNodeMixins.h"
namespace luci
{
diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleReduceAny.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleReduceAny.h
index 0456be863..c64dbbdf8 100644
--- a/compiler/luci/lang/include/luci/IR/Nodes/CircleReduceAny.h
+++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleReduceAny.h
@@ -20,7 +20,7 @@
#include "luci/IR/CircleNodeDecl.h"
#include "luci/IR/CircleOpcode.h"
-#include "luci/IR/LuciNodeMixins.h"
+#include "luci/IR/CircleNodeMixins.h"
namespace luci
{
@@ -42,7 +42,7 @@ public:
void keep_dims(bool keep_dims) { _keep_dims = keep_dims; }
private:
- bool _keep_dims = false;
+ bool _keep_dims{false};
};
} // namespace luci
diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleReduceMax.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleReduceMax.h
index 925c977e5..97cbecd08 100644
--- a/compiler/luci/lang/include/luci/IR/Nodes/CircleReduceMax.h
+++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleReduceMax.h
@@ -20,7 +20,7 @@
#include "luci/IR/CircleNodeDecl.h"
#include "luci/IR/CircleOpcode.h"
-#include "luci/IR/LuciNodeMixins.h"
+#include "luci/IR/CircleNodeMixins.h"
namespace luci
{
@@ -42,7 +42,7 @@ public:
void keep_dims(bool keep_dims) { _keep_dims = keep_dims; }
private:
- bool _keep_dims = false;
+ bool _keep_dims{false};
};
} // namespace luci
diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleReduceMin.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleReduceMin.h
index fd789ae5e..33708928f 100644
--- a/compiler/luci/lang/include/luci/IR/Nodes/CircleReduceMin.h
+++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleReduceMin.h
@@ -20,7 +20,7 @@
#include "luci/IR/CircleNodeDecl.h"
#include "luci/IR/CircleOpcode.h"
-#include "luci/IR/LuciNodeMixins.h"
+#include "luci/IR/CircleNodeMixins.h"
namespace luci
{
@@ -42,7 +42,7 @@ public:
void keep_dims(bool keep_dims) { _keep_dims = keep_dims; }
private:
- bool _keep_dims = false;
+ bool _keep_dims{false};
};
} // namespace luci
diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleReduceProd.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleReduceProd.h
index b7d226255..3689ee532 100644
--- a/compiler/luci/lang/include/luci/IR/Nodes/CircleReduceProd.h
+++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleReduceProd.h
@@ -20,7 +20,7 @@
#include "luci/IR/CircleNodeDecl.h"
#include "luci/IR/CircleOpcode.h"
-#include "luci/IR/LuciNodeMixins.h"
+#include "luci/IR/CircleNodeMixins.h"
namespace luci
{
@@ -42,7 +42,7 @@ public:
void keep_dims(bool keep_dims) { _keep_dims = keep_dims; }
private:
- bool _keep_dims = false;
+ bool _keep_dims{false};
};
} // namespace luci
diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleRelu.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleRelu.h
index 91272d2bf..6148caa03 100644
--- a/compiler/luci/lang/include/luci/IR/Nodes/CircleRelu.h
+++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleRelu.h
@@ -20,7 +20,7 @@
#include "luci/IR/CircleNodeDecl.h"
#include "luci/IR/CircleOpcode.h"
-#include "luci/IR/LuciNodeMixins.h"
+#include "luci/IR/CircleNodeMixins.h"
namespace luci
{
@@ -31,9 +31,6 @@ namespace luci
class CircleRelu final : public FixedArityNode<1, CircleNodeImpl<CircleOpcode::RELU>>
{
public:
- CircleRelu() = default;
-
-public:
loco::Node *features(void) const { return at(0)->node(); }
void features(loco::Node *node) { at(0)->node(node); }
};
diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleRelu6.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleRelu6.h
index b4274ded9..0fa25e873 100644
--- a/compiler/luci/lang/include/luci/IR/Nodes/CircleRelu6.h
+++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleRelu6.h
@@ -20,7 +20,7 @@
#include "luci/IR/CircleNodeDecl.h"
#include "luci/IR/CircleOpcode.h"
-#include "luci/IR/LuciNodeMixins.h"
+#include "luci/IR/CircleNodeMixins.h"
namespace luci
{
@@ -31,9 +31,6 @@ namespace luci
class CircleRelu6 final : public FixedArityNode<1, CircleNodeImpl<CircleOpcode::RELU6>>
{
public:
- CircleRelu6() = default;
-
-public:
loco::Node *features(void) const { return at(0)->node(); }
void features(loco::Node *node) { at(0)->node(node); }
};
diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleReluN1To1.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleReluN1To1.h
index a5c5710c2..13c0d166f 100644
--- a/compiler/luci/lang/include/luci/IR/Nodes/CircleReluN1To1.h
+++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleReluN1To1.h
@@ -20,7 +20,7 @@
#include "luci/IR/CircleNodeDecl.h"
#include "luci/IR/CircleOpcode.h"
-#include "luci/IR/LuciNodeMixins.h"
+#include "luci/IR/CircleNodeMixins.h"
namespace luci
{
@@ -31,9 +31,6 @@ namespace luci
class CircleReluN1To1 final : public FixedArityNode<1, CircleNodeImpl<CircleOpcode::RELU_N1_TO_1>>
{
public:
- CircleReluN1To1() = default;
-
-public:
loco::Node *features(void) const { return at(0)->node(); }
void features(loco::Node *node) { at(0)->node(node); }
};
diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleReshape.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleReshape.h
index b13144f7e..090df4044 100644
--- a/compiler/luci/lang/include/luci/IR/Nodes/CircleReshape.h
+++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleReshape.h
@@ -20,7 +20,7 @@
#include "luci/IR/CircleNodeDecl.h"
#include "luci/IR/CircleOpcode.h"
-#include "luci/IR/LuciNodeMixins.h"
+#include "luci/IR/CircleNodeMixins.h"
namespace luci
{
@@ -31,14 +31,11 @@ namespace luci
class CircleReshape final : public FixedArityNode<2, CircleNodeImpl<CircleOpcode::RESHAPE>>
{
public:
- CircleReshape() = default;
-
-public:
loco::Node *tensor(void) const { return at(0)->node(); }
void tensor(loco::Node *node) { at(0)->node(node); }
// NOTE shape is optional and can be CircleConst or any other type
- // and also can be CircleOutputDummy when reshape option does not exist
+ // and also should be CircleOutputDummy when reshape option does not exist
loco::Node *shape(void) const { return at(1)->node(); }
void shape(loco::Node *node) { at(1)->node(node); }
diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleResizeBilinear.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleResizeBilinear.h
index 3c8223338..091916a2b 100644
--- a/compiler/luci/lang/include/luci/IR/Nodes/CircleResizeBilinear.h
+++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleResizeBilinear.h
@@ -20,7 +20,7 @@
#include "luci/IR/CircleNodeDecl.h"
#include "luci/IR/CircleOpcode.h"
-#include "luci/IR/LuciNodeMixins.h"
+#include "luci/IR/CircleNodeMixins.h"
namespace luci
{
@@ -29,18 +29,16 @@ namespace luci
* @brief RESIZE_BILINEAR in Circle
*/
class CircleResizeBilinear final
- : public FixedArityNode<2, CircleNodeImpl<CircleOpcode::RESIZE_BILINEAR>>
+ : public FixedArityNode<2, CircleNodeImpl<CircleOpcode::RESIZE_BILINEAR>>
{
public:
- CircleResizeBilinear() = default;
-
-public:
loco::Node *input(void) const { return at(0)->node(); }
void input(loco::Node *node) { at(0)->node(node); }
loco::Node *size(void) const { return at(1)->node(); }
void size(loco::Node *node) { at(1)->node(node); }
+public:
bool align_corners() const { return _align_corners; }
void align_corners(bool value) { _align_corners = value; }
@@ -48,8 +46,8 @@ public:
void half_pixel_centers(bool value) { _half_pixel_centers = value; }
private:
- bool _align_corners = false;
- bool _half_pixel_centers = false;
+ bool _align_corners{false};
+ bool _half_pixel_centers{false};
};
} // namespace luci
diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleResizeNearestNeighbor.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleResizeNearestNeighbor.h
index dc32ebee7..ab880d767 100644
--- a/compiler/luci/lang/include/luci/IR/Nodes/CircleResizeNearestNeighbor.h
+++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleResizeNearestNeighbor.h
@@ -20,7 +20,7 @@
#include "luci/IR/CircleNodeDecl.h"
#include "luci/IR/CircleOpcode.h"
-#include "luci/IR/LuciNodeMixins.h"
+#include "luci/IR/CircleNodeMixins.h"
namespace luci
{
@@ -29,23 +29,21 @@ namespace luci
* @brief RESIZE_NEAREST_NEIGHBOR in Circle
*/
class CircleResizeNearestNeighbor final
- : public FixedArityNode<2, CircleNodeImpl<CircleOpcode::RESIZE_NEAREST_NEIGHBOR>>
+ : public FixedArityNode<2, CircleNodeImpl<CircleOpcode::RESIZE_NEAREST_NEIGHBOR>>
{
public:
- CircleResizeNearestNeighbor() = default;
-
-public:
loco::Node *input(void) const { return at(0)->node(); }
void input(loco::Node *node) { at(0)->node(node); }
loco::Node *size(void) const { return at(1)->node(); }
void size(loco::Node *node) { at(1)->node(node); }
+public:
bool align_corners() const { return _align_corners; }
void align_corners(bool value) { _align_corners = value; }
private:
- bool _align_corners = false;
+ bool _align_corners{false};
};
} // namespace luci
diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleReverseSequence.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleReverseSequence.h
index b0766dd3e..5f089a768 100644
--- a/compiler/luci/lang/include/luci/IR/Nodes/CircleReverseSequence.h
+++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleReverseSequence.h
@@ -20,7 +20,7 @@
#include "luci/IR/CircleNodeDecl.h"
#include "luci/IR/CircleOpcode.h"
-#include "luci/IR/LuciNodeMixins.h"
+#include "luci/IR/CircleNodeMixins.h"
namespace luci
{
@@ -29,12 +29,9 @@ namespace luci
* @brief REVERSE_SEQUENCE in Circle
*/
class CircleReverseSequence final
- : public FixedArityNode<2, CircleNodeImpl<CircleOpcode::REVERSE_SEQUENCE>>
+ : public FixedArityNode<2, CircleNodeImpl<CircleOpcode::REVERSE_SEQUENCE>>
{
public:
- CircleReverseSequence() = default;
-
-public:
loco::Node *input(void) const { return at(0)->node(); }
void input(loco::Node *node) { at(0)->node(node); }
@@ -42,15 +39,15 @@ public:
void seq_lengths(loco::Node *node) { at(1)->node(node); }
public:
- int seq_axis(void) const { return _seq_axis; }
- void seq_axis(int seq_axis) { _seq_axis = seq_axis; }
+ int32_t seq_axis(void) const { return _seq_axis; }
+ void seq_axis(int32_t seq_axis) { _seq_axis = seq_axis; }
- int batch_axis(void) const { return _batch_axis; }
- void batch_axis(int batch_axis) { _batch_axis = batch_axis; }
+ int32_t batch_axis(void) const { return _batch_axis; }
+ void batch_axis(int32_t batch_axis) { _batch_axis = batch_axis; }
private:
- int _seq_axis{0};
- int _batch_axis{0};
+ int32_t _seq_axis{0};
+ int32_t _batch_axis{0};
};
} // namespace luci
diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleReverseV2.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleReverseV2.h
index 71d9f65aa..96b6a793d 100644
--- a/compiler/luci/lang/include/luci/IR/Nodes/CircleReverseV2.h
+++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleReverseV2.h
@@ -20,7 +20,7 @@
#include "luci/IR/CircleNodeDecl.h"
#include "luci/IR/CircleOpcode.h"
-#include "luci/IR/LuciNodeMixins.h"
+#include "luci/IR/CircleNodeMixins.h"
namespace luci
{
diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleRound.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleRound.h
index 30296ce9e..e340266ed 100644
--- a/compiler/luci/lang/include/luci/IR/Nodes/CircleRound.h
+++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleRound.h
@@ -20,7 +20,7 @@
#include "luci/IR/CircleNodeDecl.h"
#include "luci/IR/CircleOpcode.h"
-#include "luci/IR/LuciNodeMixins.h"
+#include "luci/IR/CircleNodeMixins.h"
namespace luci
{
@@ -31,9 +31,6 @@ namespace luci
class CircleRound final : public FixedArityNode<1, CircleNodeImpl<CircleOpcode::ROUND>>
{
public:
- CircleRound() = default;
-
-public:
loco::Node *x(void) const { return at(0)->node(); }
void x(loco::Node *node) { at(0)->node(node); }
};
diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleRsqrt.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleRsqrt.h
index 873397bce..7907f326b 100644
--- a/compiler/luci/lang/include/luci/IR/Nodes/CircleRsqrt.h
+++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleRsqrt.h
@@ -20,7 +20,7 @@
#include "luci/IR/CircleNodeDecl.h"
#include "luci/IR/CircleOpcode.h"
-#include "luci/IR/LuciNodeMixins.h"
+#include "luci/IR/CircleNodeMixins.h"
namespace luci
{
@@ -31,9 +31,6 @@ namespace luci
class CircleRsqrt final : public FixedArityNode<1, CircleNodeImpl<CircleOpcode::RSQRT>>
{
public:
- CircleRsqrt() = default;
-
-public:
loco::Node *x(void) const { return at(0)->node(); }
void x(loco::Node *node) { at(0)->node(node); }
};
diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleScatterNd.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleScatterNd.h
index 9f93a0a80..fda3abafc 100644
--- a/compiler/luci/lang/include/luci/IR/Nodes/CircleScatterNd.h
+++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleScatterNd.h
@@ -20,7 +20,7 @@
#include "luci/IR/CircleNodeDecl.h"
#include "luci/IR/CircleOpcode.h"
-#include "luci/IR/LuciNodeMixins.h"
+#include "luci/IR/CircleNodeMixins.h"
namespace luci
{
diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleSegmentSum.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleSegmentSum.h
index 416d617b2..e7227e9ee 100644
--- a/compiler/luci/lang/include/luci/IR/Nodes/CircleSegmentSum.h
+++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleSegmentSum.h
@@ -20,7 +20,7 @@
#include "luci/IR/CircleNodeDecl.h"
#include "luci/IR/CircleOpcode.h"
-#include "luci/IR/LuciNodeMixins.h"
+#include "luci/IR/CircleNodeMixins.h"
namespace luci
{
@@ -31,9 +31,6 @@ namespace luci
class CircleSegmentSum final : public FixedArityNode<2, CircleNodeImpl<CircleOpcode::SEGMENT_SUM>>
{
public:
- CircleSegmentSum() = default;
-
-public:
loco::Node *input(void) const { return at(0)->node(); }
void input(loco::Node *node) { at(0)->node(node); }
diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleSelect.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleSelect.h
index 727647168..6f778d72d 100644
--- a/compiler/luci/lang/include/luci/IR/Nodes/CircleSelect.h
+++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleSelect.h
@@ -20,7 +20,7 @@
#include "luci/IR/CircleNodeDecl.h"
#include "luci/IR/CircleOpcode.h"
-#include "luci/IR/LuciNodeMixins.h"
+#include "luci/IR/CircleNodeMixins.h"
namespace luci
{
@@ -31,9 +31,6 @@ namespace luci
class CircleSelect final : public FixedArityNode<3, CircleNodeImpl<CircleOpcode::SELECT>>
{
public:
- CircleSelect() = default;
-
-public:
loco::Node *condition(void) const { return at(0)->node(); }
void condition(loco::Node *node) { at(0)->node(node); }
diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleSelectV2.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleSelectV2.h
index 7ac3c0524..7969cc2aa 100644
--- a/compiler/luci/lang/include/luci/IR/Nodes/CircleSelectV2.h
+++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleSelectV2.h
@@ -20,7 +20,7 @@
#include "luci/IR/CircleNodeDecl.h"
#include "luci/IR/CircleOpcode.h"
-#include "luci/IR/LuciNodeMixins.h"
+#include "luci/IR/CircleNodeMixins.h"
namespace luci
{
@@ -31,9 +31,6 @@ namespace luci
class CircleSelectV2 final : public FixedArityNode<3, CircleNodeImpl<CircleOpcode::SELECT_V2>>
{
public:
- CircleSelectV2() = default;
-
-public:
loco::Node *condition(void) const { return at(0)->node(); }
void condition(loco::Node *node) { at(0)->node(node); }
diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleShape.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleShape.h
index ff20ce684..903894dbd 100644
--- a/compiler/luci/lang/include/luci/IR/Nodes/CircleShape.h
+++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleShape.h
@@ -20,7 +20,7 @@
#include "luci/IR/CircleNodeDecl.h"
#include "luci/IR/CircleOpcode.h"
-#include "luci/IR/LuciNodeMixins.h"
+#include "luci/IR/CircleNodeMixins.h"
namespace luci
{
@@ -31,9 +31,6 @@ namespace luci
class CircleShape final : public FixedArityNode<1, CircleNodeImpl<CircleOpcode::SHAPE>>
{
public:
- CircleShape() = default;
-
-public:
loco::Node *input(void) const { return at(0)->node(); }
void input(loco::Node *node) { at(0)->node(node); }
diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleSin.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleSin.h
index 5624db253..25dc18b0d 100644
--- a/compiler/luci/lang/include/luci/IR/Nodes/CircleSin.h
+++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleSin.h
@@ -20,7 +20,7 @@
#include "luci/IR/CircleNodeDecl.h"
#include "luci/IR/CircleOpcode.h"
-#include "luci/IR/LuciNodeMixins.h"
+#include "luci/IR/CircleNodeMixins.h"
namespace luci
{
diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleSlice.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleSlice.h
index a2113643d..98556d7a6 100644
--- a/compiler/luci/lang/include/luci/IR/Nodes/CircleSlice.h
+++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleSlice.h
@@ -20,7 +20,7 @@
#include "luci/IR/CircleNodeDecl.h"
#include "luci/IR/CircleOpcode.h"
-#include "luci/IR/LuciNodeMixins.h"
+#include "luci/IR/CircleNodeMixins.h"
namespace luci
{
diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleSoftmax.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleSoftmax.h
index 7166a329b..d10cb1682 100644
--- a/compiler/luci/lang/include/luci/IR/Nodes/CircleSoftmax.h
+++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleSoftmax.h
@@ -20,7 +20,7 @@
#include "luci/IR/CircleNodeDecl.h"
#include "luci/IR/CircleOpcode.h"
-#include "luci/IR/LuciNodeMixins.h"
+#include "luci/IR/CircleNodeMixins.h"
namespace luci
{
diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleSpaceToBatchND.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleSpaceToBatchND.h
index 042ebffcd..ef715c6d0 100644
--- a/compiler/luci/lang/include/luci/IR/Nodes/CircleSpaceToBatchND.h
+++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleSpaceToBatchND.h
@@ -20,7 +20,7 @@
#include "luci/IR/CircleNodeDecl.h"
#include "luci/IR/CircleOpcode.h"
-#include "luci/IR/LuciNodeMixins.h"
+#include "luci/IR/CircleNodeMixins.h"
namespace luci
{
@@ -29,7 +29,7 @@ namespace luci
* @brief SPACE_TO_BATCH_ND in Circle
*/
class CircleSpaceToBatchND final
- : public FixedArityNode<3, CircleNodeImpl<CircleOpcode::SPACE_TO_BATCH_ND>>
+ : public FixedArityNode<3, CircleNodeImpl<CircleOpcode::SPACE_TO_BATCH_ND>>
{
public:
loco::Node *input(void) const { return at(0)->node(); }
diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleSpaceToDepth.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleSpaceToDepth.h
index 420a4cb96..387e0d80f 100644
--- a/compiler/luci/lang/include/luci/IR/Nodes/CircleSpaceToDepth.h
+++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleSpaceToDepth.h
@@ -20,7 +20,7 @@
#include "luci/IR/CircleNodeDecl.h"
#include "luci/IR/CircleOpcode.h"
-#include "luci/IR/LuciNodeMixins.h"
+#include "luci/IR/CircleNodeMixins.h"
namespace luci
{
@@ -29,18 +29,18 @@ namespace luci
* @brief SPACE_TO_DEPTH in Circle
*/
class CircleSpaceToDepth final
- : public FixedArityNode<1, CircleNodeImpl<CircleOpcode::SPACE_TO_DEPTH>>
+ : public FixedArityNode<1, CircleNodeImpl<CircleOpcode::SPACE_TO_DEPTH>>
{
public:
loco::Node *input(void) const { return at(0)->node(); }
void input(loco::Node *node) { at(0)->node(node); }
public:
- int block_size(void) const { return _block_size; }
- void block_size(int block_size) { _block_size = block_size; }
+ int32_t block_size(void) const { return _block_size; }
+ void block_size(int32_t block_size) { _block_size = block_size; }
private:
- int _block_size{0};
+ int32_t _block_size{0};
};
} // namespace luci
diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleSparseToDense.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleSparseToDense.h
index 7e80304b0..94a20c064 100644
--- a/compiler/luci/lang/include/luci/IR/Nodes/CircleSparseToDense.h
+++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleSparseToDense.h
@@ -20,7 +20,7 @@
#include "luci/IR/CircleNodeDecl.h"
#include "luci/IR/CircleOpcode.h"
-#include "luci/IR/LuciNodeMixins.h"
+#include "luci/IR/CircleNodeMixins.h"
namespace luci
{
@@ -29,7 +29,7 @@ namespace luci
* @brief SPARSE_TO_DENSE in Circle
*/
class CircleSparseToDense final
- : public FixedArityNode<4, CircleNodeImpl<CircleOpcode::SPARSE_TO_DENSE>>
+ : public FixedArityNode<4, CircleNodeImpl<CircleOpcode::SPARSE_TO_DENSE>>
{
public:
loco::Node *indices(void) const { return at(0)->node(); }
diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleSplit.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleSplit.h
index 0eda19501..0cb953131 100644
--- a/compiler/luci/lang/include/luci/IR/Nodes/CircleSplit.h
+++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleSplit.h
@@ -20,7 +20,7 @@
#include "luci/IR/CircleNodeDecl.h"
#include "luci/IR/CircleOpcode.h"
-#include "luci/IR/LuciNodeMixins.h"
+#include "luci/IR/CircleNodeMixins.h"
namespace luci
{
diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleSplitOut.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleSplitOut.h
index 6bf4a9fef..a507740e4 100644
--- a/compiler/luci/lang/include/luci/IR/Nodes/CircleSplitOut.h
+++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleSplitOut.h
@@ -20,7 +20,7 @@
#include "luci/IR/CircleNodeDecl.h"
#include "luci/IR/CircleOpcode.h"
-#include "luci/IR/LuciNodeMixins.h"
+#include "luci/IR/CircleNodeMixins.h"
namespace luci
{
@@ -31,9 +31,6 @@ namespace luci
class CircleSplitOut final : public FixedArityNode<1, CircleNodeImpl<CircleOpcode::CIRCLESPLITOUT>>
{
public:
- CircleSplitOut() = default;
-
-public:
loco::Node *input(void) const { return at(0)->node(); }
void input(loco::Node *node) { at(0)->node(node); }
diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleSplitV.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleSplitV.h
index 1b7d55534..cb02cbbcf 100644
--- a/compiler/luci/lang/include/luci/IR/Nodes/CircleSplitV.h
+++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleSplitV.h
@@ -20,7 +20,7 @@
#include "luci/IR/CircleNodeDecl.h"
#include "luci/IR/CircleOpcode.h"
-#include "luci/IR/LuciNodeMixins.h"
+#include "luci/IR/CircleNodeMixins.h"
namespace luci
{
diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleSplitVOut.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleSplitVOut.h
index d3b2f1e5a..adf79f30c 100644
--- a/compiler/luci/lang/include/luci/IR/Nodes/CircleSplitVOut.h
+++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleSplitVOut.h
@@ -20,7 +20,7 @@
#include "luci/IR/CircleNodeDecl.h"
#include "luci/IR/CircleOpcode.h"
-#include "luci/IR/LuciNodeMixins.h"
+#include "luci/IR/CircleNodeMixins.h"
namespace luci
{
@@ -29,12 +29,9 @@ namespace luci
* @brief Virtual CIRCLESPLITVOUT in Circle
*/
class CircleSplitVOut final
- : public FixedArityNode<1, CircleNodeImpl<CircleOpcode::CIRCLESPLITVOUT>>
+ : public FixedArityNode<1, CircleNodeImpl<CircleOpcode::CIRCLESPLITVOUT>>
{
public:
- CircleSplitVOut() = default;
-
-public:
loco::Node *input(void) const { return at(0)->node(); }
void input(loco::Node *node) { at(0)->node(node); }
diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleSqrt.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleSqrt.h
index c96ca8498..b76bd1ad5 100644
--- a/compiler/luci/lang/include/luci/IR/Nodes/CircleSqrt.h
+++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleSqrt.h
@@ -20,7 +20,7 @@
#include "luci/IR/CircleNodeDecl.h"
#include "luci/IR/CircleOpcode.h"
-#include "luci/IR/LuciNodeMixins.h"
+#include "luci/IR/CircleNodeMixins.h"
namespace luci
{
@@ -31,9 +31,6 @@ namespace luci
class CircleSqrt final : public FixedArityNode<1, CircleNodeImpl<CircleOpcode::SQRT>>
{
public:
- CircleSqrt() = default;
-
-public:
loco::Node *x(void) const { return at(0)->node(); }
void x(loco::Node *node) { at(0)->node(node); }
};
diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleSquare.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleSquare.h
index a29edfe82..3f9228b3b 100644
--- a/compiler/luci/lang/include/luci/IR/Nodes/CircleSquare.h
+++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleSquare.h
@@ -20,7 +20,7 @@
#include "luci/IR/CircleNodeDecl.h"
#include "luci/IR/CircleOpcode.h"
-#include "luci/IR/LuciNodeMixins.h"
+#include "luci/IR/CircleNodeMixins.h"
namespace luci
{
@@ -31,9 +31,6 @@ namespace luci
class CircleSquare final : public FixedArityNode<1, CircleNodeImpl<CircleOpcode::SQUARE>>
{
public:
- CircleSquare() = default;
-
-public:
loco::Node *x(void) const { return at(0)->node(); }
void x(loco::Node *node) { at(0)->node(node); }
};
diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleSquaredDifference.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleSquaredDifference.h
index b5b39f920..355c9f3d3 100644
--- a/compiler/luci/lang/include/luci/IR/Nodes/CircleSquaredDifference.h
+++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleSquaredDifference.h
@@ -20,7 +20,7 @@
#include "luci/IR/CircleNodeDecl.h"
#include "luci/IR/CircleOpcode.h"
-#include "luci/IR/LuciNodeMixins.h"
+#include "luci/IR/CircleNodeMixins.h"
namespace luci
{
@@ -29,12 +29,9 @@ namespace luci
* @brief SQUARED_DIFFERENCE in Circle
*/
class CircleSquaredDifference final
- : public FixedArityNode<2, CircleNodeImpl<CircleOpcode::SQUARED_DIFFERENCE>>
+ : public FixedArityNode<2, CircleNodeImpl<CircleOpcode::SQUARED_DIFFERENCE>>
{
public:
- CircleSquaredDifference() = default;
-
-public:
loco::Node *x(void) const { return at(0)->node(); }
void x(loco::Node *node) { at(0)->node(node); }
diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleSqueeze.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleSqueeze.h
index f175f1411..ba71ff217 100644
--- a/compiler/luci/lang/include/luci/IR/Nodes/CircleSqueeze.h
+++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleSqueeze.h
@@ -20,7 +20,7 @@
#include "luci/IR/CircleNodeDecl.h"
#include "luci/IR/CircleOpcode.h"
-#include "luci/IR/LuciNodeMixins.h"
+#include "luci/IR/CircleNodeMixins.h"
namespace luci
{
@@ -31,9 +31,6 @@ namespace luci
class CircleSqueeze final : public FixedArityNode<1, CircleNodeImpl<CircleOpcode::SQUEEZE>>
{
public:
- CircleSqueeze() = default;
-
-public:
loco::Node *input(void) const { return at(0)->node(); }
void input(loco::Node *node) { at(0)->node(node); }
diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleStridedSlice.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleStridedSlice.h
index 98799fec1..6a4155ef1 100644
--- a/compiler/luci/lang/include/luci/IR/Nodes/CircleStridedSlice.h
+++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleStridedSlice.h
@@ -20,7 +20,7 @@
#include "luci/IR/CircleNodeDecl.h"
#include "luci/IR/CircleOpcode.h"
-#include "luci/IR/LuciNodeMixins.h"
+#include "luci/IR/CircleNodeMixins.h"
namespace luci
{
@@ -29,7 +29,7 @@ namespace luci
* @brief STRIDED_SLICE in Circle
*/
class CircleStridedSlice final
- : public FixedArityNode<4, CircleNodeImpl<CircleOpcode::STRIDED_SLICE>>
+ : public FixedArityNode<4, CircleNodeImpl<CircleOpcode::STRIDED_SLICE>>
{
public:
loco::Node *input(void) const { return at(0)->node(); }
diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleSub.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleSub.h
index 08208f942..d9aaa44e5 100644
--- a/compiler/luci/lang/include/luci/IR/Nodes/CircleSub.h
+++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleSub.h
@@ -21,7 +21,7 @@
#include "luci/IR/CircleOpcode.h"
#include "luci/IR/AttrFusedActFunc.h"
-#include "luci/IR/LuciNodeMixins.h"
+#include "luci/IR/CircleNodeMixins.h"
namespace luci
{
@@ -30,12 +30,9 @@ namespace luci
* @brief SUB in Circle
*/
class CircleSub final : public FixedArityNode<2, CircleNodeImpl<CircleOpcode::SUB>>,
- public LuciNodeMixin<LuciNodeTrait::FusedActFunc>
+ public CircleNodeMixin<CircleNodeTrait::FusedActFunc>
{
public:
- CircleSub() = default;
-
-public:
loco::Node *x(void) const { return at(0)->node(); }
void x(loco::Node *node) { at(0)->node(node); }
diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleSum.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleSum.h
index 21faa76fe..a72e18f54 100644
--- a/compiler/luci/lang/include/luci/IR/Nodes/CircleSum.h
+++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleSum.h
@@ -20,7 +20,7 @@
#include "luci/IR/CircleNodeDecl.h"
#include "luci/IR/CircleOpcode.h"
-#include "luci/IR/LuciNodeMixins.h"
+#include "luci/IR/CircleNodeMixins.h"
namespace luci
{
diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleTanh.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleTanh.h
index f7444921f..2036a7301 100644
--- a/compiler/luci/lang/include/luci/IR/Nodes/CircleTanh.h
+++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleTanh.h
@@ -20,7 +20,7 @@
#include "luci/IR/CircleNodeDecl.h"
#include "luci/IR/CircleOpcode.h"
-#include "luci/IR/LuciNodeMixins.h"
+#include "luci/IR/CircleNodeMixins.h"
namespace luci
{
@@ -31,9 +31,6 @@ namespace luci
class CircleTanh final : public FixedArityNode<1, CircleNodeImpl<CircleOpcode::TANH>>
{
public:
- CircleTanh() = default;
-
-public:
loco::Node *x(void) const { return at(0)->node(); }
void x(loco::Node *node) { at(0)->node(node); }
};
diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleTile.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleTile.h
index 96e1f69c6..1ec2f5e82 100644
--- a/compiler/luci/lang/include/luci/IR/Nodes/CircleTile.h
+++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleTile.h
@@ -20,7 +20,7 @@
#include "luci/IR/CircleNodeDecl.h"
#include "luci/IR/CircleOpcode.h"
-#include "luci/IR/LuciNodeMixins.h"
+#include "luci/IR/CircleNodeMixins.h"
namespace luci
{
@@ -31,9 +31,6 @@ namespace luci
class CircleTile final : public FixedArityNode<2, CircleNodeImpl<CircleOpcode::TILE>>
{
public:
- CircleTile() = default;
-
-public:
loco::Node *input(void) const { return at(0)->node(); }
void input(loco::Node *node) { at(0)->node(node); }
diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleTopKV2.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleTopKV2.h
index 3b2b5abb7..0bf78c3ee 100644
--- a/compiler/luci/lang/include/luci/IR/Nodes/CircleTopKV2.h
+++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleTopKV2.h
@@ -20,7 +20,7 @@
#include "luci/IR/CircleNodeDecl.h"
#include "luci/IR/CircleOpcode.h"
-#include "luci/IR/LuciNodeMixins.h"
+#include "luci/IR/CircleNodeMixins.h"
namespace luci
{
@@ -31,9 +31,6 @@ namespace luci
class CircleTopKV2 final : public FixedArityNode<2, CircleNodeImpl<CircleOpcode::TOPK_V2>>
{
public:
- CircleTopKV2() = default;
-
-public:
loco::Node *input(void) const { return at(0)->node(); }
void input(loco::Node *node) { at(0)->node(node); }
diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleTopKV2Out.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleTopKV2Out.h
index 5a6dd0c02..f1a6b4a41 100644
--- a/compiler/luci/lang/include/luci/IR/Nodes/CircleTopKV2Out.h
+++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleTopKV2Out.h
@@ -20,7 +20,7 @@
#include "luci/IR/CircleNodeDecl.h"
#include "luci/IR/CircleOpcode.h"
-#include "luci/IR/LuciNodeMixins.h"
+#include "luci/IR/CircleNodeMixins.h"
namespace luci
{
@@ -29,12 +29,9 @@ namespace luci
* @brief Virtual CIRCLETOPKV2OUT in Circle
*/
class CircleTopKV2Out final
- : public FixedArityNode<1, CircleNodeImpl<CircleOpcode::CIRCLETOPKV2OUT>>
+ : public FixedArityNode<1, CircleNodeImpl<CircleOpcode::CIRCLETOPKV2OUT>>
{
public:
- CircleTopKV2Out() = default;
-
-public:
loco::Node *input(void) const { return at(0)->node(); }
void input(loco::Node *node) { at(0)->node(node); }
diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleTranspose.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleTranspose.h
index 095cd6746..72ce0738c 100644
--- a/compiler/luci/lang/include/luci/IR/Nodes/CircleTranspose.h
+++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleTranspose.h
@@ -20,7 +20,7 @@
#include "luci/IR/CircleNodeDecl.h"
#include "luci/IR/CircleOpcode.h"
-#include "luci/IR/LuciNodeMixins.h"
+#include "luci/IR/CircleNodeMixins.h"
namespace luci
{
@@ -31,13 +31,7 @@ namespace luci
class CircleTranspose final : public FixedArityNode<2, CircleNodeImpl<CircleOpcode::TRANSPOSE>>
{
public:
- CircleTranspose() = default;
-
-public:
- /// @brief Get the input node to transpose
loco::Node *a(void) const { return at(0)->node(); }
-
- /// @brief Set the input node to transpose
void a(loco::Node *node) { at(0)->node(node); }
loco::Node *perm(void) const { return at(1)->node(); }
diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleTransposeConv.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleTransposeConv.h
index e355102d6..5ae41c0c4 100644
--- a/compiler/luci/lang/include/luci/IR/Nodes/CircleTransposeConv.h
+++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleTransposeConv.h
@@ -22,7 +22,7 @@
#include "luci/IR/AttrPadding.h"
#include "luci/IR/AttrStride.h"
-#include "luci/IR/LuciNodeMixins.h"
+#include "luci/IR/CircleNodeMixins.h"
namespace luci
{
@@ -34,8 +34,8 @@ namespace luci
* 'out' acutally means 'out' and 'in' of the this node.
*/
class CircleTransposeConv final
- : public FixedArityNode<4, CircleNodeImpl<CircleOpcode::TRANSPOSE_CONV>>,
- public LuciNodeMixin<LuciNodeTrait::Bias>
+ : public FixedArityNode<4, CircleNodeImpl<CircleOpcode::TRANSPOSE_CONV>>,
+ public CircleNodeMixin<CircleNodeTrait::Bias>
{
public:
loco::Node *inputSizes(void) const { return at(0)->node(); }
diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleUnidirectionalSequenceLSTM.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleUnidirectionalSequenceLSTM.h
index 4352b045b..faf0ec94d 100644
--- a/compiler/luci/lang/include/luci/IR/Nodes/CircleUnidirectionalSequenceLSTM.h
+++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleUnidirectionalSequenceLSTM.h
@@ -21,7 +21,7 @@
#include "luci/IR/CircleOpcode.h"
#include "luci/IR/AttrFusedActFunc.h"
-#include "luci/IR/LuciNodeMixins.h"
+#include "luci/IR/CircleNodeMixins.h"
namespace luci
{
@@ -30,8 +30,8 @@ namespace luci
* @brief UNIDIRECTIONAL_SEQUENCE_LSTM in Circle
*/
class CircleUnidirectionalSequenceLSTM final
- : public FixedArityNode<24, CircleNodeImpl<CircleOpcode::UNIDIRECTIONAL_SEQUENCE_LSTM>>,
- public LuciNodeMixin<LuciNodeTrait::FusedActFunc>
+ : public FixedArityNode<24, CircleNodeImpl<CircleOpcode::UNIDIRECTIONAL_SEQUENCE_LSTM>>,
+ public CircleNodeMixin<CircleNodeTrait::FusedActFunc>
{
public:
loco::Node *input(void) const { return at(0)->node(); }
@@ -104,10 +104,10 @@ public:
}
private:
- float _cell_clip = 0.0f;
- float _proj_clip = 0.0f;
- bool _time_major = false;
- bool _asymmetric_quantize_inputs = false;
+ float _cell_clip{0.0f};
+ float _proj_clip{0.0f};
+ bool _time_major{false};
+ bool _asymmetric_quantize_inputs{false};
};
} // namespace luci
diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleUnique.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleUnique.h
index 719a72362..2dd48b2f9 100644
--- a/compiler/luci/lang/include/luci/IR/Nodes/CircleUnique.h
+++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleUnique.h
@@ -20,7 +20,7 @@
#include "luci/IR/CircleNodeDecl.h"
#include "luci/IR/CircleOpcode.h"
-#include "luci/IR/LuciNodeMixins.h"
+#include "luci/IR/CircleNodeMixins.h"
namespace luci
{
@@ -36,7 +36,7 @@ public:
public:
loco::DataType idx_out_type(void) const { return _idx_out_type; }
- void output_type(loco::DataType ot) { _idx_out_type = ot; }
+ void idx_out_type(loco::DataType ot) { _idx_out_type = ot; }
private:
loco::DataType _idx_out_type{loco::DataType::S32};
diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleUniqueOut.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleUniqueOut.h
index f846403e0..233351860 100644
--- a/compiler/luci/lang/include/luci/IR/Nodes/CircleUniqueOut.h
+++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleUniqueOut.h
@@ -20,7 +20,7 @@
#include "luci/IR/CircleNodeDecl.h"
#include "luci/IR/CircleOpcode.h"
-#include "luci/IR/LuciNodeMixins.h"
+#include "luci/IR/CircleNodeMixins.h"
namespace luci
{
@@ -29,12 +29,9 @@ namespace luci
* @brief Virtual CIRCLEUNIQUEOUT in Circle
*/
class CircleUniqueOut final
- : public FixedArityNode<1, CircleNodeImpl<CircleOpcode::CIRCLEUNIQUEOUT>>
+ : public FixedArityNode<1, CircleNodeImpl<CircleOpcode::CIRCLEUNIQUEOUT>>
{
public:
- CircleUniqueOut() = default;
-
-public:
loco::Node *input(void) const { return at(0)->node(); }
void input(loco::Node *node) { at(0)->node(node); }
diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleUnpack.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleUnpack.h
index cb91d7e6a..fd0c66ce0 100644
--- a/compiler/luci/lang/include/luci/IR/Nodes/CircleUnpack.h
+++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleUnpack.h
@@ -20,7 +20,7 @@
#include "luci/IR/CircleNodeDecl.h"
#include "luci/IR/CircleOpcode.h"
-#include "luci/IR/LuciNodeMixins.h"
+#include "luci/IR/CircleNodeMixins.h"
namespace luci
{
@@ -31,9 +31,6 @@ namespace luci
class CircleUnpack final : public FixedArityNode<1, CircleNodeImpl<CircleOpcode::UNPACK>>
{
public:
- CircleUnpack() = default;
-
-public:
loco::Node *value(void) const { return at(0)->node(); }
void value(loco::Node *node) { at(0)->node(node); }
diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleUnpackOut.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleUnpackOut.h
index 6f24578a1..640d2f1bb 100644
--- a/compiler/luci/lang/include/luci/IR/Nodes/CircleUnpackOut.h
+++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleUnpackOut.h
@@ -20,7 +20,7 @@
#include "luci/IR/CircleNodeDecl.h"
#include "luci/IR/CircleOpcode.h"
-#include "luci/IR/LuciNodeMixins.h"
+#include "luci/IR/CircleNodeMixins.h"
namespace luci
{
@@ -29,12 +29,9 @@ namespace luci
* @brief Virtual CIRCLEUNPACKOUT in Circle
*/
class CircleUnpackOut final
- : public FixedArityNode<1, CircleNodeImpl<CircleOpcode::CIRCLEUNPACKOUT>>
+ : public FixedArityNode<1, CircleNodeImpl<CircleOpcode::CIRCLEUNPACKOUT>>
{
public:
- CircleUnpackOut() = default;
-
-public:
loco::Node *input(void) const { return at(0)->node(); }
void input(loco::Node *node) { at(0)->node(node); }
diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleWhere.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleWhere.h
index 51eda3d6e..8895bcbbd 100644
--- a/compiler/luci/lang/include/luci/IR/Nodes/CircleWhere.h
+++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleWhere.h
@@ -20,7 +20,7 @@
#include "luci/IR/CircleNodeDecl.h"
#include "luci/IR/CircleOpcode.h"
-#include "luci/IR/LuciNodeMixins.h"
+#include "luci/IR/CircleNodeMixins.h"
#include <cassert>
@@ -33,9 +33,6 @@ namespace luci
class CircleWhere final : public FixedArityNode<1, CircleNodeImpl<CircleOpcode::WHERE>>
{
public:
- CircleWhere() = default;
-
-public:
loco::Node *condition() const { return at(0)->node(); }
void condition(loco::Node *node) { at(0)->node(node); }
};
diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleWhile.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleWhile.h
index 40ec96414..f4154d3ab 100644
--- a/compiler/luci/lang/include/luci/IR/Nodes/CircleWhile.h
+++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleWhile.h
@@ -34,7 +34,7 @@ class CircleWhile final : public VariadicArityNode<CircleNodeImpl<CircleOpcode::
{
public:
CircleWhile(uint32_t arity, uint32_t out)
- : VariadicArityNode<CircleNodeImpl<CircleOpcode::WHILE>>(arity), _output_count(out)
+ : VariadicArityNode<CircleNodeImpl<CircleOpcode::WHILE>>(arity), _output_count(out)
{
assert(arity > 0);
assert(out > 0);
diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleWhileOut.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleWhileOut.h
index cdf617848..98efc21e5 100644
--- a/compiler/luci/lang/include/luci/IR/Nodes/CircleWhileOut.h
+++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleWhileOut.h
@@ -20,7 +20,7 @@
#include "luci/IR/CircleNodeDecl.h"
#include "luci/IR/CircleOpcode.h"
-#include "luci/IR/LuciNodeMixins.h"
+#include "luci/IR/CircleNodeMixins.h"
namespace luci
{
@@ -31,9 +31,6 @@ namespace luci
class CircleWhileOut final : public FixedArityNode<1, CircleNodeImpl<CircleOpcode::CIRCLEWHILEOUT>>
{
public:
- CircleWhileOut() = default;
-
-public:
loco::Node *input(void) const { return at(0)->node(); }
void input(loco::Node *node) { at(0)->node(node); }
diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleZerosLike.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleZerosLike.h
index d3b6d272a..9302facd0 100644
--- a/compiler/luci/lang/include/luci/IR/Nodes/CircleZerosLike.h
+++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleZerosLike.h
@@ -20,7 +20,7 @@
#include "luci/IR/CircleNodeDecl.h"
#include "luci/IR/CircleOpcode.h"
-#include "luci/IR/LuciNodeMixins.h"
+#include "luci/IR/CircleNodeMixins.h"
namespace luci
{
@@ -31,13 +31,7 @@ namespace luci
class CircleZerosLike final : public FixedArityNode<1, CircleNodeImpl<CircleOpcode::ZEROS_LIKE>>
{
public:
- CircleZerosLike() = default;
-
-public:
- /// @brief Get the input node
loco::Node *input(void) const { return at(0)->node(); }
-
- /// @brief Set the input node
void input(loco::Node *node) { at(0)->node(node); }
};
diff --git a/compiler/luci/lang/include/luci/IR/SparsityParam.h b/compiler/luci/lang/include/luci/IR/SparsityParam.h
index f471e5ef9..6cfff67e1 100644
--- a/compiler/luci/lang/include/luci/IR/SparsityParam.h
+++ b/compiler/luci/lang/include/luci/IR/SparsityParam.h
@@ -44,7 +44,7 @@ class SparseIndexVector
public:
SparseIndexVector() = default;
SparseIndexVector(const SparseIndexVectorType &type, const std::vector<int32_t> &sparse_index_vec)
- : _type{type}
+ : _type{type}
{
switch (type)
{
@@ -53,7 +53,7 @@ public:
case SparseIndexVectorType::I32:
{
_vec_ptr = static_cast<void *>(
- new std::vector<int32_t>(sparse_index_vec.begin(), sparse_index_vec.end()));
+ new std::vector<int32_t>(sparse_index_vec.begin(), sparse_index_vec.end()));
break;
}
case SparseIndexVectorType::U16:
@@ -90,21 +90,21 @@ public:
case SparseIndexVectorType::I32:
{
const std::vector<int32_t> *vec =
- static_cast<const std::vector<int32_t> *>(sparse_index_vec);
+ static_cast<const std::vector<int32_t> *>(sparse_index_vec);
_vec_ptr = static_cast<void *>(new std::vector<int32_t>(vec->begin(), vec->end()));
break;
}
case SparseIndexVectorType::U16:
{
const std::vector<uint16_t> *vec =
- static_cast<const std::vector<uint16_t> *>(sparse_index_vec);
+ static_cast<const std::vector<uint16_t> *>(sparse_index_vec);
_vec_ptr = static_cast<void *>(new std::vector<uint16_t>(vec->begin(), vec->end()));
break;
}
case SparseIndexVectorType::U8:
{
const std::vector<uint8_t> *vec =
- static_cast<const std::vector<uint8_t> *>(sparse_index_vec);
+ static_cast<const std::vector<uint8_t> *>(sparse_index_vec);
_vec_ptr = static_cast<void *>(new std::vector<uint8_t>(vec->begin(), vec->end()));
break;
}
@@ -114,12 +114,12 @@ public:
}
SparseIndexVector(const SparseIndexVector &sparse_index_vec)
- : SparseIndexVector(sparse_index_vec._type, sparse_index_vec._vec_ptr)
+ : SparseIndexVector(sparse_index_vec._type, sparse_index_vec._vec_ptr)
{
}
SparseIndexVector(SparseIndexVector &&sparse_index_vec)
- : _type{sparse_index_vec._type}, _vec_ptr{std::exchange(sparse_index_vec._vec_ptr, nullptr)}
+ : _type{sparse_index_vec._type}, _vec_ptr{std::exchange(sparse_index_vec._vec_ptr, nullptr)}
{
}
@@ -178,8 +178,8 @@ public:
const std::vector<uint16_t> *as_uint16_vector(void) const
{
return _type == SparseIndexVectorType::U16
- ? static_cast<const std::vector<uint16_t> *>(_vec_ptr)
- : nullptr;
+ ? static_cast<const std::vector<uint16_t> *>(_vec_ptr)
+ : nullptr;
}
const std::vector<uint8_t> *as_uint8_vector(void) const
{
@@ -202,8 +202,8 @@ public:
}
DimMetaData(DimensionType format, int32_t dense_size, const SparseIndexVector &array_segments,
const SparseIndexVector &array_indices)
- : _format{format}, _dense_size{dense_size}, _array_segments{array_segments},
- _array_indices{array_indices}
+ : _format{format}, _dense_size{dense_size}, _array_segments{array_segments}, _array_indices{
+ array_indices}
{
// DO NOTHING
}
diff --git a/compiler/luci/lang/src/CircleDialect.cpp b/compiler/luci/lang/src/CircleDialect.cpp
index 42ca3c917..0d315fc55 100644
--- a/compiler/luci/lang/src/CircleDialect.cpp
+++ b/compiler/luci/lang/src/CircleDialect.cpp
@@ -15,6 +15,7 @@
*/
#include "luci/IR/CircleDialect.h"
+#include "luci/IR/DeadNodeQueryService.h"
#include "luci/IR/Nodes/CircleInput.h"
#include "luci/IR/Nodes/CircleOutput.h"
@@ -22,8 +23,6 @@
#include <loco/IR/GraphInputIndex.h>
#include <loco/IR/GraphOutputIndex.h>
-#include "DeadNodeQueryService.h"
-
#include <cassert>
#include <memory>
diff --git a/compiler/luci/lang/src/LuciNodeMixins.cpp b/compiler/luci/lang/src/CircleNodeMixins.cpp
index 660cbe1a5..f72178df5 100644
--- a/compiler/luci/lang/src/LuciNodeMixins.cpp
+++ b/compiler/luci/lang/src/CircleNodeMixins.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -14,5 +14,5 @@
* limitations under the License.
*/
-// This is to validate LuciNodeMixins.h
-#include "luci/IR/LuciNodeMixins.h"
+// This is to validate CircleNodeMixins.h
+#include "luci/IR/CircleNodeMixins.h"
diff --git a/compiler/luci/lang/src/CircleNodes.cpp b/compiler/luci/lang/src/CircleNodes.cpp
index c77c06861..2c2688c9e 100644
--- a/compiler/luci/lang/src/CircleNodes.cpp
+++ b/compiler/luci/lang/src/CircleNodes.cpp
@@ -23,31 +23,6 @@
namespace luci
{
-void set_new_shape(CircleReshape *node, int32_t *base, uint32_t size)
-{
- // Check node does not have both of new shape infos
- LUCI_ASSERT(node->shape() == nullptr, "node already has shape input");
- LUCI_ASSERT(node->newShape()->rank() == 0, "node already has newShape attribute");
-
- const loco::DataType S32 = loco::DataType::S32;
-
- // Set 2nd input as CircleConst
- auto const_shape_node = node->graph()->nodes()->create<CircleConst>();
- const_shape_node->rank(1);
- const_shape_node->dim(0) = size;
- const_shape_node->dtype(S32);
- const_shape_node->size<S32>(size);
- const_shape_node->shape_status(luci::ShapeStatus::VALID);
- for (uint32_t axis = 0; axis < size; ++axis)
- const_shape_node->at<S32>(axis) = base[axis];
- node->shape(const_shape_node);
-
- // Set newShape attribute
- node->newShape()->rank(size);
- for (uint32_t axis = 0; axis < size; ++axis)
- node->newShape()->dim(axis) = base[axis];
-}
-
void link(loco::GraphOutput *output, CircleOutput *node) { node->index(output->index()); }
CircleOutput *output_node(loco::Graph *g, const loco::GraphOutputIndex &index)
diff --git a/compiler/luci/lang/src/DeadNodeQueryService.cpp b/compiler/luci/lang/src/DeadNodeQueryService.cpp
index a22574c94..7dac08b5f 100644
--- a/compiler/luci/lang/src/DeadNodeQueryService.cpp
+++ b/compiler/luci/lang/src/DeadNodeQueryService.cpp
@@ -14,9 +14,8 @@
* limitations under the License.
*/
-#include "DeadNodeQueryService.h"
-
#include "luci/IR/CircleNodeVisitor.h"
+#include "luci/IR/DeadNodeQueryService.h"
#include <loco/IR/Graph.h>
diff --git a/compiler/luci/lang/src/Nodes/CircleBatchMatMul.test.cpp b/compiler/luci/lang/src/Nodes/CircleBatchMatMul.test.cpp
index d7712c8dd..3859d7fca 100644
--- a/compiler/luci/lang/src/Nodes/CircleBatchMatMul.test.cpp
+++ b/compiler/luci/lang/src/Nodes/CircleBatchMatMul.test.cpp
@@ -26,7 +26,7 @@ TEST(CircleBatchMatMulTest, constructor)
luci::CircleBatchMatMul batchmatmul_node;
ASSERT_EQ(luci::CircleDialect::get(), batchmatmul_node.dialect());
- ASSERT_EQ(luci::CircleOpcode::BATCHMATMUL, batchmatmul_node.opcode());
+ ASSERT_EQ(luci::CircleOpcode::BATCH_MATMUL, batchmatmul_node.opcode());
ASSERT_EQ(nullptr, batchmatmul_node.x());
ASSERT_EQ(nullptr, batchmatmul_node.y());
diff --git a/compiler/luci/lang/src/Nodes/CircleBidrectionalSequenceLSTM.test.cpp b/compiler/luci/lang/src/Nodes/CircleBidrectionalSequenceLSTM.test.cpp
new file mode 100644
index 000000000..3f13422e5
--- /dev/null
+++ b/compiler/luci/lang/src/Nodes/CircleBidrectionalSequenceLSTM.test.cpp
@@ -0,0 +1,130 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/IR/Nodes/CircleBidirectionalSequenceLSTM.h"
+
+#include "luci/IR/CircleDialect.h"
+#include "luci/IR/CircleNodeVisitor.h"
+
+#include <gtest/gtest.h>
+
+TEST(CircleBidirectionalSequenceLSTMTest, constructor_P)
+{
+ luci::CircleBidirectionalSequenceLSTM trc_node;
+
+ ASSERT_EQ(luci::CircleDialect::get(), trc_node.dialect());
+ ASSERT_EQ(luci::CircleOpcode::BIDIRECTIONAL_SEQUENCE_LSTM, trc_node.opcode());
+
+ ASSERT_EQ(nullptr, trc_node.input());
+
+ ASSERT_EQ(nullptr, trc_node.fw_input_to_input_weights());
+ ASSERT_EQ(nullptr, trc_node.fw_input_to_forget_weights());
+ ASSERT_EQ(nullptr, trc_node.fw_input_to_cell_weights());
+ ASSERT_EQ(nullptr, trc_node.fw_input_to_output_weights());
+
+ ASSERT_EQ(nullptr, trc_node.fw_recurrent_to_input_weights());
+ ASSERT_EQ(nullptr, trc_node.fw_recurrent_to_forget_weights());
+ ASSERT_EQ(nullptr, trc_node.fw_recurrent_to_cell_weights());
+ ASSERT_EQ(nullptr, trc_node.fw_recurrent_to_output_weights());
+
+ ASSERT_EQ(nullptr, trc_node.fw_cell_to_input_weights());
+ ASSERT_EQ(nullptr, trc_node.fw_cell_to_forget_weights());
+ ASSERT_EQ(nullptr, trc_node.fw_cell_to_output_weights());
+
+ ASSERT_EQ(nullptr, trc_node.fw_input_gate_bias());
+ ASSERT_EQ(nullptr, trc_node.fw_forget_gate_bias());
+ ASSERT_EQ(nullptr, trc_node.fw_cell_gate_bias());
+ ASSERT_EQ(nullptr, trc_node.fw_output_gate_bias());
+
+ ASSERT_EQ(nullptr, trc_node.fw_projection_weights());
+ ASSERT_EQ(nullptr, trc_node.fw_projection_bias());
+
+ ASSERT_EQ(nullptr, trc_node.bw_input_to_input_weights());
+ ASSERT_EQ(nullptr, trc_node.bw_input_to_forget_weights());
+ ASSERT_EQ(nullptr, trc_node.bw_input_to_cell_weights());
+ ASSERT_EQ(nullptr, trc_node.bw_input_to_output_weights());
+
+ ASSERT_EQ(nullptr, trc_node.bw_recurrent_to_input_weights());
+ ASSERT_EQ(nullptr, trc_node.bw_recurrent_to_forget_weights());
+ ASSERT_EQ(nullptr, trc_node.bw_recurrent_to_cell_weights());
+ ASSERT_EQ(nullptr, trc_node.bw_recurrent_to_output_weights());
+
+ ASSERT_EQ(nullptr, trc_node.bw_cell_to_input_weights());
+ ASSERT_EQ(nullptr, trc_node.bw_cell_to_forget_weights());
+ ASSERT_EQ(nullptr, trc_node.bw_cell_to_output_weights());
+
+ ASSERT_EQ(nullptr, trc_node.bw_input_gate_bias());
+ ASSERT_EQ(nullptr, trc_node.bw_forget_gate_bias());
+ ASSERT_EQ(nullptr, trc_node.bw_cell_gate_bias());
+ ASSERT_EQ(nullptr, trc_node.bw_output_gate_bias());
+
+ ASSERT_EQ(nullptr, trc_node.bw_projection_weights());
+ ASSERT_EQ(nullptr, trc_node.bw_projection_bias());
+
+ ASSERT_EQ(nullptr, trc_node.fw_activation_state());
+ ASSERT_EQ(nullptr, trc_node.fw_cell_state());
+ ASSERT_EQ(nullptr, trc_node.bw_activation_state());
+ ASSERT_EQ(nullptr, trc_node.bw_cell_state());
+
+ ASSERT_EQ(nullptr, trc_node.auxillary_input());
+ ASSERT_EQ(nullptr, trc_node.fw_auxillary_input_to_input_weights());
+ ASSERT_EQ(nullptr, trc_node.fw_auxillary_input_to_forget_weights());
+ ASSERT_EQ(nullptr, trc_node.fw_auxillary_input_to_cell_weights());
+ ASSERT_EQ(nullptr, trc_node.fw_auxillary_input_to_output_weights());
+ ASSERT_EQ(nullptr, trc_node.bw_auxillary_input_to_input_weights());
+ ASSERT_EQ(nullptr, trc_node.bw_auxillary_input_to_forget_weights());
+ ASSERT_EQ(nullptr, trc_node.bw_auxillary_input_to_cell_weights());
+ ASSERT_EQ(nullptr, trc_node.bw_auxillary_input_to_output_weights());
+
+ ASSERT_EQ(luci::FusedActFunc::UNDEFINED, trc_node.fusedActivationFunction());
+ ASSERT_EQ(0.f, trc_node.cell_clip());
+ ASSERT_EQ(0.f, trc_node.proj_clip());
+ ASSERT_EQ(false, trc_node.merge_outputs());
+ ASSERT_EQ(false, trc_node.time_major());
+ ASSERT_EQ(false, trc_node.asymmetric_quantize_inputs());
+}
+
+TEST(CircleBidirectionalSequenceLSTMTest, arity_NEG)
+{
+ luci::CircleBidirectionalSequenceLSTM trc_node;
+
+ ASSERT_NO_THROW(trc_node.arg(36));
+ ASSERT_THROW(trc_node.arg(48), std::out_of_range);
+}
+
+TEST(CircleBidirectionalSequenceLSTMTest, visit_mutable_NEG)
+{
+ struct TestVisitor final : public luci::CircleNodeMutableVisitor<void>
+ {
+ };
+
+ luci::CircleBidirectionalSequenceLSTM trc_node;
+
+ TestVisitor tv;
+ ASSERT_THROW(trc_node.accept(&tv), std::exception);
+}
+
+TEST(CircleBidirectionalSequenceLSTMTest, visit_NEG)
+{
+ struct TestVisitor final : public luci::CircleNodeVisitor<void>
+ {
+ };
+
+ luci::CircleBidirectionalSequenceLSTM trc_node;
+
+ TestVisitor tv;
+ ASSERT_THROW(trc_node.accept(&tv), std::exception);
+}
diff --git a/compiler/luci/lang/src/Nodes/CircleConst.test.cpp b/compiler/luci/lang/src/Nodes/CircleConst.test.cpp
new file mode 100644
index 000000000..a81f4b00d
--- /dev/null
+++ b/compiler/luci/lang/src/Nodes/CircleConst.test.cpp
@@ -0,0 +1,53 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/IR/Nodes/CircleConst.h"
+
+#include "luci/IR/CircleDialect.h"
+#include "luci/IR/CircleNodeVisitor.h"
+
+#include <gtest/gtest.h>
+
+TEST(CircleConstTest, constructor)
+{
+ luci::CircleConst const_node;
+
+ ASSERT_EQ(luci::CircleDialect::get(), const_node.dialect());
+ ASSERT_EQ(luci::CircleOpcode::CIRCLECONST, const_node.opcode());
+}
+
+TEST(CircleConstTest, dype_size)
+{
+ luci::CircleConst const_node;
+
+ const_node.dtype(loco::DataType::S32);
+ const_node.size<loco::DataType::S32>(1);
+
+ ASSERT_EQ(loco::DataType::S32, const_node.dtype());
+ ASSERT_EQ(1, const_node.size<loco::DataType::S32>());
+}
+
+TEST(CircleConstTest, scalar)
+{
+ luci::CircleConst const_node;
+
+ const_node.dtype(loco::DataType::S32);
+ const_node.size<loco::DataType::S32>(1);
+ const_node.scalar<loco::DataType::S32>() = 1;
+
+ auto const &cs = const_node.scalar<loco::DataType::S32>();
+ ASSERT_EQ(1, cs);
+}
diff --git a/compiler/luci/lang/src/Nodes/CircleCustom.test.cpp b/compiler/luci/lang/src/Nodes/CircleCustom.test.cpp
index c07268cbf..76b70f38b 100644
--- a/compiler/luci/lang/src/Nodes/CircleCustom.test.cpp
+++ b/compiler/luci/lang/src/Nodes/CircleCustom.test.cpp
@@ -22,7 +22,7 @@
TEST(CircleCustomTest, constructor)
{
- luci::CircleCustom custom_node(2);
+ luci::CircleCustom custom_node(2, 1);
ASSERT_EQ(luci::CircleDialect::get(), custom_node.dialect());
ASSERT_EQ(luci::CircleOpcode::CUSTOM, custom_node.opcode());
@@ -33,18 +33,19 @@ TEST(CircleCustomTest, constructor)
ASSERT_EQ(2, custom_node.numInputs());
ASSERT_EQ(0, custom_node.custom_code().size());
+ ASSERT_EQ(1, custom_node.numOutputs());
}
TEST(CircleCustomTest, constructor_NEG)
{
- ASSERT_DEBUG_DEATH(luci::CircleCustom{0}, "");
+ ASSERT_DEBUG_DEATH(luci::CircleCustom(0, 0), "");
SUCCEED();
}
TEST(CircleCustomTest, invalidIndex_NEG)
{
- luci::CircleCustom custom_node(2);
+ luci::CircleCustom custom_node(2, 1);
EXPECT_ANY_THROW(custom_node.arg(5));
}
diff --git a/compiler/luci/lang/src/Nodes/CircleFakeQuant.test.cpp b/compiler/luci/lang/src/Nodes/CircleFakeQuant.test.cpp
new file mode 100644
index 000000000..912e40570
--- /dev/null
+++ b/compiler/luci/lang/src/Nodes/CircleFakeQuant.test.cpp
@@ -0,0 +1,36 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/IR/Nodes/CircleFakeQuant.h"
+
+#include "luci/IR/CircleDialect.h"
+#include "luci/IR/CircleNodeVisitor.h"
+
+#include <gtest/gtest.h>
+
+TEST(CircleFakeQuantTest, constructor_P)
+{
+ luci::CircleFakeQuant fakequant;
+
+ ASSERT_EQ(fakequant.dialect(), luci::CircleDialect::get());
+ ASSERT_EQ(fakequant.opcode(), luci::CircleOpcode::FAKE_QUANT);
+
+ ASSERT_EQ(nullptr, fakequant.inputs());
+ ASSERT_EQ(0.0f, fakequant.min());
+ ASSERT_EQ(0.0f, fakequant.max());
+ ASSERT_EQ(0, fakequant.num_bits());
+ ASSERT_FALSE(fakequant.narrow_range());
+}
diff --git a/compiler/luci/logex/src/FormattedGraph.cpp b/compiler/luci/logex/src/FormattedGraph.cpp
index b2b9cb72b..f1337e3e6 100644
--- a/compiler/luci/logex/src/FormattedGraph.cpp
+++ b/compiler/luci/logex/src/FormattedGraph.cpp
@@ -146,7 +146,9 @@ std::string circle_opname(uint32_t opnum)
#define CIRCLE_NODE(OPCODE, CLASS) \
case luci::CircleOpcode::OPCODE: \
return prefix + #OPCODE;
+#define CIRCLE_VNODE CIRCLE_NODE
#include <luci/IR/CircleNodes.lst>
+#undef CIRCLE_VNODE
#undef CIRCLE_NODE
default:
break;
@@ -175,7 +177,9 @@ protected:
s.state(locop::NodeSummary::State::PartiallyKnown); \
return true; \
}
+#define CIRCLE_VNODE CIRCLE_NODE
#include <luci/IR/CircleNodes.lst>
+#undef CIRCLE_VNODE
#undef CIRCLE_NODE
protected:
@@ -205,6 +209,7 @@ private:
IMPLEMENT(luci::CircleAveragePool2D)
IMPLEMENT(luci::CircleBatchMatMul)
IMPLEMENT(luci::CircleBatchToSpaceND)
+ IMPLEMENT(luci::CircleBidirectionalSequenceLSTM)
IMPLEMENT(luci::CircleCast)
IMPLEMENT(luci::CircleCeil)
IMPLEMENT(luci::CircleConcatenation)
@@ -219,6 +224,7 @@ private:
IMPLEMENT(luci::CircleElu)
IMPLEMENT(luci::CircleExp)
IMPLEMENT(luci::CircleExpandDims)
+ IMPLEMENT(luci::CircleFakeQuant)
IMPLEMENT(luci::CircleFill)
IMPLEMENT(luci::CircleFloor)
IMPLEMENT(luci::CircleFloorDiv)
@@ -433,6 +439,96 @@ bool summary_node(const locop::SymbolTable *tbl, const luci::CircleBatchToSpaceN
return true;
}
+bool summary_node(const locop::SymbolTable *tbl, const luci::CircleBidirectionalSequenceLSTM *node,
+ locop::NodeSummary &s)
+{
+ s.args().append("input", tbl->lookup(node->input()));
+
+ s.args().append("fw_input_to_input_weights", tbl->lookup(node->fw_input_to_input_weights()));
+ s.args().append("fw_input_to_forget_weights", tbl->lookup(node->fw_input_to_forget_weights()));
+ s.args().append("fw_input_to_cell_weights", tbl->lookup(node->fw_input_to_cell_weights()));
+ s.args().append("fw_input_to_output_weights", tbl->lookup(node->fw_input_to_output_weights()));
+
+ s.args().append("fw_recurrent_to_input_weights",
+ tbl->lookup(node->fw_recurrent_to_input_weights()));
+ s.args().append("fw_recurrent_to_forget_weights",
+ tbl->lookup(node->fw_recurrent_to_forget_weights()));
+ s.args().append("fw_recurrent_to_cell_weights",
+ tbl->lookup(node->fw_recurrent_to_cell_weights()));
+ s.args().append("fw_recurrent_to_output_weights",
+ tbl->lookup(node->fw_recurrent_to_output_weights()));
+
+ s.args().append("fw_cell_to_input_weights", tbl->lookup(node->fw_cell_to_input_weights()));
+ s.args().append("fw_cell_to_forget_weights", tbl->lookup(node->fw_cell_to_forget_weights()));
+ s.args().append("fw_cell_to_output_weights", tbl->lookup(node->fw_cell_to_output_weights()));
+
+ s.args().append("fw_input_gate_bias", tbl->lookup(node->fw_input_gate_bias()));
+ s.args().append("fw_forget_gate_bias", tbl->lookup(node->fw_forget_gate_bias()));
+ s.args().append("fw_cell_gate_bias", tbl->lookup(node->fw_cell_gate_bias()));
+ s.args().append("fw_output_gate_bias", tbl->lookup(node->fw_output_gate_bias()));
+
+ s.args().append("fw_projection_weights", tbl->lookup(node->fw_projection_weights()));
+ s.args().append("fw_projection_bias", tbl->lookup(node->fw_projection_bias()));
+
+ s.args().append("bw_input_to_input_weights", tbl->lookup(node->bw_input_to_input_weights()));
+ s.args().append("bw_input_to_forget_weights", tbl->lookup(node->bw_input_to_forget_weights()));
+ s.args().append("bw_input_to_cell_weights", tbl->lookup(node->bw_input_to_cell_weights()));
+ s.args().append("bw_input_to_output_weights", tbl->lookup(node->bw_input_to_output_weights()));
+
+ s.args().append("bw_recurrent_to_input_weights",
+ tbl->lookup(node->bw_recurrent_to_input_weights()));
+ s.args().append("bw_recurrent_to_forget_weights",
+ tbl->lookup(node->bw_recurrent_to_forget_weights()));
+ s.args().append("bw_recurrent_to_cell_weights",
+ tbl->lookup(node->bw_recurrent_to_cell_weights()));
+ s.args().append("bw_recurrent_to_output_weights",
+ tbl->lookup(node->bw_recurrent_to_output_weights()));
+
+ s.args().append("bw_cell_to_input_weights", tbl->lookup(node->bw_cell_to_input_weights()));
+ s.args().append("bw_cell_to_forget_weights", tbl->lookup(node->bw_cell_to_forget_weights()));
+ s.args().append("bw_cell_to_output_weights", tbl->lookup(node->bw_cell_to_output_weights()));
+
+ s.args().append("bw_input_gate_bias", tbl->lookup(node->bw_input_gate_bias()));
+ s.args().append("bw_forget_gate_bias", tbl->lookup(node->bw_forget_gate_bias()));
+ s.args().append("bw_cell_gate_bias", tbl->lookup(node->bw_cell_gate_bias()));
+ s.args().append("bw_output_gate_bias", tbl->lookup(node->bw_output_gate_bias()));
+
+ s.args().append("bw_projection_weights", tbl->lookup(node->bw_projection_weights()));
+ s.args().append("bw_projection_bias", tbl->lookup(node->bw_projection_bias()));
+
+ s.args().append("fw_activation_state", tbl->lookup(node->fw_activation_state()));
+ s.args().append("fw_cell_state", tbl->lookup(node->fw_cell_state()));
+ s.args().append("bw_activation_state", tbl->lookup(node->bw_activation_state()));
+ s.args().append("bw_cell_state", tbl->lookup(node->bw_cell_state()));
+
+ s.args().append("auxillary_input", tbl->lookup(node->auxillary_input()));
+ s.args().append("fw_auxillary_input_to_input_weights",
+ tbl->lookup(node->fw_auxillary_input_to_input_weights()));
+ s.args().append("fw_auxillary_input_to_forget_weights",
+ tbl->lookup(node->fw_auxillary_input_to_forget_weights()));
+ s.args().append("fw_auxillary_input_to_cell_weights",
+ tbl->lookup(node->fw_auxillary_input_to_cell_weights()));
+ s.args().append("fw_auxillary_input_to_output_weights",
+ tbl->lookup(node->fw_auxillary_input_to_output_weights()));
+ s.args().append("bw_auxillary_input_to_input_weights",
+ tbl->lookup(node->bw_auxillary_input_to_input_weights()));
+ s.args().append("bw_auxillary_input_to_forget_weights",
+ tbl->lookup(node->bw_auxillary_input_to_forget_weights()));
+ s.args().append("bw_auxillary_input_to_cell_weights",
+ tbl->lookup(node->bw_auxillary_input_to_cell_weights()));
+ s.args().append("bw_auxillary_input_to_output_weights",
+ tbl->lookup(node->bw_auxillary_input_to_output_weights()));
+
+ s.args().append("cell_clip", to_str(node->cell_clip()));
+ s.args().append("proj_clip", to_str(node->proj_clip()));
+ s.args().append("merge_outputs", to_str(node->merge_outputs()));
+ s.args().append("time_major", to_str(node->time_major()));
+ s.args().append("asymmetric_quantize_inputs", to_str(node->asymmetric_quantize_inputs()));
+
+ s.state(locop::NodeSummary::State::Complete);
+ return true;
+}
+
bool summary_node(const locop::SymbolTable *tbl, const luci::CircleCast *node,
locop::NodeSummary &s)
{
@@ -521,6 +617,18 @@ bool summary_node(const locop::SymbolTable *tbl, const luci::CircleExpandDims *n
return true;
}
+bool summary_node(const locop::SymbolTable *tbl, const luci::CircleFakeQuant *node,
+ locop::NodeSummary &s)
+{
+ s.args().append("inputs", tbl->lookup(node->inputs()));
+ s.args().append("min", pepper::str(node->min()));
+ s.args().append("max", pepper::str(node->max()));
+ s.args().append("num_bits", pepper::str(node->num_bits()));
+ s.args().append("narrow_range", node->narrow_range() ? "true" : "false");
+ s.state(locop::NodeSummary::State::Complete);
+ return true;
+}
+
bool summary_node(const locop::SymbolTable *tbl, const luci::CircleFill *node,
locop::NodeSummary &s)
{
@@ -1189,7 +1297,9 @@ bool CircleNodeSummaryBuilderBase::build(const loco::Node *node, locop::NodeSumm
s.comments().append("Mem = " + ptr_to_str(node)); \
return summary(dynamic_cast<const CLASS *>(node), s); \
}
+#define CIRCLE_VNODE CIRCLE_NODE
#include <luci/IR/CircleNodes.lst>
+#undef CIRCLE_VNODE
#undef CIRCLE_NODE
return false;
@@ -1238,6 +1348,12 @@ bool CircleNodeSummaryBuilder::summary(const luci::CircleBatchToSpaceND *node,
return summary_node(tbl(), node, s);
}
+bool CircleNodeSummaryBuilder::summary(const luci::CircleBidirectionalSequenceLSTM *node,
+ locop::NodeSummary &s) const
+{
+ return summary_node(tbl(), node, s);
+}
+
bool CircleNodeSummaryBuilder::summary(const luci::CircleCast *node, locop::NodeSummary &s) const
{
return summary_node(tbl(), node, s);
@@ -1314,6 +1430,17 @@ bool CircleNodeSummaryBuilder::summary(const luci::CircleExpandDims *node,
return summary_node(tbl(), node, s);
}
+bool CircleNodeSummaryBuilder::summary(const luci::CircleFakeQuant *node,
+ locop::NodeSummary &s) const
+{
+ return summary_node(tbl(), node, s);
+}
+
+bool CircleNodeSummaryBuilder::summary(const luci::CircleFill *node, locop::NodeSummary &s) const
+{
+ return summary_node(tbl(), node, s);
+}
+
bool CircleNodeSummaryBuilder::summary(const luci::CircleFloor *node, locop::NodeSummary &s) const
{
return use_x(tbl(), node, s);
@@ -1331,11 +1458,6 @@ bool CircleNodeSummaryBuilder::summary(const luci::CircleFloorMod *node,
return use_xy(tbl(), node, s);
}
-bool CircleNodeSummaryBuilder::summary(const luci::CircleFill *node, locop::NodeSummary &s) const
-{
- return summary_node(tbl(), node, s);
-}
-
bool CircleNodeSummaryBuilder::summary(const luci::CircleFullyConnected *node,
locop::NodeSummary &s) const
{
diff --git a/compiler/luci/partition/CMakeLists.txt b/compiler/luci/partition/CMakeLists.txt
new file mode 100644
index 000000000..838642b6e
--- /dev/null
+++ b/compiler/luci/partition/CMakeLists.txt
@@ -0,0 +1,29 @@
+file(GLOB_RECURSE SOURCES "src/*.cpp")
+file(GLOB_RECURSE TESTS "src/*.test.cpp")
+list(REMOVE_ITEM SOURCES ${TESTS})
+
+add_library(luci_partition SHARED ${SOURCES})
+target_include_directories(luci_partition PRIVATE src)
+target_include_directories(luci_partition PUBLIC include)
+target_link_libraries(luci_partition PUBLIC luci_lang)
+target_link_libraries(luci_partition PRIVATE luci_service)
+target_link_libraries(luci_partition PRIVATE luci_log)
+target_link_libraries(luci_partition PRIVATE luci_logex)
+target_link_libraries(luci_partition PRIVATE mio_circle)
+target_link_libraries(luci_partition PRIVATE nncc_common)
+target_link_libraries(luci_partition PRIVATE oops)
+
+install(TARGETS luci_partition DESTINATION lib)
+
+if(NOT ENABLE_TEST)
+ return()
+endif(NOT ENABLE_TEST)
+
+nnas_find_package(GTest REQUIRED)
+
+GTest_AddTest(luci_partition_test ${TESTS})
+target_include_directories(luci_partition_test PRIVATE src)
+target_link_libraries(luci_partition_test luci_lang)
+target_link_libraries(luci_partition_test luci_partition)
+target_link_libraries(luci_partition_test luci_testhelper)
+target_link_libraries(luci_partition_test luci_service)
diff --git a/compiler/luci/partition/README.md b/compiler/luci/partition/README.md
new file mode 100644
index 000000000..40a46bc56
--- /dev/null
+++ b/compiler/luci/partition/README.md
@@ -0,0 +1,4 @@
+# luci-partition
+
+`luci-partition` provides partition of a model to two or more sub models and
+its connection configuration having same computational results.
diff --git a/compiler/luci/partition/include/luci/Partition.h b/compiler/luci/partition/include/luci/Partition.h
new file mode 100644
index 000000000..cf90e448b
--- /dev/null
+++ b/compiler/luci/partition/include/luci/Partition.h
@@ -0,0 +1,71 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LUCI_PARTITION_H__
+#define __LUCI_PARTITION_H__
+
+#include <luci/IR/Module.h>
+
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+namespace luci
+{
+
+/**
+ * @brief PartitionTable holds partition information
+ */
+struct PartitionTable
+{
+ std::vector<std::string> groups;
+ std::string default_group;
+
+ // assign by opcode name: OPCODENAME=group
+ std::unordered_map<std::string /* OPCODENAME */, std::string /* group */> byopcodes;
+
+ // TODO add assign by OP name
+};
+
+/**
+ * @brief PartedModule holds partitioned module and group name
+ */
+struct PartedModule
+{
+ std::unique_ptr<Module> module;
+ // group name used to partition this module
+ std::string group;
+
+ // unique name(filename) of this module
+ std::string name;
+};
+
+struct PartedModules
+{
+ std::vector<PartedModule> pmodules;
+
+ // TODO add connections ?
+};
+
+/**
+ * @brief Method to do paritioning from module and PartitionTable to produce PartedModules
+ */
+PartedModules apply(Module *module, const PartitionTable &partition);
+
+} // namespace luci
+
+#endif // __LUCI_PARTITION_H__
diff --git a/compiler/luci/partition/src/CircleOpCode.cpp b/compiler/luci/partition/src/CircleOpCode.cpp
new file mode 100644
index 000000000..86694fa40
--- /dev/null
+++ b/compiler/luci/partition/src/CircleOpCode.cpp
@@ -0,0 +1,79 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "CircleOpCode.h"
+
+#include <luci/IR/CircleNodes.h>
+#include <luci/IR/CircleNodeVisitor.h>
+
+#include <mio/circle/schema_generated.h>
+
+namespace
+{
+
+using namespace luci;
+using namespace circle;
+
+class QueryOpCode final : public CircleNodeVisitor<BuiltinOperator>
+{
+public:
+// NOTE only circle operator may have BuiltinOperator_XXX
+#define CIRCLE_NODE(OPCODE, CIRCLE_CLASS) \
+ BuiltinOperator visit(const CIRCLE_CLASS *) final { return BuiltinOperator_##OPCODE; }
+#define CIRCLE_VNODE(OPCODE, CIRCLE_CLASS)
+
+#include "luci/IR/CircleNodes.lst"
+#undef CIRCLE_VNODE
+#undef CIRCLE_NODE
+
+ // NOTE only builtin operators should be called (NOT virtual nodes)
+};
+
+class QueryCircleName final : public luci::CircleNodeVisitor<const char *>
+{
+public:
+// NOTE provide names for circle virtual nodes
+#define CIRCLE_NODE(OPCODE, CIRCLE_CLASS)
+#define CIRCLE_VNODE(OPCODE, CIRCLE_CLASS) \
+ const char *visit(const CIRCLE_CLASS *) final { return #OPCODE; }
+
+#include "luci/IR/CircleNodes.lst"
+#undef CIRCLE_VNODE
+#undef CIRCLE_NODE
+
+ // default is null
+ const char *visit(const luci::CircleNode *) final { return nullptr; }
+};
+
+} // namespace
+
+namespace luci
+{
+
+std::string opcode_name(const CircleNode *node)
+{
+ QueryCircleName qcn;
+ auto cname = node->accept(&qcn);
+ if (cname != nullptr)
+ return std::string(cname);
+
+ QueryOpCode qoc;
+ auto opcode = node->accept(&qoc);
+ auto name = circle::EnumNameBuiltinOperator(opcode);
+ return std::string(name);
+}
+
+} // namespace luci
diff --git a/compiler/luci/lang/src/CircleShapeSignature.cpp b/compiler/luci/partition/src/CircleOpCode.h
index 970000203..d17b09261 100644
--- a/compiler/luci/lang/src/CircleShapeSignature.cpp
+++ b/compiler/luci/partition/src/CircleOpCode.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -14,21 +14,18 @@
* limitations under the License.
*/
-#include "luci/IR/CircleShapeSignature.h"
+#ifndef __LUCI_PARTITION_CIRCLE_OP_CODE_H__
+#define __LUCI_PARTITION_CIRCLE_OP_CODE_H__
-namespace luci
-{
+#include <luci/IR/CircleNode.h>
-bool operator==(const ShapeSignature &lhs, const ShapeSignature &rhs)
-{
- if (lhs.rank() != rhs.rank())
- return false;
+#include <string>
- for (uint32_t i = 0; i < lhs.rank(); ++i)
- if (lhs.dim(i) != rhs.dim(i))
- return false;
+namespace luci
+{
- return true;
-}
+std::string opcode_name(const CircleNode *node);
} // namespace luci
+
+#endif // __LUCI_PARTITION_CIRCLE_OP_CODE_H__
diff --git a/compiler/luci/partition/src/CircleOpCode.test.cpp b/compiler/luci/partition/src/CircleOpCode.test.cpp
new file mode 100644
index 000000000..d2524a2ef
--- /dev/null
+++ b/compiler/luci/partition/src/CircleOpCode.test.cpp
@@ -0,0 +1,31 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "CircleOpCode.h"
+
+// NOTE any node will do for testing
+#include <luci/IR/Nodes/CircleSqrt.h>
+
+#include <gtest/gtest.h>
+
+TEST(CircleOpCodeTest, name)
+{
+ auto g = loco::make_graph();
+ auto node = g->nodes()->create<luci::CircleSqrt>();
+
+ auto name = luci::opcode_name(node);
+ ASSERT_EQ(name, "SQRT");
+}
diff --git a/compiler/luci/partition/src/ConnectNode.cpp b/compiler/luci/partition/src/ConnectNode.cpp
new file mode 100644
index 000000000..336be7c57
--- /dev/null
+++ b/compiler/luci/partition/src/ConnectNode.cpp
@@ -0,0 +1,38 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+#include <oops/UserExn.h>
+
+namespace luci
+{
+
+void clone_connect(const luci::CircleNode *node, luci::CloneContext &clonecontext)
+{
+ ConnectNode cn(clonecontext);
+ node->accept(&cn);
+}
+
+luci::CircleNode *ConnectNode::find_clone(const luci::CircleNode *node)
+{
+ auto it = _clonecontext.find(node);
+ if (it == _clonecontext.end())
+ throw oops::UserExn("Invalid node in ConnectNode");
+ return it->second;
+}
+
+} // namespace luci
diff --git a/compiler/luci/partition/src/ConnectNode.h b/compiler/luci/partition/src/ConnectNode.h
new file mode 100644
index 000000000..017c587e5
--- /dev/null
+++ b/compiler/luci/partition/src/ConnectNode.h
@@ -0,0 +1,209 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LUCI_PARTITION_CONNECT_NODE_H__
+#define __LUCI_PARTITION_CONNECT_NODE_H__
+
+#include <luci/IR/CircleNode.h>
+#include <luci/IR/CircleNodeVisitor.h>
+
+namespace luci
+{
+
+/**
+ * @note MapNode2Clone is used as a map from original node to cloned node
+ * to find input of a cloned node
+ *
+ * (Original) (Clone)
+ *
+ * [A] [A']
+ * | [B] | [B']
+ * | | | |
+ * \ / \ /
+ * [C] [C']
+ *
+ * From view of [C'] we need to find [A'] and [B']. We know [C] from [C'],
+ * then we can get from input of [C] as [A], [B] then [A]->[A'] and [B]->[B']
+ * from the map.
+ */
+using MapNode2Clone = std::map<const CircleNode * /* ORG */, CircleNode * /* CLONE */>;
+
+struct CloneContext
+{
+ std::pair<MapNode2Clone::iterator, bool> emplace(const CircleNode *org, CircleNode *clone)
+ {
+ return node2clone.emplace(org, clone);
+ }
+ MapNode2Clone::iterator find(const CircleNode *org) { return node2clone.find(org); }
+ MapNode2Clone::iterator end(void) { return node2clone.end(); }
+
+ MapNode2Clone node2clone;
+};
+
+class ConnectNode final : public luci::CircleNodeVisitor<void>
+{
+public:
+ ConnectNode(luci::CloneContext &clonecontext) : _clonecontext(clonecontext){};
+
+public:
+ // void visit(const luci::CircleAbs *) final;
+ void visit(const luci::CircleAdd *) final;
+ // void visit(const luci::CircleAddN *) final;
+ // void visit(const luci::CircleArgMax *) final;
+ // void visit(const luci::CircleArgMin *) final;
+ // void visit(const luci::CircleAveragePool2D *) final;
+ // void visit(const luci::CircleBatchMatMul *) final;
+ // void visit(const luci::CircleBatchToSpaceND *) final;
+ // void visit(const luci::CircleCast *) final;
+ // void visit(const luci::CircleCeil *) final;
+ // void visit(const luci::CircleConcatenation *) final;
+ void visit(const luci::CircleConst *) final;
+ // void visit(const luci::CircleConv2D *) final;
+ // void visit(const luci::CircleCos *) final;
+ // void visit(const luci::CircleCustom *) final;
+ // void visit(const luci::CircleDepthToSpace *) final;
+ // void visit(const luci::CircleDepthwiseConv2D *) final;
+ // void visit(const luci::CircleDequantize *) final;
+ void visit(const luci::CircleDiv *) final;
+ // void visit(const luci::CircleElu *) final;
+ // void visit(const luci::CircleEqual *) final;
+ // void visit(const luci::CircleExp *) final;
+ // void visit(const luci::CircleExpandDims *) final;
+ // void visit(const luci::CircleFakeQuant *) final;
+ // void visit(const luci::CircleFill *) final;
+ // void visit(const luci::CircleFloor *) final;
+ // void visit(const luci::CircleFloorDiv *) final;
+ // void visit(const luci::CircleFloorMod *) final;
+ // void visit(const luci::CircleFullyConnected *) final;
+ // void visit(const luci::CircleGather *) final;
+ // void visit(const luci::CircleGatherNd *) final;
+ // void visit(const luci::CircleGreater *) final;
+ // void visit(const luci::CircleGreaterEqual *) final;
+ // void visit(const luci::CircleIf *) final;
+ // void visit(const luci::CircleL2Normalize *) final;
+ // void visit(const luci::CircleL2Pool2D *) final;
+ // void visit(const luci::CircleLeakyRelu *) final;
+ // void visit(const luci::CircleLess *) final;
+ // void visit(const luci::CircleLessEqual *) final;
+ // void visit(const luci::CircleLocalResponseNormalization *) final;
+ // void visit(const luci::CircleLog *) final;
+ // void visit(const luci::CircleLogicalAnd *) final;
+ // void visit(const luci::CircleLogicalNot *) final;
+ // void visit(const luci::CircleLogicalOr *) final;
+ // void visit(const luci::CircleLogistic *) final;
+ // void visit(const luci::CircleLogSoftmax *) final;
+ // void visit(const luci::CircleMatrixDiag *) final;
+ // void visit(const luci::CircleMatrixSetDiag *) final;
+ // void visit(const luci::CircleMaximum *) final;
+ // void visit(const luci::CircleMaxPool2D *) final;
+ void visit(const luci::CircleMean *) final;
+ // void visit(const luci::CircleMinimum *) final;
+ // void visit(const luci::CircleMirrorPad *) final;
+ void visit(const luci::CircleMul *) final;
+ // void visit(const luci::CircleNeg *) final;
+ // void visit(const luci::CircleNonMaxSuppressionV4 *) final;
+ // void visit(const luci::CircleNonMaxSuppressionV5 *) final;
+ // void visit(const luci::CircleNotEqual *) final;
+ // void visit(const luci::CircleOneHot *) final;
+ // void visit(const luci::CirclePack *) final;
+ // void visit(const luci::CirclePad *) final;
+ // void visit(const luci::CirclePadV2 *) final;
+ void visit(const luci::CirclePow *) final;
+ // void visit(const luci::CirclePRelu *) final;
+ // void visit(const luci::CircleRange *) final;
+ // void visit(const luci::CircleRank *) final;
+ // void visit(const luci::CircleReduceAny *) final;
+ // void visit(const luci::CircleReduceMax *) final;
+ // void visit(const luci::CircleReduceMin *) final;
+ // void visit(const luci::CircleReduceProd *) final;
+ // void visit(const luci::CircleRelu *) final;
+ // void visit(const luci::CircleRelu6 *) final;
+ // void visit(const luci::CircleReluN1To1 *) final;
+ // void visit(const luci::CircleReshape *) final;
+ // void visit(const luci::CircleResizeBilinear *) final;
+ // void visit(const luci::CircleResizeNearestNeighbor *) final;
+ // void visit(const luci::CircleReverseSequence *) final;
+ // void visit(const luci::CircleReverseV2 *) final;
+ // void visit(const luci::CircleRound *) final;
+ void visit(const luci::CircleRsqrt *) final;
+ // void visit(const luci::CircleScatterNd *) final;
+ // void visit(const luci::CircleSegmentSum *) final;
+ // void visit(const luci::CircleSelect *) final;
+ // void visit(const luci::CircleSelectV2 *) final;
+ // void visit(const luci::CircleShape *) final;
+ // void visit(const luci::CircleSin *) final;
+ // void visit(const luci::CircleSlice *) final;
+ // void visit(const luci::CircleSoftmax *) final;
+ // void visit(const luci::CircleSpaceToBatchND *) final;
+ // void visit(const luci::CircleSpaceToDepth *) final;
+ // void visit(const luci::CircleSparseToDense *) final;
+ // void visit(const luci::CircleSplit *) final;
+ // void visit(const luci::CircleSplitV *) final;
+ void visit(const luci::CircleSqrt *) final;
+ // void visit(const luci::CircleSquare *) final;
+ void visit(const luci::CircleSquaredDifference *) final;
+ // void visit(const luci::CircleSqueeze *) final;
+ // void visit(const luci::CircleStridedSlice *) final;
+ void visit(const luci::CircleSub *) final;
+ // void visit(const luci::CircleSum *) final;
+ // void visit(const luci::CircleTanh *) final;
+ // void visit(const luci::CircleTile *) final;
+ // void visit(const luci::CircleTopKV2 *) final;
+ // void visit(const luci::CircleTranspose *) final;
+ // void visit(const luci::CircleTransposeConv *) final;
+ // void visit(const luci::CircleUnidirectionalSequenceLSTM *) final;
+ // void visit(const luci::CircleUnique *) final;
+ // void visit(const luci::CircleUnpack *) final;
+ // void visit(const luci::CircleWhere *) final;
+ // void visit(const luci::CircleWhile *) final;
+ // void visit(const luci::CircleZerosLike *) final;
+
+ // Circle Only
+ // void visit(const luci::CircleBCQFullyConnected *) final;
+ // void visit(const luci::CircleBCQGather *) final;
+ // void visit(const luci::CircleInstanceNorm *) final;
+
+ // Virtual
+ // void visit(const luci::CircleCustomOut *) final;
+ // void visit(const luci::CircleIfOut *) final;
+ // void visit(const luci::CircleInput *) final;
+ // void visit(const luci::CircleNonMaxSuppressionV4Out *) final;
+ // void visit(const luci::CircleNonMaxSuppressionV5Out *) final;
+ // void visit(const luci::CircleOutput *) final;
+ // void visit(const luci::CircleOutputDummy *) final;
+ // void visit(const luci::CircleOutputExclude *) final;
+ // void visit(const luci::CircleSplitOut *) final;
+ // void visit(const luci::CircleSplitVOut *) final;
+ // void visit(const luci::CircleTopKV2Out *) final;
+ // void visit(const luci::CircleUniqueOut *) final;
+ // void visit(const luci::CircleUnpackOut *) final;
+ // void visit(const luci::CircleWhileOut *) final;
+
+public:
+ luci::CircleNode *find_clone(const luci::CircleNode *node);
+
+protected:
+ luci::CloneContext &_clonecontext;
+};
+
+/**
+ * @brief Connect cloned node from input node
+ */
+void clone_connect(const luci::CircleNode *node, luci::CloneContext &clonecontext);
+
+} // namespace luci
+
+#endif // __LUCI_PARTITION_CONNECT_NODE_H__
diff --git a/compiler/luci/partition/src/ConnectNode.test.cpp b/compiler/luci/partition/src/ConnectNode.test.cpp
new file mode 100644
index 000000000..a2009c654
--- /dev/null
+++ b/compiler/luci/partition/src/ConnectNode.test.cpp
@@ -0,0 +1,19 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.test.h"
+
+// This file validates "ConnectNode.test.h". Please DO NOT remove this file.
diff --git a/compiler/luci/partition/src/ConnectNode.test.h b/compiler/luci/partition/src/ConnectNode.test.h
new file mode 100644
index 000000000..f7333ff99
--- /dev/null
+++ b/compiler/luci/partition/src/ConnectNode.test.h
@@ -0,0 +1,146 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __CONNECT_NODE_TEST_H__
+#define __CONNECT_NODE_TEST_H__
+
+#include "ConnectNode.h"
+
+#include <luci/Service/CircleNodeClone.h>
+#include <luci/test/TestIOGraph.h>
+
+#include <loco/IR/Graph.h>
+
+#include <initializer_list>
+#include <memory>
+#include <stdexcept>
+#include <vector>
+
+namespace luci
+{
+namespace test
+{
+
+template <unsigned N> class TestIsOGraph : public TestIsGraphlet<N>, public TestOGraphlet
+{
+public:
+ TestIsOGraph() = default;
+
+public:
+ virtual void init(const std::initializer_list<ShapeU32> shape_in, const ShapeU32 shape_out)
+ {
+ if (shape_in.size() != N)
+ throw std::runtime_error("Failed to init TestIsOGraph");
+
+ TestIsGraphlet<N>::init(TestIsGraphlet<N>::g(), shape_in);
+ TestOGraphlet::init(TestIsGraphlet<N>::g(), shape_out);
+ }
+};
+
+template <class T> class NodeGraphletT
+{
+public:
+ virtual void init(loco::Graph *g)
+ {
+ _node = g->nodes()->create<T>();
+ _node->dtype(loco::DataType::S32);
+ _node->name("node");
+ }
+
+ T *node(void) const { return _node; }
+
+protected:
+ T *_node{nullptr};
+};
+
+template <class T> class NodeIsGraphletT
+{
+public:
+ virtual void init(loco::Graph *g, uint32_t n)
+ {
+ _node = g->nodes()->create<T>(n);
+ _node->dtype(loco::DataType::S32);
+ _node->name("node");
+ }
+
+ T *node(void) const { return _node; }
+
+protected:
+ T *_node{nullptr};
+};
+
+/**
+ * @brief ConnectionTestHelper provides common framework for testing
+ * cloned CircleNode connection
+ */
+class ConnectionTestHelper
+{
+public:
+ ConnectionTestHelper() { _graph_clone = loco::make_graph(); }
+
+public:
+ template <unsigned N> void prepare_inputs(TestIsOGraph<N> *isograph)
+ {
+ assert(N == isograph->num_inputs());
+
+ for (uint32_t i = 0; i < N; ++i)
+ {
+ auto *input = _graph_clone->nodes()->create<luci::CircleInput>();
+ luci::copy_common_attributes(isograph->input(i), input);
+ _clonectx.emplace(isograph->input(i), input);
+ _inputs.push_back(input);
+ }
+ }
+
+ /**
+ * @note prepare_inputs_miss is for negative testing
+ */
+ template <unsigned N> void prepare_inputs_miss(TestIsOGraph<N> *isograph)
+ {
+ assert(N == isograph->num_inputs());
+
+ for (uint32_t i = 0; i < N; ++i)
+ {
+ auto *input = _graph_clone->nodes()->create<luci::CircleInput>();
+ luci::copy_common_attributes(isograph->input(i), input);
+ if (i != 0)
+ _clonectx.emplace(isograph->input(i), input);
+ _inputs.push_back(input);
+ }
+ }
+
+ void clone_connect(luci::CircleNode *node, luci::CircleNode *clone)
+ {
+ _clonectx.emplace(node, clone);
+
+ luci::clone_connect(node, _clonectx);
+ }
+
+public:
+ loco::Graph *graph_clone(void) { return _graph_clone.get(); }
+
+ luci::CircleNode *inputs(uint32_t idx) { return _inputs.at(idx); }
+
+protected:
+ luci::CloneContext _clonectx;
+ std::vector<luci::CircleInput *> _inputs;
+ std::unique_ptr<loco::Graph> _graph_clone; // graph for clones
+};
+
+} // namespace test
+} // namespace luci
+
+#endif // __CONNECT_NODE_TEST_H__
diff --git a/compiler/luci/partition/src/Nodes/CircleAdd.cpp b/compiler/luci/partition/src/Nodes/CircleAdd.cpp
new file mode 100644
index 000000000..d393997e9
--- /dev/null
+++ b/compiler/luci/partition/src/Nodes/CircleAdd.cpp
@@ -0,0 +1,40 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+namespace
+{
+
+void connect(luci::ConnectNode *cn, const luci::CircleAdd *node)
+{
+ auto *cloned = loco::must_cast<luci::CircleAdd *>(cn->find_clone(node));
+
+ luci::CircleNode *x = loco::must_cast<luci::CircleNode *>(node->x());
+ luci::CircleNode *y = loco::must_cast<luci::CircleNode *>(node->y());
+
+ cloned->x(cn->find_clone(x));
+ cloned->y(cn->find_clone(y));
+}
+
+} // namespace
+
+namespace luci
+{
+
+void ConnectNode::visit(const luci::CircleAdd *node) { connect(this, node); }
+
+} // namespace luci
diff --git a/compiler/luci/partition/src/Nodes/CircleAdd.test.cpp b/compiler/luci/partition/src/Nodes/CircleAdd.test.cpp
new file mode 100644
index 000000000..e457b83d2
--- /dev/null
+++ b/compiler/luci/partition/src/Nodes/CircleAdd.test.cpp
@@ -0,0 +1,100 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+#include "ConnectNode.test.h"
+
+#include <luci/Service/CircleNodeClone.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+using namespace luci::test;
+
+class NodeGraphlet : public NodeGraphletT<luci::CircleAdd>
+{
+public:
+ NodeGraphlet() = default;
+
+public:
+ void init(loco::Graph *g) override
+ {
+ NodeGraphletT<luci::CircleAdd>::init(g);
+
+ _node->fusedActivationFunction(luci::FusedActFunc::RELU);
+ }
+};
+
+class TestNodeGraph : public TestIsOGraph<2>, public NodeGraphlet
+{
+public:
+ TestNodeGraph() = default;
+
+public:
+ void init(const ShapeU32 shape)
+ {
+ TestIsOGraph<2>::init({shape, shape}, shape);
+ NodeGraphlet::init(g());
+
+ node()->x(input(0));
+ node()->y(input(1));
+
+ output()->from(node());
+ }
+};
+
+} // namespace
+
+TEST(ConnectNodeTest, connect_Add)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleAdd *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleAdd *>(clone));
+
+ cth.clone_connect(node, clone);
+
+ ASSERT_EQ(2, clone->arity());
+ ASSERT_EQ(cth.inputs(0), clone->arg(0));
+ ASSERT_EQ(cth.inputs(1), clone->arg(1));
+}
+
+TEST(ConnectNodeTest, connect_Add_NEG)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs_miss(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleAdd *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleAdd *>(clone));
+
+ EXPECT_ANY_THROW(cth.clone_connect(node, clone));
+}
diff --git a/compiler/luci/service/src/Nodes/CircleInput.cpp b/compiler/luci/partition/src/Nodes/CircleConst.cpp
index 24eab7bd6..118cd8de2 100644
--- a/compiler/luci/service/src/Nodes/CircleInput.cpp
+++ b/compiler/luci/partition/src/Nodes/CircleConst.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -14,14 +14,14 @@
* limitations under the License.
*/
-#include <luci/Service/CircleShapeSignatureInference.h>
+#include "ConnectNode.h"
namespace luci
{
-ShapeSignature ssinf::Algorithm::visit(const luci::CircleInput *node)
+void ConnectNode::visit(const luci::CircleConst *)
{
- return node->shape_signature();
+ // Nothing to do
}
} // namespace luci
diff --git a/compiler/luci/partition/src/Nodes/CircleDiv.cpp b/compiler/luci/partition/src/Nodes/CircleDiv.cpp
new file mode 100644
index 000000000..480338542
--- /dev/null
+++ b/compiler/luci/partition/src/Nodes/CircleDiv.cpp
@@ -0,0 +1,40 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+namespace
+{
+
+void connect(luci::ConnectNode *cn, const luci::CircleDiv *node)
+{
+ auto *cloned = loco::must_cast<luci::CircleDiv *>(cn->find_clone(node));
+
+ luci::CircleNode *x = loco::must_cast<luci::CircleNode *>(node->x());
+ luci::CircleNode *y = loco::must_cast<luci::CircleNode *>(node->y());
+
+ cloned->x(cn->find_clone(x));
+ cloned->y(cn->find_clone(y));
+}
+
+} // namespace
+
+namespace luci
+{
+
+void ConnectNode::visit(const luci::CircleDiv *node) { connect(this, node); }
+
+} // namespace luci
diff --git a/compiler/luci/partition/src/Nodes/CircleDiv.test.cpp b/compiler/luci/partition/src/Nodes/CircleDiv.test.cpp
new file mode 100644
index 000000000..226932337
--- /dev/null
+++ b/compiler/luci/partition/src/Nodes/CircleDiv.test.cpp
@@ -0,0 +1,100 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+#include "ConnectNode.test.h"
+
+#include <luci/Service/CircleNodeClone.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+using namespace luci::test;
+
+class NodeGraphlet : public NodeGraphletT<luci::CircleDiv>
+{
+public:
+ NodeGraphlet() = default;
+
+public:
+ void init(loco::Graph *g) override
+ {
+ NodeGraphletT<luci::CircleDiv>::init(g);
+
+ _node->fusedActivationFunction(luci::FusedActFunc::RELU);
+ }
+};
+
+class TestNodeGraph : public TestIsOGraph<2>, public NodeGraphlet
+{
+public:
+ TestNodeGraph() = default;
+
+public:
+ void init(const ShapeU32 shape)
+ {
+ TestIsOGraph<2>::init({shape, shape}, shape);
+ NodeGraphlet::init(g());
+
+ node()->x(input(0));
+ node()->y(input(1));
+
+ output()->from(node());
+ }
+};
+
+} // namespace
+
+TEST(ConnectNodeTest, connect_Div)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleDiv *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleDiv *>(clone));
+
+ cth.clone_connect(node, clone);
+
+ ASSERT_EQ(2, clone->arity());
+ ASSERT_EQ(cth.inputs(0), clone->arg(0));
+ ASSERT_EQ(cth.inputs(1), clone->arg(1));
+}
+
+TEST(ConnectNodeTest, connect_Div_NEG)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs_miss(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleDiv *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleDiv *>(clone));
+
+ EXPECT_ANY_THROW(cth.clone_connect(node, clone));
+}
diff --git a/compiler/luci/partition/src/Nodes/CircleMean.cpp b/compiler/luci/partition/src/Nodes/CircleMean.cpp
new file mode 100644
index 000000000..b634e5838
--- /dev/null
+++ b/compiler/luci/partition/src/Nodes/CircleMean.cpp
@@ -0,0 +1,41 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+namespace
+{
+
+void connect(luci::ConnectNode *cn, const luci::CircleMean *node)
+{
+ auto *cloned = loco::must_cast<luci::CircleMean *>(cn->find_clone(node));
+
+ luci::CircleNode *input = loco::must_cast<luci::CircleNode *>(node->input());
+ luci::CircleNode *reduction_indices =
+ loco::must_cast<luci::CircleNode *>(node->reduction_indices());
+
+ cloned->input(cn->find_clone(input));
+ cloned->reduction_indices(cn->find_clone(reduction_indices));
+}
+
+} // namespace
+
+namespace luci
+{
+
+void ConnectNode::visit(const luci::CircleMean *node) { connect(this, node); }
+
+} // namespace luci
diff --git a/compiler/luci/partition/src/Nodes/CircleMul.cpp b/compiler/luci/partition/src/Nodes/CircleMul.cpp
new file mode 100644
index 000000000..2cd2b4038
--- /dev/null
+++ b/compiler/luci/partition/src/Nodes/CircleMul.cpp
@@ -0,0 +1,40 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+namespace
+{
+
+void connect(luci::ConnectNode *cn, const luci::CircleMul *node)
+{
+ auto *cloned = loco::must_cast<luci::CircleMul *>(cn->find_clone(node));
+
+ luci::CircleNode *x = loco::must_cast<luci::CircleNode *>(node->x());
+ luci::CircleNode *y = loco::must_cast<luci::CircleNode *>(node->y());
+
+ cloned->x(cn->find_clone(x));
+ cloned->y(cn->find_clone(y));
+}
+
+} // namespace
+
+namespace luci
+{
+
+void ConnectNode::visit(const luci::CircleMul *node) { connect(this, node); }
+
+} // namespace luci
diff --git a/compiler/luci/partition/src/Nodes/CircleMul.test.cpp b/compiler/luci/partition/src/Nodes/CircleMul.test.cpp
new file mode 100644
index 000000000..99cf0824d
--- /dev/null
+++ b/compiler/luci/partition/src/Nodes/CircleMul.test.cpp
@@ -0,0 +1,100 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+#include "ConnectNode.test.h"
+
+#include <luci/Service/CircleNodeClone.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+using namespace luci::test;
+
+class NodeGraphlet : public NodeGraphletT<luci::CircleMul>
+{
+public:
+ NodeGraphlet() = default;
+
+public:
+ void init(loco::Graph *g)
+ {
+ NodeGraphletT<luci::CircleMul>::init(g);
+
+ _node->fusedActivationFunction(luci::FusedActFunc::RELU);
+ }
+};
+
+class TestNodeGraph : public TestIsOGraph<2>, public NodeGraphlet
+{
+public:
+ TestNodeGraph() = default;
+
+public:
+ void init(const ShapeU32 shape)
+ {
+ TestIsOGraph<2>::init({shape, shape}, shape);
+ NodeGraphlet::init(g());
+
+ node()->x(input(0));
+ node()->y(input(1));
+
+ output()->from(node());
+ }
+};
+
+} // namespace
+
+TEST(ConnectNodeTest, connect_Mul)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleMul *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleMul *>(clone));
+
+ cth.clone_connect(node, clone);
+
+ ASSERT_EQ(2, clone->arity());
+ ASSERT_EQ(cth.inputs(0), clone->arg(0));
+ ASSERT_EQ(cth.inputs(1), clone->arg(1));
+}
+
+TEST(ConnectNodeTest, connect_Mul_NEG)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs_miss(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleMul *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleMul *>(clone));
+
+ EXPECT_ANY_THROW(cth.clone_connect(node, clone));
+}
diff --git a/compiler/luci/partition/src/Nodes/CirclePow.cpp b/compiler/luci/partition/src/Nodes/CirclePow.cpp
new file mode 100644
index 000000000..fb180ee69
--- /dev/null
+++ b/compiler/luci/partition/src/Nodes/CirclePow.cpp
@@ -0,0 +1,40 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+namespace
+{
+
+void connect(luci::ConnectNode *cn, const luci::CirclePow *node)
+{
+ auto *cloned = loco::must_cast<luci::CirclePow *>(cn->find_clone(node));
+
+ luci::CircleNode *x = loco::must_cast<luci::CircleNode *>(node->x());
+ luci::CircleNode *y = loco::must_cast<luci::CircleNode *>(node->y());
+
+ cloned->x(cn->find_clone(x));
+ cloned->y(cn->find_clone(y));
+}
+
+} // namespace
+
+namespace luci
+{
+
+void ConnectNode::visit(const luci::CirclePow *node) { connect(this, node); }
+
+} // namespace luci
diff --git a/compiler/luci/partition/src/Nodes/CircleRsqrt.cpp b/compiler/luci/partition/src/Nodes/CircleRsqrt.cpp
new file mode 100644
index 000000000..03e64aad0
--- /dev/null
+++ b/compiler/luci/partition/src/Nodes/CircleRsqrt.cpp
@@ -0,0 +1,38 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+namespace
+{
+
+void connect(luci::ConnectNode *cn, const luci::CircleRsqrt *node)
+{
+ auto *cloned = loco::must_cast<luci::CircleRsqrt *>(cn->find_clone(node));
+
+ luci::CircleNode *x = loco::must_cast<luci::CircleNode *>(node->x());
+
+ cloned->x(cn->find_clone(x));
+}
+
+} // namespace
+
+namespace luci
+{
+
+void ConnectNode::visit(const luci::CircleRsqrt *node) { connect(this, node); }
+
+} // namespace luci
diff --git a/compiler/luci/partition/src/Nodes/CircleSqrt.cpp b/compiler/luci/partition/src/Nodes/CircleSqrt.cpp
new file mode 100644
index 000000000..f737aac8d
--- /dev/null
+++ b/compiler/luci/partition/src/Nodes/CircleSqrt.cpp
@@ -0,0 +1,38 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+namespace
+{
+
+void connect(luci::ConnectNode *cn, const luci::CircleSqrt *node)
+{
+ auto *cloned = loco::must_cast<luci::CircleSqrt *>(cn->find_clone(node));
+
+ luci::CircleNode *x = loco::must_cast<luci::CircleNode *>(node->x());
+
+ cloned->x(cn->find_clone(x));
+}
+
+} // namespace
+
+namespace luci
+{
+
+void ConnectNode::visit(const luci::CircleSqrt *node) { connect(this, node); }
+
+} // namespace luci
diff --git a/compiler/luci/partition/src/Nodes/CircleSquaredDifference.cpp b/compiler/luci/partition/src/Nodes/CircleSquaredDifference.cpp
new file mode 100644
index 000000000..40dd31706
--- /dev/null
+++ b/compiler/luci/partition/src/Nodes/CircleSquaredDifference.cpp
@@ -0,0 +1,40 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+namespace
+{
+
+void connect(luci::ConnectNode *cn, const luci::CircleSquaredDifference *node)
+{
+ auto *cloned = loco::must_cast<luci::CircleSquaredDifference *>(cn->find_clone(node));
+
+ luci::CircleNode *x = loco::must_cast<luci::CircleNode *>(node->x());
+ luci::CircleNode *y = loco::must_cast<luci::CircleNode *>(node->y());
+
+ cloned->x(cn->find_clone(x));
+ cloned->y(cn->find_clone(y));
+}
+
+} // namespace
+
+namespace luci
+{
+
+void ConnectNode::visit(const luci::CircleSquaredDifference *node) { connect(this, node); }
+
+} // namespace luci
diff --git a/compiler/luci/partition/src/Nodes/CircleSub.cpp b/compiler/luci/partition/src/Nodes/CircleSub.cpp
new file mode 100644
index 000000000..8ac294b7b
--- /dev/null
+++ b/compiler/luci/partition/src/Nodes/CircleSub.cpp
@@ -0,0 +1,40 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+namespace
+{
+
+void connect(luci::ConnectNode *cn, const luci::CircleSub *node)
+{
+ auto *cloned = loco::must_cast<luci::CircleSub *>(cn->find_clone(node));
+
+ luci::CircleNode *x = loco::must_cast<luci::CircleNode *>(node->x());
+ luci::CircleNode *y = loco::must_cast<luci::CircleNode *>(node->y());
+
+ cloned->x(cn->find_clone(x));
+ cloned->y(cn->find_clone(y));
+}
+
+} // namespace
+
+namespace luci
+{
+
+void ConnectNode::visit(const luci::CircleSub *node) { connect(this, node); }
+
+} // namespace luci
diff --git a/compiler/luci/partition/src/Nodes/CircleSub.test.cpp b/compiler/luci/partition/src/Nodes/CircleSub.test.cpp
new file mode 100644
index 000000000..7c0d83745
--- /dev/null
+++ b/compiler/luci/partition/src/Nodes/CircleSub.test.cpp
@@ -0,0 +1,100 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConnectNode.h"
+
+#include "ConnectNode.test.h"
+
+#include <luci/Service/CircleNodeClone.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+using namespace luci::test;
+
+class NodeGraphlet : public NodeGraphletT<luci::CircleSub>
+{
+public:
+ NodeGraphlet() = default;
+
+public:
+ void init(loco::Graph *g)
+ {
+ NodeGraphletT<luci::CircleSub>::init(g);
+
+ _node->fusedActivationFunction(luci::FusedActFunc::RELU);
+ }
+};
+
+class TestNodeGraph : public TestIsOGraph<2>, public NodeGraphlet
+{
+public:
+ TestNodeGraph() = default;
+
+public:
+ void init(const ShapeU32 shape)
+ {
+ TestIsOGraph<2>::init({shape, shape}, shape);
+ NodeGraphlet::init(g());
+
+ node()->x(input(0));
+ node()->y(input(1));
+
+ output()->from(node());
+ }
+};
+
+} // namespace
+
+TEST(ConnectNodeTest, connect_Sub)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleSub *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleSub *>(clone));
+
+ cth.clone_connect(node, clone);
+
+ ASSERT_EQ(2, clone->arity());
+ ASSERT_EQ(cth.inputs(0), clone->arg(0));
+ ASSERT_EQ(cth.inputs(1), clone->arg(1));
+}
+
+TEST(ConnectNodeTest, connect_Sub_NEG)
+{
+ TestNodeGraph tng;
+ tng.init({2, 3});
+
+ ConnectionTestHelper cth;
+ cth.prepare_inputs_miss(&tng);
+
+ auto *node = tng.node();
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleSub *>(node));
+
+ auto *clone = luci::clone_node(node, cth.graph_clone());
+ ASSERT_NO_THROW(loco::must_cast<luci::CircleSub *>(clone));
+
+ EXPECT_ANY_THROW(cth.clone_connect(node, clone));
+}
diff --git a/compiler/luci/partition/src/Partition.cpp b/compiler/luci/partition/src/Partition.cpp
new file mode 100644
index 000000000..cc7106ca9
--- /dev/null
+++ b/compiler/luci/partition/src/Partition.cpp
@@ -0,0 +1,61 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "PartitionIR.h"
+#include "PartitionIRDump.h"
+#include "PartitionPGroups.h"
+#include "PartitionMerge.h"
+#include "PartitionCleanup.h"
+#include "PartitionPModules.h"
+#include "PartitionPModulesDump.h"
+
+#include "luci/Partition.h"
+#include "luci/Log.h"
+
+#include <cassert>
+
+namespace luci
+{
+
+/**
+ * @brief This will return Partitioned Modules object
+ */
+PartedModules apply(Module *source, const PartitionTable &partition)
+{
+ assert(source != nullptr);
+
+ LOGGER(l);
+
+ auto pgroups = produce_pgroups(source, partition);
+ INFO(l) << "--- Partition Graph (1)------------------------";
+ INFO(l) << pgroups.get();
+
+ auto mpgroups = merge_pgroups(pgroups.get());
+ INFO(l) << "--- Partition Graph (2)------------------------";
+ INFO(l) << mpgroups.get();
+
+ remove_unused_inputoutputs(mpgroups.get(), source);
+ INFO(l) << "--- Partition Graph (3)------------------------";
+ INFO(l) << mpgroups.get();
+
+ auto pmodules = produce_pmodules(mpgroups.get());
+ INFO(l) << "--- Modules -----------------------------------";
+ INFO(l) << &pmodules;
+
+ return pmodules;
+}
+
+} // namespace luci
diff --git a/compiler/luci/partition/src/Partition.test.cpp b/compiler/luci/partition/src/Partition.test.cpp
new file mode 100644
index 000000000..9e24c441c
--- /dev/null
+++ b/compiler/luci/partition/src/Partition.test.cpp
@@ -0,0 +1,83 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Partition.h"
+
+#include <luci/test/TestIOGraph.h>
+
+#include <luci/IR/Nodes/CircleSqrt.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+using namespace luci::test;
+
+class SqrtGraphlet
+{
+public:
+ SqrtGraphlet() = default;
+
+public:
+ void init(loco::Graph *g, const ShapeU32 input_shape)
+ {
+ _sqrt = g->nodes()->create<luci::CircleSqrt>();
+ _sqrt->dtype(loco::DataType::S32);
+ _sqrt->name("sqrt");
+ }
+
+protected:
+ luci::CircleSqrt *_sqrt = nullptr;
+};
+
+class SqrtGraph : public TestIOGraph, public SqrtGraphlet
+{
+public:
+ SqrtGraph() = default;
+
+public:
+ void init(const ShapeU32 shape)
+ {
+ TestIOGraph::init(shape, shape);
+ SqrtGraphlet::init(g(), shape);
+
+ _sqrt->x(input());
+
+ output()->from(_sqrt);
+ }
+};
+
+} // namespace
+
+TEST(PartitionTest, simple_apply)
+{
+ luci::Module module;
+
+ SqrtGraph g;
+ g.init({3, 3});
+ g.transfer_to(&module);
+
+ luci::PartitionTable pt;
+ pt.default_group = "A";
+
+ auto pms = apply(&module, pt);
+
+ ASSERT_EQ(1, pms.pmodules.size());
+
+ auto &pm = *pms.pmodules.begin();
+ ASSERT_NE(nullptr, pm.module->graph());
+}
diff --git a/compiler/luci/partition/src/PartitionCleanup.cpp b/compiler/luci/partition/src/PartitionCleanup.cpp
new file mode 100644
index 000000000..6545295df
--- /dev/null
+++ b/compiler/luci/partition/src/PartitionCleanup.cpp
@@ -0,0 +1,139 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "PartitionCleanup.h"
+
+#include "luci/Log.h"
+
+namespace
+{
+
+using CircleNodes = std::vector<luci::CircleNode *>;
+
+/**
+ * @note Original source outputs should be outputs
+ */
+void gather_graph_outputs(CircleNodes &nodes, const luci::Module *source)
+{
+ // graph outputs are treated as used
+ auto graph = source->graph();
+ for (uint32_t n = 0; n < graph->outputs()->size(); ++n)
+ {
+ auto output = luci::output_node(graph, n); // output is CircleOutput
+ assert(output != nullptr);
+
+ auto node = loco::must_cast<luci::CircleNode *>(output->from());
+
+ nodes.push_back(node);
+ }
+
+ // TODO add unused virtual outputs
+}
+
+/**
+ * @note If one PGroup requires an input, that input should be an output
+ * from another PGroup
+ */
+void gather_pgroups_outputs(CircleNodes &nodes, const luci::PGroups *pgroups)
+{
+ // input of a pgroup is used output
+ for (auto &pgroup : pgroups->pgroups)
+ {
+ for (auto input : pgroup->inputs)
+ {
+ nodes.push_back(input);
+ }
+ }
+}
+
+} // namespace
+
+namespace luci
+{
+
+void remove_unused_inputoutputs(luci::PGroups *pgroups, const luci::Module *source)
+{
+ assert(source != nullptr);
+ assert(pgroups != nullptr);
+
+ LOGGER(l);
+
+ // TODO support multiple subgraph
+ assert(source->size() == 1);
+
+ INFO(l) << "--- Cleanup unused inputs/outputs";
+
+ // remove input within same pgroup
+ for (auto &pgroup : pgroups->pgroups)
+ {
+ bool changed;
+ do
+ {
+ changed = false;
+ for (auto it = pgroup->inputs.begin(); it != pgroup->inputs.end(); ++it)
+ {
+ auto input = *it;
+ if (pgroups->pgroup_of(input) == pgroup.get())
+ {
+ INFO(l) << " Cleanup input " << input->name() << " from group " << pgroup->group;
+ pgroup->inputs.erase(it);
+ changed = true;
+ break;
+ }
+ // NOTE CircleConst is one of input type, as they are registered as
+ // input to some node and then (should be) merged.
+ // Remove if this input is CircleConst
+ if (dynamic_cast<CircleConst *>(input) != nullptr)
+ {
+ INFO(l) << " Cleanup CircleConst " << input->name() << " from group " << pgroup->group;
+ pgroup->inputs.erase(it);
+ changed = true;
+ break;
+ }
+ }
+ } while (changed);
+ }
+
+ // remove unused output(s)
+ // 'used_outputs' will hold actual used outputs for all PGroups
+ CircleNodes used_outputs;
+
+ gather_graph_outputs(used_outputs, source);
+ gather_pgroups_outputs(used_outputs, pgroups);
+
+ for (auto &pgroup : pgroups->pgroups)
+ {
+ bool changed;
+ do
+ {
+ changed = false;
+ for (auto it = pgroup->outputs.begin(); it != pgroup->outputs.end(); ++it)
+ {
+ auto output = *it;
+ auto oit = std::find(used_outputs.begin(), used_outputs.end(), output);
+ if (oit == used_outputs.end())
+ {
+ INFO(l) << " Cleanup output " << output->name() << " from group " << pgroup->group;
+ pgroup->outputs.erase(it);
+ changed = true;
+ break;
+ }
+ }
+ } while (changed);
+ }
+}
+
+} // namespace luci
diff --git a/compiler/luci/partition/src/PartitionCleanup.h b/compiler/luci/partition/src/PartitionCleanup.h
new file mode 100644
index 000000000..f81b4a7cb
--- /dev/null
+++ b/compiler/luci/partition/src/PartitionCleanup.h
@@ -0,0 +1,34 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LUCI_PARTITON_CLEANUP_H__
+#define __LUCI_PARTITON_CLEANUP_H__
+
+#include "PartitionIR.h"
+
+#include <luci/IR/Module.h>
+
+namespace luci
+{
+
+/**
+ * @brief This will remove unused inputs/outputs in each pgroup of pgroups
+ */
+void remove_unused_inputoutputs(luci::PGroups *, const luci::Module *);
+
+} // namespace luci
+
+#endif // __LUCI_PARTITON_CLEANUP_H__
diff --git a/compiler/luci/partition/src/PartitionIR.cpp b/compiler/luci/partition/src/PartitionIR.cpp
new file mode 100644
index 000000000..ebd6b25fa
--- /dev/null
+++ b/compiler/luci/partition/src/PartitionIR.cpp
@@ -0,0 +1,101 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "PartitionIR.h"
+#include "CircleOpCode.h"
+
+#include "luci/Log.h"
+
+#include <cassert>
+#include <ostream>
+#include <iostream>
+
+namespace luci
+{
+
+std::unique_ptr<PGroups> PGroups::make_copy(void) const
+{
+ auto d_pgroups = std::make_unique<luci::PGroups>();
+
+ for (auto &s_pgroup : pgroups)
+ {
+ // make a copy of s_pgroup to d_pgroup
+ std::unique_ptr<luci::PGroup> d_pgroup = std::make_unique<luci::PGroup>();
+
+ d_pgroup->group = s_pgroup->group;
+ d_pgroup->id = s_pgroup->id;
+
+ for (auto &pnode : s_pgroup->pnodes)
+ {
+ auto pnodec = std::make_unique<luci::PNode>();
+ pnodec->node = pnode->node;
+ pnodec->group = pnode->group;
+ pnodec->pgroup = d_pgroup.get();
+ d_pgroup->pnodes.push_back(std::move(pnodec));
+ }
+
+ for (auto &input : s_pgroup->inputs)
+ d_pgroup->inputs.push_back(input);
+
+ for (auto &output : s_pgroup->outputs)
+ d_pgroup->outputs.push_back(output);
+
+ // copy node2group
+ for (auto it = node2group.begin(); it != node2group.end(); ++it)
+ d_pgroups->node2group[it->first] = it->second;
+
+ // build id2pgroup
+ d_pgroups->id2pgroup[d_pgroup->id] = d_pgroup.get();
+
+ d_pgroups->pgroups.push_back(std::move(d_pgroup));
+ // note: d_pgroup is now nullptr as it's moved
+ }
+
+ return std::move(d_pgroups);
+}
+
+std::string PGroups::group_of(luci::CircleNode *node) const
+{
+ assert(node != nullptr);
+
+ LOGGER(l);
+
+ auto it = node2group.find(node);
+ if (it == node2group.end())
+ {
+ INFO(l) << "PGroups::group_of " << node << "(" << node->name() << ") not found" << std::endl;
+ return "";
+ }
+ return it->second;
+}
+
+const PGroup *PGroups::pgroup_of(luci::CircleNode *node) const
+{
+ assert(node != nullptr);
+
+ for (auto &pgroup : pgroups)
+ {
+ for (auto &pnode : pgroup->pnodes)
+ {
+ if (node == pnode->node)
+ return pgroup.get();
+ }
+ }
+ // node maybe graph input (CircleInput)
+ return nullptr;
+}
+
+} // namespace luci
diff --git a/compiler/luci/partition/src/PartitionIR.h b/compiler/luci/partition/src/PartitionIR.h
new file mode 100644
index 000000000..852e38cc0
--- /dev/null
+++ b/compiler/luci/partition/src/PartitionIR.h
@@ -0,0 +1,91 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LUCI_PARTITION_IR_H__
+#define __LUCI_PARTITION_IR_H__
+
+#include <luci/IR/CircleNodes.h>
+
+#include <map>
+#include <memory>
+#include <string>
+#include <vector>
+
+namespace luci
+{
+
+struct PGroup;
+
+/**
+ * @brief Partition Node with CircleNode with group name
+ * @note node just points to source luci::CircleNode, NOT the cloned node
+ * CloneContext is used to find cloned node from source node
+ */
+struct PNode
+{
+ const luci::CircleNode *node = nullptr;
+ std::string group;
+
+ const PGroup *pgroup = nullptr;
+};
+
+/**
+ * @brief Partition Group with Partition Nodes of same group and I/Os nodes
+ */
+struct PGroup
+{
+ std::vector<std::unique_ptr<PNode>> pnodes;
+ std::string group;
+ uint32_t id = 0;
+
+ // I/O while partitioning
+ std::vector<luci::CircleNode *> inputs;
+ std::vector<luci::CircleNode *> outputs;
+};
+
+struct PGroups
+{
+ std::vector<std::unique_ptr<PGroup>> pgroups;
+
+ // node2group is to find group key from source node
+ std::map<const luci::CircleNode *, std::string> node2group;
+
+ // id2pngroup is to find *pngroup from pngroup id
+ std::map<uint32_t, PGroup *> id2pgroup;
+
+ // default group key for reference
+ std::string default_group;
+
+public:
+ /**
+ * @brief return a copy of PGroups
+ */
+ std::unique_ptr<PGroups> make_copy(void) const;
+
+ /**
+ * @brief return group key of node, empty string if not found
+ */
+ std::string group_of(luci::CircleNode *node) const;
+
+ /**
+ * @brief return holding pgroup of node, nullptr if not found
+ */
+ const PGroup *pgroup_of(luci::CircleNode *node) const;
+};
+
+} // namespace luci
+
+#endif // __LUCI_PARTITION_IR_H__
diff --git a/compiler/luci/partition/src/PartitionIR.test.cpp b/compiler/luci/partition/src/PartitionIR.test.cpp
new file mode 100644
index 000000000..4c051a96d
--- /dev/null
+++ b/compiler/luci/partition/src/PartitionIR.test.cpp
@@ -0,0 +1,75 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "PartitionIR.h"
+
+// NOTE any node will do for testing
+#include <luci/IR/Nodes/CircleAdd.h>
+
+#include <gtest/gtest.h>
+
+#include <memory>
+
+TEST(PartitionIRTest, PNode_ctor)
+{
+ auto g = loco::make_graph();
+ auto node = g->nodes()->create<luci::CircleAdd>();
+
+ luci::PNode pnode;
+ pnode.node = node;
+
+ ASSERT_NE(nullptr, pnode.node);
+ ASSERT_EQ(nullptr, pnode.pgroup);
+}
+
+// TODO add more tests with luci::PNode
+
+TEST(PartitionIRTest, PGroup_ctor)
+{
+ auto g = loco::make_graph();
+ auto node = g->nodes()->create<luci::CircleAdd>();
+
+ luci::PGroup pgroup;
+ auto pnode = std::make_unique<luci::PNode>();
+ pnode->node = node;
+
+ pgroup.pnodes.push_back(std::move(pnode));
+
+ ASSERT_NE(pgroup.pnodes.end(), pgroup.pnodes.begin());
+ ASSERT_EQ(0, pgroup.inputs.size());
+ ASSERT_EQ(0, pgroup.outputs.size());
+}
+
+// TODO add more tests with luci::PGroup
+
+TEST(PartitionIRTest, PGroups_ctor)
+{
+ auto g = loco::make_graph();
+ auto node = g->nodes()->create<luci::CircleAdd>();
+
+ auto pnode = std::make_unique<luci::PNode>();
+ pnode->node = node;
+
+ auto pgroup = std::make_unique<luci::PGroup>();
+ pgroup->pnodes.push_back(std::move(pnode));
+
+ luci::PGroups pgroups;
+ pgroups.pgroups.push_back(std::move(pgroup));
+
+ ASSERT_NE(pgroups.pgroups.end(), pgroups.pgroups.begin());
+}
+
+// TODO add more tests with luci::PGroups
diff --git a/compiler/luci/partition/src/PartitionIRDump.cpp b/compiler/luci/partition/src/PartitionIRDump.cpp
new file mode 100644
index 000000000..4f2c26800
--- /dev/null
+++ b/compiler/luci/partition/src/PartitionIRDump.cpp
@@ -0,0 +1,70 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "PartitionIRDump.h"
+
+#include "CircleOpCode.h"
+
+#include <iostream>
+
+namespace luci
+{
+
+void dump(std::ostream &os, const PNode *pnode)
+{
+ os << "PNode: " << pnode->group << ", " << pnode->node << ":" << luci::opcode_name(pnode->node)
+ << ":" << pnode->node->name() << std::endl;
+}
+
+void dump(std::ostream &os, const PGroup *pgroup)
+{
+ os << "--- PGroup: " << pgroup->group << std::endl;
+ os << "Input(s): ";
+ for (auto &node_in : pgroup->inputs)
+ os << node_in->name() << " ";
+ os << std::endl;
+ for (auto &pnode : pgroup->pnodes)
+ {
+ dump(os, pnode.get());
+ }
+ os << "Output(s): ";
+ for (auto &node_out : pgroup->outputs)
+ os << node_out->name() << " ";
+ os << std::endl;
+}
+
+void dump(std::ostream &os, const PGroups *pgroups)
+{
+ for (auto &pgroup : pgroups->pgroups)
+ {
+ dump(os, pgroup.get());
+ }
+ os << "--- Node2Group items: " << std::endl;
+ for (auto it = pgroups->node2group.begin(); it != pgroups->node2group.end(); ++it)
+ {
+ auto node = it->first;
+ auto group = it->second;
+ os << " Node: " << node << "(" << node->name() << "): " << group << std::endl;
+ }
+}
+
+} // namespace luci
+
+std::ostream &operator<<(std::ostream &os, const luci::PGroups *pgroups)
+{
+ luci::dump(os, pgroups);
+ return os;
+}
diff --git a/compiler/luci/partition/src/PartitionIRDump.h b/compiler/luci/partition/src/PartitionIRDump.h
new file mode 100644
index 000000000..8a4b3f579
--- /dev/null
+++ b/compiler/luci/partition/src/PartitionIRDump.h
@@ -0,0 +1,35 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LUCI_PARTITION_IR_DUMP_H__
+#define __LUCI_PARTITION_IR_DUMP_H__
+
+#include "PartitionIR.h"
+
+#include <iostream>
+
+namespace luci
+{
+
+void dump(std::ostream &os, const PNode *pnode);
+void dump(std::ostream &os, const PGroup *pgroup);
+void dump(std::ostream &os, const PGroups *pgroups);
+
+} // namespace luci
+
+std::ostream &operator<<(std::ostream &os, const luci::PGroups *pgroups);
+
+#endif // __LUCI_PARTITION_IR_DUMP_H__
diff --git a/compiler/luci/partition/src/PartitionMerge.cpp b/compiler/luci/partition/src/PartitionMerge.cpp
new file mode 100644
index 000000000..038fc2a0c
--- /dev/null
+++ b/compiler/luci/partition/src/PartitionMerge.cpp
@@ -0,0 +1,207 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "PartitionMerge.h"
+
+#include <algorithm>
+
+namespace
+{
+
+/**
+ * @brief return true if pgroup_i output is one of the inputs of pgroup
+ */
+bool is_input_of(const luci::PGroup *pgroup_i, const luci::PGroup *pgroup)
+{
+ for (auto *output : pgroup_i->outputs)
+ {
+ for (auto *input : pgroup->inputs)
+ {
+ if (input == output)
+ return true;
+ }
+ }
+ return false;
+}
+
+/**
+ * @brief return true if there is only one input or all the inputs have same group
+ * @note pgroups is used to find group of pgroup
+ */
+bool is_input_same(const luci::PGroup *pgroup, const luci::PGroups *pgroups)
+{
+ assert(pgroups != nullptr);
+ assert(pgroup != nullptr);
+
+ const luci::PGroup *input_pgroup = nullptr;
+ std::string group;
+ for (auto &input : pgroup->inputs)
+ {
+ auto input_group = pgroups->group_of(input);
+ // NOTE: all the nodes should be registered and return should be valid group.
+ // convert_to_proups() should ensure this.
+ // assert here to find if there is any problem with this.
+ assert(not input_group.empty());
+ if (input_group.empty())
+ input_group = pgroups->default_group;
+
+ if (group.empty())
+ group = input_group;
+ else
+ {
+ if (group != input_group)
+ return false;
+ }
+ // if there are multiple inputs, all the inputs should be in same pgroup
+ // https://github.com/Samsung/ONE/issues/6230#issuecomment-801618150
+ // https://github.com/Samsung/ONE/issues/6230#issuecomment-801680531
+ auto pgroup_input = pgroups->pgroup_of(input);
+ if (pgroup_input != nullptr)
+ {
+ if (input_pgroup == nullptr)
+ input_pgroup = pgroup_input;
+ else
+ {
+ if (input_pgroup != pgroup_input)
+ return false;
+ }
+ }
+ }
+ return true;
+}
+
+/**
+ * @brief merge pgroup into pgroup_i
+ * @note output of pgroup_i should be input of pgroup
+ */
+void merge_into(luci::PGroup *pgroup, luci::PGroup *pgroup_i)
+{
+ for (auto &pnode : pgroup->pnodes)
+ {
+ // update pgroup for this pnode
+ pnode->pgroup = pgroup_i;
+ assert(pnode->group == pgroup_i->group);
+
+ // we don't need to add this in topological order:
+ // all the nodes will be created first then connection will be held
+ pgroup_i->pnodes.push_back(std::move(pnode));
+ // note: pnode is now nullptr as it's moved into pgroup_i->pnodes
+ }
+
+ for (auto &input : pgroup->inputs)
+ {
+ // add inputs of pgroup to pgroup_i if not member of pgroup_i
+ bool found_in_pgroup_i = false;
+ for (auto &pnode : pgroup_i->pnodes)
+ {
+ if (input == pnode->node)
+ {
+ found_in_pgroup_i = true;
+ break;
+ }
+ }
+ // skip if this input is already in the inputs
+ auto fit = std::find(pgroup_i->inputs.begin(), pgroup_i->inputs.end(), input);
+ if (fit != pgroup_i->inputs.end())
+ {
+ found_in_pgroup_i = true;
+ }
+ // note: if we force found_in_pgroup_i to false, for testing there will be
+ // unnecessary inputs
+ if (not found_in_pgroup_i)
+ {
+ // node input maybe in another pgroup
+ pgroup_i->inputs.push_back(input);
+ }
+ }
+ // add outputs of pgroup to pgroup_i outputs if not exist
+ for (auto &output : pgroup->outputs)
+ {
+ auto it = std::find(pgroup_i->outputs.begin(), pgroup_i->outputs.end(), output);
+ if (it == pgroup_i->outputs.end())
+ {
+ pgroup_i->outputs.push_back(output);
+ }
+ }
+}
+
+} // namespace
+
+namespace luci
+{
+
+/**
+ * @brief This will merge pgroups with same group values in topological order
+ */
+std::unique_ptr<luci::PGroups> merge_pgroups(const luci::PGroups *s_pgroups)
+{
+ // Make a copy of pgroups to apply merge action
+ // Q) do we really need a copy?
+ auto d_pgroups = s_pgroups->make_copy();
+
+ // Merge partition graphs
+ // - This is initial implementation that works for limited networks
+ // - if A and B is same group -> if A is input of B -> ... -> merge B into A
+ auto &pgroups = d_pgroups->pgroups;
+ bool changed;
+ do
+ {
+ changed = false;
+ for (auto &pgroup_i : pgroups)
+ {
+ bool merged = false;
+ for (auto it = pgroups.begin(); it != pgroups.end(); ++it)
+ {
+ auto &pgroup = *it;
+
+ // skip if same object
+ if (pgroup->id == pgroup_i->id)
+ continue;
+ // skip if different group
+ if (pgroup->group != pgroup_i->group)
+ continue;
+ // skip if not connected
+ if (!is_input_of(pgroup_i.get(), pgroup.get()))
+ continue;
+ // skip if there are multiple inputs but inputs differ in group
+ if (!is_input_same(pgroup.get(), d_pgroups.get()))
+ continue;
+ // TODO add more condition may be needed
+
+ merge_into(pgroup.get(), pgroup_i.get());
+
+ auto eit = d_pgroups->id2pgroup.find(pgroup->id);
+ assert(eit != d_pgroups->id2pgroup.end());
+ d_pgroups->id2pgroup.erase(eit);
+
+ // remove merged pgroup from pgroups
+ pgroups.erase(it);
+
+ merged = true;
+ break;
+ }
+ if (merged)
+ {
+ changed = true;
+ break;
+ }
+ }
+ } while (changed);
+
+ return std::move(d_pgroups);
+}
+
+} // namespace luci
diff --git a/compiler/luci/partition/src/PartitionMerge.h b/compiler/luci/partition/src/PartitionMerge.h
new file mode 100644
index 000000000..5c9fec2d2
--- /dev/null
+++ b/compiler/luci/partition/src/PartitionMerge.h
@@ -0,0 +1,31 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LUCI_PARTITON_MERGE_H__
+#define __LUCI_PARTITON_MERGE_H__
+
+#include "PartitionIR.h"
+
+#include <memory>
+
+namespace luci
+{
+
+std::unique_ptr<luci::PGroups> merge_pgroups(const luci::PGroups *s_pgroups);
+
+} // namespace luci
+
+#endif // __LUCI_PARTITON_MERGE_H__
diff --git a/compiler/luci/partition/src/PartitionPGroups.cpp b/compiler/luci/partition/src/PartitionPGroups.cpp
new file mode 100644
index 000000000..594ed6c40
--- /dev/null
+++ b/compiler/luci/partition/src/PartitionPGroups.cpp
@@ -0,0 +1,139 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "PartitionPGroups.h"
+#include "PartitionIR.h"
+#include "CircleOpCode.h"
+
+#include "luci/Partition.h"
+#include "luci/Log.h"
+#include "luci/LogHelper.h"
+
+#include <luci/IR/CircleNodes.h>
+#include <luci/IR/CircleNodeVisitor.h>
+
+#include <loco.h>
+
+namespace
+{
+
+class IsVirtualNode final : public luci::CircleNodeVisitor<bool>
+{
+public:
+ bool visit(const luci::CircleInput *) final { return true; }
+ bool visit(const luci::CircleOutput *) final { return true; }
+ // TODO add all virtual nodes
+
+ // default is false
+ bool visit(const luci::CircleNode *) final { return false; }
+};
+
+bool check_allocate_partition(const luci::CircleNode *node)
+{
+ IsVirtualNode query;
+ if (node->accept(&query))
+ return false;
+ /**
+ * @note About CircleConst
+ * CirleConst acts like a part of some CircleNode and managing mulitiple
+ * used(referenced) CircleConst is a bit difficult if it's used across
+ * different PGroup. So we treat this different to other types.
+ * https://github.com/Samsung/ONE/issues/6230#issuecomment-809802813
+ */
+ if (dynamic_cast<const luci::CircleConst *>(node) != nullptr)
+ return false;
+ return true;
+}
+
+} // namespace
+
+namespace luci
+{
+
+std::unique_ptr<luci::PGroups> produce_pgroups(const luci::Module *source,
+ const luci::PartitionTable &partition)
+{
+ assert(source != nullptr);
+ // TODO support multiple subgraphs
+ assert(source->size() == 1);
+
+ LOGGER(l);
+
+ auto pgroups = std::make_unique<luci::PGroups>();
+
+ pgroups->default_group = partition.default_group;
+
+ // Create a PGroup per CircleNode: each PGroup will have one CircleNode
+ auto graph = source->graph();
+ auto nodes = graph->nodes();
+ for (uint32_t idx = 0; idx < nodes->size(); ++idx)
+ {
+ auto node = loco::must_cast<luci::CircleNode *>(nodes->at(idx));
+
+ // check if node is normal node that we are interested
+ if (check_allocate_partition(node))
+ {
+ auto opcodename = luci::opcode_name(node);
+ assert(!opcodename.empty());
+
+ auto group = partition.default_group;
+ auto it = partition.byopcodes.find(opcodename);
+ if (it != partition.byopcodes.end())
+ group = it->second;
+
+ INFO(l) << "Op: " << node->name() << ": " << opcodename << ", " << node << ", " << group
+ << std::endl;
+
+ auto pgroup = std::make_unique<luci::PGroup>();
+ pgroup->group = group;
+ pgroup->id = idx + 1;
+
+ auto pnode = std::make_unique<luci::PNode>();
+ pnode->node = node;
+ pnode->group = group;
+ pnode->pgroup = pgroup.get();
+
+ pgroup->pnodes.push_back(std::move(pnode));
+
+ // Set input of PGroup
+ for (uint32_t in = 0; in < node->arity(); ++in)
+ {
+ auto input = loco::must_cast<luci::CircleNode *>(node->arg(in));
+ // this input maybe CircleInput in source graph
+ // --> not confident this is safe
+ pgroup->inputs.push_back(input);
+ }
+ // Set output of PGroup: node itself or multiple virtual outputs
+ // TODO support multiple virtual outputs
+ pgroup->outputs.push_back(node);
+
+ pgroups->node2group[node] = group;
+ pgroups->id2pgroup[pgroup->id] = pgroup.get();
+
+ pgroups->pgroups.push_back(std::move(pgroup));
+ }
+ else
+ {
+ INFO(l) << "Skip Op: " << node->name() << std::endl;
+ // record as default group
+ pgroups->node2group[node] = partition.default_group;
+ }
+ }
+
+ return std::move(pgroups);
+}
+
+} // namespace luci
diff --git a/compiler/luci/partition/src/PartitionPGroups.h b/compiler/luci/partition/src/PartitionPGroups.h
new file mode 100644
index 000000000..998e11cbd
--- /dev/null
+++ b/compiler/luci/partition/src/PartitionPGroups.h
@@ -0,0 +1,39 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LUCI_PARTITON_PGROUPS_H__
+#define __LUCI_PARTITON_PGROUPS_H__
+
+#include "PartitionIR.h"
+
+#include "luci/Partition.h"
+
+#include <luci/IR/Module.h>
+
+namespace luci
+{
+
+/**
+ * @brief This will produce a PGroups from Module and PartitionTable.
+ * @note Each PGroup will hold one CircleNode and partition key value as group.
+ * Supports only single Graph in the Module for now.
+ */
+std::unique_ptr<luci::PGroups> produce_pgroups(const luci::Module *source,
+ const luci::PartitionTable &partition);
+
+} // namespace luci
+
+#endif // __LUCI_PARTITON_PGROUPS_H__
diff --git a/compiler/luci/partition/src/PartitionPGroups.test.cpp b/compiler/luci/partition/src/PartitionPGroups.test.cpp
new file mode 100644
index 000000000..960f3cde9
--- /dev/null
+++ b/compiler/luci/partition/src/PartitionPGroups.test.cpp
@@ -0,0 +1,80 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "PartitionPGroups.h"
+
+#include <luci/test/TestIOGraph.h>
+
+#include <luci/IR/Nodes/CircleSqrt.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+using namespace luci::test;
+
+class SqrtGraphlet
+{
+public:
+ SqrtGraphlet() = default;
+
+public:
+ void init(loco::Graph *g, const ShapeU32 input_shape)
+ {
+ _sqrt = g->nodes()->create<luci::CircleSqrt>();
+ _sqrt->dtype(loco::DataType::S32);
+ _sqrt->name("sqrt");
+ }
+
+protected:
+ luci::CircleSqrt *_sqrt = nullptr;
+};
+
+class SqrtGraph : public TestIOGraph, public SqrtGraphlet
+{
+public:
+ SqrtGraph() = default;
+
+public:
+ void init(const ShapeU32 shape)
+ {
+ TestIOGraph::init(shape, shape);
+ SqrtGraphlet::init(g(), shape);
+
+ _sqrt->x(input());
+
+ output()->from(_sqrt);
+ }
+};
+
+} // namespace
+
+TEST(PartitionPGroupsTest, simple_produce)
+{
+ luci::Module module;
+
+ SqrtGraph g;
+ g.init({3, 3});
+ g.transfer_to(&module);
+
+ luci::PartitionTable pt;
+ pt.default_group = "A";
+
+ auto pgs = produce_pgroups(&module, pt);
+
+ ASSERT_EQ(1, pgs->pgroups.size());
+}
diff --git a/compiler/luci/partition/src/PartitionPModules.cpp b/compiler/luci/partition/src/PartitionPModules.cpp
new file mode 100644
index 000000000..36f4d47a4
--- /dev/null
+++ b/compiler/luci/partition/src/PartitionPModules.cpp
@@ -0,0 +1,203 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "PartitionPModules.h"
+#include "ConnectNode.h"
+
+#include "luci/Service/CircleNodeClone.h"
+#include "luci/Log.h"
+
+#include <loco.h>
+
+namespace
+{
+
+void add_graph_input(loco::Graph *graph, luci::CircleInput *input_node)
+{
+ assert(graph != nullptr);
+ assert(input_node != nullptr);
+
+ auto graph_input = graph->inputs()->create();
+ graph_input->name(input_node->name());
+
+ // Set GraphInputOutputIndex for graph
+ input_node->index(graph_input->index());
+
+ // Data type
+ graph_input->dtype(input_node->dtype());
+
+ // Shape of GraphInput
+ auto input_shape = std::make_unique<loco::TensorShape>();
+ input_shape->rank(input_node->rank());
+ for (uint32_t r = 0; r < input_node->rank(); ++r)
+ {
+ if (input_node->dim(r).known())
+ input_shape->dim(r).set(input_node->dim(r).value());
+ }
+ graph_input->shape(std::move(input_shape));
+}
+
+void add_graph_output(loco::Graph *graph, luci::CircleOutput *output_node)
+{
+ assert(graph != nullptr);
+ assert(output_node != nullptr);
+
+ auto graph_output = graph->outputs()->create();
+ graph_output->name(output_node->name());
+
+ // Set GraphInputOutputIndex for graph
+ output_node->index(graph_output->index());
+
+ // Data type
+ graph_output->dtype(output_node->dtype());
+
+ // Shape of GraphOutput
+ auto output_shape = std::make_unique<loco::TensorShape>();
+ output_shape->rank(output_node->rank());
+ for (uint32_t r = 0; r < output_node->rank(); ++r)
+ {
+ if (output_node->dim(r).known())
+ output_shape->dim(r).set(output_node->dim(r).value());
+ }
+ graph_output->shape(std::move(output_shape));
+}
+
+/**
+ * @brief Build loco::graph from pgroup into graph
+ */
+void build_graph(loco::Graph *graph, const luci::PGroup *pgroup)
+{
+ LOGGER(l);
+
+ luci::CloneContext clonectx;
+
+ // add input node(s)
+ for (auto *input : pgroup->inputs)
+ {
+ auto *input_clone = graph->nodes()->create<luci::CircleInput>();
+ luci::copy_common_attributes(input, input_clone);
+
+ add_graph_input(graph, input_clone);
+ clonectx.emplace(input, input_clone);
+
+ INFO(l) << "MAP: "
+ << " input(" << input << ") -> " << input_clone << "(" << input_clone->name() << ")";
+ }
+
+ // add CircleConst for inputs
+ for (auto &pnode : pgroup->pnodes)
+ {
+ auto node = pnode->node;
+ uint32_t arity = node->arity();
+ for (uint32_t a = 0; a < arity; ++a)
+ {
+ auto in_a_const = dynamic_cast<luci::CircleConst *>(node->arg(a));
+ if (in_a_const != nullptr)
+ {
+ auto it = clonectx.find(in_a_const);
+ if (it == clonectx.end())
+ {
+ auto *clone = clone_node(in_a_const, graph);
+ clonectx.emplace(in_a_const, clone);
+
+ INFO(l) << "MAP: "
+ << " const(" << in_a_const << ") -> " << clone << "(" << clone->name() << ")";
+ }
+ }
+ }
+ }
+
+ // add nodes
+ for (auto &pnode : pgroup->pnodes)
+ {
+ auto *clone = clone_node(pnode->node, graph);
+ clonectx.emplace(pnode->node, clone);
+
+ INFO(l) << "MAP: "
+ << " node(" << pnode->node << ") -> " << clone << "(" << clone->name() << ")";
+ }
+ // connect nodes
+ for (auto &pnode : pgroup->pnodes)
+ {
+ clone_connect(pnode->node, clonectx);
+ }
+
+ // add output node(s)
+ for (auto *output : pgroup->outputs)
+ {
+ auto *output_clone = graph->nodes()->create<luci::CircleOutput>();
+ luci::copy_common_attributes(output, output_clone);
+ // note: we don't add output_clone to clonectx.
+ // logically, output is not used as an input to any other nodes.
+
+ auto it = clonectx.find(output);
+ assert(it != clonectx.end());
+ output_clone->from(it->second);
+
+ add_graph_output(graph, output_clone);
+
+ INFO(l) << "MAP: "
+ << "output(" << output << ") -> " << output_clone << "(" << output_clone->name() << ")"
+ << ": from " << it->second << "(" << it->second->name() << ")";
+ }
+}
+
+std::string make_name(const luci::PGroup *pgroup)
+{
+ auto &first_pnode = *pgroup->pnodes.begin();
+ auto *first_node = first_pnode->node;
+ std::string name = first_node->graph()->name();
+ name = name + "_" + pgroup->group;
+ return name;
+}
+
+} // namespace
+
+namespace luci
+{
+
+/**
+ * @brief This will produce list of luci::Module as PartedModules from pgroups
+ */
+luci::PartedModules produce_pmodules(const luci::PGroups *pgroups)
+{
+ LOGGER(l);
+
+ luci::PartedModules pms;
+
+ for (auto &pgroup : pgroups->pgroups)
+ {
+ luci::PartedModule pm;
+ pm.module = std::make_unique<luci::Module>();
+ pm.group = pgroup->group;
+
+ auto graph = loco::make_graph();
+
+ auto graph_name = make_name(pgroup.get());
+ graph->name(graph_name);
+
+ INFO(l) << "--- Partition Graph build----------------------";
+ INFO(l) << "--- name: " << graph_name;
+ build_graph(graph.get(), pgroup.get());
+
+ pm.module->add(std::move(graph));
+ pms.pmodules.emplace_back(std::move(pm));
+ }
+
+ return pms;
+}
+
+} // namespace luci
diff --git a/compiler/luci/partition/src/PartitionPModules.h b/compiler/luci/partition/src/PartitionPModules.h
new file mode 100644
index 000000000..628ada56c
--- /dev/null
+++ b/compiler/luci/partition/src/PartitionPModules.h
@@ -0,0 +1,31 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LUCI_PARTITON_PMODULES_H__
+#define __LUCI_PARTITON_PMODULES_H__
+
+#include "PartitionIR.h"
+
+#include "luci/Partition.h"
+
+namespace luci
+{
+
+luci::PartedModules produce_pmodules(const luci::PGroups *pgroups);
+
+} // namespace luci
+
+#endif // __LUCI_PARTITON_PMODULES_H__
diff --git a/compiler/luci/partition/src/PartitionPModules.test.cpp b/compiler/luci/partition/src/PartitionPModules.test.cpp
new file mode 100644
index 000000000..99c39e839
--- /dev/null
+++ b/compiler/luci/partition/src/PartitionPModules.test.cpp
@@ -0,0 +1,82 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "PartitionPModules.h"
+#include "PartitionPGroups.h"
+
+#include <luci/test/TestIOGraph.h>
+
+#include <luci/IR/Nodes/CircleSqrt.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+using namespace luci::test;
+
+class SqrtGraphlet
+{
+public:
+ SqrtGraphlet() = default;
+
+public:
+ void init(loco::Graph *g, const ShapeU32 input_shape)
+ {
+ _sqrt = g->nodes()->create<luci::CircleSqrt>();
+ _sqrt->dtype(loco::DataType::S32);
+ _sqrt->name("sqrt");
+ }
+
+protected:
+ luci::CircleSqrt *_sqrt = nullptr;
+};
+
+class SqrtGraph : public TestIOGraph, public SqrtGraphlet
+{
+public:
+ SqrtGraph() = default;
+
+public:
+ void init(const ShapeU32 shape)
+ {
+ TestIOGraph::init(shape, shape);
+ SqrtGraphlet::init(g(), shape);
+
+ _sqrt->x(input());
+
+ output()->from(_sqrt);
+ }
+};
+
+} // namespace
+
+TEST(PartitionPModulesTest, simple_convert)
+{
+ luci::Module module;
+
+ SqrtGraph g;
+ g.init({3, 3});
+ g.transfer_to(&module);
+
+ luci::PartitionTable pt;
+ pt.default_group = "A";
+
+ auto pgs = produce_pgroups(&module, pt);
+ auto pms = produce_pmodules(pgs.get());
+
+ ASSERT_EQ(1, pms.pmodules.size());
+}
diff --git a/compiler/luci/partition/src/PartitionPModulesDump.cpp b/compiler/luci/partition/src/PartitionPModulesDump.cpp
new file mode 100644
index 000000000..ee50bc6fb
--- /dev/null
+++ b/compiler/luci/partition/src/PartitionPModulesDump.cpp
@@ -0,0 +1,47 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "PartitionPModulesDump.h"
+
+#include "luci/LogHelper.h"
+
+#include <iostream>
+
+namespace luci
+{
+
+void dump(std::ostream &os, const PartedModule *pmodule)
+{
+ os << "--- PartedModule: " << pmodule->group << std::endl;
+ os << luci::fmt(pmodule->module->graph());
+}
+
+void dump(std::ostream &os, const PartedModules *pmodules)
+{
+ for (auto &pmodule : pmodules->pmodules)
+ {
+ dump(os, &pmodule);
+ }
+ os << std::endl;
+}
+
+} // namespace luci
+
+std::ostream &operator<<(std::ostream &os, const luci::PartedModules *pmodules)
+{
+ luci::dump(os, pmodules);
+ return os;
+}
diff --git a/compiler/luci/partition/src/PartitionPModulesDump.h b/compiler/luci/partition/src/PartitionPModulesDump.h
new file mode 100644
index 000000000..e77b235f4
--- /dev/null
+++ b/compiler/luci/partition/src/PartitionPModulesDump.h
@@ -0,0 +1,34 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LUCI_PARTITION_PMODULES_DUMP_H__
+#define __LUCI_PARTITION_PMODULES_DUMP_H__
+
+#include "luci/Partition.h"
+
+#include <iostream>
+
+namespace luci
+{
+
+void dump(std::ostream &os, const PartedModule *pmodule);
+void dump(std::ostream &os, const PartedModules *pmodules);
+
+} // namespace luci
+
+std::ostream &operator<<(std::ostream &os, const luci::PartedModules *pmodules);
+
+#endif // __LUCI_PARTITION_PMODULES_DUMP_H__
diff --git a/compiler/luci/pass/CMakeLists.txt b/compiler/luci/pass/CMakeLists.txt
index 2c5fb3407..2977fbed7 100644
--- a/compiler/luci/pass/CMakeLists.txt
+++ b/compiler/luci/pass/CMakeLists.txt
@@ -12,6 +12,7 @@ target_link_libraries(luci_pass PRIVATE luci_lang)
target_link_libraries(luci_pass PRIVATE luci_log)
target_link_libraries(luci_pass PRIVATE luci_service)
target_link_libraries(luci_pass PRIVATE luci_logex)
+target_link_libraries(luci_pass PRIVATE luci_profile)
target_link_libraries(luci_pass PRIVATE nncc_common)
target_link_libraries(luci_pass PRIVATE oops)
install(TARGETS luci_pass DESTINATION lib)
@@ -26,4 +27,5 @@ GTest_AddTest(luci_pass_test ${TESTS})
target_include_directories(luci_pass_test PRIVATE src)
target_link_libraries(luci_pass_test luci_pass)
target_link_libraries(luci_pass_test luci_lang)
+target_link_libraries(luci_pass_test luci_testhelper)
#target_link_libraries(luci_pass_test oops)
diff --git a/compiler/luci/pass/include/luci/CircleOptimizer.h b/compiler/luci/pass/include/luci/CircleOptimizer.h
index 906760e0a..1f5e1c8b9 100644
--- a/compiler/luci/pass/include/luci/CircleOptimizer.h
+++ b/compiler/luci/pass/include/luci/CircleOptimizer.h
@@ -35,6 +35,8 @@ public:
enum Algorithm
{
FuseAddWithTConv,
+ FuseBatchNormWithConv,
+ FuseBatchNormWithDwConv,
FuseBatchNormWithTConv,
FuseBCQ,
FuseInstanceNorm,
@@ -44,7 +46,11 @@ public:
QuantizeDequantizeWeights,
QuantizeWithMinMax,
Requantize,
+ FoldAddV2,
+ FoldCast,
FoldDequantize,
+ FoldSparseToDense,
+ ForwardReshapeToUnaryOp,
SparsifyTensorPass,
FusePreActivationBatchNorm,
MakeBatchNormGammaPositive,
@@ -53,6 +59,15 @@ public:
RemoveRedundantTranspose,
ReplaceMulAddWithDepthwiseConv,
SubstitutePackToReshape,
+ SubstituteSqueezeToReshape,
+ ConvertNCHWToNHWC,
+ RemoveUnnecessarySlice,
+ RemoveUnnecessaryStridedSlice,
+ RemoveUnnecessarySplit,
+ RemoveUnnecessaryReshape,
+ TransformMinMaxToRelu6Pass,
+ SubstituteTransposeToReshape,
+ RemoveRedundantReshape,
};
enum AlgorithmParameters
@@ -68,6 +83,10 @@ public:
Sparsify_format,
Sparsify_block_size,
Sparsify_block_map,
+
+ // convert NCHW to NHWC
+ NCHW_to_NHWC_preserve_input_shape,
+ NCHW_to_NHWC_preserve_output_shape,
};
virtual ~Options() = default;
diff --git a/compiler/luci/pass/include/luci/Pass/ShapeInferencePass.h b/compiler/luci/pass/include/luci/Pass/CircleShapeInferencePass.h
index e21ab4cce..21d6d09d6 100644
--- a/compiler/luci/pass/include/luci/Pass/ShapeInferencePass.h
+++ b/compiler/luci/pass/include/luci/Pass/CircleShapeInferencePass.h
@@ -14,8 +14,8 @@
* limitations under the License.
*/
-#ifndef __LUCI_SHAPE_INFERENCE_PASS_H__
-#define __LUCI_SHAPE_INFERENCE_PASS_H__
+#ifndef __LUCI_CIRCLE_SHAPE_INFERENCE_PASS_H__
+#define __LUCI_CIRCLE_SHAPE_INFERENCE_PASS_H__
#include <loco.h>
@@ -25,12 +25,12 @@ namespace luci
{
/**
- * @brief Pass to infer shape of nodes
+ * @brief Pass to infer shape of circle nodes
*/
-class ShapeInferencePass : public luci::Pass
+class CircleShapeInferencePass : public luci::Pass
{
public:
- virtual const char *name(void) const { return "luci::ShapeInferencePass"; }
+ virtual const char *name(void) const { return "luci::CircleShapeInferencePass"; }
public:
bool run(luci::Module *m);
@@ -39,4 +39,4 @@ public:
} // namespace luci
-#endif //__LUCI_SHAPE_INFERENCE_PASS_H__
+#endif //__LUCI_CIRCLE_SHAPE_INFERENCE_PASS_H__
diff --git a/compiler/luci/pass/include/luci/Pass/ConvertNCHWToNHWCPass.h b/compiler/luci/pass/include/luci/Pass/ConvertNCHWToNHWCPass.h
new file mode 100644
index 000000000..ba2392596
--- /dev/null
+++ b/compiler/luci/pass/include/luci/Pass/ConvertNCHWToNHWCPass.h
@@ -0,0 +1,60 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LUCI_CONVERT_NCHW_TO_NHWC_PASS_H__
+#define __LUCI_CONVERT_NCHW_TO_NHWC_PASS_H__
+
+#include <logo/Pass.h>
+
+namespace luci
+{
+
+/**
+ * @brief Class to convert NCHW Ops to NHWC
+ *
+ * @details Find operators that use NCHW layout and make them use NHWC.
+ * Strictly speaking, it is impossible to distinguish whether
+ * an operator is using NCHW or NHWC without programmers' annotations.
+ * But we guess the data layout of each operator as much as possible
+ * based on the assumptions described in the comments.
+ * Note that this Pass does not change the execution result even
+ * for the false-positive cases.
+ */
+struct ConvertNCHWToNHWCPass final : public logo::Pass
+{
+public:
+ ConvertNCHWToNHWCPass(bool preserve_input, bool preserve_output)
+ : _preserve_input(preserve_input), _preserve_output(preserve_output)
+ {
+ // Do nothing
+ }
+
+ ConvertNCHWToNHWCPass() = delete;
+
+ virtual ~ConvertNCHWToNHWCPass() = default;
+
+ const char *name(void) const final { return "luci::ConvertNCHWToNHWCPass"; }
+
+ bool run(loco::Graph *g) final;
+
+private:
+ bool _preserve_input = false;
+ bool _preserve_output = false;
+};
+
+} // namespace luci
+
+#endif // __LUCI_CONVERT_NCHW_TO_NHWC_PASS_H__
diff --git a/compiler/luci/pass/include/luci/Pass/FoldAddV2Pass.h b/compiler/luci/pass/include/luci/Pass/FoldAddV2Pass.h
new file mode 100644
index 000000000..cd260b916
--- /dev/null
+++ b/compiler/luci/pass/include/luci/Pass/FoldAddV2Pass.h
@@ -0,0 +1,38 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LUCI_FOLD_ADD_V2_PASS_H__
+#define __LUCI_FOLD_ADD_V2_PASS_H__
+
+#include <logo/Pass.h>
+
+namespace luci
+{
+
+/**
+ * @brief Class to fold AddV2 to a constant tensor
+ *
+ */
+struct FoldAddV2Pass final : public logo::Pass
+{
+ const char *name(void) const final { return "luci::FoldAddV2Pass"; }
+
+ bool run(loco::Graph *g) final;
+};
+
+} // namespace luci
+
+#endif // __LUCI_FOLD_ADD_V2_PASS_H__
diff --git a/compiler/luci/pass/include/luci/Pass/FoldCastPass.h b/compiler/luci/pass/include/luci/Pass/FoldCastPass.h
new file mode 100644
index 000000000..5d7ce4ad3
--- /dev/null
+++ b/compiler/luci/pass/include/luci/Pass/FoldCastPass.h
@@ -0,0 +1,38 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LUCI_FOLD_CAST_PASS_H__
+#define __LUCI_FOLD_CAST_PASS_H__
+
+#include <logo/Pass.h>
+
+namespace luci
+{
+
+/**
+ * @brief Class to fold Cast to a constant tensor
+ *
+ */
+struct FoldCastPass final : public logo::Pass
+{
+ const char *name(void) const final { return "luci::FoldCastPass"; }
+
+ bool run(loco::Graph *g) final;
+};
+
+} // namespace luci
+
+#endif // __LUCI_FOLD_CAST_PASS_H__
diff --git a/compiler/luci/pass/include/luci/Pass/FoldSparseToDensePass.h b/compiler/luci/pass/include/luci/Pass/FoldSparseToDensePass.h
new file mode 100644
index 000000000..00d2447a5
--- /dev/null
+++ b/compiler/luci/pass/include/luci/Pass/FoldSparseToDensePass.h
@@ -0,0 +1,38 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LUCI_FOLD_SPARSE_TO_DENSE_PASS_H__
+#define __LUCI_FOLD_SPARSE_TO_DENSE_PASS_H__
+
+#include <logo/Pass.h>
+
+namespace luci
+{
+
+/**
+ * @brief Class to fold SparseToDense to a constant tensor
+ *
+ */
+struct FoldSparseToDensePass final : public logo::Pass
+{
+ const char *name(void) const final { return "luci::FoldSparseToDensePass"; }
+
+ bool run(loco::Graph *g) final;
+};
+
+} // namespace luci
+
+#endif // __LUCI_FOLD_SPARSE_TO_DENSE_PASS_H__
diff --git a/compiler/luci/pass/include/luci/Pass/ForwardReshapeToUnaryOpPass.h b/compiler/luci/pass/include/luci/Pass/ForwardReshapeToUnaryOpPass.h
new file mode 100644
index 000000000..4c308e531
--- /dev/null
+++ b/compiler/luci/pass/include/luci/Pass/ForwardReshapeToUnaryOpPass.h
@@ -0,0 +1,37 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LUCI_FORWARD_RESHAPE_TO_UNARYOP_PASS_H__
+#define __LUCI_FORWARD_RESHAPE_TO_UNARYOP_PASS_H__
+
+#include <logo/Pass.h>
+
+namespace luci
+{
+
+/**
+ * @brief Class to Forward send Reshape after UnaryOp.
+ */
+struct ForwardReshapeToUnaryOpPass final : public logo::Pass
+{
+ const char *name(void) const final { return "luci::ForwardReshapeToUnaryOpPass"; }
+
+ bool run(loco::Graph *g) final;
+};
+
+} // namespace luci
+
+#endif // __LUCI_FORWARD_RESHAPE_TO_UNARYOP_PASS_H__
diff --git a/compiler/luci/pass/include/luci/Pass/FuseBatchNormWithConvPass.h b/compiler/luci/pass/include/luci/Pass/FuseBatchNormWithConvPass.h
new file mode 100644
index 000000000..1ed85447b
--- /dev/null
+++ b/compiler/luci/pass/include/luci/Pass/FuseBatchNormWithConvPass.h
@@ -0,0 +1,37 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LUCI_FUSE_BATCH_NORM_WITH_CONV_PASS_H__
+#define __LUCI_FUSE_BATCH_NORM_WITH_CONV_PASS_H__
+
+#include <logo/Pass.h>
+
+namespace luci
+{
+
+/**
+ * @brief Class to fuse Batch Normalization into CircleConv
+ */
+struct FuseBatchNormWithConvPass final : public logo::Pass
+{
+ const char *name(void) const final { return "luci::FuseBatchNormWithConvPass"; }
+
+ bool run(loco::Graph *g) final;
+};
+
+} // namespace luci
+
+#endif // __LUCI_FUSE_BATCH_NORM_WITH_CONV_PASS_H__
diff --git a/compiler/luci/pass/include/luci/Pass/FuseBatchNormWithDwConvPass.h b/compiler/luci/pass/include/luci/Pass/FuseBatchNormWithDwConvPass.h
new file mode 100644
index 000000000..32885c6b2
--- /dev/null
+++ b/compiler/luci/pass/include/luci/Pass/FuseBatchNormWithDwConvPass.h
@@ -0,0 +1,37 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LUCI_FUSE_BATCH_NORM_WITH_DWCONV_PASS_H__
+#define __LUCI_FUSE_BATCH_NORM_WITH_DWCONV_PASS_H__
+
+#include <logo/Pass.h>
+
+namespace luci
+{
+
+/**
+ * @brief Class to fuse Batch Normalization into CircleDepthWiseConv2D
+ */
+struct FuseBatchNormWithDwConvPass final : public logo::Pass
+{
+ const char *name(void) const final { return "luci::FuseBatchNormWithDwConvPass"; }
+
+ bool run(loco::Graph *g) final;
+};
+
+} // namespace luci
+
+#endif // __LUCI_FUSE_BATCH_NORM_WITH_DWCONV_PASS_H__
diff --git a/compiler/luci/pass/include/luci/Pass/FuseBatchNormWithTConv.h b/compiler/luci/pass/include/luci/Pass/FuseBatchNormWithTConvPass.h
index d3e930a36..d3e930a36 100644
--- a/compiler/luci/pass/include/luci/Pass/FuseBatchNormWithTConv.h
+++ b/compiler/luci/pass/include/luci/Pass/FuseBatchNormWithTConvPass.h
diff --git a/compiler/luci/pass/include/luci/Pass/MigrateLegacyShapeDtypePass.h b/compiler/luci/pass/include/luci/Pass/MigrateLegacyShapeDtypePass.h
deleted file mode 100644
index c0ebc4e5d..000000000
--- a/compiler/luci/pass/include/luci/Pass/MigrateLegacyShapeDtypePass.h
+++ /dev/null
@@ -1,44 +0,0 @@
-/*
- * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef __LUCI_MIGRATE_LEGACY_SHAPE_DTYPE_PASS_H__
-#define __LUCI_MIGRATE_LEGACY_SHAPE_DTYPE_PASS_H__
-
-#include <loco.h>
-
-#include <luci/ModulePass.h>
-
-namespace luci
-{
-
-/**
- * @brief Pass to copy shape/dtype of loco to circle node
- *
- * CAUTION : This pass will be removed after refactoring is finished
- */
-class MigrateLegacyShapeDtypePass : public luci::Pass
-{
-public:
- virtual const char *name(void) const { return "luci::MigrateLegacyShapeDtypePass"; }
-
-public:
- bool run(luci::Module *m);
- bool run(loco::Graph *graph);
-};
-
-} // namespace luci
-
-#endif //__LUCI_MIGRATE_LEGACY_SHAPE_DTYPE_PASS_H__
diff --git a/compiler/luci/pass/include/luci/Pass/QuantizeDequantizeWeightsPass.h b/compiler/luci/pass/include/luci/Pass/QuantizeDequantizeWeightsPass.h
index 713b88f9d..78e7323f9 100644
--- a/compiler/luci/pass/include/luci/Pass/QuantizeDequantizeWeightsPass.h
+++ b/compiler/luci/pass/include/luci/Pass/QuantizeDequantizeWeightsPass.h
@@ -34,7 +34,7 @@ class QuantizeDequantizeWeightsPass : public logo::Pass
public:
QuantizeDequantizeWeightsPass(loco::DataType input_dtype, loco::DataType output_dtype,
QuantizationGranularity granularity)
- : _input_dtype{input_dtype}, _output_dtype{output_dtype}, _granularity{granularity}
+ : _input_dtype{input_dtype}, _output_dtype{output_dtype}, _granularity{granularity}
{
// DO NOTHING
}
diff --git a/compiler/luci/pass/include/luci/Pass/QuantizeWithMinMaxPass.h b/compiler/luci/pass/include/luci/Pass/QuantizeWithMinMaxPass.h
index bb0d0ff40..9520910d5 100644
--- a/compiler/luci/pass/include/luci/Pass/QuantizeWithMinMaxPass.h
+++ b/compiler/luci/pass/include/luci/Pass/QuantizeWithMinMaxPass.h
@@ -34,7 +34,7 @@ class QuantizeWithMinMaxPass : public logo::Pass
public:
QuantizeWithMinMaxPass(loco::DataType input_dtype, loco::DataType output_dtype,
QuantizationGranularity granularity)
- : _input_dtype{input_dtype}, _output_dtype{output_dtype}, _granularity{granularity}
+ : _input_dtype{input_dtype}, _output_dtype{output_dtype}, _granularity{granularity}
{
// DO NOTHING
}
diff --git a/compiler/luci/pass/include/luci/Pass/RemoveRedundantReshapePass.h b/compiler/luci/pass/include/luci/Pass/RemoveRedundantReshapePass.h
new file mode 100644
index 000000000..458ffc094
--- /dev/null
+++ b/compiler/luci/pass/include/luci/Pass/RemoveRedundantReshapePass.h
@@ -0,0 +1,39 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LUCI_REMOVE_REDUNDANT_RESHAPE_PASS_H__
+#define __LUCI_REMOVE_REDUNDANT_RESHAPE_PASS_H__
+
+#include <logo/Pass.h>
+
+namespace luci
+{
+
+/**
+ * @brief Class to remove redundant Reshape node into 1 Reshape node.
+ * @details This class will update consecutive two Reshape node into single Reshape node.
+ * As Reshape operation just change shape, not buffer, former reshape could be unnecessary.
+ */
+struct RemoveRedundantReshapePass final : public logo::Pass
+{
+ const char *name(void) const final { return "luci::RemoveRedundantReshapePass"; }
+
+ bool run(loco::Graph *g) final;
+};
+
+} // namespace luci
+
+#endif // __LUCI_REMOVE_REDUNDANT_RESHAPE_PASS_H__
diff --git a/compiler/luci/pass/include/luci/Pass/RemoveUnnecessaryReshapePass.h b/compiler/luci/pass/include/luci/Pass/RemoveUnnecessaryReshapePass.h
new file mode 100644
index 000000000..8fca35e5b
--- /dev/null
+++ b/compiler/luci/pass/include/luci/Pass/RemoveUnnecessaryReshapePass.h
@@ -0,0 +1,37 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LUCI_REMOVE_UNNECESSARY_RESHAPE_PASS_H__
+#define __LUCI_REMOVE_UNNECESSARY_RESHAPE_PASS_H__
+
+#include <logo/Pass.h>
+
+namespace luci
+{
+
+/**
+ * @brief Class to Remove Unnecessary(input shape and output shape same) Reshape node.
+ */
+struct RemoveUnnecessaryReshapePass final : public logo::Pass
+{
+ const char *name(void) const final { return "luci::RemoveUnnecessaryReshapePass"; }
+
+ bool run(loco::Graph *g) final;
+};
+
+} // namespace luci
+
+#endif // __LUCI_REMOVE_UNNECESSARY_RESHAPE_PASS_H__
diff --git a/compiler/luci/pass/include/luci/Pass/RemoveUnnecessarySlicePass.h b/compiler/luci/pass/include/luci/Pass/RemoveUnnecessarySlicePass.h
new file mode 100644
index 000000000..a3b0f2f8c
--- /dev/null
+++ b/compiler/luci/pass/include/luci/Pass/RemoveUnnecessarySlicePass.h
@@ -0,0 +1,37 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LUCI_REMOVE_NO_EFFECT_SLICE_PASS_H__
+#define __LUCI_REMOVE_NO_EFFECT_SLICE_PASS_H__
+
+#include <logo/Pass.h>
+
+namespace luci
+{
+
+/**
+ * @brief Class to Remove Unnecessary(input and output are same) Slice node.
+ */
+struct RemoveUnnecessarySlicePass final : public logo::Pass
+{
+ const char *name(void) const final { return "luci::RemoveUnnecessarySlicePass"; }
+
+ bool run(loco::Graph *g) final;
+};
+
+} // namespace luci
+
+#endif // __LUCI_REMOVE_NO_EFFECT_SLICE_PASS_H__
diff --git a/compiler/luci/pass/include/luci/Pass/RemoveUnnecessarySplitPass.h b/compiler/luci/pass/include/luci/Pass/RemoveUnnecessarySplitPass.h
new file mode 100644
index 000000000..0d9330fe7
--- /dev/null
+++ b/compiler/luci/pass/include/luci/Pass/RemoveUnnecessarySplitPass.h
@@ -0,0 +1,37 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LUCI_REMOVE_UNNECESSARY_SPLIT_PASS_H__
+#define __LUCI_REMOVE_UNNECESSARY_SPLIT_PASS_H__
+
+#include <logo/Pass.h>
+
+namespace luci
+{
+
+/**
+ * @brief Remove unnecessary Split OP
+ */
+struct RemoveUnnecessarySplitPass final : public logo::Pass
+{
+ const char *name(void) const final { return "luci::RemoveUnnecessarySplitPass"; }
+
+ bool run(loco::Graph *g) final;
+};
+
+} // namespace luci
+
+#endif // __LUCI_REMOVE_UNNECESSARY_SPLIT_PASS_H__
diff --git a/compiler/luci/pass/include/luci/Pass/RemoveUnnecessaryStridedSlicePass.h b/compiler/luci/pass/include/luci/Pass/RemoveUnnecessaryStridedSlicePass.h
new file mode 100644
index 000000000..0f6a61d43
--- /dev/null
+++ b/compiler/luci/pass/include/luci/Pass/RemoveUnnecessaryStridedSlicePass.h
@@ -0,0 +1,37 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LUCI_REMOVE_UNNECESSARY_STRIDED_SLICE_PASS_H__
+#define __LUCI_REMOVE_UNNECESSARY_STRIDED_SLICE_PASS_H__
+
+#include <logo/Pass.h>
+
+namespace luci
+{
+
+/**
+ * @brief Class to Remove Unnecessary(input and output are same) StridedSlice node.
+ */
+struct RemoveUnnecessaryStridedSlicePass final : public logo::Pass
+{
+ const char *name(void) const final { return "luci::RemoveUnnecessaryStridedSlicePass"; }
+
+ bool run(loco::Graph *g) final;
+};
+
+} // namespace luci
+
+#endif // __LUCI_REMOVE_UNNECESSARY_STRIDED_SLICE_PASS_H__
diff --git a/compiler/luci/pass/include/luci/Pass/RequantizePass.h b/compiler/luci/pass/include/luci/Pass/RequantizePass.h
index 2442b24ea..c6c424f1b 100644
--- a/compiler/luci/pass/include/luci/Pass/RequantizePass.h
+++ b/compiler/luci/pass/include/luci/Pass/RequantizePass.h
@@ -33,7 +33,7 @@ class RequantizePass : public logo::Pass
{
public:
RequantizePass(loco::DataType input_dtype, loco::DataType output_dtype)
- : _input_dtype{input_dtype}, _output_dtype{output_dtype}
+ : _input_dtype{input_dtype}, _output_dtype{output_dtype}
{
// DO NOTHING
}
diff --git a/compiler/luci/pass/include/luci/Pass/SparsifyTensorPass.h b/compiler/luci/pass/include/luci/Pass/SparsifyTensorPass.h
index 41f43bf88..0ce142c55 100644
--- a/compiler/luci/pass/include/luci/Pass/SparsifyTensorPass.h
+++ b/compiler/luci/pass/include/luci/Pass/SparsifyTensorPass.h
@@ -35,8 +35,8 @@ public:
SparsifyTensorPass(const std::string &tensor_name, const std::vector<int32_t> &traversal_order,
const std::vector<DimensionType> &format,
const std::vector<int32_t> &block_size, const std::vector<int32_t> &block_map)
- : _tensor_name{tensor_name}, _traversal_order{traversal_order}, _format{format},
- _block_size{block_size}, _block_map{block_map}
+ : _tensor_name{tensor_name}, _traversal_order{traversal_order}, _format{format},
+ _block_size{block_size}, _block_map{block_map}
{
// DO NOTHING
}
diff --git a/compiler/luci/pass/include/luci/Pass/SubstituteSqueezeToReshapePass.h b/compiler/luci/pass/include/luci/Pass/SubstituteSqueezeToReshapePass.h
new file mode 100644
index 000000000..d8df6ac3f
--- /dev/null
+++ b/compiler/luci/pass/include/luci/Pass/SubstituteSqueezeToReshapePass.h
@@ -0,0 +1,37 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LUCI_SUBSTITUTE_SQUEEZE_TO_RESHAPE_PASS_H__
+#define __LUCI_SUBSTITUTE_SQUEEZE_TO_RESHAPE_PASS_H__
+
+#include <logo/Pass.h>
+
+namespace luci
+{
+
+/**
+ * @brief Class to Substitute Squeeze to Reshape node for certain conditions.
+ */
+struct SubstituteSqueezeToReshapePass final : public logo::Pass
+{
+ const char *name(void) const final { return "luci::SubstituteSqueezeToReshapePass"; }
+
+ bool run(loco::Graph *g) final;
+};
+
+} // namespace luci
+
+#endif // __LUCI_SUBSTITUTE_SQUEEZE_TO_RESHAPE_PASS_H__
diff --git a/compiler/luci/pass/include/luci/Pass/SubstituteTransposeToReshapePass.h b/compiler/luci/pass/include/luci/Pass/SubstituteTransposeToReshapePass.h
new file mode 100644
index 000000000..ee708585a
--- /dev/null
+++ b/compiler/luci/pass/include/luci/Pass/SubstituteTransposeToReshapePass.h
@@ -0,0 +1,37 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LUCI_SUBSTITUTE_TRANSPOSE_TO_RESHAPE_PASS_H__
+#define __LUCI_SUBSTITUTE_TRANSPOSE_TO_RESHAPE_PASS_H__
+
+#include <logo/Pass.h>
+
+namespace luci
+{
+
+/**
+ * @brief Class to Substitute Transpose with certain input shape condition to single reshape node.
+ */
+struct SubstituteTransposeToReshapePass final : public logo::Pass
+{
+ const char *name(void) const final { return "luci::SubstituteTransposeToReshapePass"; }
+
+ bool run(loco::Graph *g) final;
+};
+
+} // namespace luci
+
+#endif // __LUCI_SUBSTITUTE_TRANSPOSE_TO_RESHAPE_PASS_H__
diff --git a/compiler/luci/pass/include/luci/Pass/TransformMinMaxToRelu6Pass.h b/compiler/luci/pass/include/luci/Pass/TransformMinMaxToRelu6Pass.h
new file mode 100644
index 000000000..9ea39ee4e
--- /dev/null
+++ b/compiler/luci/pass/include/luci/Pass/TransformMinMaxToRelu6Pass.h
@@ -0,0 +1,37 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LUCI_TRANSFORM_MIN_MAX_TO_RELU6_PASS_H__
+#define __LUCI_TRANSFORM_MIN_MAX_TO_RELU6_PASS_H__
+
+#include <logo/Pass.h>
+
+namespace luci
+{
+
+/**
+ * @brief Class to transform Maximum(Minimum(input, 6), 0) to Relu6
+ */
+struct TransformMinMaxToRelu6Pass final : public logo::Pass
+{
+ const char *name(void) const final { return "luci::TransformMinMaxToRelu6Pass"; }
+
+ bool run(loco::Graph *g) final;
+};
+
+} // namespace luci
+
+#endif // __LUCI_TRANSFORM_MIN_MAX_TO_RELU6_PASS_H__
diff --git a/compiler/luci/pass/src/BatchNormPatternFinder.cpp b/compiler/luci/pass/src/BatchNormPatternFinder.cpp
new file mode 100644
index 000000000..c1a06bfda
--- /dev/null
+++ b/compiler/luci/pass/src/BatchNormPatternFinder.cpp
@@ -0,0 +1,106 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "BatchNormPatternFinder.h"
+
+#include <luci/IR/CircleNodes.h>
+
+namespace luci
+{
+
+bool is_batchnorm_add(const luci::CircleAdd *add, luci::CircleMul *&mul, luci::CircleConst *&beta)
+{
+ auto x = loco::must_cast<luci::CircleNode *>(add->x());
+ auto y = loco::must_cast<luci::CircleNode *>(add->y());
+
+ luci::CircleMul *pred = nullptr;
+ luci::CircleConst *constant = nullptr;
+
+ if (x->opcode() == luci::CircleOpcode::CIRCLECONST && y->opcode() == luci::CircleOpcode::MUL)
+ {
+ pred = loco::must_cast<luci::CircleMul *>(y);
+ constant = loco::must_cast<luci::CircleConst *>(x);
+ }
+ else if (x->opcode() == luci::CircleOpcode::MUL && y->opcode() == luci::CircleOpcode::CIRCLECONST)
+ {
+ pred = loco::must_cast<luci::CircleMul *>(x);
+ constant = loco::must_cast<luci::CircleConst *>(y);
+ }
+ else
+ {
+ return false;
+ }
+
+ if (constant->rank() != 1)
+ return false;
+
+ auto channel_dim = constant->dim(0);
+ // Assumption: Layout is channel-last
+ if (!(channel_dim == add->dim(add->rank() - 1)))
+ return false;
+
+ mul = pred;
+ beta = constant;
+ return true;
+}
+
+bool is_batchnorm_add(const luci::CircleAdd *add)
+{
+ // for dummy mul and beta
+ luci::CircleMul *mul = nullptr;
+ luci::CircleConst *beta = nullptr;
+
+ return is_batchnorm_add(add, mul, beta);
+}
+
+bool is_batchnorm_mul(const luci::CircleMul *mul, luci::CircleNode *&pred_node,
+ luci::CircleConst *&gamma)
+{
+ auto x = dynamic_cast<luci::CircleConst *>(mul->x());
+ auto y = dynamic_cast<luci::CircleConst *>(mul->y());
+
+ luci::CircleNode *pred = nullptr;
+ luci::CircleConst *constant = nullptr;
+
+ if (x != nullptr && y == nullptr)
+ {
+ pred = loco::must_cast<luci::CircleNode *>(mul->y());
+ constant = x;
+ }
+ else if (x == nullptr && y != nullptr)
+ {
+ pred = loco::must_cast<luci::CircleNode *>(mul->x());
+ constant = y;
+ }
+ else
+ {
+ return false;
+ }
+
+ if (constant->rank() != 1)
+ return false;
+
+ auto channel_dim = constant->dim(0);
+ // Assumption: Layout is channel-last
+ if (!(channel_dim == mul->dim(mul->rank() - 1)))
+ return false;
+
+ pred_node = pred;
+ gamma = constant;
+ return true;
+}
+
+} // namespace luci
diff --git a/compiler/luci/pass/src/BatchNormPatternFinder.h b/compiler/luci/pass/src/BatchNormPatternFinder.h
new file mode 100644
index 000000000..58cdbb464
--- /dev/null
+++ b/compiler/luci/pass/src/BatchNormPatternFinder.h
@@ -0,0 +1,43 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LUCI_PASS_BATCH_NORM_PATTERN_FINDER_H__
+#define __LUCI_PASS_BATCH_NORM_PATTERN_FINDER_H__
+
+#include <luci/IR/CircleNodes.h>
+
+namespace luci
+{
+
+/**
+ * @brief Find Mul-Add pattern and return Mul and beta as BatchNorm
+ */
+bool is_batchnorm_add(const luci::CircleAdd *add, luci::CircleMul *&mul, luci::CircleConst *&beta);
+
+/**
+ * @brief Find Mul-Add pattern
+ */
+bool is_batchnorm_add(const luci::CircleAdd *add);
+
+/**
+ * @brief Find Const-Mul pattern and return Node and gamma as BatchNorm
+ */
+bool is_batchnorm_mul(const luci::CircleMul *mul, luci::CircleNode *&pred_node,
+ luci::CircleConst *&gamma);
+
+} // namespace luci
+
+#endif // __LUCI_PASS_BATCH_NORM_PATTERN_FINDER_H__
diff --git a/compiler/luci/pass/src/BatchNormPatternFinder.test.cpp b/compiler/luci/pass/src/BatchNormPatternFinder.test.cpp
new file mode 100644
index 000000000..08e7fac1c
--- /dev/null
+++ b/compiler/luci/pass/src/BatchNormPatternFinder.test.cpp
@@ -0,0 +1,217 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "BatchNormPatternFinder.h"
+
+#include <luci/test/TestIOGraph.h>
+
+#include <luci/IR/CircleNodes.h>
+
+#include <gtest/gtest.h>
+
+namespace luci
+{
+namespace test
+{
+
+/**
+ * @brief Graphlet with Add and Const as beta from BatchNorm
+ */
+class AddBetaGraphlet
+{
+public:
+ AddBetaGraphlet() = default;
+
+ void init(loco::Graph *g, const ShapeU32 shape, luci::FusedActFunc actf)
+ {
+ _add = g->nodes()->create<luci::CircleAdd>();
+ _add_beta = g->nodes()->create<luci::CircleConst>();
+
+ _add->dtype(loco::DataType::FLOAT32);
+ _add_beta->dtype(loco::DataType::FLOAT32);
+
+ _add->fusedActivationFunction(actf);
+
+ assert(shape.size() > 0);
+ auto last_it = std::prev(shape.end(), 1);
+ auto channel_size = *last_it;
+
+ _add->shape(shape);
+ _add_beta->shape({channel_size});
+ _add_beta->size<loco::DataType::FLOAT32>(channel_size);
+ for (uint32_t i = 0; i < channel_size; i++)
+ _add_beta->at<loco::DataType::FLOAT32>(i) = i;
+
+ _add->name("add");
+ _add_beta->name("add_beta");
+ }
+
+public:
+ luci::CircleAdd *add() { return _add; }
+
+protected:
+ luci::CircleAdd *_add = nullptr;
+ luci::CircleConst *_add_beta = nullptr;
+};
+
+/**
+ * @brief Graphlet with Mul and Const as gamma from BatchNorm
+ */
+class MulGammaGraphlet
+{
+public:
+ MulGammaGraphlet() = default;
+
+ void init(loco::Graph *g, const ShapeU32 shape, luci::FusedActFunc actf)
+ {
+ _mul = g->nodes()->create<luci::CircleMul>();
+ _mul_gamma = g->nodes()->create<luci::CircleConst>();
+
+ _mul->dtype(loco::DataType::FLOAT32);
+ _mul_gamma->dtype(loco::DataType::FLOAT32);
+
+ _mul->fusedActivationFunction(actf);
+
+ assert(shape.size() > 0);
+ auto last_it = std::prev(shape.end(), 1);
+ auto channel_size = *last_it;
+
+ _mul->shape(shape);
+ _mul_gamma->shape({channel_size});
+ _mul_gamma->size<loco::DataType::FLOAT32>(channel_size);
+ for (uint32_t i = 0; i < channel_size; i++)
+ _mul_gamma->at<loco::DataType::FLOAT32>(i) = i;
+
+ _mul->name("mul");
+ _mul_gamma->name("mul_gamma");
+ }
+
+public:
+ luci::CircleMul *mul(void) { return _mul; }
+
+protected:
+ luci::CircleMul *_mul = nullptr;
+ luci::CircleConst *_mul_gamma = nullptr;
+};
+
+/**
+ * @brief Graph of Mul-Add pattern from BatchNorm
+ */
+class MulAddGraph : public TestIOGraph, public AddBetaGraphlet, public MulGammaGraphlet
+{
+public:
+ MulAddGraph() = default;
+
+ void init(const ShapeU32 shape_in, const ShapeU32 shape_out)
+ {
+ TestIOGraph::init(shape_in, shape_out);
+ MulGammaGraphlet::init(g(), shape_in, luci::FusedActFunc::NONE);
+ AddBetaGraphlet::init(g(), shape_out, luci::FusedActFunc::RELU);
+
+ // connect network
+ _mul->x(input());
+ _mul->y(_mul_gamma);
+ _add->x(_mul);
+ _add->y(_add_beta);
+ output()->from(_add);
+ }
+};
+
+/**
+ * @brief Graph of Add with Const
+ */
+class AddGraph : public TestIOGraph, public AddBetaGraphlet
+{
+public:
+ AddGraph() = default;
+
+ void init(const ShapeU32 shape_in, const ShapeU32 shape_out)
+ {
+ TestIOGraph::init(shape_in, shape_out);
+ AddBetaGraphlet::init(g(), shape_in, luci::FusedActFunc::RELU);
+
+ // connect network
+ _add->x(input());
+ _add->y(_add_beta);
+ output()->from(_add);
+ }
+};
+
+} // namespace test
+} // namespace luci
+
+class BatchNormPatternFinderMulAddTest : public ::testing::Test
+{
+public:
+ BatchNormPatternFinderMulAddTest() = default;
+
+protected:
+ luci::test::MulAddGraph _mag;
+};
+
+class BatchNormPatternFinderAddTest : public ::testing::Test
+{
+public:
+ BatchNormPatternFinderAddTest() = default;
+
+protected:
+ luci::test::AddGraph _ag;
+};
+
+TEST_F(BatchNormPatternFinderMulAddTest, is_batchnorm_add)
+{
+ _mag.init({1, 16, 16, 4}, {1, 16, 16, 4});
+
+ luci::CircleMul *mul = nullptr;
+ luci::CircleConst *beta = nullptr;
+
+ auto res = luci::is_batchnorm_add(_mag.add(), mul, beta);
+ ASSERT_TRUE(res);
+ ASSERT_NE(nullptr, mul);
+ ASSERT_NE(nullptr, beta);
+}
+
+TEST_F(BatchNormPatternFinderMulAddTest, is_batchnorm_add2)
+{
+ _mag.init({1, 16, 16, 4}, {1, 16, 16, 4});
+
+ auto res = luci::is_batchnorm_add(_mag.add());
+ ASSERT_TRUE(res);
+}
+
+TEST_F(BatchNormPatternFinderAddTest, is_batchnorm_add_NEG)
+{
+ _ag.init({1, 16, 16, 4}, {1, 16, 16, 4});
+
+ luci::CircleMul *mul = nullptr;
+ luci::CircleConst *beta = nullptr;
+
+ auto res = luci::is_batchnorm_add(_ag.add(), mul, beta);
+ ASSERT_FALSE(res);
+}
+
+TEST_F(BatchNormPatternFinderMulAddTest, is_batchnorm_mul)
+{
+ _mag.init({1, 16, 16, 4}, {1, 16, 16, 4});
+
+ luci::CircleNode *pred = nullptr;
+ luci::CircleConst *gamma = nullptr;
+
+ auto res = luci::is_batchnorm_mul(_mag.mul(), pred, gamma);
+ ASSERT_TRUE(res);
+ ASSERT_NE(nullptr, pred);
+ ASSERT_NE(nullptr, gamma);
+}
diff --git a/compiler/luci/pass/src/CircleOptimizer.cpp b/compiler/luci/pass/src/CircleOptimizer.cpp
index cc9fe481c..bddad34fa 100644
--- a/compiler/luci/pass/src/CircleOptimizer.cpp
+++ b/compiler/luci/pass/src/CircleOptimizer.cpp
@@ -16,16 +16,28 @@
#include "luci/CircleOptimizer.h"
+#include "luci/Pass/ConvertNCHWToNHWCPass.h"
+#include "luci/Pass/FoldAddV2Pass.h"
+#include "luci/Pass/FoldCastPass.h"
#include "luci/Pass/FoldDequantizePass.h"
+#include "luci/Pass/FoldSparseToDensePass.h"
+#include "luci/Pass/ForwardReshapeToUnaryOpPass.h"
#include "luci/Pass/FuseActivationFunctionPass.h"
#include "luci/Pass/FuseAddWithTConvPass.h"
-#include "luci/Pass/FuseBatchNormWithTConv.h"
+#include "luci/Pass/FuseBatchNormWithConvPass.h"
+#include "luci/Pass/FuseBatchNormWithDwConvPass.h"
+#include "luci/Pass/FuseBatchNormWithTConvPass.h"
#include "luci/Pass/FuseBCQPass.h"
#include "luci/Pass/FuseInstanceNormPass.h"
#include "luci/Pass/FusePreActivationBatchNormPass.h"
#include "luci/Pass/MakeBatchNormGammaPositivePass.h"
#include "luci/Pass/PropagateQuantParamPass.h"
+#include "luci/Pass/RemoveRedundantReshapePass.h"
#include "luci/Pass/RemoveRedundantTransposePass.h"
+#include "luci/Pass/RemoveUnnecessaryReshapePass.h"
+#include "luci/Pass/RemoveUnnecessarySlicePass.h"
+#include "luci/Pass/RemoveUnnecessaryStridedSlicePass.h"
+#include "luci/Pass/RemoveUnnecessarySplitPass.h"
#include "luci/Pass/ReplaceMulAddWithDepthwiseConvPass.h"
#include "luci/Pass/ResolveCustomOpAddPass.h"
#include "luci/Pass/ResolveCustomOpBatchMatMulPass.h"
@@ -36,21 +48,22 @@
#include "luci/Pass/SparsifyTensorPass.h"
#include "luci/Pass/ShuffleWeightTo16x1Float32Pass.h"
#include "luci/Pass/SubstitutePackToReshapePass.h"
+#include "luci/Pass/SubstituteSqueezeToReshapePass.h"
+#include "luci/Pass/SubstituteTransposeToReshapePass.h"
+#include "luci/Pass/TransformMinMaxToRelu6Pass.h"
// TODO add more passes
-#include "luci/Pass/ShapeInferencePass.h"
-#include "luci/Pass/ShapeSignatureInferencePass.h"
-#include "luci/Pass/TypeInferencePass.h"
-
-// Following passes will be removed after refactoring is finished
-#include "luci/Pass/MigrateLegacyShapeDtypePass.h"
+#include "luci/Pass/CircleShapeInferencePass.h"
+#include "luci/Pass/CircleTypeInferencePass.h"
// logo passes
#include <logo/RemoveDeadNodeWithQueryPass.h>
#include "ModulePhase.h"
#include "ProgressReporter.h"
-#include "CircleOptimizerUtils.h"
+#include "helpers/Strings.h"
+
+#include "QuantizedModelVerifier.h"
#include <luci/IR/CircleNodes.h>
#include <logo/Phase.h>
@@ -61,20 +74,6 @@
namespace
{
-std::vector<int> parseIntFromCommadelimitedStr(std::string str)
-{
- std::vector<int> ret;
- std::istringstream is(str);
- for (uint32_t i; is >> i;)
- {
- assert(i != ',');
- ret.push_back(i);
- if (is.peek() == ',')
- is.ignore();
- }
- return ret;
-}
-
using namespace luci;
class OptimizeOptionsImpl final : public luci::CircleOptimizer::Options
@@ -138,13 +137,9 @@ void CircleOptimizer::optimize(luci::Module *m) const
{
luci::Phase phase;
- // Following passes will be deprecated after refactoring is finished.
- phase.emplace_back(std::make_unique<luci::MigrateLegacyShapeDtypePass>());
-
// Following passes are needed everytime when other passes create new node or modify some nodes.
- phase.emplace_back(std::make_unique<luci::ShapeInferencePass>());
- phase.emplace_back(std::make_unique<luci::ShapeSignatureInferencePass>());
- phase.emplace_back(std::make_unique<luci::TypeInferencePass>());
+ phase.emplace_back(std::make_unique<luci::CircleShapeInferencePass>());
+ phase.emplace_back(std::make_unique<luci::CircleTypeInferencePass>());
if (_options->query(Options::Algorithm::FuseBCQ))
{
@@ -164,13 +159,9 @@ void CircleOptimizer::optimize(loco::Graph *g) const
/* TRANSFORM DECLARATION BEGIN */
phase.emplace_back(std::make_unique<logo::RemoveDeadNodeWithQueryPass>());
- // Following passes will be deprecated after refactoring is finished.
- phase.emplace_back(std::make_unique<luci::MigrateLegacyShapeDtypePass>());
-
// Following passes are needed everytime when other passes create new node or modify some nodes.
- phase.emplace_back(std::make_unique<luci::TypeInferencePass>());
- phase.emplace_back(std::make_unique<luci::ShapeInferencePass>());
- phase.emplace_back(std::make_unique<luci::ShapeSignatureInferencePass>());
+ phase.emplace_back(std::make_unique<luci::CircleShapeInferencePass>());
+ phase.emplace_back(std::make_unique<luci::CircleTypeInferencePass>());
if (_options->query(Options::Algorithm::ResolveCustomOpAdd))
{
@@ -188,6 +179,14 @@ void CircleOptimizer::optimize(loco::Graph *g) const
{
phase.emplace_back(std::make_unique<FuseInstanceNormPass>());
}
+ if (_options->query(Options::Algorithm::FuseBatchNormWithConv))
+ {
+ phase.emplace_back(std::make_unique<FuseBatchNormWithConvPass>());
+ }
+ if (_options->query(Options::Algorithm::FuseBatchNormWithDwConv))
+ {
+ phase.emplace_back(std::make_unique<FuseBatchNormWithDwConvPass>());
+ }
if (_options->query(Options::Algorithm::FuseBatchNormWithTConv))
{
phase.emplace_back(std::make_unique<FuseBatchNormWithTConvPass>());
@@ -200,10 +199,26 @@ void CircleOptimizer::optimize(loco::Graph *g) const
{
phase.emplace_back(std::make_unique<FuseActivationFunctionPass>());
}
+ if (_options->query(Options::Algorithm::FoldAddV2))
+ {
+ phase.emplace_back(std::make_unique<luci::FoldAddV2Pass>());
+ }
+ if (_options->query(Options::Algorithm::FoldCast))
+ {
+ phase.emplace_back(std::make_unique<luci::FoldCastPass>());
+ }
if (_options->query(Options::Algorithm::FoldDequantize))
{
phase.emplace_back(std::make_unique<luci::FoldDequantizePass>());
}
+ if (_options->query(Options::Algorithm::FoldSparseToDense))
+ {
+ phase.emplace_back(std::make_unique<luci::FoldSparseToDensePass>());
+ }
+ if (_options->query(Options::Algorithm::ForwardReshapeToUnaryOp))
+ {
+ phase.emplace_back(std::make_unique<luci::ForwardReshapeToUnaryOpPass>());
+ }
if (_options->query(Options::Algorithm::FusePreActivationBatchNorm))
{
phase.emplace_back(std::make_unique<luci::FusePreActivationBatchNormPass>());
@@ -216,6 +231,26 @@ void CircleOptimizer::optimize(loco::Graph *g) const
{
phase.emplace_back(std::make_unique<luci::ShuffleWeightTo16x1Float32Pass>());
}
+ if (_options->query(Options::Algorithm::RemoveUnnecessaryReshape))
+ {
+ phase.emplace_back(std::make_unique<luci::RemoveUnnecessaryReshapePass>());
+ }
+ if (_options->query(Options::Algorithm::RemoveUnnecessarySlice))
+ {
+ phase.emplace_back(std::make_unique<luci::RemoveUnnecessarySlicePass>());
+ }
+ if (_options->query(Options::Algorithm::RemoveUnnecessaryStridedSlice))
+ {
+ phase.emplace_back(std::make_unique<luci::RemoveUnnecessaryStridedSlicePass>());
+ }
+ if (_options->query(Options::Algorithm::RemoveUnnecessarySplit))
+ {
+ phase.emplace_back(std::make_unique<luci::RemoveUnnecessarySplitPass>());
+ }
+ if (_options->query(Options::Algorithm::RemoveRedundantReshape))
+ {
+ phase.emplace_back(std::make_unique<luci::RemoveRedundantReshapePass>());
+ }
if (_options->query(Options::Algorithm::RemoveRedundantTranspose))
{
phase.emplace_back(std::make_unique<luci::RemoveRedundantTransposePass>());
@@ -228,6 +263,28 @@ void CircleOptimizer::optimize(loco::Graph *g) const
{
phase.emplace_back(std::make_unique<luci::SubstitutePackToReshapePass>());
}
+ if (_options->query(Options::Algorithm::SubstituteSqueezeToReshape))
+ {
+ phase.emplace_back(std::make_unique<luci::SubstituteSqueezeToReshapePass>());
+ }
+ if (_options->query(Options::Algorithm::SubstituteTransposeToReshape))
+ {
+ phase.emplace_back(std::make_unique<luci::SubstituteTransposeToReshapePass>());
+ }
+ if (_options->query(Options::Algorithm::TransformMinMaxToRelu6Pass))
+ {
+ phase.emplace_back(std::make_unique<luci::TransformMinMaxToRelu6Pass>());
+ }
+ if (_options->query(Options::Algorithm::ConvertNCHWToNHWC))
+ {
+ bool preserve_input =
+ _options->param(Options::AlgorithmParameters::NCHW_to_NHWC_preserve_input_shape) == "true";
+ bool preserve_output =
+ _options->param(Options::AlgorithmParameters::NCHW_to_NHWC_preserve_output_shape) == "true";
+
+ phase.emplace_back(
+ std::make_unique<luci::ConvertNCHWToNHWCPass>(preserve_input, preserve_output));
+ }
/* TRANSFORM DECLARATION END */
@@ -275,7 +332,7 @@ void CircleOptimizer::quantize(loco::Graph *g) const
}
luci::QuantizeDequantizeWeightsPass fake_quantizer(
- str_to_dtype(input_dtype), str_to_dtype(output_dtype), str_to_granularity(granularity));
+ str_to_dtype(input_dtype), str_to_dtype(output_dtype), str_to_granularity(granularity));
fake_quantizer.run(g);
}
@@ -315,14 +372,19 @@ void CircleOptimizer::quantize(loco::Graph *g) const
phase.emplace_back(std::make_unique<luci::PropagateQuantParamPass>());
- phase.emplace_back(std::make_unique<luci::ShapeInferencePass>());
- phase.emplace_back(std::make_unique<luci::TypeInferencePass>());
+ phase.emplace_back(std::make_unique<luci::CircleShapeInferencePass>());
+ phase.emplace_back(std::make_unique<luci::CircleTypeInferencePass>());
phase.emplace_back(std::make_unique<logo::RemoveDeadNodeWithQueryPass>());
ProgressReporter prog(g, logo::PhaseStrategy::Saturate);
logo::PhaseRunner<logo::PhaseStrategy::Saturate> phase_runner{g};
phase_runner.attach(&prog);
phase_runner.run(phase);
+
+ // Verify the type/granularity of the quantized model
+ luci::QuantizedModelVerifier verifier(str_to_dtype(output_dtype),
+ str_to_granularity(granularity));
+ verifier.verify(g);
}
// Requantize
@@ -349,8 +411,8 @@ void CircleOptimizer::quantize(loco::Graph *g) const
logo::Phase phase;
// Do Shape/Type inference
- phase.emplace_back(std::make_unique<luci::ShapeInferencePass>());
- phase.emplace_back(std::make_unique<luci::TypeInferencePass>());
+ phase.emplace_back(std::make_unique<luci::CircleShapeInferencePass>());
+ phase.emplace_back(std::make_unique<luci::CircleTypeInferencePass>());
ProgressReporter prog(g, logo::PhaseStrategy::Saturate);
logo::PhaseRunner<logo::PhaseStrategy::Saturate> phase_runner{g};
@@ -364,13 +426,13 @@ void CircleOptimizer::sparsify(loco::Graph *g) const
{
std::string tensor_name = _options->param(Options::AlgorithmParameters::Sparsify_tensor_name);
std::string str_tarversal_order =
- _options->param(Options::AlgorithmParameters::Sparsify_traversal_order);
+ _options->param(Options::AlgorithmParameters::Sparsify_traversal_order);
std::string str_format = _options->param(Options::AlgorithmParameters::Sparsify_format);
std::string str_block_size = _options->param(Options::AlgorithmParameters::Sparsify_block_size);
std::string str_block_map = _options->param(Options::AlgorithmParameters::Sparsify_block_map);
// traversal order
- std::vector<int32_t> traversal_order = parseIntFromCommadelimitedStr(str_tarversal_order);
+ std::vector<int32_t> traversal_order = csv_to_vector<int32_t>(str_tarversal_order);
// format
std::vector<DimensionType> format;
std::istringstream is(str_format);
@@ -385,9 +447,9 @@ void CircleOptimizer::sparsify(loco::Graph *g) const
is.ignore();
}
// block size
- std::vector<int32_t> block_size = parseIntFromCommadelimitedStr(str_block_size);
+ std::vector<int32_t> block_size = csv_to_vector<int32_t>(str_block_size);
// block map
- std::vector<int32_t> block_map = parseIntFromCommadelimitedStr(str_block_map);
+ std::vector<int32_t> block_map = csv_to_vector<int32_t>(str_block_map);
luci::SparsifyTensorPass sparsifier{tensor_name, traversal_order, format, block_size,
block_map};
diff --git a/compiler/luci/pass/src/CircleOptimizer.test.cpp b/compiler/luci/pass/src/CircleOptimizer.test.cpp
new file mode 100644
index 000000000..ca6dc77f3
--- /dev/null
+++ b/compiler/luci/pass/src/CircleOptimizer.test.cpp
@@ -0,0 +1,238 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/CircleOptimizer.h"
+
+#include <gtest/gtest.h>
+
+using namespace luci;
+using Algorithms = luci::CircleOptimizer::Options::Algorithm;
+using AlgorithmParameters = luci::CircleOptimizer::Options::AlgorithmParameters;
+
+TEST(CircleOptimizerTest, optimize_algorithms)
+{
+ loco::Graph g;
+ luci::CircleOptimizer o;
+
+ auto options = o.options();
+
+ // NOTE these are added to cover the test
+ // TODO add more if needed
+ options->enable(Algorithms::FoldAddV2);
+ options->enable(Algorithms::FoldCast);
+ options->enable(Algorithms::FoldDequantize);
+ options->enable(Algorithms::FoldSparseToDense);
+ options->enable(Algorithms::FusePreActivationBatchNorm);
+ options->enable(Algorithms::MakeBatchNormGammaPositive);
+ options->enable(Algorithms::ShuffleWeightTo16x1Float32);
+ options->enable(Algorithms::RemoveUnnecessaryReshape);
+ options->enable(Algorithms::RemoveUnnecessarySlice);
+ options->enable(Algorithms::RemoveUnnecessarySplit);
+ options->enable(Algorithms::ReplaceMulAddWithDepthwiseConv);
+ options->enable(Algorithms::SubstituteTransposeToReshape);
+ options->enable(Algorithms::ConvertNCHWToNHWC);
+
+ o.optimize(&g);
+
+ SUCCEED();
+}
+
+TEST(CircleOptimizerTest, sparsify_simple)
+{
+ loco::Graph g;
+ luci::CircleOptimizer o;
+
+ auto options = o.options();
+
+ options->enable(Algorithms::SparsifyTensorPass);
+ options->param(AlgorithmParameters::Sparsify_tensor_name, "dummy");
+ options->param(AlgorithmParameters::Sparsify_traversal_order, "dummy");
+ options->param(AlgorithmParameters::Sparsify_format, "ds");
+ options->param(AlgorithmParameters::Sparsify_block_size, "1,1");
+ options->param(AlgorithmParameters::Sparsify_block_map, "1,1");
+
+ o.sparsify(&g);
+
+ SUCCEED();
+}
+
+TEST(CircleOptimizerTest, quantize_quantdequant_simple)
+{
+ loco::Graph g;
+ luci::CircleOptimizer o;
+
+ auto options = o.options();
+
+ options->enable(Algorithms::QuantizeDequantizeWeights);
+ options->param(AlgorithmParameters::Quantize_input_dtype, "float32");
+ options->param(AlgorithmParameters::Quantize_output_dtype, "uint8");
+ options->param(AlgorithmParameters::Quantize_granularity, "layer");
+
+ o.quantize(&g);
+
+ SUCCEED();
+}
+
+TEST(CircleOptimizerTest, quantize_quantdequant_input_NEG)
+{
+ loco::Graph g;
+ luci::CircleOptimizer o;
+
+ auto options = o.options();
+
+ options->enable(Algorithms::QuantizeDequantizeWeights);
+ options->param(AlgorithmParameters::Quantize_input_dtype, "invalid");
+ options->param(AlgorithmParameters::Quantize_output_dtype, "uint8");
+ options->param(AlgorithmParameters::Quantize_granularity, "layer");
+
+ EXPECT_THROW(o.quantize(&g), std::runtime_error);
+}
+
+TEST(CircleOptimizerTest, quantize_quantdequant_output_NEG)
+{
+ loco::Graph g;
+ luci::CircleOptimizer o;
+
+ auto options = o.options();
+
+ options->enable(Algorithms::QuantizeDequantizeWeights);
+ options->param(AlgorithmParameters::Quantize_input_dtype, "float32");
+ options->param(AlgorithmParameters::Quantize_output_dtype, "invalid");
+ options->param(AlgorithmParameters::Quantize_granularity, "layer");
+
+ EXPECT_THROW(o.quantize(&g), std::runtime_error);
+}
+
+TEST(CircleOptimizerTest, quantize_quantdequant_gran_NEG)
+{
+ loco::Graph g;
+ luci::CircleOptimizer o;
+
+ auto options = o.options();
+
+ options->enable(Algorithms::QuantizeDequantizeWeights);
+ options->param(AlgorithmParameters::Quantize_input_dtype, "float32");
+ options->param(AlgorithmParameters::Quantize_output_dtype, "uint8");
+ options->param(AlgorithmParameters::Quantize_granularity, "invalid");
+
+ EXPECT_THROW(o.quantize(&g), std::runtime_error);
+}
+
+TEST(CircleOptimizerTest, quantize_minmax_simple)
+{
+ loco::Graph g;
+ luci::CircleOptimizer o;
+
+ auto options = o.options();
+
+ options->enable(Algorithms::QuantizeWithMinMax);
+ options->param(AlgorithmParameters::Quantize_input_dtype, "float32");
+ options->param(AlgorithmParameters::Quantize_output_dtype, "uint8");
+ options->param(AlgorithmParameters::Quantize_granularity, "layer");
+
+ o.quantize(&g);
+
+ SUCCEED();
+}
+
+TEST(CircleOptimizerTest, quantize_minmax_input_NEG)
+{
+ loco::Graph g;
+ luci::CircleOptimizer o;
+
+ auto options = o.options();
+
+ options->enable(Algorithms::QuantizeWithMinMax);
+ options->param(AlgorithmParameters::Quantize_input_dtype, "invalid");
+ options->param(AlgorithmParameters::Quantize_output_dtype, "uint8");
+ options->param(AlgorithmParameters::Quantize_granularity, "layer");
+
+ EXPECT_THROW(o.quantize(&g), std::runtime_error);
+}
+
+TEST(CircleOptimizerTest, quantize_minmax_output_NEG)
+{
+ loco::Graph g;
+ luci::CircleOptimizer o;
+
+ auto options = o.options();
+
+ options->enable(Algorithms::QuantizeWithMinMax);
+ options->param(AlgorithmParameters::Quantize_input_dtype, "float32");
+ options->param(AlgorithmParameters::Quantize_output_dtype, "invalid");
+ options->param(AlgorithmParameters::Quantize_granularity, "layer");
+
+ EXPECT_THROW(o.quantize(&g), std::runtime_error);
+}
+
+TEST(CircleOptimizerTest, quantize_minmax_gran_NEG)
+{
+ loco::Graph g;
+ luci::CircleOptimizer o;
+
+ auto options = o.options();
+
+ options->enable(Algorithms::QuantizeWithMinMax);
+ options->param(AlgorithmParameters::Quantize_input_dtype, "float32");
+ options->param(AlgorithmParameters::Quantize_output_dtype, "uint8");
+ options->param(AlgorithmParameters::Quantize_granularity, "invalid");
+
+ EXPECT_THROW(o.quantize(&g), std::runtime_error);
+}
+
+TEST(CircleOptimizerTest, quantize_requant_simple)
+{
+ loco::Graph g;
+ luci::CircleOptimizer o;
+
+ auto options = o.options();
+
+ options->enable(Algorithms::Requantize);
+ options->param(AlgorithmParameters::Quantize_input_dtype, "int8");
+ options->param(AlgorithmParameters::Quantize_output_dtype, "uint8");
+
+ o.quantize(&g);
+
+ SUCCEED();
+}
+
+TEST(CircleOptimizerTest, quantize_requant_input_NEG)
+{
+ loco::Graph g;
+ luci::CircleOptimizer o;
+
+ auto options = o.options();
+
+ options->enable(Algorithms::Requantize);
+ options->param(AlgorithmParameters::Quantize_input_dtype, "invalid");
+ options->param(AlgorithmParameters::Quantize_output_dtype, "uint8");
+
+ EXPECT_THROW(o.quantize(&g), std::runtime_error);
+}
+
+TEST(CircleOptimizerTest, quantize_requant_output_NEG)
+{
+ loco::Graph g;
+ luci::CircleOptimizer o;
+
+ auto options = o.options();
+
+ options->enable(Algorithms::Requantize);
+ options->param(AlgorithmParameters::Quantize_input_dtype, "int8");
+ options->param(AlgorithmParameters::Quantize_output_dtype, "invalid");
+
+ EXPECT_THROW(o.quantize(&g), std::runtime_error);
+}
diff --git a/compiler/luci/pass/src/CircleOptimizerUtils.cpp b/compiler/luci/pass/src/CircleOptimizerUtils.cpp
index ffc372392..127573db4 100644
--- a/compiler/luci/pass/src/CircleOptimizerUtils.cpp
+++ b/compiler/luci/pass/src/CircleOptimizerUtils.cpp
@@ -16,74 +16,18 @@
#include "CircleOptimizerUtils.h"
-namespace luci
-{
-
-bool in_array(const std::string &str, const std::vector<std::string> &array)
-{
- return std::find(array.begin(), array.end(), str) != array.end();
-}
+#include <luci/IR/CircleNode.h>
-std::string to_string(const std::vector<std::string> &strings)
-{
- assert(!strings.empty());
-
- std::string res;
- for (unsigned int i = 0; i < strings.size() - 1; i++)
- res += strings[i] + ", ";
-
- res += strings[strings.size() - 1];
- return res;
-}
-
-std::string to_lower_case(std::string s)
-{
- std::transform(s.begin(), s.end(), s.begin(), [](unsigned char c) { return std::tolower(c); });
- return s;
-}
-
-loco::DataType str_to_dtype(const std::string &str)
+namespace luci
{
- if (to_lower_case(str).compare("uint8") == 0)
- return loco::DataType::U8;
- if (to_lower_case(str).compare("uint16") == 0)
- return loco::DataType::U16;
- if (to_lower_case(str).compare("uint32") == 0)
- return loco::DataType::U32;
- if (to_lower_case(str).compare("uint64") == 0)
- return loco::DataType::U64;
-
- if (to_lower_case(str).compare("int8") == 0)
- return loco::DataType::S8;
- if (to_lower_case(str).compare("int16") == 0)
- return loco::DataType::S16;
- if (to_lower_case(str).compare("int32") == 0)
- return loco::DataType::S32;
- if (to_lower_case(str).compare("int64") == 0)
- return loco::DataType::S64;
-
- if (to_lower_case(str).compare("float16") == 0)
- return loco::DataType::FLOAT16;
- if (to_lower_case(str).compare("float32") == 0)
- return loco::DataType::FLOAT32;
- if (to_lower_case(str).compare("float64") == 0)
- return loco::DataType::FLOAT64;
- if (to_lower_case(str).compare("bool") == 0)
- return loco::DataType::BOOL;
-
- return loco::DataType::Unknown;
-}
-
-QuantizationGranularity str_to_granularity(const std::string &str)
+bool has_dynamic_shape(const loco::Node *node)
{
- if (to_lower_case(str).compare("layer") == 0)
- return QuantizationGranularity::LayerWise;
-
- if (to_lower_case(str).compare("channel") == 0)
- return QuantizationGranularity::ChannelWise;
-
- throw std::runtime_error("Quantization granularity must be either 'layer' or 'channel'");
+ const auto circle_node = loco::must_cast<const luci::CircleNode *>(node);
+ for (uint32_t i = 0; i < circle_node->rank(); ++i)
+ if (!circle_node->dim(i).known())
+ return true;
+ return false;
}
} // namespace luci
diff --git a/compiler/luci/pass/src/CircleOptimizerUtils.h b/compiler/luci/pass/src/CircleOptimizerUtils.h
index 7e577a05f..e04942bfa 100644
--- a/compiler/luci/pass/src/CircleOptimizerUtils.h
+++ b/compiler/luci/pass/src/CircleOptimizerUtils.h
@@ -17,25 +17,12 @@
#ifndef __LUCI_CIRCLE_OPTIMIZER_UTILS_H__
#define __LUCI_CIRCLE_OPTIMIZER_UTILS_H__
-#include "luci/Pass/QuantizeDequantizeWeightsPass.h"
-#include "luci/Pass/QuantizeWithMinMaxPass.h"
-
#include <loco.h>
-#include <algorithm>
-
namespace luci
{
-bool in_array(const std::string &, const std::vector<std::string> &);
-
-std::string to_string(const std::vector<std::string> &);
-
-std::string to_lower_case(std::string);
-
-loco::DataType str_to_dtype(const std::string &);
-
-QuantizationGranularity str_to_granularity(const std::string &);
+bool has_dynamic_shape(const loco::Node *node);
} // namespace luci
diff --git a/compiler/luci/pass/src/CircleShapeInferencePass.cpp b/compiler/luci/pass/src/CircleShapeInferencePass.cpp
new file mode 100644
index 000000000..ddab22421
--- /dev/null
+++ b/compiler/luci/pass/src/CircleShapeInferencePass.cpp
@@ -0,0 +1,91 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "helpers/InferenceCandidates.h"
+
+#include "luci/Pass/CircleShapeInferencePass.h"
+
+#include <luci/Service/CircleShapeInference.h>
+
+#include <loco.h>
+
+namespace
+{
+
+bool is_same_shape(luci::CircleNode *node, loco::TensorShape shape)
+{
+ if (node->shape_status() != luci::ShapeStatus::VALID)
+ return false;
+
+ if (node->rank() != shape.rank())
+ return false;
+
+ for (uint32_t i = 0; i < node->rank(); ++i)
+ {
+ if (node->dim(i).known() != shape.dim(i).known())
+ return false;
+
+ if (node->dim(i).value() != shape.dim(i).value())
+ return false;
+ }
+
+ return true;
+}
+
+} // namespace
+
+namespace luci
+{
+
+bool CircleShapeInferencePass::run(luci::Module *m)
+{
+ bool changed = false;
+
+ for (size_t g = 0; g < m->size(); ++g)
+ {
+ if (run(m->graph(g)))
+ changed = true;
+ }
+
+ return changed;
+}
+
+bool CircleShapeInferencePass::run(loco::Graph *g)
+{
+ luci::sinf::Rule shape_infer_rule;
+ bool changed = false;
+
+ for (auto node : inference_candidates(g))
+ {
+ loco::TensorShape shape;
+ auto circle_node = loco::must_cast<luci::CircleNode *>(node);
+
+ if (shape_infer_rule.infer(circle_node, shape) && !is_same_shape(circle_node, shape))
+ {
+ circle_node->rank(shape.rank());
+ for (uint32_t i = 0; i < shape.rank(); ++i)
+ circle_node->dim(i) = shape.dim(i);
+
+ circle_node->shape_status(luci::ShapeStatus::VALID);
+
+ changed = true;
+ }
+ }
+
+ return changed;
+}
+
+} // namespace luci
diff --git a/compiler/luci/pass/src/CircleShapeInferencePass.test.cpp b/compiler/luci/pass/src/CircleShapeInferencePass.test.cpp
new file mode 100644
index 000000000..cb3f1fe5f
--- /dev/null
+++ b/compiler/luci/pass/src/CircleShapeInferencePass.test.cpp
@@ -0,0 +1,364 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Pass/CircleShapeInferencePass.h"
+
+#include <loco.h>
+
+#include <luci/IR/CircleNodes.h>
+
+#include <gtest/gtest.h>
+
+TEST(CircleShapeInferencePassTest, name)
+{
+ luci::CircleShapeInferencePass pass;
+ auto const name = pass.name();
+ ASSERT_NE(nullptr, name);
+}
+
+/**
+ * This test is to check whether shape inference is done by topological order.
+ *
+ * When perm() of "transpose1" is changed from "old_perm" to "new_perm"
+ * by some of luci/Pass like below diagram, shape_status of "transpose1" is
+ * still VALID even the shape should be changed.
+ * If "transpose2" is visited first before shape of "transpose1" is updated,
+ * "transpose2" can reference the shape of "relu" which is not updated yet.
+ * Then shape of "transpose2" becomes 3x5x5x1 and it causes an error at "conv2d".
+ *
+ * <Initial graph>
+ * 4x1x1x3
+ * [old_perm] ----------+ [filter] ----------+
+ * (0,2,1,3) | |
+ * | [bias] ----------+
+ * | |
+ * input ------> [transpose1] ------> [relu] ------> [conv2d] ------> output
+ * 1x5x5x3 1x5x5x3 1x5x5x3 1x5x5x4
+ *
+ *
+ * <Right after transformation>
+ * 4x1x1x3
+ * [new_perm] ----------+-----------------------------------+ [filter] ------+
+ * (3,2,1,0) | | |
+ * | | [bias] ------+
+ * | | |
+ * input ------> [transpose1] ------> [relu] ------> [transpose2] ------> [conv2d] ------> output
+ * 1x5x5x3 1x5x5x3 1x5x5x3 ? 1x5x5x4
+ *
+ *
+ * <Expected result>
+ * 4x1x1x3
+ * [new_perm] ----------+-----------------------------------+ [filter] ------+
+ * (3,2,1,0) | | |
+ * | | [bias] ------+
+ * | | |
+ * input ------> [transpose1] ------> [relu] ------> [transpose2] ------> [conv2d] ------> output
+ * 1x5x5x3 3x5x5x1 3x5x5x1 1x5x5x3 1x5x5x4
+ *
+ */
+TEST(CircleShapeInferencePassTest, original_node_change)
+{
+ luci::CircleShapeInferencePass pass;
+ auto g = loco::make_graph();
+
+ // Have to be packed into lambda to check throw
+ auto shape_inference_run = [&]() {
+ while (pass.run(g.get()) == true)
+ ;
+ };
+
+ // Create nodes to make relu traversed first
+ auto input = g->nodes()->create<luci::CircleInput>();
+ auto relu = g->nodes()->create<luci::CircleRelu>();
+ auto old_perm = g->nodes()->create<luci::CircleConst>();
+ auto transpose1 = g->nodes()->create<luci::CircleTranspose>();
+ auto filter = g->nodes()->create<luci::CircleConst>();
+ auto bias = g->nodes()->create<luci::CircleConst>();
+ auto conv2d = g->nodes()->create<luci::CircleConv2D>();
+ auto output = g->nodes()->create<luci::CircleOutput>();
+ auto new_perm = g->nodes()->create<luci::CircleConst>();
+ auto transpose2 = g->nodes()->create<luci::CircleTranspose>();
+
+ // Build up initial graph
+ auto graph_input = g->inputs()->create();
+ graph_input->shape({1, 5, 5, 3});
+
+ input->index(graph_input->index());
+ input->shape({1, 5, 5, 3});
+ input->shape_status(luci::ShapeStatus::VALID);
+
+ old_perm->dtype(loco::DataType::S32);
+ old_perm->size<loco::DataType::S32>(4);
+ old_perm->shape({4});
+ old_perm->at<loco::DataType::S32>(0) = 0;
+ old_perm->at<loco::DataType::S32>(1) = 2;
+ old_perm->at<loco::DataType::S32>(2) = 1;
+ old_perm->at<loco::DataType::S32>(3) = 3;
+ old_perm->shape_status(luci::ShapeStatus::VALID);
+
+ transpose1->a(input);
+ transpose1->perm(old_perm);
+
+ relu->features(transpose1);
+
+ filter->dtype(loco::DataType::FLOAT32);
+ filter->size<loco::DataType::FLOAT32>(4 * 1 * 1 * 3);
+ filter->shape({4, 1, 1, 3});
+ filter->shape_status(luci::ShapeStatus::VALID);
+
+ bias->dtype(loco::DataType::FLOAT32);
+ bias->size<loco::DataType::FLOAT32>(4);
+ bias->shape({4});
+ bias->shape_status(luci::ShapeStatus::VALID);
+
+ conv2d->input(relu);
+ conv2d->filter(filter);
+ conv2d->bias(bias);
+ conv2d->padding(luci::Padding::VALID);
+ conv2d->stride()->h(1);
+ conv2d->stride()->w(1);
+ conv2d->dilation()->h(1);
+ conv2d->dilation()->w(1);
+
+ output->from(conv2d);
+ auto graph_output = g->outputs()->create();
+ output->index(graph_output->index());
+ graph_output->shape({1, 5, 5, 4});
+
+ ASSERT_NO_THROW(shape_inference_run());
+
+ // Transform graph
+ new_perm->dtype(loco::DataType::S32);
+ new_perm->size<loco::DataType::S32>(4);
+ new_perm->shape({4});
+ new_perm->at<loco::DataType::S32>(0) = 3;
+ new_perm->at<loco::DataType::S32>(1) = 2;
+ new_perm->at<loco::DataType::S32>(2) = 1;
+ new_perm->at<loco::DataType::S32>(3) = 0;
+ new_perm->shape_status(luci::ShapeStatus::VALID);
+
+ transpose1->perm(new_perm);
+
+ transpose2->a(relu);
+ transpose2->perm(new_perm);
+
+ conv2d->input(transpose2);
+
+ ASSERT_NO_THROW(shape_inference_run());
+
+ // Check result of shape inference is correct
+ ASSERT_EQ(3, transpose1->dim(0).value());
+ ASSERT_EQ(5, transpose1->dim(1).value());
+ ASSERT_EQ(5, transpose1->dim(2).value());
+ ASSERT_EQ(1, transpose1->dim(3).value());
+
+ ASSERT_EQ(3, relu->dim(0).value());
+ ASSERT_EQ(5, relu->dim(1).value());
+ ASSERT_EQ(5, relu->dim(2).value());
+ ASSERT_EQ(1, relu->dim(3).value());
+
+ ASSERT_EQ(1, transpose2->dim(0).value());
+ ASSERT_EQ(5, transpose2->dim(1).value());
+ ASSERT_EQ(5, transpose2->dim(2).value());
+ ASSERT_EQ(3, transpose2->dim(3).value());
+
+ ASSERT_EQ(1, conv2d->dim(0).value());
+ ASSERT_EQ(5, conv2d->dim(1).value());
+ ASSERT_EQ(5, conv2d->dim(2).value());
+ ASSERT_EQ(4, conv2d->dim(3).value());
+
+ SUCCEED();
+}
+
+/**
+ * This test is for checking when imported shape is wrong.
+ *
+ * Even "concat1" has wrong shape at first, correct shape should be inferred.
+ *
+ * <Initial graph>
+ *
+ * 1x1x1x1
+ * input1 ------+ 8x7x6x5
+ * +-----> [concat1] ------+
+ * input2 ------+ (axis=3) | 1x1x2x3
+ * 1x1x1x2 +------> [concat2] ------> output
+ * | (axis=2)
+ * 1x1x1x3 |
+ * input3 ------------------------------+
+ *
+ *
+ * <Expected result>
+ *
+ * 1x1x1x1
+ * input1 ------+ 1x1x1x3
+ * +-----> [concat1] ------+
+ * input2 ------+ (axis=3) | 1x1x2x3
+ * 1x1x1x2 +------> [concat2] ------> output
+ * | (axis=2)
+ * 1x1x1x3 |
+ * input3 ------------------------------+
+ */
+TEST(CircleShapeInferencePassTest, wrong_imported_shape)
+{
+ luci::CircleShapeInferencePass pass;
+ auto g = loco::make_graph();
+
+ // Have to be packed into lambda to check throw
+ auto shape_inference_run = [&]() {
+ while (pass.run(g.get()) == true)
+ ;
+ };
+
+ // Create nodes to make concat2 traversed first
+ auto concat2 = g->nodes()->create<luci::CircleConcatenation>(2);
+ auto concat1 = g->nodes()->create<luci::CircleConcatenation>(2);
+ auto input1 = g->nodes()->create<luci::CircleInput>();
+ auto input2 = g->nodes()->create<luci::CircleInput>();
+ auto input3 = g->nodes()->create<luci::CircleInput>();
+
+ // Build up initial graph
+ auto graph_input1 = g->inputs()->create();
+ auto graph_input2 = g->inputs()->create();
+ auto graph_input3 = g->inputs()->create();
+ graph_input1->shape({1, 1, 1, 1});
+ graph_input2->shape({1, 1, 1, 2});
+ graph_input2->shape({1, 1, 1, 3});
+
+ input1->index(graph_input1->index());
+ input1->shape({1, 1, 1, 1});
+ input1->shape_status(luci::ShapeStatus::VALID);
+
+ input2->index(graph_input2->index());
+ input2->shape({1, 1, 1, 2});
+ input2->shape_status(luci::ShapeStatus::VALID);
+
+ input3->index(graph_input3->index());
+ input3->shape({1, 1, 1, 3});
+ input3->shape_status(luci::ShapeStatus::VALID);
+
+ concat1->values(0, input1);
+ concat1->values(1, input2);
+ concat1->axis(3);
+ concat1->shape({8, 7, 6, 5}); // Intentionally set wrong shape
+ concat1->shape_status(luci::ShapeStatus::VALID);
+
+ concat2->values(0, concat1);
+ concat2->values(1, input3);
+ concat2->axis(2);
+
+ auto output = g->nodes()->create<luci::CircleOutput>();
+ output->from(concat2);
+ auto graph_output = g->outputs()->create();
+ output->index(graph_output->index());
+ graph_output->shape({1, 1, 2, 3});
+
+ ASSERT_NO_THROW(shape_inference_run());
+
+ // Check result of shape inference is correct
+ ASSERT_EQ(1, concat1->dim(0).value());
+ ASSERT_EQ(1, concat1->dim(1).value());
+ ASSERT_EQ(1, concat1->dim(2).value());
+ ASSERT_EQ(3, concat1->dim(3).value());
+
+ ASSERT_EQ(1, concat2->dim(0).value());
+ ASSERT_EQ(1, concat2->dim(1).value());
+ ASSERT_EQ(2, concat2->dim(2).value());
+ ASSERT_EQ(3, concat2->dim(3).value());
+
+ SUCCEED();
+}
+
+/**
+ * This test is for checking that virtual operations which is not used for graph output
+ * but shape should be exported.
+ *
+ * Although "split_out2" is not used for graph output, shape should be inferenced.
+ *
+ * <Initial graph>
+ *
+ *
+ * 1x6 +----> [split_out1] ----> output
+ * input ------> [split] -----+
+ * (split_dim=1) +----> [split_out2]
+ * (num_split=2)
+ *
+ *
+ * <Expected result>
+ * 1x3 1x3
+ * 1x6 +----> [split_out1] ----> output
+ * input ------> [split] -----+
+ * (split_dim=1) +----> [split_out2]
+ * (num_split=2) 1x3
+ */
+TEST(CircleShapeInferencePassTest, not_used_virtual_op)
+{
+ luci::CircleShapeInferencePass pass;
+ auto g = loco::make_graph();
+
+ // Have to be packed into lambda to check throw
+ auto shape_inference_run = [&]() {
+ while (pass.run(g.get()) == true)
+ ;
+ };
+
+ // Create nodes
+ auto input = g->nodes()->create<luci::CircleInput>();
+ auto split = g->nodes()->create<luci::CircleSplit>();
+ auto split_out1 = g->nodes()->create<luci::CircleSplitOut>();
+ auto split_out2 = g->nodes()->create<luci::CircleSplitOut>();
+ auto split_dim = g->nodes()->create<luci::CircleConst>();
+
+ // Build up initial graph
+ auto graph_input1 = g->inputs()->create();
+ graph_input1->shape({1, 6});
+
+ input->index(graph_input1->index());
+ input->shape({1, 6});
+ input->shape_status(luci::ShapeStatus::VALID);
+
+ split_dim->dtype(loco::DataType::S32);
+ split_dim->size<loco::DataType::S32>(1);
+ split_dim->shape({1});
+ split_dim->at<loco::DataType::S32>(0) = 1;
+ split_dim->shape_status(luci::ShapeStatus::VALID);
+
+ split->split_dim(split_dim);
+ split->input(input);
+ split->num_split(2);
+
+ split_out1->input(split);
+ split_out1->index(0);
+
+ split_out2->input(split);
+ split_out2->index(1);
+
+ auto output = g->nodes()->create<luci::CircleOutput>();
+ output->from(split_out1);
+ auto graph_output = g->outputs()->create();
+ output->index(graph_output->index());
+ graph_output->shape({1, 3});
+
+ ASSERT_NO_THROW(shape_inference_run());
+
+ // Check result of shape inference is correct
+ ASSERT_EQ(1, split_out1->dim(0).value());
+ ASSERT_EQ(3, split_out1->dim(1).value());
+
+ ASSERT_EQ(1, split_out2->dim(0).value());
+ ASSERT_EQ(3, split_out2->dim(1).value());
+
+ SUCCEED();
+}
diff --git a/compiler/luci/pass/src/CircleTypeInferencePass.cpp b/compiler/luci/pass/src/CircleTypeInferencePass.cpp
index 67bd253e0..fb3755ffa 100644
--- a/compiler/luci/pass/src/CircleTypeInferencePass.cpp
+++ b/compiler/luci/pass/src/CircleTypeInferencePass.cpp
@@ -14,6 +14,8 @@
* limitations under the License.
*/
+#include "helpers/InferenceCandidates.h"
+
#include "luci/Pass/CircleTypeInferencePass.h"
#include <luci/Service/CircleTypeInference.h>
@@ -41,7 +43,7 @@ bool CircleTypeInferencePass::run(loco::Graph *g)
luci::tinf::Rule type_infer_rule;
bool changed = false;
- for (auto node : loco::postorder_traversal(loco::output_nodes(g)))
+ for (auto node : inference_candidates(g))
{
loco::DataType dtype;
auto circle_node = loco::must_cast<luci::CircleNode *>(node);
diff --git a/compiler/luci/pass/src/CircleTypeInferencePass.test.cpp b/compiler/luci/pass/src/CircleTypeInferencePass.test.cpp
new file mode 100644
index 000000000..415424a6f
--- /dev/null
+++ b/compiler/luci/pass/src/CircleTypeInferencePass.test.cpp
@@ -0,0 +1,26 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Pass/CircleTypeInferencePass.h"
+
+#include <gtest/gtest.h>
+
+TEST(CircleTypeInferencePassTest, name)
+{
+ luci::CircleTypeInferencePass pass;
+ auto const name = pass.name();
+ ASSERT_NE(nullptr, name);
+}
diff --git a/compiler/luci/pass/src/ConvertNCHWToNHWCPass.cpp b/compiler/luci/pass/src/ConvertNCHWToNHWCPass.cpp
new file mode 100644
index 000000000..c9022f122
--- /dev/null
+++ b/compiler/luci/pass/src/ConvertNCHWToNHWCPass.cpp
@@ -0,0 +1,698 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Pass/ConvertNCHWToNHWCPass.h"
+#include "CircleOptimizerUtils.h"
+
+#include <luci/IR/CircleNodes.h>
+#include <luci/IR/CircleNodeVisitor.h>
+#include <luci/Profile/CircleNodeOrigin.h>
+#include <luci/Log.h>
+
+namespace
+{
+
+enum class DataFormat
+{
+ NCHW,
+ NHWC
+};
+
+/**
+ * @brief Set annotation for DataFormat (NCHW, NHWC)
+ *
+ * @note DataFormatAnnotation will live longer than this Pass (until the
+ * annotated loco::Node is erased). So, do not use large data in the
+ * annotation to avoid excessive memory usage.
+ */
+class DataFormatAnnotation final : public loco::NodeAnnotation
+{
+public:
+ DataFormatAnnotation(const DataFormat &format) : _format{format}
+ {
+ // DO NOTHING
+ }
+
+public:
+ const DataFormat &format(void) const { return _format; }
+
+private:
+ DataFormat _format;
+};
+
+void set_data_format(loco::Node *node, const DataFormat &format)
+{
+ node->annot(std::make_unique<DataFormatAnnotation>(format));
+}
+
+DataFormat get_data_format(loco::Node *node)
+{
+ assert(node->annot<DataFormatAnnotation>() != nullptr);
+ return node->annot<DataFormatAnnotation>()->format();
+}
+
+bool has_data_format(loco::Node *node) { return node->annot<DataFormatAnnotation>() != nullptr; }
+
+luci::CircleTranspose *create_4d_transpose(luci::CircleNode *node,
+ const std::vector<int32_t> indices)
+{
+ assert(indices.size() == 4);
+
+ auto name = node->name();
+ assert(name.length() > 0);
+
+ auto perm = node->graph()->nodes()->create<luci::CircleConst>();
+ perm->dtype(loco::DataType::S32);
+ perm->size<loco::DataType::S32>(4);
+ perm->rank(1);
+ perm->dim(0) = 4;
+ for (uint32_t i = 0; i < 4; i++)
+ perm->at<loco::DataType::S32>(i) = indices[i];
+ perm->shape_status(luci::ShapeStatus::VALID);
+
+ auto make_string = [](const std::vector<int32_t> &nums) {
+ std::string str;
+ for (auto num : nums)
+ {
+ if (str.length() > 0)
+ str += ".";
+ str += std::to_string(num);
+ }
+ return str;
+ };
+
+ auto str_indices = make_string(indices);
+
+ perm->name(name + "/Transpose_" + str_indices + "/perm");
+
+ auto trans = node->graph()->nodes()->create<luci::CircleTranspose>();
+ trans->perm(perm);
+ trans->name(name + "/Transpose_" + str_indices);
+ luci::add_origin(trans, luci::get_origin(node));
+
+ return trans;
+}
+
+int32_t nchw_axis_to_nhwc(int32_t axis)
+{
+ uint32_t pos_axis = axis >= 0 ? static_cast<uint32_t>(axis) : static_cast<uint32_t>(axis + 4);
+ static const uint32_t to_nhwc[4] = {0, 3, 1, 2};
+ if (pos_axis > 3)
+ throw std::runtime_error("Concat axis must be in range [-4, 4)");
+ return to_nhwc[pos_axis];
+}
+
+luci::CircleTranspose *create_post_transpose(luci::CircleNode *node)
+{
+ return create_4d_transpose(node, {0, 3, 1, 2});
+}
+
+luci::CircleTranspose *create_pre_transpose(luci::CircleNode *node)
+{
+ return create_4d_transpose(node, {0, 2, 3, 1});
+}
+
+uint32_t cal_offset(const loco::TensorShape &dimension, const uint32_t *indices)
+{
+ return indices[0] * dimension.dim(1).value() * dimension.dim(2).value() *
+ dimension.dim(3).value() +
+ indices[1] * dimension.dim(2).value() * dimension.dim(3).value() +
+ indices[2] * dimension.dim(3).value() + indices[3];
+}
+
+luci::CircleConst *create_NHWC_paddings(luci::CircleConst *paddings)
+{
+ // paddings shape is (4,2) (it was checked by is_NCHW)
+ assert(paddings != nullptr);
+ assert(paddings->rank() == 2);
+ assert(paddings->dim(0).value() == 4);
+ assert(paddings->dim(1).value() == 2);
+
+ // paddings for idx 0~3 are 0 (checked by is_NCHW)
+ assert(paddings->at<loco::DataType::S32>(0) == 0);
+ assert(paddings->at<loco::DataType::S32>(1) == 0);
+ assert(paddings->at<loco::DataType::S32>(2) == 0);
+ assert(paddings->at<loco::DataType::S32>(3) == 0);
+
+ auto name = paddings->name();
+ assert(name.length() > 0);
+
+ auto nhwc_paddings = paddings->graph()->nodes()->create<luci::CircleConst>();
+ nhwc_paddings->dtype(loco::DataType::S32);
+ nhwc_paddings->shape({4, 2});
+ nhwc_paddings->shape_status(luci::ShapeStatus::VALID);
+ nhwc_paddings->size<loco::DataType::S32>(4 * 2);
+ nhwc_paddings->name(name + "_NHWC");
+
+ for (uint32_t dim = 0; dim < 4; dim++)
+ {
+ for (uint32_t i = 0; i < 2; i++)
+ {
+ int32_t data = 0;
+
+ if (dim == 1)
+ {
+ // get third dimension (H in NCHW)
+ data = paddings->at<loco::DataType::S32>(2 * 2 + i);
+ }
+ else if (dim == 2)
+ {
+ // get fourth dimension (W in NCHW)
+ data = paddings->at<loco::DataType::S32>(3 * 2 + i);
+ }
+
+ nhwc_paddings->at<loco::DataType::S32>(dim * 2 + i) = data;
+ }
+ }
+ return nhwc_paddings;
+}
+
+luci::CircleConst *create_NHWC_from_NCHW(luci::CircleConst *constant)
+{
+ LOGGER(l);
+ assert(constant->rank() == 4);
+
+ // TODO: Support non-float types
+ if (constant->dtype() != loco::DataType::FLOAT32)
+ {
+ INFO(l) << "Non-float type constant: " << constant->name() << std::endl;
+ return nullptr;
+ }
+
+ loco::TensorShape nchw_dimension{constant->dim(0), constant->dim(1), constant->dim(2),
+ constant->dim(3)};
+ loco::TensorShape nhwc_dimension{constant->dim(0), constant->dim(2), constant->dim(3),
+ constant->dim(1)};
+
+ auto name = constant->name();
+ assert(name.length() > 0);
+
+ auto nhwc_const = constant->graph()->nodes()->create<luci::CircleConst>();
+ nhwc_const->dtype(constant->dtype());
+ nhwc_const->rank(4);
+ nhwc_const->dim(0).set(constant->dim(0).value());
+ nhwc_const->dim(1).set(constant->dim(2).value());
+ nhwc_const->dim(2).set(constant->dim(3).value());
+ nhwc_const->dim(3).set(constant->dim(1).value());
+ nhwc_const->shape_status(luci::ShapeStatus::VALID);
+ nhwc_const->size<loco::DataType::FLOAT32>(constant->size<loco::DataType::FLOAT32>());
+ nhwc_const->name(name + "_NHWC");
+
+ for (uint32_t n = 0; n < nchw_dimension.dim(0).value(); n++)
+ {
+ for (uint32_t c = 0; c < nchw_dimension.dim(1).value(); c++)
+ {
+ for (uint32_t h = 0; h < nchw_dimension.dim(2).value(); h++)
+ {
+ for (uint32_t w = 0; w < nchw_dimension.dim(3).value(); w++)
+ {
+ uint32_t nchw_indices[4] = {n, c, h, w};
+ uint32_t nhwc_indices[4] = {n, h, w, c};
+ auto data =
+ constant->at<loco::DataType::FLOAT32>(cal_offset(nchw_dimension, nchw_indices));
+ nhwc_const->at<loco::DataType::FLOAT32>(cal_offset(nhwc_dimension, nhwc_indices)) = data;
+ }
+ }
+ }
+ }
+ return nhwc_const;
+}
+
+// NOTE Following conditions can be extended later
+//
+// Find PAD with an NCHW pattern described below
+// - Paddings shape : [4, 2]
+// - Paddings value : [[0, 0], [0, 0], [h_t, h_b], [w_t, w_b]]]
+bool is_NCHW(const luci::CirclePad *node)
+{
+ const auto paddings = dynamic_cast<luci::CircleConst *>(node->paddings());
+ // Non-const paddings is not supported
+ if (paddings == nullptr)
+ return false;
+
+ if (paddings->rank() != 2)
+ return false;
+
+ if (paddings->dim(0).value() != 4 || paddings->dim(1).value() != 2)
+ return false;
+
+ // Only check the first two dimensions
+ for (uint32_t dim = 0; dim < 2; dim++)
+ {
+ for (uint32_t i = 0; i < 2; i++)
+ {
+ auto data = paddings->at<loco::DataType::S32>(dim * 2 + i);
+ if (data != 0)
+ return false;
+ }
+ }
+
+ return true;
+}
+
+// NOTE Following conditions can be extended later
+//
+// Find MUL with an NCHW pattern described below
+// - Input (non-constant) shape : [N, C, H, W]
+// - Input (constant) shape : [1, C, 1, 1]
+// - Output shape : [N, C, H, W]
+bool is_NCHW_with_const(const luci::CircleMul *node, luci::CircleNode *&pred_node,
+ luci::CircleConst *&multiplier)
+{
+ auto x = dynamic_cast<luci::CircleConst *>(node->x());
+ auto y = dynamic_cast<luci::CircleConst *>(node->y());
+
+ if (x != nullptr && y == nullptr)
+ {
+ pred_node = loco::must_cast<luci::CircleNode *>(node->y());
+ multiplier = x;
+ }
+ else if (x == nullptr && y != nullptr)
+ {
+ pred_node = loco::must_cast<luci::CircleNode *>(node->x());
+ multiplier = y;
+ }
+ else
+ {
+ // Ignore if MUL does not have a multiplier input.
+ return false;
+ }
+
+ if (pred_node->rank() != 4)
+ return false;
+
+ const auto const_rank = multiplier->rank();
+ if (const_rank != 4)
+ return false;
+
+ for (uint32_t i = 0; i < const_rank; i++)
+ {
+ if (i != 1 && multiplier->dim(i).value() != 1)
+ return false;
+ }
+
+ const auto const_cdim = multiplier->dim(1);
+ const auto input_cdim = pred_node->dim(1);
+ const auto output_cdim = node->dim(1);
+
+ if (const_cdim == input_cdim && input_cdim == output_cdim)
+ return true;
+ else
+ return false;
+}
+
+// We assume ADD with const input is NCHW if,
+// Input shape: (N, C, H, W)
+// Output shape: (N, C, H, W)
+// 1. Const shape is (1, C, 1, 1)
+// 2. Input, Output, Const have the same C.
+bool is_NCHW_with_const(const luci::CircleAdd *node, luci::CircleNode *&pred_node,
+ luci::CircleConst *&beta)
+{
+ auto x = dynamic_cast<luci::CircleConst *>(node->x());
+ auto y = dynamic_cast<luci::CircleConst *>(node->y());
+
+ if (x != nullptr && y == nullptr)
+ {
+ pred_node = loco::must_cast<luci::CircleNode *>(node->y());
+ beta = x;
+ }
+ else if (x == nullptr && y != nullptr)
+ {
+ pred_node = loco::must_cast<luci::CircleNode *>(node->x());
+ beta = y;
+ }
+ else
+ {
+ // Ignore if ADD does not have a constant input.
+ return false;
+ }
+
+ if (pred_node->rank() != 4)
+ return false;
+
+ const auto const_rank = beta->rank();
+ if (const_rank != 4)
+ return false;
+
+ // Check the shape is (1, C, 1, 1)
+ for (uint32_t i = 0; i < const_rank; i++)
+ {
+ if (i == 1)
+ continue;
+
+ if (beta->dim(i).value() != 1)
+ return false;
+ }
+
+ const auto const_cdim = beta->dim(1);
+ const auto input_cdim = pred_node->dim(1);
+ const auto output_cdim = node->dim(1);
+
+ // Check Input, Output, Const have the same channel size
+ if (const_cdim == input_cdim && input_cdim == output_cdim)
+ return true;
+ else
+ return false;
+}
+
+template <class T> bool convert_unary_features(T *node)
+{
+ const auto pred_node = loco::must_cast<luci::CircleNode *>(node->features());
+ auto pre_trans = create_pre_transpose(node);
+ pre_trans->a(pred_node);
+ node->features(pre_trans);
+
+ // Do shape inference for this node again.
+ node->shape_status(luci::ShapeStatus::UNDEFINED);
+
+ auto post_trans = create_post_transpose(node);
+ loco::replace(node).with(post_trans);
+
+ post_trans->a(node);
+
+ return true;
+}
+
+class ConvertNCHWToNHWC final : public luci::CircleNodeMutableVisitor<bool>
+{
+ // Default
+ bool visit(luci::CircleNode *node)
+ {
+ throw std::runtime_error(node->name() + " is an unsupported operator.");
+ }
+
+ bool visit(luci::CircleInput *node)
+ {
+ const auto n = node->dim(0);
+ const auto c = node->dim(1);
+ const auto h = node->dim(2);
+ const auto w = node->dim(3);
+
+ node->dim(1) = h;
+ node->dim(2) = w;
+ node->dim(3) = c;
+
+ // Do shape inference for this node again.
+ node->shape_status(luci::ShapeStatus::UNDEFINED);
+
+ // Insert post-tranpose
+ auto post_trans = create_post_transpose(node);
+ loco::replace(node).with(post_trans);
+
+ post_trans->a(node);
+
+ // Update graph input
+ auto graph_inputs = node->graph()->inputs();
+ auto graph_input = graph_inputs->at(node->index());
+ graph_input->shape({n, h, w, c});
+
+ return true;
+ }
+
+ bool visit(luci::CircleOutput *node)
+ {
+ // Insert pre-transpose
+ auto pre_trans = create_pre_transpose(node);
+ pre_trans->a(node->from());
+
+ node->from(pre_trans);
+
+ // Do shape inference for this node again.
+ node->shape_status(luci::ShapeStatus::UNDEFINED);
+
+ // Update graph output
+ const auto n = node->dim(0).value();
+ const auto c = node->dim(1).value();
+ const auto h = node->dim(2).value();
+ const auto w = node->dim(3).value();
+
+ auto graph_outputs = node->graph()->outputs();
+ auto graph_output = graph_outputs->at(node->index());
+ graph_output->shape({n, h, w, c});
+
+ return true;
+ }
+
+ bool visit(luci::CircleAdd *node)
+ {
+ luci::CircleNode *pred_node = nullptr;
+ luci::CircleConst *beta = nullptr;
+
+ if (is_NCHW_with_const(node, pred_node, beta))
+ {
+ auto pre_trans = create_pre_transpose(node);
+ pre_trans->a(pred_node);
+
+ auto nhwc_const = create_NHWC_from_NCHW(beta);
+ if (nhwc_const == nullptr)
+ return false;
+
+ node->x(pre_trans);
+ node->y(nhwc_const);
+ }
+ else if (beta == nullptr)
+ {
+ // Both inputs are not constant.
+ // In this case, we cannot distinguish NCHW from NHWC,
+ // so just insert Transpose Ops.
+ auto pre_trans_x = create_pre_transpose(node);
+ pre_trans_x->a(node->x());
+ node->x(pre_trans_x);
+
+ auto pre_trans_y = create_pre_transpose(node);
+ pre_trans_y->a(node->y());
+ node->y(pre_trans_y);
+ }
+ else
+ {
+ return false;
+ }
+
+ // Do shape inference for this node again.
+ node->shape_status(luci::ShapeStatus::UNDEFINED);
+
+ auto post_trans = create_post_transpose(node);
+ loco::replace(node).with(post_trans);
+
+ post_trans->a(node);
+ return true;
+ }
+
+ bool visit(luci::CircleConcatenation *node)
+ {
+ const auto num_values = node->numValues();
+ for (uint32_t i = 0; i < num_values; i++)
+ {
+ auto pred_node = loco::must_cast<luci::CircleNode *>(node->values(i));
+ auto pre_trans = create_pre_transpose(node);
+ pre_trans->a(pred_node);
+ node->values(i, pre_trans);
+ }
+
+ // Do shape inference for this node again.
+ node->shape_status(luci::ShapeStatus::UNDEFINED);
+
+ node->axis(nchw_axis_to_nhwc(node->axis()));
+
+ auto post_trans = create_post_transpose(node);
+ loco::replace(node).with(post_trans);
+
+ post_trans->a(node);
+
+ return true;
+ }
+
+ bool visit(luci::CircleLeakyRelu *node)
+ {
+ return convert_unary_features<luci::CircleLeakyRelu>(node);
+ }
+
+ bool visit(luci::CircleMul *node)
+ {
+ LOGGER(l);
+
+ luci::CircleNode *pred_node = nullptr;
+ luci::CircleConst *multiplier = nullptr;
+
+ if (is_NCHW_with_const(node, pred_node, multiplier))
+ {
+ auto pre_trans = create_pre_transpose(node);
+ pre_trans->a(pred_node);
+ node->x(pre_trans);
+
+ auto nhwc_const = create_NHWC_from_NCHW(multiplier);
+ node->y(nhwc_const);
+ }
+ else if (multiplier == nullptr)
+ {
+ // TODO : Implement this case.
+ INFO(l) << "Not yet implemented. Both inputs of MUL are non-const." << std::endl;
+ return false;
+ }
+ else
+ {
+ return false;
+ }
+
+ // Do shape inference for this node again.
+ node->shape_status(luci::ShapeStatus::UNDEFINED);
+
+ auto post_trans = create_post_transpose(node);
+ loco::replace(node).with(post_trans);
+
+ post_trans->a(node);
+ return true;
+ }
+
+ bool visit(luci::CircleNeg *node)
+ {
+ const auto pred_node = loco::must_cast<luci::CircleNode *>(node->x());
+ auto pre_trans = create_pre_transpose(node);
+ pre_trans->a(pred_node);
+ node->x(pre_trans);
+
+ // Do shape inference for this node again.
+ node->shape_status(luci::ShapeStatus::UNDEFINED);
+
+ auto post_trans = create_post_transpose(node);
+ loco::replace(node).with(post_trans);
+
+ post_trans->a(node);
+
+ return true;
+ }
+
+ bool visit(luci::CirclePad *node)
+ {
+ if (!is_NCHW(node))
+ return false;
+
+ const auto pred_node = loco::must_cast<luci::CircleNode *>(node->input());
+ auto pre_trans = create_pre_transpose(node);
+ pre_trans->a(pred_node);
+ node->input(pre_trans);
+
+ auto nchw_paddings = loco::must_cast<luci::CircleConst *>(node->paddings());
+ const auto nhwc_paddings = create_NHWC_paddings(nchw_paddings);
+ node->paddings(nhwc_paddings);
+
+ // Do shape inference for this node again.
+ node->shape_status(luci::ShapeStatus::UNDEFINED);
+
+ auto post_trans = create_post_transpose(node);
+ loco::replace(node).with(post_trans);
+
+ post_trans->a(node);
+
+ return true;
+ }
+
+ bool visit(luci::CircleRelu *node) { return convert_unary_features<luci::CircleRelu>(node); }
+
+ bool visit(luci::CircleRelu6 *node) { return convert_unary_features<luci::CircleRelu6>(node); }
+};
+
+} // namespace
+
+namespace luci
+{
+
+bool ConvertNCHWToNHWCPass::run(loco::Graph *g)
+{
+ LOGGER(l);
+ INFO(l) << "ConvertNCHWToNHWCPass Start" << std::endl;
+
+ // Annotate NCHW operators
+ for (auto node : loco::active_nodes(loco::output_nodes(g)))
+ {
+ auto circle_node = loco::must_cast<luci::CircleNode *>(node);
+ switch (circle_node->opcode())
+ {
+ // List of supported Ops
+ case luci::CircleOpcode::CIRCLEINPUT:
+ if (!_preserve_input && !has_data_format(node))
+ {
+ set_data_format(node, DataFormat::NCHW);
+ }
+ break;
+ case luci::CircleOpcode::CIRCLEOUTPUT:
+ if (!_preserve_output && !has_data_format(node))
+ {
+ set_data_format(node, DataFormat::NCHW);
+ }
+ break;
+ case luci::CircleOpcode::ADD:
+ case luci::CircleOpcode::CONCATENATION:
+ case luci::CircleOpcode::LEAKY_RELU:
+ case luci::CircleOpcode::MUL:
+ case luci::CircleOpcode::NEG:
+ case luci::CircleOpcode::PAD:
+ case luci::CircleOpcode::RELU:
+ case luci::CircleOpcode::RELU6:
+ if (!has_data_format(node))
+ {
+ set_data_format(node, DataFormat::NCHW);
+ }
+ break;
+ default:
+ break;
+ }
+ }
+
+ bool changed = false;
+ for (auto node : loco::active_nodes(loco::output_nodes(g)))
+ {
+ if (!has_data_format(node))
+ {
+ // Unsupported Op
+ continue;
+ }
+ else if (get_data_format(node) == DataFormat::NHWC)
+ {
+ // Already converted to NHWC
+ continue;
+ }
+ else if (has_dynamic_shape(node))
+ {
+ // This pass only works for static-shaped node
+ INFO(l) << "Skip the node with a dynamic shape." << std::endl;
+ continue;
+ }
+ else
+ {
+ ConvertNCHWToNHWC converter;
+ auto circle_node = loco::must_cast<luci::CircleNode *>(node);
+ if (circle_node->rank() != 4)
+ continue;
+
+ if (circle_node->accept(&converter))
+ {
+ set_data_format(node, DataFormat::NHWC);
+ changed = true;
+ }
+ else
+ {
+ continue;
+ }
+ }
+ }
+
+ INFO(l) << "ConvertNCHWToNHWCPass End" << std::endl;
+ return changed;
+}
+
+} // namespace luci
diff --git a/compiler/luci/pass/src/ConvertNCHWToNHWCPass.test.cpp b/compiler/luci/pass/src/ConvertNCHWToNHWCPass.test.cpp
new file mode 100644
index 000000000..831d5f89a
--- /dev/null
+++ b/compiler/luci/pass/src/ConvertNCHWToNHWCPass.test.cpp
@@ -0,0 +1,636 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <logo/Phase.h>
+
+#include "luci/Pass/ConvertNCHWToNHWCPass.h"
+#include "luci/Pass/CircleShapeInferencePass.h"
+
+#include <luci/IR/CircleNodes.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+/**
+ * Graph with a single Op (example: Add).
+ *
+ * BEFORE
+ * - All Ops including Input/Output are NCHW.
+ *
+ * [Input] [beta]
+ * | /
+ * [Add]
+ * |
+ * [Output]
+ *
+ * AFTER
+ * - All Ops including Input/Output are NHWC.
+ *
+ * [Input]
+ * |
+ * [Transpose]
+ * |
+ * [Transpose] [beta]
+ * | /
+ * [Add]
+ * |
+ * [Transpose]
+ * |
+ * [Transpose]
+ * |
+ * [Output]
+ */
+class SimpleGraph
+{
+public:
+ SimpleGraph() = default;
+
+public:
+ void init()
+ {
+ input = g.nodes()->create<luci::CircleInput>();
+ output = g.nodes()->create<luci::CircleOutput>();
+ input->name("input");
+ output->name("output");
+
+ auto graph_input = g.inputs()->create();
+ input->index(graph_input->index());
+ auto graph_output = g.outputs()->create();
+ output->index(graph_output->index());
+
+ graph_input->dtype(loco::DataType::FLOAT32);
+ input->dtype(loco::DataType::FLOAT32);
+ output->dtype(loco::DataType::FLOAT32);
+ graph_output->dtype(loco::DataType::FLOAT32);
+
+ uint32_t channel_size = 16;
+ graph_input->shape({1, channel_size, 4, 4});
+ input->shape({1, channel_size, 4, 4});
+ output->shape({1, channel_size, 4, 4});
+ graph_output->shape({1, channel_size, 4, 4});
+
+ auto graph_body = insertGraphBody(input);
+ output->from(graph_body);
+ }
+
+ virtual ~SimpleGraph() = default;
+
+protected:
+ virtual loco::Node *insertGraphBody(loco::Node *input) = 0;
+
+public:
+ loco::Graph g;
+ luci::CircleInput *input = nullptr;
+ luci::CircleOutput *output = nullptr;
+};
+
+class AddGraph final : public SimpleGraph
+{
+protected:
+ loco::Node *insertGraphBody(loco::Node *input) override
+ {
+ add = g.nodes()->create<luci::CircleAdd>();
+ beta = g.nodes()->create<luci::CircleConst>();
+
+ add->dtype(loco::DataType::FLOAT32);
+ beta->dtype(loco::DataType::FLOAT32);
+
+ uint32_t channel_size = 16;
+ add->shape({1, channel_size, 4, 4});
+ beta->shape({1, channel_size, 1, 1});
+
+ beta->size<loco::DataType::FLOAT32>(channel_size);
+ for (uint32_t i = 0; i < channel_size; i++)
+ {
+ beta->at<loco::DataType::FLOAT32>(i) = i;
+ }
+
+ add->x(input);
+ add->y(beta);
+
+ add->name("add");
+ beta->name("beta");
+
+ return add;
+ }
+
+public:
+ luci::CircleAdd *add = nullptr;
+ luci::CircleConst *beta = nullptr;
+};
+
+class ConcatenationGraph final : public SimpleGraph
+{
+protected:
+ loco::Node *insertGraphBody(loco::Node *input) override
+ {
+ concat = g.nodes()->create<luci::CircleConcatenation>(2);
+ concat->values(0, input);
+ concat->axis(1);
+
+ input2 = g.nodes()->create<luci::CircleConst>();
+ input2->dtype(loco::DataType::FLOAT32);
+ input2->shape({1, 16, 4, 4});
+ input2->size<loco::DataType::FLOAT32>(16 * 4 * 4);
+ for (uint32_t i = 0; i < 16 * 4 * 4; i++)
+ {
+ input2->at<loco::DataType::FLOAT32>(i) = i;
+ }
+ concat->values(1, input2);
+
+ concat->name("concat");
+ input2->name("input2");
+
+ return concat;
+ }
+
+public:
+ luci::CircleConcatenation *concat = nullptr;
+ luci::CircleConst *input2 = nullptr;
+};
+
+class LeakyReluGraph final : public SimpleGraph
+{
+protected:
+ loco::Node *insertGraphBody(loco::Node *input) override
+ {
+ leakyrelu = g.nodes()->create<luci::CircleLeakyRelu>();
+ leakyrelu->features(input);
+ leakyrelu->name("leakyrelu");
+
+ return leakyrelu;
+ }
+
+public:
+ luci::CircleLeakyRelu *leakyrelu = nullptr;
+};
+
+class MulGraph final : public SimpleGraph
+{
+protected:
+ loco::Node *insertGraphBody(loco::Node *input) override
+ {
+ mul = g.nodes()->create<luci::CircleMul>();
+ multiplier = g.nodes()->create<luci::CircleConst>();
+
+ mul->dtype(loco::DataType::FLOAT32);
+ multiplier->dtype(loco::DataType::FLOAT32);
+
+ uint32_t channel_size = 16;
+ mul->shape({1, channel_size, 4, 4});
+ multiplier->shape({1, channel_size, 1, 1});
+
+ multiplier->size<loco::DataType::FLOAT32>(channel_size);
+ for (uint32_t i = 0; i < channel_size; i++)
+ {
+ multiplier->at<loco::DataType::FLOAT32>(i) = i;
+ }
+
+ mul->x(input);
+ mul->y(multiplier);
+
+ mul->name("mul");
+ multiplier->name("multiplier");
+
+ return mul;
+ }
+
+public:
+ luci::CircleMul *mul = nullptr;
+ luci::CircleConst *multiplier = nullptr;
+};
+
+class NegGraph final : public SimpleGraph
+{
+protected:
+ loco::Node *insertGraphBody(loco::Node *input) override
+ {
+ neg = g.nodes()->create<luci::CircleNeg>();
+ neg->x(input);
+ neg->name("neg");
+
+ return neg;
+ }
+
+public:
+ luci::CircleNeg *neg = nullptr;
+};
+
+class PadGraph final : public SimpleGraph
+{
+protected:
+ loco::Node *insertGraphBody(loco::Node *input) override
+ {
+ pad = g.nodes()->create<luci::CirclePad>();
+ paddings = g.nodes()->create<luci::CircleConst>();
+
+ pad->dtype(loco::DataType::FLOAT32);
+ paddings->dtype(loco::DataType::S32);
+
+ uint32_t channel_size = 16;
+ pad->shape({1, channel_size, 4, 4});
+ paddings->shape({4, 2});
+
+ // paddings data (NCHW)
+ // [[0,0], [0,0], [1,1], [2,2]]
+ paddings->size<loco::DataType::S32>(8);
+ for (uint32_t dim = 0; dim < 4; dim++)
+ {
+ for (uint32_t i = 0; i < 2; i++)
+ {
+ int32_t data = 0;
+
+ if (dim == 2)
+ data = 1;
+ else if (dim == 3)
+ data = 2;
+
+ paddings->at<loco::DataType::S32>(dim * 2 + i) = data;
+ }
+ }
+
+ pad->input(input);
+ pad->paddings(paddings);
+
+ pad->name("pad");
+ paddings->name("paddings");
+
+ return pad;
+ }
+
+public:
+ luci::CirclePad *pad = nullptr;
+ luci::CircleConst *paddings = nullptr;
+};
+
+class ReluGraph final : public SimpleGraph
+{
+protected:
+ loco::Node *insertGraphBody(loco::Node *input) override
+ {
+ relu = g.nodes()->create<luci::CircleRelu>();
+ relu->features(input);
+ relu->name("Relu");
+
+ return relu;
+ }
+
+public:
+ luci::CircleRelu *relu = nullptr;
+};
+
+class Relu6Graph final : public SimpleGraph
+{
+protected:
+ loco::Node *insertGraphBody(loco::Node *input) override
+ {
+ relu6 = g.nodes()->create<luci::CircleRelu6>();
+ relu6->features(input);
+ relu6->name("relu6");
+
+ return relu6;
+ }
+
+public:
+ luci::CircleRelu6 *relu6 = nullptr;
+};
+
+void check_pre_trans(loco::Node *node)
+{
+ auto pre_trans = dynamic_cast<luci::CircleTranspose *>(node);
+ EXPECT_NE(nullptr, pre_trans);
+ auto pre_trans_perm = dynamic_cast<luci::CircleConst *>(pre_trans->perm());
+ EXPECT_NE(nullptr, pre_trans_perm);
+ EXPECT_EQ(1, pre_trans_perm->rank());
+ EXPECT_EQ(4, pre_trans_perm->dim(0).value());
+ EXPECT_EQ(loco::DataType::S32, pre_trans_perm->dtype());
+ EXPECT_EQ(0, pre_trans_perm->at<loco::DataType::S32>(0));
+ EXPECT_EQ(2, pre_trans_perm->at<loco::DataType::S32>(1));
+ EXPECT_EQ(3, pre_trans_perm->at<loco::DataType::S32>(2));
+ EXPECT_EQ(1, pre_trans_perm->at<loco::DataType::S32>(3));
+}
+
+void check_post_trans(loco::Node *node)
+{
+ auto post_trans = dynamic_cast<luci::CircleTranspose *>(node);
+ EXPECT_NE(nullptr, post_trans);
+ auto post_trans_perm = dynamic_cast<luci::CircleConst *>(post_trans->perm());
+ EXPECT_NE(nullptr, post_trans_perm);
+ EXPECT_EQ(1, post_trans_perm->rank());
+ EXPECT_EQ(4, post_trans_perm->dim(0).value());
+ EXPECT_EQ(loco::DataType::S32, post_trans_perm->dtype());
+ EXPECT_EQ(0, post_trans_perm->at<loco::DataType::S32>(0));
+ EXPECT_EQ(3, post_trans_perm->at<loco::DataType::S32>(1));
+ EXPECT_EQ(1, post_trans_perm->at<loco::DataType::S32>(2));
+ EXPECT_EQ(2, post_trans_perm->at<loco::DataType::S32>(3));
+}
+
+void run_phase(loco::Graph *g, bool preserve_input, bool preserve_output)
+{
+ logo::Phase phase;
+
+ // Default passes.
+ phase.emplace_back(std::make_unique<luci::CircleShapeInferencePass>());
+
+ // Pass to test
+ phase.emplace_back(
+ std::make_unique<luci::ConvertNCHWToNHWCPass>(preserve_input, preserve_output));
+
+ logo::PhaseRunner<logo::PhaseStrategy::Restart> phase_runner{g};
+ phase_runner.run(phase);
+}
+
+} // namespace
+
+TEST(ConvertNCHWToNHWCPassTest, name)
+{
+ luci::ConvertNCHWToNHWCPass pass(false, false);
+ auto const name = pass.name();
+ ASSERT_NE(nullptr, name);
+}
+
+TEST(ConvertNCHWToNHWC, Add)
+{
+ AddGraph g;
+ g.init();
+
+ run_phase(&g.g, false, false);
+
+ auto input_succs = loco::succs(g.input);
+ EXPECT_EQ(1, input_succs.size());
+ check_post_trans(*input_succs.begin());
+
+ check_pre_trans(g.add->x());
+
+ auto add_succs = loco::succs(g.add);
+ EXPECT_EQ(1, add_succs.size());
+ check_post_trans(*add_succs.begin());
+
+ uint32_t channel_size = 16;
+ auto new_beta = dynamic_cast<luci::CircleConst *>(g.add->y());
+ EXPECT_NE(nullptr, new_beta);
+ EXPECT_EQ(4, new_beta->rank());
+ EXPECT_EQ(1, new_beta->dim(0).value());
+ EXPECT_EQ(1, new_beta->dim(1).value());
+ EXPECT_EQ(1, new_beta->dim(2).value());
+ EXPECT_EQ(channel_size, new_beta->dim(3).value());
+
+ check_pre_trans(g.output->from());
+}
+
+TEST(ConvertNCHWToNHWC, Concatenation)
+{
+ ConcatenationGraph g;
+ g.init();
+
+ run_phase(&g.g, true, true);
+
+ check_pre_trans(g.concat->values(0));
+ check_pre_trans(g.concat->values(1));
+
+ auto concat_succs = loco::succs(g.concat);
+ EXPECT_EQ(1, concat_succs.size());
+ check_post_trans(*concat_succs.begin());
+
+ // Check concat shape, axis
+ EXPECT_EQ(1, g.concat->dim(0).value());
+ EXPECT_EQ(4, g.concat->dim(1).value());
+ EXPECT_EQ(4, g.concat->dim(2).value());
+ EXPECT_EQ(32, g.concat->dim(3).value());
+ EXPECT_EQ(3, g.concat->axis());
+}
+
+TEST(ConvertNCHWToNHWC, LeakyRelu)
+{
+ LeakyReluGraph g;
+ g.init();
+
+ run_phase(&g.g, true, true);
+
+ check_pre_trans(g.leakyrelu->features());
+
+ auto leakyrelu_succs = loco::succs(g.leakyrelu);
+ EXPECT_EQ(1, leakyrelu_succs.size());
+ check_post_trans(*leakyrelu_succs.begin());
+
+ // Check leakyrelu shape
+ EXPECT_EQ(1, g.leakyrelu->dim(0).value());
+ EXPECT_EQ(4, g.leakyrelu->dim(1).value());
+ EXPECT_EQ(4, g.leakyrelu->dim(2).value());
+ EXPECT_EQ(16, g.leakyrelu->dim(3).value());
+}
+
+TEST(ConvertNCHWToNHWC, Mul)
+{
+ MulGraph g;
+ g.init();
+
+ run_phase(&g.g, false, false);
+
+ auto input_succs = loco::succs(g.input);
+ EXPECT_EQ(1, input_succs.size());
+ check_post_trans(*input_succs.begin());
+
+ check_pre_trans(g.mul->x());
+
+ auto mul_succs = loco::succs(g.mul);
+ EXPECT_EQ(1, mul_succs.size());
+ check_post_trans(*mul_succs.begin());
+
+ uint32_t channel_size = 16;
+ auto new_multiplier = dynamic_cast<luci::CircleConst *>(g.mul->y());
+ EXPECT_NE(nullptr, new_multiplier);
+ EXPECT_EQ(4, new_multiplier->rank());
+ EXPECT_EQ(1, new_multiplier->dim(0).value());
+ EXPECT_EQ(1, new_multiplier->dim(1).value());
+ EXPECT_EQ(1, new_multiplier->dim(2).value());
+ EXPECT_EQ(channel_size, new_multiplier->dim(3).value());
+
+ check_pre_trans(g.output->from());
+}
+
+TEST(ConvertNCHWToNHWC, Neg)
+{
+ NegGraph g;
+ g.init();
+
+ run_phase(&g.g, true, true);
+
+ check_pre_trans(g.neg->x());
+
+ auto neg_succs = loco::succs(g.neg);
+ EXPECT_EQ(1, neg_succs.size());
+ check_post_trans(*neg_succs.begin());
+
+ // Check leakyrelu shape
+ EXPECT_EQ(1, g.neg->dim(0).value());
+ EXPECT_EQ(4, g.neg->dim(1).value());
+ EXPECT_EQ(4, g.neg->dim(2).value());
+ EXPECT_EQ(16, g.neg->dim(3).value());
+}
+
+TEST(ConvertNCHWToNHWC, Pad)
+{
+ PadGraph g;
+ g.init();
+
+ run_phase(&g.g, false, false);
+
+ auto input_succs = loco::succs(g.input);
+ EXPECT_EQ(1, input_succs.size());
+ check_post_trans(*input_succs.begin());
+
+ check_pre_trans(g.pad->input());
+
+ auto pad_succs = loco::succs(g.pad);
+ EXPECT_EQ(1, pad_succs.size());
+ check_post_trans(*pad_succs.begin());
+
+ auto new_paddings = dynamic_cast<luci::CircleConst *>(g.pad->paddings());
+ EXPECT_NE(nullptr, new_paddings);
+ EXPECT_EQ(2, new_paddings->rank());
+ EXPECT_EQ(4, new_paddings->dim(0).value());
+ EXPECT_EQ(2, new_paddings->dim(1).value());
+ EXPECT_EQ(0, new_paddings->at<loco::DataType::S32>(0));
+ EXPECT_EQ(0, new_paddings->at<loco::DataType::S32>(1));
+ EXPECT_EQ(1, new_paddings->at<loco::DataType::S32>(2));
+ EXPECT_EQ(1, new_paddings->at<loco::DataType::S32>(3));
+ EXPECT_EQ(2, new_paddings->at<loco::DataType::S32>(4));
+ EXPECT_EQ(2, new_paddings->at<loco::DataType::S32>(5));
+ EXPECT_EQ(0, new_paddings->at<loco::DataType::S32>(6));
+ EXPECT_EQ(0, new_paddings->at<loco::DataType::S32>(7));
+
+ check_pre_trans(g.output->from());
+}
+
+TEST(ConvertNCHWToNHWC, Unknown_Shape_NEG)
+{
+ AddGraph g;
+ g.init();
+
+ // Unknown shape
+ g.input->dim(0).unset();
+ g.add->dim(0).unset();
+ g.output->dim(0).unset();
+
+ luci::ConvertNCHWToNHWCPass pass(false, false);
+ EXPECT_EQ(false, pass.run(&g.g));
+}
+
+TEST(ConvertNCHWToNHWC, Preserve_Input_Output)
+{
+ // Preserve input
+ {
+ AddGraph g;
+ g.init();
+
+ run_phase(&g.g, true, false);
+
+ // Check input shape
+ EXPECT_EQ(1, g.input->dim(0).value());
+ EXPECT_EQ(16, g.input->dim(1).value());
+ EXPECT_EQ(4, g.input->dim(2).value());
+ EXPECT_EQ(4, g.input->dim(3).value());
+
+ // Check output shape
+ EXPECT_EQ(1, g.output->dim(0).value());
+ EXPECT_EQ(4, g.output->dim(1).value());
+ EXPECT_EQ(4, g.output->dim(2).value());
+ EXPECT_EQ(16, g.output->dim(3).value());
+ }
+
+ // Preserve output
+ {
+ AddGraph g;
+ g.init();
+
+ run_phase(&g.g, false, true);
+
+ // Check input shape
+ EXPECT_EQ(1, g.input->dim(0).value());
+ EXPECT_EQ(4, g.input->dim(1).value());
+ EXPECT_EQ(4, g.input->dim(2).value());
+ EXPECT_EQ(16, g.input->dim(3).value());
+
+ // Check output shape
+ EXPECT_EQ(1, g.output->dim(0).value());
+ EXPECT_EQ(16, g.output->dim(1).value());
+ EXPECT_EQ(4, g.output->dim(2).value());
+ EXPECT_EQ(4, g.output->dim(3).value());
+ }
+
+ // Preserve both input and output
+ {
+ AddGraph g;
+ g.init();
+
+ run_phase(&g.g, true, true);
+
+ // Check input shape
+ EXPECT_EQ(1, g.input->dim(0).value());
+ EXPECT_EQ(16, g.input->dim(1).value());
+ EXPECT_EQ(4, g.input->dim(2).value());
+ EXPECT_EQ(4, g.input->dim(3).value());
+
+ // Check output shape
+ EXPECT_EQ(1, g.output->dim(0).value());
+ EXPECT_EQ(16, g.output->dim(1).value());
+ EXPECT_EQ(4, g.output->dim(2).value());
+ EXPECT_EQ(4, g.output->dim(3).value());
+ }
+}
+
+TEST(ConvertNCHWToNHWC, Relu)
+{
+ ReluGraph g;
+ g.init();
+
+ run_phase(&g.g, true, true);
+
+ check_pre_trans(g.relu->features());
+
+ auto relu_succs = loco::succs(g.relu);
+ EXPECT_EQ(1, relu_succs.size());
+ check_post_trans(*relu_succs.begin());
+
+ // Check relu shape
+ EXPECT_EQ(1, g.relu->dim(0).value());
+ EXPECT_EQ(4, g.relu->dim(1).value());
+ EXPECT_EQ(4, g.relu->dim(2).value());
+ EXPECT_EQ(16, g.relu->dim(3).value());
+}
+
+TEST(ConvertNCHWToNHWC, Relu6)
+{
+ Relu6Graph g;
+ g.init();
+
+ run_phase(&g.g, true, true);
+
+ check_pre_trans(g.relu6->features());
+
+ auto relu6_succs = loco::succs(g.relu6);
+ EXPECT_EQ(1, relu6_succs.size());
+ check_post_trans(*relu6_succs.begin());
+
+ // Check relu6 shape
+ EXPECT_EQ(1, g.relu6->dim(0).value());
+ EXPECT_EQ(4, g.relu6->dim(1).value());
+ EXPECT_EQ(4, g.relu6->dim(2).value());
+ EXPECT_EQ(16, g.relu6->dim(3).value());
+}
diff --git a/compiler/luci/pass/src/FoldAddV2Pass.cpp b/compiler/luci/pass/src/FoldAddV2Pass.cpp
new file mode 100644
index 000000000..20c1022f8
--- /dev/null
+++ b/compiler/luci/pass/src/FoldAddV2Pass.cpp
@@ -0,0 +1,122 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Pass/FoldAddV2Pass.h"
+
+#include <luci/IR/CircleNodes.h>
+
+#include <iostream>
+
+namespace
+{
+
+bool same_shape(const luci::CircleConst *x, const luci::CircleConst *y)
+{
+ if (x->rank() != y->rank())
+ return false;
+
+ for (uint32_t i = 0; i < x->rank(); i++)
+ {
+ if (!(x->dim(i) == y->dim(i)))
+ return false;
+ }
+
+ return true;
+}
+
+/**
+ * Fold AddV2 to const if both inputs are const
+ **/
+template <loco::DataType T> bool fold_add_v2(luci::CircleCustom *add_v2)
+{
+ // This should hold for AddV2
+ if (add_v2->numInputs() != 2)
+ return false;
+
+ // Check first input is const
+ auto x = dynamic_cast<luci::CircleConst *>(add_v2->inputs(0));
+ if (not x)
+ return false;
+
+ // Check second input is const
+ auto y = dynamic_cast<luci::CircleConst *>(add_v2->inputs(1));
+ if (not y)
+ return false;
+
+ if (x->dtype() != y->dtype())
+ return false;
+
+ if (!same_shape(x, y))
+ return false;
+
+ auto name_x = x->name();
+ auto name_y = y->name();
+ assert(name_x.length() > 0);
+ assert(name_y.length() > 0);
+ auto constant = add_v2->graph()->nodes()->create<luci::CircleConst>();
+ constant->dtype(x->dtype());
+ constant->rank(x->rank());
+ for (uint32_t i = 0; i < x->rank(); i++)
+ constant->dim(i).set(x->dim(i).value());
+
+ const auto size = x->size<T>();
+ constant->size<T>(size);
+ for (uint32_t i = 0; i < size; i++)
+ constant->at<T>(i) = x->at<T>(i) + y->at<T>(i);
+
+ constant->shape_status(luci::ShapeStatus::VALID);
+ constant->name(name_x + ";" + name_y);
+
+ for (auto succ : loco::succs(add_v2))
+ {
+ auto custom_out = loco::must_cast<luci::CircleCustomOut *>(succ);
+ loco::replace(custom_out).with(constant);
+ }
+
+ return true;
+}
+
+} // namespace
+
+namespace luci
+{
+
+/**
+ * Constant Folding for AddV2 Op
+ **/
+bool FoldAddV2Pass::run(loco::Graph *g)
+{
+ bool changed = false;
+ for (auto node : loco::active_nodes(loco::output_nodes(g)))
+ {
+ if (auto custom = dynamic_cast<luci::CircleCustom *>(node))
+ {
+ if (custom->custom_code() == "AddV2")
+ {
+ // TODO: Support more data types
+ if (custom->dtype() == loco::DataType::S64)
+ {
+ if (fold_add_v2<loco::DataType::S64>(custom))
+ changed = true;
+ }
+ }
+ }
+ }
+
+ return changed;
+}
+
+} // namespace luci
diff --git a/compiler/luci/pass/src/FoldAddV2Pass.test.cpp b/compiler/luci/pass/src/FoldAddV2Pass.test.cpp
new file mode 100644
index 000000000..438d7f077
--- /dev/null
+++ b/compiler/luci/pass/src/FoldAddV2Pass.test.cpp
@@ -0,0 +1,137 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Pass/FoldAddV2Pass.h"
+#include "PassTestGraphs.h"
+
+#include <luci/IR/CircleNodes.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+/**
+ * Graph has an AddV2 Op with constant inputs
+ *
+ * BEFORE
+ *
+ * [CircleConst] [CircleConst]
+ * | |
+ * [CircleCustom (AddV2)]
+ * |
+ * [CircleCustomOut]
+ *
+ * AFTER
+ *
+ * [CircleConst]
+ */
+template <loco::DataType T> class FoldAddV2Test : public luci::ConstantFoldingAddTestGraph
+{
+public:
+ FoldAddV2Test(std::initializer_list<uint32_t> shape) : luci::ConstantFoldingAddTestGraph(shape, T)
+ {
+ _addV2 = _g.nodes()->create<luci::CircleCustom>(2, 1);
+ _x = _g.nodes()->create<luci::CircleConst>();
+ _y = _g.nodes()->create<luci::CircleConst>();
+ _addV2_out = _g.nodes()->create<luci::CircleCustomOut>();
+
+ _addV2->dtype(T);
+ _x->dtype(T);
+ _y->dtype(T);
+ _addV2_out->dtype(T);
+
+ _addV2->shape(shape);
+ _x->shape(shape);
+ _y->shape(shape);
+ _addV2_out->shape(shape);
+
+ uint32_t num_elems = 1;
+ for (auto dim = shape.begin(); dim != shape.end(); dim++)
+ num_elems *= *dim;
+
+ _x->size<T>(num_elems);
+ _y->size<T>(num_elems);
+
+ for (uint32_t i = 0; i < num_elems; i++)
+ {
+ _x->at<T>(i) = i + 1;
+ _y->at<T>(i) = i + 1;
+ }
+
+ _addV2->custom_code("AddV2");
+ _addV2->inputs(0, _x);
+ _addV2->inputs(1, _y);
+ _addV2_out->input(_addV2);
+
+ _addV2->name("addV2");
+ _x->name("x");
+ _y->name("y");
+ }
+
+ loco::Node *createFoldedPattern() override { return _addV2_out; }
+
+ virtual ~FoldAddV2Test() = default;
+
+protected:
+ luci::CircleCustom *_addV2 = nullptr;
+ luci::CircleCustomOut *_addV2_out = nullptr;
+ luci::CircleConst *_x = nullptr;
+ luci::CircleConst *_y = nullptr;
+};
+
+class FoldS64AddV2Test : public FoldAddV2Test<loco::DataType::S64>, public ::testing::Test
+{
+public:
+ FoldS64AddV2Test() : FoldAddV2Test<loco::DataType::S64>({3}) {}
+
+ virtual void SetUp() { init(); }
+};
+
+} // namespace
+
+TEST(FoldAddV2PassTest, name)
+{
+ luci::FoldAddV2Pass pass;
+ auto const name = pass.name();
+ ASSERT_NE(nullptr, name);
+}
+
+TEST_F(FoldS64AddV2Test, fold_addV2)
+{
+ luci::FoldAddV2Pass pass;
+ while (pass.run(graph()))
+ ;
+
+ auto folded_const = getFoldedPattern();
+ EXPECT_NE(nullptr, folded_const);
+
+ // Check type, shape, values of folded const
+ EXPECT_EQ(loco::DataType::S64, folded_const->dtype());
+ EXPECT_EQ(1, folded_const->rank());
+ EXPECT_EQ(3, folded_const->dim(0).value());
+ EXPECT_EQ(2, folded_const->at<loco::DataType::S64>(0));
+ EXPECT_EQ(4, folded_const->at<loco::DataType::S64>(1));
+ EXPECT_EQ(6, folded_const->at<loco::DataType::S64>(2));
+}
+
+TEST_F(FoldS64AddV2Test, input_type_mismatch_NEG)
+{
+ _x->dtype(loco::DataType::S32);
+
+ luci::FoldAddV2Pass pass;
+ EXPECT_FALSE(pass.run(graph()));
+}
diff --git a/compiler/luci/pass/src/FoldCastPass.cpp b/compiler/luci/pass/src/FoldCastPass.cpp
new file mode 100644
index 000000000..00b86fe48
--- /dev/null
+++ b/compiler/luci/pass/src/FoldCastPass.cpp
@@ -0,0 +1,107 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Pass/FoldCastPass.h"
+
+#include <luci/IR/CircleNodes.h>
+
+namespace
+{
+
+luci::CircleConst *cast_const(luci::CircleConst *node, loco::DataType from_dtype,
+ loco::DataType to_dtype)
+{
+ assert(node->dtype() == from_dtype);
+
+ auto name = node->name();
+ assert(name.length() > 0);
+ auto constant = node->graph()->nodes()->create<luci::CircleConst>();
+ constant->dtype(to_dtype);
+ constant->rank(node->rank());
+ uint32_t num_elems = 1;
+ for (uint32_t i = 0; i < node->rank(); i++)
+ {
+ constant->dim(i).set(node->dim(i).value());
+ num_elems *= node->dim(i).value();
+ }
+
+ constant->shape_status(luci::ShapeStatus::VALID);
+
+ // TODO: Support more data types
+ if (from_dtype == loco::DataType::S64)
+ {
+ if (to_dtype == loco::DataType::S32)
+ {
+ constant->size<loco::DataType::S32>(num_elems);
+ for (uint32_t i = 0; i < num_elems; i++)
+ constant->at<loco::DataType::S32>(i) =
+ static_cast<int32_t>(node->at<loco::DataType::S64>(i));
+
+ constant->name(name + "_S32");
+ return constant;
+ }
+ return nullptr;
+ }
+
+ return nullptr;
+}
+
+/**
+ * Fold Cast to const if it has const input
+ **/
+bool fold_cast(luci::CircleCast *cast)
+{
+ // Check cast has const input
+ auto const_x = dynamic_cast<luci::CircleConst *>(cast->x());
+ if (not const_x)
+ return false;
+
+ const auto in_dtype = const_x->dtype();
+ const auto out_dtype = cast->dtype();
+
+ auto casted_const = cast_const(const_x, in_dtype, out_dtype);
+ if (not casted_const)
+ return false;
+
+ loco::replace(cast).with(casted_const);
+
+ return true;
+}
+
+} // namespace
+
+namespace luci
+{
+
+/**
+ * Constant Folding for Cast Op
+ **/
+bool FoldCastPass::run(loco::Graph *g)
+{
+ bool changed = false;
+ for (auto node : loco::active_nodes(loco::output_nodes(g)))
+ {
+ if (auto cast = dynamic_cast<luci::CircleCast *>(node))
+ {
+ if (fold_cast(cast))
+ changed = true;
+ }
+ }
+
+ return changed;
+}
+
+} // namespace luci
diff --git a/compiler/luci/pass/src/FoldCastPass.test.cpp b/compiler/luci/pass/src/FoldCastPass.test.cpp
new file mode 100644
index 000000000..5911adf11
--- /dev/null
+++ b/compiler/luci/pass/src/FoldCastPass.test.cpp
@@ -0,0 +1,112 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Pass/FoldCastPass.h"
+#include "PassTestGraphs.h"
+
+#include <luci/IR/CircleNodes.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+template <loco::DataType FromT, loco::DataType ToT>
+class FoldCastTest : public luci::ConstantFoldingAddTestGraph
+{
+public:
+ FoldCastTest(std::initializer_list<uint32_t> shape)
+ : luci::ConstantFoldingAddTestGraph(shape, ToT)
+ {
+ _cast = _g.nodes()->create<luci::CircleCast>();
+ _x = _g.nodes()->create<luci::CircleConst>();
+
+ _cast->dtype(ToT);
+ _x->dtype(FromT);
+
+ _cast->shape(shape);
+ _x->shape(shape);
+
+ uint32_t num_elems = 1;
+ for (auto dim = shape.begin(); dim != shape.end(); dim++)
+ num_elems *= *dim;
+
+ _x->size<FromT>(num_elems);
+ for (uint32_t i = 0; i < num_elems; i++)
+ _x->at<FromT>(i) = i + 1;
+
+ _cast->x(_x);
+
+ _cast->name("cast");
+ _x->name("x");
+ }
+
+ loco::Node *createFoldedPattern() override { return _cast; }
+
+protected:
+ luci::CircleCast *_cast = nullptr;
+ luci::CircleConst *_x = nullptr;
+};
+
+/**
+ * Graph that has a Cast Op with constant input
+ *
+ * BEFORE
+ *
+ * [CircleConst]
+ * |
+ * [Cast]
+ *
+ * AFTER
+ *
+ * [CircleConst]
+ *
+ */
+class FoldS64ToS32CastTest : public FoldCastTest<loco::DataType::S64, loco::DataType::S32>,
+ public ::testing::Test
+{
+public:
+ FoldS64ToS32CastTest() : FoldCastTest<loco::DataType::S64, loco::DataType::S32>({3}) {}
+
+ virtual void SetUp() { init(); }
+};
+
+} // namespace
+
+TEST(FoldCastPassTest, name)
+{
+ luci::FoldCastPass pass;
+ auto const name = pass.name();
+ ASSERT_NE(nullptr, name);
+}
+
+TEST_F(FoldS64ToS32CastTest, fold_cast_s64_to_s32)
+{
+ luci::FoldCastPass pass;
+ while (pass.run(graph()))
+ ;
+
+ auto folded_const = getFoldedPattern();
+ EXPECT_NE(nullptr, folded_const);
+
+ // Check type, shape, values of folded const
+ EXPECT_EQ(loco::DataType::S32, folded_const->dtype());
+ EXPECT_EQ(1, folded_const->rank());
+ EXPECT_EQ(3, folded_const->dim(0).value());
+ EXPECT_EQ(1, folded_const->at<loco::DataType::S32>(0));
+ EXPECT_EQ(2, folded_const->at<loco::DataType::S32>(1));
+ EXPECT_EQ(3, folded_const->at<loco::DataType::S32>(2));
+}
diff --git a/compiler/luci/pass/src/FoldDequantizePass.cpp b/compiler/luci/pass/src/FoldDequantizePass.cpp
index 01c04f478..3dd4f8cea 100644
--- a/compiler/luci/pass/src/FoldDequantizePass.cpp
+++ b/compiler/luci/pass/src/FoldDequantizePass.cpp
@@ -17,8 +17,7 @@
#include "luci/Pass/FoldDequantizePass.h"
#include <luci/IR/CircleNodes.h>
-
-#include <loco/Service/TypeInference.h>
+#include <luci/Profile/CircleNodeOrigin.h>
namespace
{
@@ -51,6 +50,8 @@ luci::CircleConst *dequantized_const_node(luci::CircleConst *const_node)
throw std::runtime_error("Given constant node has no quantization parameter");
}
+ auto name = const_node->name();
+ assert(name.length() > 0);
auto g = const_node->graph();
auto new_const_node = g->nodes()->create<luci::CircleConst>();
@@ -64,6 +65,7 @@ luci::CircleConst *dequantized_const_node(luci::CircleConst *const_node)
}
new_const_node->size<loco::DataType::FLOAT32>(dim_size);
new_const_node->shape_status(luci::ShapeStatus::VALID);
+ new_const_node->name(name + "_DQ");
const int32_t q_dim = const_node->quantparam()->quantized_dimension;
const int32_t q_dim_value = const_node->dim(q_dim).value();
@@ -81,8 +83,8 @@ luci::CircleConst *dequantized_const_node(luci::CircleConst *const_node)
qd = 0;
new_const_node->at<loco::DataType::FLOAT32>(i) =
- (float)(const_node->at<loco::DataType::S8>(i) - const_node->quantparam()->zerop.at(qd)) *
- const_node->quantparam()->scale.at(qd);
+ (float)(const_node->at<loco::DataType::S8>(i) - const_node->quantparam()->zerop.at(qd)) *
+ const_node->quantparam()->scale.at(qd);
}
}
else
@@ -94,9 +96,9 @@ luci::CircleConst *dequantized_const_node(luci::CircleConst *const_node)
qd = 0;
new_const_node->at<loco::DataType::FLOAT32>(i) =
- (float)((int)const_node->at<loco::DataType::U8>(i) -
- const_node->quantparam()->zerop.at(qd)) *
- const_node->quantparam()->scale.at(qd);
+ (float)((int)const_node->at<loco::DataType::U8>(i) -
+ const_node->quantparam()->zerop.at(qd)) *
+ const_node->quantparam()->scale.at(qd);
}
}
@@ -192,6 +194,8 @@ bool FoldDequantizePass::run(loco::Graph *g)
if (replace_const_node(const_node_user, const_node))
{
loco::replace(dequant).with(const_node_user);
+ luci::add_origin(loco::must_cast<luci::CircleNode *>(const_node_user),
+ luci::get_origin(dequant));
changed = true;
}
}
diff --git a/compiler/luci/service/src/Nodes/CircleOutput.cpp b/compiler/luci/pass/src/FoldDequantizePass.test.cpp
index d4c8da2d8..d82a7bc87 100644
--- a/compiler/luci/service/src/Nodes/CircleOutput.cpp
+++ b/compiler/luci/pass/src/FoldDequantizePass.test.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -14,14 +14,13 @@
* limitations under the License.
*/
-#include <luci/Service/CircleShapeSignatureInference.h>
+#include "luci/Pass/FoldDequantizePass.h"
-namespace luci
-{
+#include <gtest/gtest.h>
-ShapeSignature ssinf::Algorithm::visit(const luci::CircleOutput *node)
+TEST(FoldDequantizePassTest, name)
{
- return input_arg_signature(node, 0);
+ luci::FoldDequantizePass pass;
+ auto const name = pass.name();
+ ASSERT_NE(nullptr, name);
}
-
-} // namespace luci
diff --git a/compiler/luci/pass/src/FoldSparseToDensePass.cpp b/compiler/luci/pass/src/FoldSparseToDensePass.cpp
new file mode 100644
index 000000000..0c6fc43ed
--- /dev/null
+++ b/compiler/luci/pass/src/FoldSparseToDensePass.cpp
@@ -0,0 +1,140 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Pass/FoldSparseToDensePass.h"
+#include "CircleOptimizerUtils.h"
+
+#include <luci/IR/CircleNodes.h>
+
+namespace
+{
+
+/**
+ * Fold to const if
+ *
+ * 1. indices has 0-sized static shape such as [0]
+ * (i.e., output is filled with default value)
+ * 2. default_value: const scalar
+ * 3. output_shape: const
+ *
+ * TODO: Support more general patterns
+ **/
+template <loco::DataType IndexT, loco::DataType ValueT>
+bool fold_sparse_to_dense(luci::CircleSparseToDense *stod)
+{
+ const auto indices = loco::must_cast<luci::CircleNode *>(stod->indices());
+ const auto default_value = loco::must_cast<luci::CircleConst *>(stod->default_value());
+ const auto output_shape = loco::must_cast<luci::CircleConst *>(stod->output_shape());
+
+ bool has_zero = false;
+ for (uint32_t i = 0; i < indices->rank(); i++)
+ {
+ if (indices->dim(i).known() && indices->dim(i).value() == 0)
+ has_zero = true;
+ }
+ if (!has_zero)
+ return false;
+
+ if (default_value->rank() != 0 || default_value->size<ValueT>() != 1)
+ return false;
+
+ auto rank = output_shape->size<IndexT>();
+ std::vector<uint32_t> shape;
+ for (uint32_t i = 0; i < rank; i++)
+ {
+ auto dim = output_shape->at<IndexT>(i);
+ assert(dim >= 0 && dim <= std::numeric_limits<uint32_t>::max());
+ if (!(dim >= 0 && dim <= std::numeric_limits<uint32_t>::max()))
+ return false;
+
+ shape.push_back(dim);
+ }
+
+ auto name = stod->name();
+ assert(name.length() > 0);
+ auto constant = stod->graph()->nodes()->create<luci::CircleConst>();
+ constant->dtype(default_value->dtype());
+ constant->rank(rank);
+ uint32_t dim_size = 1;
+ for (uint32_t i = 0; i < rank; i++)
+ {
+ constant->dim(i).set(shape[i]);
+ dim_size *= shape[i];
+ }
+
+ constant->size<ValueT>(dim_size);
+ const auto value = default_value->scalar<ValueT>();
+ for (uint32_t i = 0; i < dim_size; i++)
+ constant->at<ValueT>(i) = value;
+
+ constant->shape_status(luci::ShapeStatus::VALID);
+ constant->name(name + "_D");
+
+ loco::replace(stod).with(constant);
+
+ return true;
+}
+
+bool fold_sparse_to_dense(luci::CircleSparseToDense *stod)
+{
+ auto indices = loco::must_cast<luci::CircleNode *>(stod->indices());
+ auto default_value = dynamic_cast<luci::CircleConst *>(stod->default_value());
+ if (not default_value)
+ return false;
+
+ auto output_shape = dynamic_cast<luci::CircleConst *>(stod->output_shape());
+ if (not output_shape)
+ return false;
+
+ // Illegal input check
+ if (indices->dtype() != output_shape->dtype())
+ throw std::runtime_error("indices and output_shape of SparseToDense must have the same dtype");
+
+ // TODO: Support more data types
+ if (indices->dtype() == loco::DataType::S64)
+ {
+ if (default_value->dtype() == loco::DataType::S64)
+ {
+ return fold_sparse_to_dense<loco::DataType::S64, loco::DataType::S64>(stod);
+ }
+ }
+ return false;
+}
+
+} // namespace
+
+namespace luci
+{
+
+/**
+ * Constant Folding for SparseToDense Op
+ **/
+bool FoldSparseToDensePass::run(loco::Graph *g)
+{
+ bool changed = false;
+ for (auto node : loco::active_nodes(loco::output_nodes(g)))
+ {
+ if (auto stod = dynamic_cast<luci::CircleSparseToDense *>(node))
+ {
+ if (fold_sparse_to_dense(stod))
+ changed = true;
+ }
+ }
+
+ return changed;
+}
+
+} // namespace luci
diff --git a/compiler/luci/pass/src/FoldSparseToDensePass.test.cpp b/compiler/luci/pass/src/FoldSparseToDensePass.test.cpp
new file mode 100644
index 000000000..7c6dcb033
--- /dev/null
+++ b/compiler/luci/pass/src/FoldSparseToDensePass.test.cpp
@@ -0,0 +1,133 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Pass/FoldSparseToDensePass.h"
+#include "PassTestGraphs.h"
+
+#include <luci/IR/CircleNodes.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+/**
+ * Graph that has a SparseToDense Op with zero-sized indices
+ *
+ * BEFORE
+ * - shape of indices: [0,1]
+ * - output_shape: [3]
+ * - default_value: scalar 2
+ *
+ * [indices] [output_shape] [values] [default_value]
+ * | | | |
+ * +------[SparseToDense]------+
+ *
+ * AFTER
+ *
+ * [Const] (shape: [3], values: [2, 2, 2])
+ *
+ */
+class S64SparseToDenseZeroIndicesTest : public luci::ConstantFoldingAddTestGraph,
+ public ::testing::Test
+{
+public:
+ S64SparseToDenseZeroIndicesTest() : luci::ConstantFoldingAddTestGraph({3}, loco::DataType::S64) {}
+
+ virtual void SetUp() { init(); }
+
+ loco::Node *createFoldedPattern() override
+ {
+ _stod = _g.nodes()->create<luci::CircleSparseToDense>();
+ _indices = _g.nodes()->create<luci::CircleConst>();
+ _output_shape = _g.nodes()->create<luci::CircleConst>();
+ _values = _g.nodes()->create<luci::CircleConst>();
+ _default_value = _g.nodes()->create<luci::CircleConst>();
+
+ _stod->dtype(loco::DataType::S64);
+ _indices->dtype(loco::DataType::S64);
+ _output_shape->dtype(loco::DataType::S64);
+ _values->dtype(loco::DataType::S64);
+ _default_value->dtype(loco::DataType::S64);
+
+ _indices->shape({0, 1});
+ _output_shape->shape({1});
+ _values->shape({0});
+ _default_value->rank(0);
+
+ _indices->size<loco::DataType::S64>(0);
+ _output_shape->size<loco::DataType::S64>(1);
+ _output_shape->at<loco::DataType::S64>(0) = 3;
+ _values->size<loco::DataType::S64>(0);
+ _default_value->size<loco::DataType::S64>(1);
+ _default_value->at<loco::DataType::S64>(0) = 2;
+
+ _stod->indices(_indices);
+ _stod->output_shape(_output_shape);
+ _stod->values(_values);
+ _stod->default_value(_default_value);
+
+ _stod->name("stod");
+ _indices->name("indices");
+ _output_shape->name("output_shape");
+ _values->name("values");
+ _default_value->name("default_value");
+
+ return _stod;
+ }
+
+protected:
+ luci::CircleSparseToDense *_stod = nullptr;
+ luci::CircleConst *_indices = nullptr;
+ luci::CircleConst *_output_shape = nullptr;
+ luci::CircleConst *_values = nullptr;
+ luci::CircleConst *_default_value = nullptr;
+};
+
+} // namespace
+
+TEST(FoldSparseToDensePassTest, name)
+{
+ luci::FoldSparseToDensePass pass;
+ auto const name = pass.name();
+ ASSERT_NE(nullptr, name);
+}
+
+TEST_F(S64SparseToDenseZeroIndicesTest, fold_stod_with_zero_indices)
+{
+ luci::FoldSparseToDensePass pass;
+ while (pass.run(graph()))
+ ;
+
+ auto folded_const = getFoldedPattern();
+ EXPECT_NE(nullptr, folded_const);
+
+ // Chec type, shape, values of folded const
+ EXPECT_EQ(loco::DataType::S64, folded_const->dtype());
+ EXPECT_EQ(1, folded_const->rank());
+ EXPECT_EQ(3, folded_const->dim(0).value());
+ EXPECT_EQ(2, folded_const->at<loco::DataType::S64>(0));
+ EXPECT_EQ(2, folded_const->at<loco::DataType::S64>(1));
+ EXPECT_EQ(2, folded_const->at<loco::DataType::S64>(2));
+}
+
+TEST_F(S64SparseToDenseZeroIndicesTest, illegal_input_NEG)
+{
+ _indices->dtype(loco::DataType::S32);
+
+ luci::FoldSparseToDensePass pass;
+ EXPECT_ANY_THROW(pass.run(graph()));
+}
diff --git a/compiler/luci/pass/src/ForwardReshapeToUnaryOpPass.cpp b/compiler/luci/pass/src/ForwardReshapeToUnaryOpPass.cpp
new file mode 100644
index 000000000..2c990f0a5
--- /dev/null
+++ b/compiler/luci/pass/src/ForwardReshapeToUnaryOpPass.cpp
@@ -0,0 +1,154 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Pass/ForwardReshapeToUnaryOpPass.h"
+
+#include <luci/IR/CircleNodes.h>
+#include <luci/IR/CircleNodeVisitor.h>
+#include <luci/Log.h>
+#include <luci/Profile/CircleNodeOrigin.h>
+#include <luci/Service/CircleShapeInference.h>
+#include <luci/Service/Nodes/CircleConst.h>
+
+namespace
+{
+
+luci::CircleReshape *as_reshape(loco::Node *node)
+{
+ return dynamic_cast<luci::CircleReshape *>(node);
+}
+
+luci::CircleConst *clone_shape(luci::CircleReshape *reshape)
+{
+ const auto shape = dynamic_cast<luci::CircleConst *>(reshape->shape());
+ // only support CircleConst for now
+ if (shape == nullptr)
+ return nullptr;
+
+ // NOTE tflite and circle only supports S32
+ // TODO just check with assert() after import handles this
+ auto dtype = shape->dtype();
+ if (dtype != loco::DataType::S32)
+ return nullptr;
+
+ return luci::clone(shape);
+}
+
+void copy_shape(luci::CircleReshape *reshape, luci::CircleReshape *new_reshape)
+{
+ auto ns_rank = reshape->newShape()->rank();
+ new_reshape->newShape()->rank(ns_rank);
+ for (uint32_t r = 0; r < ns_rank; ++r)
+ new_reshape->newShape()->dim(r) = reshape->newShape()->dim(r);
+}
+
+bool forward_reshape(luci::CircleReshape *reshape, luci::CircleNeg *neg)
+{
+ assert(reshape != nullptr);
+ assert(neg != nullptr);
+
+ luci::CircleConst *cloned_shape = clone_shape(reshape);
+ if (cloned_shape == nullptr)
+ return false;
+
+ auto name = reshape->name();
+ assert(name.length() > 0);
+ loco::Graph *graph = neg->graph();
+ // create reshape placed after neg
+ luci::CircleReshape *new_reshape = graph->nodes()->create<luci::CircleReshape>();
+ copy_shape(reshape, new_reshape);
+ new_reshape->shape(cloned_shape);
+ new_reshape->name(name + "_C");
+ luci::add_origin(new_reshape, luci::get_origin(reshape));
+
+ // reconnect network
+ loco::replace(neg).with(new_reshape);
+ neg->x(reshape->tensor());
+ new_reshape->tensor(neg);
+
+ // Do shape inference for this node again.
+ neg->shape_status(luci::ShapeStatus::UNDEFINED);
+
+ return true;
+}
+
+class ForwardReshape final : public luci::CircleNodeMutableVisitor<bool>
+{
+protected:
+ bool visit(luci::CircleNode *node)
+ {
+ LOGGER(l);
+ INFO(l) << "ForwardReshape: Unsupported operator: " << node->name() << std::endl;
+ return false;
+ }
+
+ bool visit(luci::CircleNeg *node)
+ {
+ auto reshape = as_reshape(node->x());
+ if (reshape == nullptr)
+ return false;
+ return forward_reshape(reshape, node);
+ }
+
+ // TODO add more unary operators
+};
+
+} // namespace
+
+namespace luci
+{
+
+/**
+ * BEFORE
+ * |
+ * [CircleNode] [CircleConst]
+ * | /
+ * [CircleReshape]
+ * / |
+ * [CircleNode] [(UnaryOp)]
+ * | | \
+ * | | [CircleNode]
+ * | | |
+ *
+ * UnaryOp: CircleNeg, ...
+ *
+ * AFTER
+ * |
+ * [CircleConst] [CircleNode]
+ * | / |
+ * [CircleReshape] [(UnaryOp)] [CircleConst]
+ * | | /
+ * [CircleNode] [CircleReshape]
+ * | | \
+ * | | [CircleNode]
+ * | | |
+ *
+ * Note: new [CircleReshape] after [(UnaryOp)] added
+ */
+bool ForwardReshapeToUnaryOpPass::run(loco::Graph *g)
+{
+ bool changed = false;
+ ForwardReshape forward;
+ for (auto node : loco::active_nodes(loco::output_nodes(g)))
+ {
+ auto circle_node = loco::must_cast<luci::CircleNode *>(node);
+ if (circle_node->accept(&forward))
+ changed = true;
+ }
+ return changed;
+}
+
+} // namespace luci
diff --git a/compiler/luci/pass/src/ForwardReshapeToUnaryOpPass.test.cpp b/compiler/luci/pass/src/ForwardReshapeToUnaryOpPass.test.cpp
new file mode 100644
index 000000000..2593a014c
--- /dev/null
+++ b/compiler/luci/pass/src/ForwardReshapeToUnaryOpPass.test.cpp
@@ -0,0 +1,125 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Pass/ForwardReshapeToUnaryOpPass.h"
+#include "luci/Pass/CircleShapeInferencePass.h"
+
+#include <luci/IR/CircleNodes.h>
+
+#include <luci/test/TestIOGraph.h>
+
+#include <gtest/gtest.h>
+
+#include <vector>
+
+namespace
+{
+
+using namespace luci::test;
+
+class ReshapeNegGraphlet
+{
+public:
+ ReshapeNegGraphlet() = default;
+
+public:
+ void init(loco::Graph *g, const ShapeU32 shape_in, const ShapeU32 shape_out)
+ {
+ std::vector<uint32_t> shape_out_v = shape_out;
+
+ _reshape_shape = g->nodes()->create<luci::CircleConst>();
+ _reshape = g->nodes()->create<luci::CircleReshape>();
+ _neg = g->nodes()->create<luci::CircleNeg>();
+
+ _reshape_shape->dtype(loco::DataType::S32);
+ _reshape_shape->rank(1);
+ _reshape_shape->dim(0).set(shape_out_v.size());
+ _reshape_shape->shape_status(luci::ShapeStatus::VALID);
+ // values
+ const auto size = shape_out_v.size();
+ _reshape_shape->size<loco::DataType::S32>(size);
+ for (uint32_t i = 0; i < size; i++)
+ _reshape_shape->at<loco::DataType::S32>(i) = shape_out_v[i];
+
+ _reshape_shape->name("reshape_shape");
+ _reshape->name("reshape");
+ _neg->name("neg");
+ }
+
+protected:
+ luci::CircleReshape *_reshape = nullptr;
+ luci::CircleNeg *_neg = nullptr;
+ luci::CircleConst *_reshape_shape = nullptr;
+};
+
+class ForwardReshapeToNegGraph : public TestIOGraph, public ReshapeNegGraphlet
+{
+public:
+ ForwardReshapeToNegGraph() = default;
+
+public:
+ void init(const ShapeU32 shape_in, const ShapeU32 shape_out)
+ {
+ TestIOGraph::init(shape_in, shape_out);
+ ReshapeNegGraphlet::init(g(), shape_in, shape_out);
+
+ // connect network
+ _reshape->tensor(input());
+ _reshape->shape(_reshape_shape);
+ _neg->x(_reshape);
+
+ output()->from(_neg);
+ }
+};
+
+class ForwardReshapeToNegGraphTest : public ::testing::Test
+{
+public:
+ ForwardReshapeToNegGraphTest() = default;
+
+ void run_pass(void)
+ {
+ while (_pass.run(_graph.g()))
+ ;
+ }
+
+protected:
+ ForwardReshapeToNegGraph _graph;
+ luci::ForwardReshapeToUnaryOpPass _pass;
+};
+
+} // namespace
+
+TEST(ForwardReshapeToUnaryOpPassTest, name)
+{
+ luci::ForwardReshapeToUnaryOpPass pass;
+ auto const name = pass.name();
+ ASSERT_NE(nullptr, name);
+}
+
+TEST_F(ForwardReshapeToNegGraphTest, simple_forward)
+{
+ _graph.init({2, 2, 2}, {2, 4});
+
+ run_pass();
+
+ auto reshape = dynamic_cast<luci::CircleReshape *>(_graph.output()->from());
+ auto neg = dynamic_cast<luci::CircleNeg *>(_graph.output()->from());
+ ASSERT_NE(nullptr, reshape);
+ ASSERT_EQ(nullptr, neg);
+ neg = dynamic_cast<luci::CircleNeg *>(reshape->tensor());
+ ASSERT_NE(nullptr, neg);
+}
diff --git a/compiler/luci/pass/src/FuseActivationFunctionPass.cpp b/compiler/luci/pass/src/FuseActivationFunctionPass.cpp
index 844541d2d..66e341518 100644
--- a/compiler/luci/pass/src/FuseActivationFunctionPass.cpp
+++ b/compiler/luci/pass/src/FuseActivationFunctionPass.cpp
@@ -17,7 +17,9 @@
#include "luci/Pass/FuseActivationFunctionPass.h"
#include <luci/IR/CircleNodes.h>
+#include <luci/IR/CircleNodeMixins.h>
#include <luci/IR/CircleOpcode.h>
+#include <luci/Profile/CircleNodeOrigin.h>
namespace luci
{
@@ -32,10 +34,15 @@ bool fuse_activation_function(luci::CircleNode *node)
return false;
auto node_with_fused_act =
- dynamic_cast<luci::LuciNodeMixin<luci::LuciNodeTrait::FusedActFunc> *>(pred_node);
+ dynamic_cast<luci::CircleNodeMixin<luci::CircleNodeTrait::FusedActFunc> *>(pred_node);
if (node_with_fused_act == nullptr)
return false;
+ // TODO remove this work-around
+ // This will skip fuse for concat as luci-interpreter doesn't support this yet
+ if (dynamic_cast<luci::CircleConcatenation *>(pred_node) != nullptr)
+ return false;
+
auto fused_act = node_with_fused_act->fusedActivationFunction();
luci::FusedActFunc target_func = luci::FusedActFunc::UNDEFINED;
@@ -76,6 +83,7 @@ bool fuse_activation_function(luci::CircleNode *node)
return false;
node_with_fused_act->fusedActivationFunction(target_func);
+ luci::add_origin(pred_node, luci::get_origin(node));
loco::replace(node).with(pred_node);
node->drop();
diff --git a/compiler/luci/pass/src/FuseActivationFunctionPass.test.cpp b/compiler/luci/pass/src/FuseActivationFunctionPass.test.cpp
index 226a303a1..56b414143 100644
--- a/compiler/luci/pass/src/FuseActivationFunctionPass.test.cpp
+++ b/compiler/luci/pass/src/FuseActivationFunctionPass.test.cpp
@@ -14,15 +14,19 @@
* limitations under the License.
*/
-#include "FuseActivationFunctionPassInternal.h"
+#include "luci/Pass/FuseActivationFunctionPass.h"
#include <luci/IR/CircleNodes.h>
+#include <luci/test/TestIOGraph.h>
+
#include <gtest/gtest.h>
namespace
{
+using namespace luci::test;
+
/**
* Simple graph for test
*
@@ -41,60 +45,148 @@ namespace
* [Conv2]
*
*/
-class SimpleGraph
+class ConvReluConvGraphlet
+{
+public:
+ ConvReluConvGraphlet() = default;
+
+ void init(loco::Graph *g)
+ {
+ _conv1 = g->nodes()->create<luci::CircleConv2D>();
+ _conv2 = g->nodes()->create<luci::CircleConv2D>();
+ _relu = g->nodes()->create<luci::CircleRelu>();
+ _conv1_f = g->nodes()->create<luci::CircleConst>();
+ _conv1_b = g->nodes()->create<luci::CircleConst>();
+ _conv2_f = g->nodes()->create<luci::CircleConst>();
+ _conv2_b = g->nodes()->create<luci::CircleConst>();
+
+ _conv1->fusedActivationFunction(luci::FusedActFunc::NONE);
+
+ _conv1->name("conv1");
+ _conv2->name("conv2");
+ _relu->name("relu");
+ _conv1_f->name("conv1f");
+ _conv1_b->name("conv1b");
+ _conv2_f->name("conv2f");
+ _conv2_b->name("conv2b");
+ }
+
+public:
+ luci::CircleRelu *relu() { return _relu; }
+ luci::CircleConv2D *conv1() { return _conv1; }
+ luci::CircleConv2D *conv2() { return _conv2; }
+
+protected:
+ luci::CircleConv2D *_conv1 = nullptr;
+ luci::CircleConv2D *_conv2 = nullptr;
+ luci::CircleRelu *_relu = nullptr;
+ luci::CircleConst *_conv1_f = nullptr;
+ luci::CircleConst *_conv1_b = nullptr;
+ luci::CircleConst *_conv2_f = nullptr;
+ luci::CircleConst *_conv2_b = nullptr;
+};
+
+class FuseActTestGraph : public TestIOGraph, public ConvReluConvGraphlet
{
public:
- SimpleGraph()
+ FuseActTestGraph() = default;
+
+ void init(void)
{
- conv1 = g.nodes()->create<luci::CircleConv2D>();
- conv2 = g.nodes()->create<luci::CircleConv2D>();
- relu = g.nodes()->create<luci::CircleRelu>();
+ TestIOGraph::init({1}, {1});
+ ConvReluConvGraphlet::init(g());
- conv1->fusedActivationFunction(luci::FusedActFunc::NONE);
+ _conv1->input(input());
+ _conv1->filter(_conv1_f);
+ _conv1->bias(_conv1_b);
- relu->features(conv1);
- conv2->input(relu);
+ _relu->features(_conv1);
+
+ _conv2->input(_relu);
+ _conv2->filter(_conv2_f);
+ _conv2->bias(_conv2_b);
+
+ output()->from(_conv2);
}
+};
+class ConvHasMultiSuccGraph : public TestIOGraph, public ConvReluConvGraphlet
+{
public:
- loco::Graph g;
- luci::CircleConv2D *conv1;
- luci::CircleConv2D *conv2;
- luci::CircleRelu *relu;
+ ConvHasMultiSuccGraph() = default;
+
+ void init(void)
+ {
+ TestIOGraph::init({1}, {1});
+ ConvReluConvGraphlet::init(g());
+
+ _conv1->input(input());
+ _conv1->filter(_conv1_f);
+ _conv1->bias(_conv1_b);
+
+ _relu->features(_conv1);
+
+ _conv2->input(_conv1);
+ _conv2->filter(_conv2_f);
+ _conv2->bias(_conv2_b);
+
+ output()->from(_relu); // We need to check from relu
+ }
};
+// TODO use ::testing::Test
+
} // namespace
+TEST(FuseActivationFunctionPassTest, name)
+{
+ luci::FuseActivationFunctionPass pass;
+ auto const name = pass.name();
+ ASSERT_NE(nullptr, name);
+}
+
TEST(FusePreActivationBatchNorm, fuse_activation_function)
{
- SimpleGraph g;
+ FuseActTestGraph g;
+ luci::FuseActivationFunctionPass pass;
- EXPECT_TRUE(luci::fuse_activation_function(g.relu));
+ g.init();
- EXPECT_EQ(g.conv1, g.conv2->input());
+ EXPECT_TRUE(pass.run(g.g()));
+ EXPECT_EQ(g.conv1(), g.conv2()->input());
}
TEST(FusePreActivationBatchNorm, fuse_activation_function_dup_relu)
{
- SimpleGraph g;
- g.conv1->fusedActivationFunction(luci::FusedActFunc::RELU);
+ FuseActTestGraph g;
+ luci::FuseActivationFunctionPass pass;
- EXPECT_TRUE(luci::fuse_activation_function(g.relu));
+ g.init();
+ g.conv1()->fusedActivationFunction(luci::FusedActFunc::RELU);
- EXPECT_EQ(g.conv1, g.conv2->input());
+ EXPECT_TRUE(pass.run(g.g()));
+ EXPECT_EQ(g.conv1(), g.conv2()->input());
}
-TEST(FusePreActivationBatchNorm, fuse_activation_function_NEG)
+TEST(FusePreActivationBatchNorm, fuse_activation_function_mulsucc_NEG)
{
- SimpleGraph g;
- g.conv2->input(g.conv1);
+ ConvHasMultiSuccGraph g;
+ luci::FuseActivationFunctionPass pass;
+
+ g.init();
- // Conv1 has multiple successors
- EXPECT_FALSE(luci::fuse_activation_function(g.relu));
+ // Relu input Conv2D has multiple successors
+ EXPECT_FALSE(pass.run(g.g()));
+}
+
+TEST(FusePreActivationBatchNorm, fuse_activation_function_tanh_NEG)
+{
+ FuseActTestGraph g;
+ luci::FuseActivationFunctionPass pass;
- g.conv2->input(g.relu);
- g.conv1->fusedActivationFunction(luci::FusedActFunc::TANH);
+ g.init();
+ g.conv1()->fusedActivationFunction(luci::FusedActFunc::TANH);
- // Conv1 already has activation function
- EXPECT_FALSE(luci::fuse_activation_function(g.relu));
+ // Relu input Conv2D already has activation function
+ EXPECT_FALSE(pass.run(g.g()));
}
diff --git a/compiler/luci/pass/src/FuseAddWithTConvPass.cpp b/compiler/luci/pass/src/FuseAddWithTConvPass.cpp
index bd7805f6a..2bca57014 100644
--- a/compiler/luci/pass/src/FuseAddWithTConvPass.cpp
+++ b/compiler/luci/pass/src/FuseAddWithTConvPass.cpp
@@ -17,20 +17,30 @@
#include "luci/Pass/FuseAddWithTConvPass.h"
#include <luci/IR/CircleNodes.h>
+#include <luci/Profile/CircleNodeOrigin.h>
namespace
{
/**
- * Fuse add to TCONV if possible
+ * Fuse Add to TransposeConv if possible
*
* BEFORE
- *
- * [CircleTransposeConv]
+ * |
+ * [CircleConst] [CircleTransposeConv]
+ * \ |
+ * [CircleAdd]
* |
- * [add]
+ *
* AFTER
+ * |
+ * [CircleConst] |
+ * \ |
+ * [CircleTransposeConv] [CircleAdd]
+ * |
+ * ([CircleRelu6])
+ * |
*
- * [CircleTransposeConv]
+ * Note: CircleRelu6 is inserted if Add activation is ReLU6
*/
bool fuse_add_with_tconv(luci::CircleTransposeConv *tconv)
{
@@ -81,9 +91,13 @@ bool fuse_add_with_tconv(luci::CircleTransposeConv *tconv)
if (add->fusedActivationFunction() == luci::FusedActFunc::RELU6)
{
+ auto name = addition->name();
+ assert(name.length() > 0);
// separate relu op from add op
auto relu = add->graph()->nodes()->create<luci::CircleRelu6>();
relu->features(tconv);
+ relu->name(name + "/Relu6");
+ luci::add_origin(relu, luci::get_origin(add));
// remove add node
replace(add).with(relu);
@@ -93,6 +107,9 @@ bool fuse_add_with_tconv(luci::CircleTransposeConv *tconv)
replace(add).with(tconv);
}
+ // set origin
+ luci::add_origin(tconv, luci::get_origin(add));
+
return true;
}
diff --git a/compiler/luci/pass/src/FuseAddWithTConvPass.test.cpp b/compiler/luci/pass/src/FuseAddWithTConvPass.test.cpp
new file mode 100644
index 000000000..8748d73ef
--- /dev/null
+++ b/compiler/luci/pass/src/FuseAddWithTConvPass.test.cpp
@@ -0,0 +1,26 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Pass/FuseAddWithTConvPass.h"
+
+#include <gtest/gtest.h>
+
+TEST(FuseAddWithTConvPassTest, name)
+{
+ luci::FuseAddWithTConvPass pass;
+ auto const name = pass.name();
+ ASSERT_NE(nullptr, name);
+}
diff --git a/compiler/luci/pass/src/FuseBCQPass.cpp b/compiler/luci/pass/src/FuseBCQPass.cpp
index c0583d848..09180d8c1 100644
--- a/compiler/luci/pass/src/FuseBCQPass.cpp
+++ b/compiler/luci/pass/src/FuseBCQPass.cpp
@@ -17,6 +17,7 @@
#include "luci/Pass/FuseBCQPass.h"
#include <luci/IR/CircleNodes.h>
+#include <luci/Profile/CircleNodeOrigin.h>
#include <luci/Log.h>
#include <cassert>
@@ -111,7 +112,7 @@ template <> class BCQFuser<1>
{
public:
BCQFuser<1>(int32_t original_output_cnt, int32_t bundle_cnt)
- : _original_output_cnt{original_output_cnt}, _bundle_cnt{bundle_cnt}
+ : _original_output_cnt{original_output_cnt}, _bundle_cnt{bundle_cnt}
{
// Do nothing
}
@@ -133,7 +134,7 @@ public:
{
const auto prefix = (output_node->index() - (_original_output_cnt + 1)) / (_bundle_cnt);
const MetadataType metadata_type = static_cast<MetadataType>(
- (output_node->index() - (_original_output_cnt + 1)) % (_bundle_cnt));
+ (output_node->index() - (_original_output_cnt + 1)) % (_bundle_cnt));
const auto circle_node = loco::must_cast<luci::CircleNode *>(output_node->from());
add_BCQ_info_node(prefix, metadata_type, circle_node);
}
@@ -156,13 +157,18 @@ public:
if (prefix == -1 || !is_valid_prefix(prefix))
continue;
+ auto name = gather->name();
+ assert(name.length() > 0);
+
auto bcq_gather = g->nodes()->create<luci::CircleBCQGather>();
+ luci::add_origin(bcq_gather, luci::get_origin(gather));
bcq_gather->op_version(1);
bcq_gather->input_scales(alpha(g, prefix));
bcq_gather->input_binary(packed_binary_code(g, prefix));
bcq_gather->indices(gather->indices());
bcq_gather->input_clusters(packed_clusters(g, prefix));
+ bcq_gather->name(name + "/BCQGather");
if (_do_w_x[prefix]->at<loco::DataType::BOOL>(0))
{
@@ -177,7 +183,7 @@ public:
bcq_gather->axis(axis_transpose);
const auto indices_rank =
- loco::must_cast<luci::CircleNode *>(gather->indices())->rank();
+ loco::must_cast<luci::CircleNode *>(gather->indices())->rank();
auto perm = g->nodes()->create<luci::CircleConst>();
perm->dtype(loco::DataType::S32);
@@ -188,10 +194,13 @@ public:
perm->at<loco::DataType::S32>(idx) = idx + 1;
perm->at<loco::DataType::S32>(indices_rank) = 0;
perm->shape_status(luci::ShapeStatus::VALID);
+ perm->name(name + "/Transpose/perm");
auto output_transpose = g->nodes()->create<luci::CircleTranspose>();
+ luci::add_origin(output_transpose, luci::get_origin(gather));
output_transpose->a(bcq_gather);
output_transpose->perm(perm);
+ output_transpose->name(name + "/Transpose");
loco::replace(gather).with(output_transpose);
}
@@ -209,7 +218,11 @@ public:
if (prefix == -1 || !is_valid_prefix(prefix))
continue;
+ auto name = fully_connected->name();
+ assert(name.length() > 0);
+
auto bcq_fc = g->nodes()->create<luci::CircleBCQFullyConnected>();
+ luci::add_origin(bcq_fc, luci::get_origin(fully_connected));
bcq_fc->op_version(1);
bcq_fc->weights_scales(alpha(g, prefix));
@@ -217,6 +230,7 @@ public:
bcq_fc->bias(fully_connected->bias());
bcq_fc->weights_clusters(packed_clusters(g, prefix));
bcq_fc->fusedActivationFunction(fully_connected->fusedActivationFunction());
+ bcq_fc->name(name + "/BCQFullyConnected");
loco::Node *bcq_input = fully_connected->input();
@@ -231,18 +245,16 @@ public:
new_shape->rank(1);
new_shape->dim(0) = 2;
- auto batch_size = 1;
- for (uint32_t i = 0; i < original_input->rank() - 1; ++i)
- batch_size *= original_input->dim(i).value();
-
- new_shape->at<loco::DataType::S32>(0) = batch_size;
- new_shape->at<loco::DataType::S32>(1) =
- original_input->dim(original_input->rank() - 1).value();
+ new_shape->at<loco::DataType::S32>(0) = -1;
+ new_shape->at<loco::DataType::S32>(1) = weights->dim(1).value();
new_shape->shape_status(luci::ShapeStatus::VALID);
+ new_shape->name(name + "/Reshape/shape");
auto reshape = g->nodes()->create<luci::CircleReshape>();
+ luci::add_origin(reshape, luci::get_origin(fully_connected));
reshape->tensor(original_input);
reshape->shape(new_shape);
+ reshape->name(name + "/Reshape");
bcq_input = reshape;
}
@@ -258,23 +270,28 @@ public:
perm->at<loco::DataType::S32>(0) = 1;
perm->at<loco::DataType::S32>(1) = 0;
perm->shape_status(luci::ShapeStatus::VALID);
+ perm->name(name + "/Transpose/perm");
auto input_transpose = g->nodes()->create<luci::CircleTranspose>();
+ luci::add_origin(input_transpose, luci::get_origin(fully_connected));
input_transpose->a(bcq_input);
input_transpose->perm(perm);
+ input_transpose->name(name + "_input/Transpose");
bcq_fc->input(input_transpose);
auto output_transpose = g->nodes()->create<luci::CircleTranspose>();
+ luci::add_origin(output_transpose, luci::get_origin(fully_connected));
output_transpose->a(bcq_fc);
output_transpose->perm(perm);
+ output_transpose->name(name + "_output/Transpose");
loco::replace(fully_connected).with(output_transpose);
return true;
}
else if (auto weights_as_input =
- dynamic_cast<luci::CircleConst *>(fully_connected->input()))
+ dynamic_cast<luci::CircleConst *>(fully_connected->input()))
{
auto prefix = get_prefix_of_const(weights_as_input);
if (prefix == -1 || !is_valid_prefix(prefix))
@@ -282,6 +299,9 @@ public:
assert(_do_w_x[prefix]->at<loco::DataType::BOOL>(0) == true);
+ auto name = weights_as_input->name();
+ assert(name.length() > 0);
+
auto perm = g->nodes()->create<luci::CircleConst>();
perm->dtype(loco::DataType::S32);
perm->size<loco::DataType::S32>(2);
@@ -290,12 +310,16 @@ public:
perm->at<loco::DataType::S32>(0) = 1;
perm->at<loco::DataType::S32>(1) = 0;
perm->shape_status(luci::ShapeStatus::VALID);
+ perm->name(name + "/Transpose/perm");
auto input_transpose = g->nodes()->create<luci::CircleTranspose>();
+ luci::add_origin(input_transpose, luci::get_origin(fully_connected));
input_transpose->a(fully_connected->weights());
input_transpose->perm(perm);
+ input_transpose->name(name + "/Transpose");
auto bcq_fc = g->nodes()->create<luci::CircleBCQFullyConnected>();
+ luci::add_origin(bcq_fc, luci::get_origin(fully_connected));
assert(dynamic_cast<luci::CircleOutputExclude *>(fully_connected->bias()) != nullptr);
@@ -308,6 +332,8 @@ public:
bcq_fc->weights_hidden_size(weights_as_input->dim(1).value());
bcq_fc->input(input_transpose);
+ bcq_fc->name(name + "/BCQFullyConnected");
+
loco::replace(fully_connected).with(bcq_fc);
return true;
@@ -533,7 +559,7 @@ private:
new_beta->dim(1) = _packed_binary_code[prefix]->dim(1);
for (uint32_t i = 0; i < _packed_binary_code[prefix]->size<loco::DataType::S32>(); ++i)
new_beta->at<loco::DataType::S32>(i) =
- _packed_binary_code[prefix]->at<loco::DataType::S32>(i);
+ _packed_binary_code[prefix]->at<loco::DataType::S32>(i);
new_beta->shape_status(luci::ShapeStatus::VALID);
return new_beta;
@@ -556,9 +582,9 @@ private:
for (int i = 0; i < number_of_clusters; ++i)
{
packed_clusters->at<loco::DataType::S32>(i * 2) =
- qbits_of_clusters->at<loco::DataType::S32>(i);
+ qbits_of_clusters->at<loco::DataType::S32>(i);
packed_clusters->at<loco::DataType::S32>(i * 2 + 1) =
- size_of_clusters->at<loco::DataType::S32>(i);
+ size_of_clusters->at<loco::DataType::S32>(i);
}
return packed_clusters;
diff --git a/compiler/luci/pass/src/FuseBCQPass.test.cpp b/compiler/luci/pass/src/FuseBCQPass.test.cpp
new file mode 100644
index 000000000..73677affd
--- /dev/null
+++ b/compiler/luci/pass/src/FuseBCQPass.test.cpp
@@ -0,0 +1,26 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Pass/FuseBCQPass.h"
+
+#include <gtest/gtest.h>
+
+TEST(FuseBCQPassTest, name)
+{
+ luci::FuseBCQPass pass;
+ auto const name = pass.name();
+ ASSERT_NE(nullptr, name);
+}
diff --git a/compiler/luci/pass/src/FuseBatchNormWithConvPass.cpp b/compiler/luci/pass/src/FuseBatchNormWithConvPass.cpp
new file mode 100644
index 000000000..062da7058
--- /dev/null
+++ b/compiler/luci/pass/src/FuseBatchNormWithConvPass.cpp
@@ -0,0 +1,232 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Pass/FuseBatchNormWithConvPass.h"
+
+#include <luci/IR/CircleNodes.h>
+#include <luci/Profile/CircleNodeOrigin.h>
+
+namespace
+{
+/**
+ * Fuse Mul-Add to Conv2D if possible.
+ *
+ * NOTE TF's BatchNormalization is converted to Mul and Add.
+ *
+ * BEFORE
+ * | [CircleConst]
+ * | / [CircleConst]
+ * | / /
+ * [CircleConv2D] [CircleConst]
+ * | /
+ * [CircleMul] [CircleConst]
+ * | /
+ * [CircleAdd]
+ * |
+ *
+ * AFTER
+ * | [CircleConst]
+ * +--------------+ / [CircleConst]
+ * | | / /
+ * | [CircleConv2D] [CircleConst]
+ * [CircleConst] | | /
+ * [CircleConst] \ | [CircleMul] [CircleConst]
+ * \ \ | | /
+ * [CircleConv2D] [CircleAdd]
+ * |
+ */
+bool fused_batch_norm_with_conv(luci::CircleAdd *add)
+{
+ luci::CircleMul *mul = nullptr;
+ luci::CircleConst *shift = nullptr;
+ if (auto add_lhs = dynamic_cast<luci::CircleMul *>(add->x()))
+ {
+ mul = add_lhs;
+ shift = dynamic_cast<luci::CircleConst *>(add->y());
+ }
+ else if (auto add_rhs = dynamic_cast<luci::CircleMul *>(add->y()))
+ {
+ mul = add_rhs;
+ shift = dynamic_cast<luci::CircleConst *>(add->x());
+ }
+
+ // If CircleMul is not found or constant operand of CircleAdd is not found,
+ // this pass cannot be applied.
+ if (mul == nullptr || shift == nullptr)
+ return false;
+
+ // If FusedActivationFunction of mul is not none, this pass cannot be applied.
+ if (mul->fusedActivationFunction() != luci::FusedActFunc::NONE)
+ return false;
+
+ // To apply this pass, shape of shift should be [1, 1, 1, out_channel].
+ if (shift->rank() != 4)
+ return false;
+ for (uint32_t i = 0; i < 3; ++i)
+ if (shift->dim(i).value() != 1)
+ return false;
+
+ luci::CircleConv2D *conv = nullptr;
+ luci::CircleConst *scale = nullptr;
+ if (auto mul_lhs = dynamic_cast<luci::CircleConv2D *>(mul->x()))
+ {
+ conv = mul_lhs;
+ scale = dynamic_cast<luci::CircleConst *>(mul->y());
+ }
+ else if (auto mul_rhs = dynamic_cast<luci::CircleConv2D *>(mul->y()))
+ {
+ conv = mul_rhs;
+ scale = dynamic_cast<luci::CircleConst *>(mul->x());
+ }
+
+ // If CircleConv2D is not found or constant operand of CircleMul is not found,
+ // this pass cannot be applied.
+ if (conv == nullptr || scale == nullptr)
+ return false;
+
+ // To apply this pass, shape of scale should be [1, 1, 1, out_channel].
+ if (scale->rank() != 4)
+ return false;
+ for (uint32_t i = 0; i < 3; ++i)
+ if (scale->dim(i).value() != 1)
+ return false;
+
+ // If FusedActivationFunction of conv is not none, this pass cannot be applied.
+ if (conv->fusedActivationFunction() != luci::FusedActFunc::NONE)
+ return false;
+
+ luci::CircleConst *filter = dynamic_cast<luci::CircleConst *>(conv->filter());
+ luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(conv->bias());
+
+ // If filter or bias of conv is not const, this pass cannot be applied.
+ if (filter == nullptr || bias == nullptr)
+ return false;
+
+ // If dtype of filter is different with scale and shift, multiplication may be impossible.
+ if (filter->dtype() != scale->dtype())
+ return false;
+ if (filter->dtype() != shift->dtype())
+ return false;
+
+ // TODO Support more data type
+ if (filter->dtype() != loco::DataType::FLOAT32)
+ return false;
+
+ // Output channel dimension should be same. If not, this pass cannot be applied.
+ if (filter->dim(0).value() != scale->dim(3).value())
+ return false;
+ if (filter->dim(0).value() != shift->dim(3).value())
+ return false;
+
+ auto name = add->name();
+ assert(name.length() > 0);
+
+ luci::CircleConv2D *fused_conv = add->graph()->nodes()->create<luci::CircleConv2D>();
+ luci::CircleConst *fused_filter = add->graph()->nodes()->create<luci::CircleConst>();
+ luci::CircleConst *fused_bias = add->graph()->nodes()->create<luci::CircleConst>();
+
+ uint32_t filter_out_channel = filter->dim(0).value();
+ uint32_t filter_height = filter->dim(1).value();
+ uint32_t filter_width = filter->dim(2).value();
+ uint32_t filter_in_channel = filter->dim(3).value();
+
+ // Copy filter
+ fused_filter->dtype(filter->dtype());
+ fused_filter->size<loco::DataType::FLOAT32>(filter->size<loco::DataType::FLOAT32>());
+ fused_filter->rank(4);
+ fused_filter->dim(0).set(filter_out_channel);
+ fused_filter->dim(1).set(filter_height);
+ fused_filter->dim(2).set(filter_width);
+ fused_filter->dim(3).set(filter_in_channel);
+ fused_filter->shape_status(luci::ShapeStatus::VALID);
+ fused_filter->name(name + "/Conv2D/filter");
+
+ // Fuse scale to new filter
+ for (uint32_t c = 0; c < filter_out_channel; c++)
+ {
+ for (uint32_t h = 0; h < filter_height; h++)
+ {
+ for (uint32_t w = 0; w < filter_width; w++)
+ {
+ for (uint32_t b = 0; b < filter_in_channel; b++)
+ {
+ uint32_t offset = c * filter_height * filter_width * filter_in_channel +
+ h * filter_width * filter_in_channel + w * filter_in_channel + b;
+ fused_filter->at<loco::DataType::FLOAT32>(offset) =
+ filter->at<loco::DataType::FLOAT32>(offset) * scale->at<loco::DataType::FLOAT32>(c);
+ }
+ }
+ }
+ }
+
+ // Copy bias
+ assert(bias->rank() == 1);
+ assert(bias->dim(0).value() == filter_out_channel);
+ fused_bias->dtype(bias->dtype());
+ fused_bias->size<loco::DataType::FLOAT32>(bias->size<loco::DataType::FLOAT32>());
+ fused_bias->rank(1);
+ fused_bias->dim(0).set(filter_out_channel);
+ fused_bias->shape_status(luci::ShapeStatus::VALID);
+ fused_bias->name(name + "/Conv2D/bias");
+
+ // Fuse scale and shift to bias
+ for (uint32_t b = 0; b < filter_out_channel; ++b)
+ {
+ fused_bias->at<loco::DataType::FLOAT32>(b) =
+ bias->at<loco::DataType::FLOAT32>(b) * scale->at<loco::DataType::FLOAT32>(b) +
+ shift->at<loco::DataType::FLOAT32>(b);
+ }
+
+ // Set attributes of fused_conv
+ fused_conv->input(conv->input());
+ fused_conv->filter(fused_filter);
+ fused_conv->bias(fused_bias);
+ fused_conv->fusedActivationFunction(add->fusedActivationFunction());
+ fused_conv->padding(conv->padding());
+ fused_conv->stride()->h(conv->stride()->h());
+ fused_conv->stride()->w(conv->stride()->w());
+ fused_conv->dilation()->h(conv->dilation()->h());
+ fused_conv->dilation()->w(conv->dilation()->w());
+ fused_conv->name(name + "/Conv2D");
+ luci::add_origin(fused_conv, luci::composite_origin({luci::get_origin(add), luci::get_origin(mul),
+ luci::get_origin(conv)}));
+
+ replace(add).with(fused_conv);
+
+ return true;
+}
+
+} // namespace
+
+namespace luci
+{
+
+bool FuseBatchNormWithConvPass::run(loco::Graph *g)
+{
+ bool changed = false;
+ for (auto node : loco::active_nodes(loco::output_nodes(g)))
+ {
+ if (auto add = dynamic_cast<luci::CircleAdd *>(node))
+ {
+ if (fused_batch_norm_with_conv(add))
+ changed = true;
+ }
+ }
+
+ return changed;
+}
+
+} // namespace luci
diff --git a/compiler/luci/pass/src/FuseBatchNormWithConvPass.test.cpp b/compiler/luci/pass/src/FuseBatchNormWithConvPass.test.cpp
new file mode 100644
index 000000000..96bc2bd35
--- /dev/null
+++ b/compiler/luci/pass/src/FuseBatchNormWithConvPass.test.cpp
@@ -0,0 +1,26 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Pass/FuseBatchNormWithConvPass.h"
+
+#include <gtest/gtest.h>
+
+TEST(FuseBatchNormWithConvPassTest, name)
+{
+ luci::FuseBatchNormWithConvPass pass;
+ auto const name = pass.name();
+ ASSERT_NE(nullptr, name);
+}
diff --git a/compiler/luci/pass/src/FuseBatchNormWithDwConvPass.cpp b/compiler/luci/pass/src/FuseBatchNormWithDwConvPass.cpp
new file mode 100644
index 000000000..8b2286f43
--- /dev/null
+++ b/compiler/luci/pass/src/FuseBatchNormWithDwConvPass.cpp
@@ -0,0 +1,237 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Pass/FuseBatchNormWithDwConvPass.h"
+
+#include "helpers/NodeFiller.h"
+
+#include <luci/IR/CircleNodes.h>
+#include <luci/Profile/CircleNodeOrigin.h>
+
+namespace
+{
+/**
+ * Fuse Mul-Add to DepthwiseConv2D if possible.
+ *
+ * NOTE TF's BatchNormalization is converted to Mul and Add.
+ *
+ * BEFORE
+ * | [CircleConst]
+ * | / [CircleConst]
+ * | / /
+ * [CircleDepthwiseConv2D] [CircleConst]
+ * | /
+ * [CircleMul] [CircleConst]
+ * | /
+ * [CircleAdd]
+ * |
+ *
+ * AFTER
+ * | [CircleConst]
+ * +-------------------------------------+ / [CircleConst]
+ * | | / /
+ * | [CircleDepthwiseConv2D] [CircleConst]
+ * | [CircleConst] | /
+ * | / [CircleConst] [CircleMul] [CircleConst]
+ * | / / | /
+ * [CircleDepthwiseConv2D] [CircleAdd]
+ * |
+ *
+ */
+
+/**
+ * @brief Check shape is [x] or [1, 1, 1, x]
+ */
+bool is_scale_shift_shape(luci::CircleConst *node)
+{
+ auto rank = node->rank();
+ if (rank != 1 && rank != 4)
+ return false;
+ for (uint32_t r = 0; r < rank - 1; ++r)
+ {
+ if (node->dim(r).value() != 1)
+ return false;
+ }
+ return true;
+}
+
+bool fused_batch_norm_with_dwconv(luci::CircleAdd *add)
+{
+ assert(add != nullptr);
+
+ // Find the pattern of CircleDepthwiseConv2D - CircleMul - CircleAdd
+ luci::CircleConst *scale = nullptr;
+ luci::CircleConst *shift = nullptr;
+ luci::CircleDepthwiseConv2D *dwconv = nullptr;
+ luci::CircleMul *mul = nullptr;
+ if (not luci::fill(&shift, &mul).with_commutative_args_of(add))
+ return false;
+ if (not luci::fill(&scale, &dwconv).with_commutative_args_of(mul))
+ return false;
+
+ // check scale and shift constant attributes
+ // scale and shift can be [x] or [1, 1, 1, x]
+ if (not is_scale_shift_shape(scale))
+ return false;
+ if (not is_scale_shift_shape(shift))
+ return false;
+
+ // check mul, add attributes
+ if (mul->dtype() != loco::DataType::FLOAT32)
+ return false;
+ if (mul->fusedActivationFunction() != luci::FusedActFunc::NONE)
+ return false;
+ if (add->dtype() != loco::DataType::FLOAT32)
+ return false;
+ // TODO support more Activations
+ if (add->fusedActivationFunction() != luci::FusedActFunc::NONE &&
+ add->fusedActivationFunction() != luci::FusedActFunc::RELU6)
+ return false;
+
+ // get weight of dwconv
+ auto filter = dynamic_cast<luci::CircleConst *>(dwconv->filter());
+ if (not filter)
+ return false;
+ if (filter->dtype() != loco::DataType::FLOAT32)
+ return false;
+ if (filter->rank() != 4)
+ return false;
+
+ // check attributes of dwconv
+ if (dwconv->fusedActivationFunction() != luci::FusedActFunc::NONE)
+ return false;
+ if (dwconv->depthMultiplier() < 0) // can this happen?
+ return false;
+
+ // get bias of dwconv
+ auto bias = dynamic_cast<luci::CircleConst *>(dwconv->bias());
+ if (not bias)
+ return false;
+ if (bias->dtype() != loco::DataType::FLOAT32)
+ return false;
+ if (bias->rank() != 1)
+ return false;
+
+ // filter represents as [1, H, W, C*M] where M is multiplier.
+ auto filter_out_chn = filter->dim(3).value();
+ auto multiplier = static_cast<uint32_t>(dwconv->depthMultiplier());
+ auto srank = scale->rank(); // as rank can be 1 or 4
+ if (filter_out_chn != scale->dim(srank - 1).value() * multiplier)
+ return false;
+ srank = shift->rank();
+ if (filter_out_chn != shift->dim(srank - 1).value() * multiplier)
+ return false;
+ auto channel = filter_out_chn / multiplier;
+
+ auto name = add->name();
+ assert(name.length() > 0);
+
+ loco::Graph *graph = add->graph();
+ luci::CircleDepthwiseConv2D *fused_dwconv = graph->nodes()->create<luci::CircleDepthwiseConv2D>();
+ luci::CircleConst *fused_filter = graph->nodes()->create<luci::CircleConst>();
+ luci::CircleConst *fused_bias = graph->nodes()->create<luci::CircleConst>();
+
+ auto filter_in_chn = filter->dim(0).value();
+ auto filter_height = filter->dim(1).value();
+ auto filter_width = filter->dim(2).value();
+ assert(filter_in_chn == 1);
+
+ // Copy filter shape
+ fused_filter->dtype(filter->dtype());
+ fused_filter->size<loco::DataType::FLOAT32>(filter->size<loco::DataType::FLOAT32>());
+ fused_filter->rank(4);
+ fused_filter->dim(0).set(filter_in_chn);
+ fused_filter->dim(1).set(filter_height);
+ fused_filter->dim(2).set(filter_width);
+ fused_filter->dim(3).set(filter_out_chn);
+ fused_filter->shape_status(luci::ShapeStatus::VALID);
+ fused_filter->name(name + "/DepthwiseConv2D/filter");
+
+ // fused filter weight = filter weight * mul(scale) + add(shift)
+ for (uint32_t b = 0; b < filter_in_chn; b++)
+ {
+ for (uint32_t h = 0; h < filter_height; h++)
+ {
+ for (uint32_t w = 0; w < filter_width; w++)
+ {
+ for (uint32_t c = 0; c < filter_out_chn; c++)
+ {
+ uint32_t offset = b * filter_height * filter_width * filter_out_chn +
+ h * filter_width * filter_out_chn + w * filter_out_chn + c;
+ uint32_t chn = c / multiplier;
+ fused_filter->at<loco::DataType::FLOAT32>(offset) =
+ filter->at<loco::DataType::FLOAT32>(offset) * scale->at<loco::DataType::FLOAT32>(chn);
+ }
+ }
+ }
+ }
+
+ // Fuse bias with scale and shift
+ fused_bias->dtype(shift->dtype());
+ fused_bias->size<loco::DataType::FLOAT32>(shift->size<loco::DataType::FLOAT32>());
+ fused_bias->rank(1);
+ fused_bias->dim(0).set(channel);
+ fused_bias->shape_status(luci::ShapeStatus::VALID);
+ for (uint32_t c = 0; c < channel; ++c)
+ {
+ fused_bias->at<loco::DataType::FLOAT32>(c) =
+ bias->at<loco::DataType::FLOAT32>(c) * scale->at<loco::DataType::FLOAT32>(c) +
+ shift->at<loco::DataType::FLOAT32>(c);
+ }
+ fused_bias->name(name + "/DepthwiseConv2D/bias");
+
+ // set new tconv properties
+ fused_dwconv->input(dwconv->input());
+ fused_dwconv->filter(fused_filter);
+ fused_dwconv->bias(fused_bias);
+ fused_dwconv->fusedActivationFunction(add->fusedActivationFunction());
+ fused_dwconv->padding(dwconv->padding());
+ fused_dwconv->stride()->h(dwconv->stride()->h());
+ fused_dwconv->stride()->w(dwconv->stride()->w());
+ fused_dwconv->depthMultiplier(dwconv->depthMultiplier());
+ fused_dwconv->dilation()->h(dwconv->dilation()->h());
+ fused_dwconv->dilation()->w(dwconv->dilation()->w());
+ fused_dwconv->name(name + "/DepthwiseConv2D");
+ luci::add_origin(fused_dwconv,
+ luci::composite_origin(
+ {luci::get_origin(add), luci::get_origin(mul), luci::get_origin(dwconv)}));
+
+ replace(add).with(fused_dwconv);
+
+ return true;
+}
+
+} // namespace
+
+namespace luci
+{
+
+bool FuseBatchNormWithDwConvPass::run(loco::Graph *g)
+{
+ bool changed = false;
+ for (auto node : loco::active_nodes(loco::output_nodes(g)))
+ {
+ if (auto add = dynamic_cast<luci::CircleAdd *>(node))
+ {
+ if (fused_batch_norm_with_dwconv(add))
+ changed = true;
+ }
+ }
+
+ return changed;
+}
+
+} // namespace luci
diff --git a/compiler/luci/pass/src/FuseBatchNormWithDwConvPass.test.cpp b/compiler/luci/pass/src/FuseBatchNormWithDwConvPass.test.cpp
new file mode 100644
index 000000000..3030a7306
--- /dev/null
+++ b/compiler/luci/pass/src/FuseBatchNormWithDwConvPass.test.cpp
@@ -0,0 +1,26 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Pass/FuseBatchNormWithDwConvPass.h"
+
+#include <gtest/gtest.h>
+
+TEST(FuseBatchNormWithDwConvPassTest, name)
+{
+ luci::FuseBatchNormWithDwConvPass pass;
+ auto const name = pass.name();
+ ASSERT_NE(nullptr, name);
+}
diff --git a/compiler/luci/pass/src/FuseBatchNormWithTConv.cpp b/compiler/luci/pass/src/FuseBatchNormWithTConv.cpp
deleted file mode 100644
index 95ccd8176..000000000
--- a/compiler/luci/pass/src/FuseBatchNormWithTConv.cpp
+++ /dev/null
@@ -1,159 +0,0 @@
-/*
- * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "luci/Pass/FuseBatchNormWithTConv.h"
-
-#include <luci/IR/CircleNodes.h>
-
-namespace
-{
-/**
- * NOTE TF's fusedBatchNorm is converted to mul and add of Circle.
- *
- * BEFORE
- *
- * [CircleTransposeConv]
- * |
- * [mul]
- * |
- * [add]
- * AFTER
- *
- * [CircleTransposeConv]
- */
-bool fused_batch_norm_with_tconv(luci::CircleTransposeConv *tconv)
-{
- // check whether it has bias or not. This optimization works only if it doesn't.
- auto bias = dynamic_cast<luci::CircleOutputExclude *>(tconv->bias());
- if (not bias)
- return false;
-
- // get weight of tconv
- auto filter = dynamic_cast<luci::CircleConst *>(tconv->filter());
- if (not filter)
- return false;
- if (filter->dtype() != loco::DataType::FLOAT32)
- return false;
-
- // get mul node
- auto tconv_output = loco::succs(tconv);
- assert(tconv_output.size() == 1);
- auto mul = dynamic_cast<luci::CircleMul *>(*tconv_output.begin());
- if (not mul)
- return false;
- if (mul->dtype() != loco::DataType::FLOAT32)
- return false;
-
- // get add node
- auto mul_output = loco::succs(mul);
- assert(mul_output.size() == 1);
- auto add = dynamic_cast<luci::CircleAdd *>(*mul_output.begin());
- if (not add)
- return false;
- if (add->dtype() != loco::DataType::FLOAT32)
- return false;
- if (add->fusedActivationFunction() != luci::FusedActFunc::NONE &&
- add->fusedActivationFunction() != luci::FusedActFunc::RELU6)
- return false;
-
- // get scale of batchnorm
- auto scale = dynamic_cast<luci::CircleConst *>(mul->y());
- if (not scale)
- return false;
-
- // scale dim(0) == tconv filter channel dim
- if (filter->rank() != 4)
- return false;
- auto filter_out_dim = filter->dim(0).value();
- if (scale->rank() != 1)
- return false;
- auto scale_dim = scale->dim(0).value();
- if (filter_out_dim != scale_dim)
- return false;
-
- // get shift of batchnorm
- auto shift = dynamic_cast<luci::CircleConst *>(add->y());
- if (not shift)
- return false;
-
- // shift dim(0) == tconv filter channel dim
- if (shift->rank() != 1)
- return false;
- auto shift_dim = shift->dim(0).value();
- if (filter_out_dim != shift_dim)
- return false;
-
- // filter weight = filter weight * mul(scale) + add(shift)
- uint32_t filter_height_dim = filter->dim(1).value();
- uint32_t filter_width_dim = filter->dim(2).value();
- uint32_t filter_in_dim = filter->dim(3).value();
- for (uint32_t c = 0; c < filter_out_dim; c++)
- {
- for (uint32_t h = 0; h < filter_height_dim; h++)
- {
- for (uint32_t w = 0; w < filter_width_dim; w++)
- {
- for (uint32_t b = 0; b < filter_in_dim; b++)
- {
- uint32_t offset = c * filter_height_dim * filter_width_dim * filter_in_dim +
- h * filter_width_dim * filter_in_dim + w * filter_in_dim + b;
- filter->at<loco::DataType::FLOAT32>(offset) *= scale->at<loco::DataType::FLOAT32>(c);
- }
- }
- }
- }
-
- // fuse shift with transposed conv
- tconv->bias(shift);
-
- if (add->fusedActivationFunction() == luci::FusedActFunc::RELU6)
- {
- // separate relu op from add op
- auto relu = add->graph()->nodes()->create<luci::CircleRelu6>();
- relu->features(tconv);
-
- // remove mul node
- replace(add).with(relu);
- }
- else
- {
- replace(add).with(tconv);
- }
-
- return true;
-}
-
-} // namespace
-
-namespace luci
-{
-
-bool FuseBatchNormWithTConvPass::run(loco::Graph *g)
-{
- bool changed = false;
- for (auto node : loco::active_nodes(loco::output_nodes(g)))
- {
- auto tconv = dynamic_cast<luci::CircleTransposeConv *>(node);
- if (not tconv)
- continue;
-
- changed |= fused_batch_norm_with_tconv(tconv);
- }
-
- return changed;
-}
-
-} // namespace luci
diff --git a/compiler/luci/pass/src/FuseBatchNormWithTConvPass.cpp b/compiler/luci/pass/src/FuseBatchNormWithTConvPass.cpp
new file mode 100644
index 000000000..337954960
--- /dev/null
+++ b/compiler/luci/pass/src/FuseBatchNormWithTConvPass.cpp
@@ -0,0 +1,208 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Pass/FuseBatchNormWithTConvPass.h"
+
+#include "helpers/NodeFiller.h"
+
+#include <luci/IR/CircleNodes.h>
+#include <luci/Profile/CircleNodeOrigin.h>
+
+namespace
+{
+/**
+ * Fuse Mul-Add to TransposeConv if possible.
+ *
+ * NOTE TF's BatchNormalization is converted to Mul and Add.
+ *
+ * BEFORE
+ * | [CircleOutputExclude]
+ * | / [CircleConst]
+ * | / /
+ * [CircleTransposeConv] [CircleConst]
+ * | /
+ * [CircleMul] [CircleConst]
+ * | /
+ * [CircleAdd]
+ * |
+ *
+ * AFTER
+ * | [CircleOutputExclude]
+ * +-------------------------------------+ / [CircleConst]
+ * | | / /
+ * | [CircleTransposeConv] [CircleConst]
+ * | [CircleConst] | /
+ * | / [CircleConst] [CircleMul] [CircleConst]
+ * | / / | /
+ * [CircleTransposeConv] [CircleAdd]
+ * |
+ * ([CircleRelu6])
+ * |
+ *
+ * Note: CircleRelu6 is inserted if Add activation is ReLU6
+ */
+bool fused_batch_norm_with_tconv(luci::CircleAdd *add)
+{
+ assert(add != nullptr);
+
+ // Find the pattern of CircleTransposeConv - CircleMul - CircleAdd
+ luci::CircleConst *scale = nullptr;
+ luci::CircleConst *shift = nullptr;
+ luci::CircleTransposeConv *tconv = nullptr;
+ luci::CircleMul *mul = nullptr;
+ if (not luci::fill(&shift, &mul).with_commutative_args_of(add))
+ return false;
+ if (not luci::fill(&scale, &tconv).with_commutative_args_of(mul))
+ return false;
+
+ // check scale and shift constant attributes
+ if (scale->rank() != 1)
+ return false;
+ if (shift->rank() != 1)
+ return false;
+ // check mul, add attributes
+ if (mul->dtype() != loco::DataType::FLOAT32)
+ return false;
+ if (add->dtype() != loco::DataType::FLOAT32)
+ return false;
+ if (add->fusedActivationFunction() != luci::FusedActFunc::NONE &&
+ add->fusedActivationFunction() != luci::FusedActFunc::RELU6)
+ return false;
+
+ // tconv bias should be not set
+ if (not dynamic_cast<luci::CircleOutputExclude *>(tconv->bias()))
+ return false;
+
+ // get weight of tconv
+ auto filter = dynamic_cast<luci::CircleConst *>(tconv->filter());
+ if (not filter)
+ return false;
+ if (filter->dtype() != loco::DataType::FLOAT32)
+ return false;
+ if (filter->rank() != 4)
+ return false;
+
+ auto filter_out_chn = filter->dim(0).value();
+ if (filter_out_chn != scale->dim(0).value())
+ return false;
+ if (filter_out_chn != shift->dim(0).value())
+ return false;
+
+ auto name = add->name();
+ assert(name.length() > 0);
+
+ loco::Graph *graph = add->graph();
+ luci::CircleTransposeConv *fused_tconv = graph->nodes()->create<luci::CircleTransposeConv>();
+ luci::CircleConst *fused_filter = graph->nodes()->create<luci::CircleConst>();
+ luci::CircleConst *fused_bias = graph->nodes()->create<luci::CircleConst>();
+
+ auto filter_height = filter->dim(1).value();
+ auto filter_width = filter->dim(2).value();
+ auto filter_in_chn = filter->dim(3).value();
+
+ // Copy filter shape
+ fused_filter->dtype(filter->dtype());
+ fused_filter->size<loco::DataType::FLOAT32>(filter->size<loco::DataType::FLOAT32>());
+ fused_filter->rank(4);
+ fused_filter->dim(0).set(filter_out_chn);
+ fused_filter->dim(1).set(filter_height);
+ fused_filter->dim(2).set(filter_width);
+ fused_filter->dim(3).set(filter_in_chn);
+ fused_filter->shape_status(luci::ShapeStatus::VALID);
+ fused_filter->name(name + "/TransposeConv/filter");
+
+ // fused filter weight = filter weight * mul(scale) + add(shift)
+ for (uint32_t c = 0; c < filter_out_chn; c++)
+ {
+ for (uint32_t h = 0; h < filter_height; h++)
+ {
+ for (uint32_t w = 0; w < filter_width; w++)
+ {
+ for (uint32_t b = 0; b < filter_in_chn; b++)
+ {
+ uint32_t offset = c * filter_height * filter_width * filter_in_chn +
+ h * filter_width * filter_in_chn + w * filter_in_chn + b;
+ fused_filter->at<loco::DataType::FLOAT32>(offset) =
+ filter->at<loco::DataType::FLOAT32>(offset) * scale->at<loco::DataType::FLOAT32>(c);
+ }
+ }
+ }
+ }
+
+ // Copy fused_bias from shift
+ fused_bias->dtype(shift->dtype());
+ fused_bias->size<loco::DataType::FLOAT32>(shift->size<loco::DataType::FLOAT32>());
+ fused_bias->rank(1);
+ fused_bias->dim(0).set(filter_out_chn);
+ fused_bias->shape_status(luci::ShapeStatus::VALID);
+ for (uint32_t c = 0; c < filter_out_chn; ++c)
+ {
+ fused_bias->at<loco::DataType::FLOAT32>(c) = shift->at<loco::DataType::FLOAT32>(c);
+ }
+ fused_bias->name(name + "/TransposeConv/bias");
+
+ // set new tconv properties
+ fused_tconv->inputSizes(tconv->inputSizes());
+ fused_tconv->filter(fused_filter);
+ fused_tconv->outBackprop(tconv->outBackprop());
+ fused_tconv->bias(fused_bias);
+ fused_tconv->padding(tconv->padding());
+ fused_tconv->stride()->h(tconv->stride()->h());
+ fused_tconv->stride()->w(tconv->stride()->w());
+ fused_tconv->name(name + "/TransposeConv");
+ luci::add_origin(fused_tconv,
+ luci::composite_origin(
+ {luci::get_origin(add), luci::get_origin(mul), luci::get_origin(tconv)}));
+
+ if (add->fusedActivationFunction() == luci::FusedActFunc::RELU6)
+ {
+ // separate relu op from add op
+ auto relu = add->graph()->nodes()->create<luci::CircleRelu6>();
+ relu->features(fused_tconv);
+ relu->name(name + "/Relu6");
+ luci::add_origin(relu, luci::get_origin(add));
+
+ replace(add).with(relu);
+ }
+ else
+ {
+ replace(add).with(fused_tconv);
+ }
+
+ return true;
+}
+
+} // namespace
+
+namespace luci
+{
+
+bool FuseBatchNormWithTConvPass::run(loco::Graph *g)
+{
+ bool changed = false;
+ for (auto node : loco::active_nodes(loco::output_nodes(g)))
+ {
+ if (auto add = dynamic_cast<luci::CircleAdd *>(node))
+ {
+ if (fused_batch_norm_with_tconv(add))
+ changed = true;
+ }
+ }
+
+ return changed;
+}
+
+} // namespace luci
diff --git a/compiler/luci/pass/src/FuseBatchNormWithTConvPass.test.cpp b/compiler/luci/pass/src/FuseBatchNormWithTConvPass.test.cpp
new file mode 100644
index 000000000..051100dc9
--- /dev/null
+++ b/compiler/luci/pass/src/FuseBatchNormWithTConvPass.test.cpp
@@ -0,0 +1,26 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Pass/FuseBatchNormWithTConvPass.h"
+
+#include <gtest/gtest.h>
+
+TEST(FuseBatchNormWithTConvPassTest, name)
+{
+ luci::FuseBatchNormWithTConvPass pass;
+ auto const name = pass.name();
+ ASSERT_NE(nullptr, name);
+}
diff --git a/compiler/luci/pass/src/FuseInstanceNormPass.cpp b/compiler/luci/pass/src/FuseInstanceNormPass.cpp
index 237152f98..ab7baa1fa 100644
--- a/compiler/luci/pass/src/FuseInstanceNormPass.cpp
+++ b/compiler/luci/pass/src/FuseInstanceNormPass.cpp
@@ -15,105 +15,16 @@
*/
#include "luci/Pass/FuseInstanceNormPass.h"
+#include "helpers/NodeFiller.h"
#include "FuseInstanceNormPassInternal.h"
#include <luci/IR/CircleNodes.h>
-#include <loco/Service/ShapeInference.h>
+#include <luci/Profile/CircleNodeOrigin.h>
#include <cassert>
#include <set>
-// Helper to find commutative node's arguments
-namespace
-{
-
-/**
- * INTRODUCTION
- * Binary operation f(x,y) is 'commutative' when
- * f(x,y) == f(y,x) holds for all x, y.
- * For examples, ADD, MUL and SQUARED_DIFFERENCE are commutative.
- * These helpers make it easy to find commutative arguemnts of commtative node.
- *
- * HOW TO USE
- * COMM_NODE *node;
- * ARG_TYPE_1 *arg1;
- * ARG_TYPE_2 *arg2;
- *
- * bool ok = fill(&arg1, &arg2).with_commutative_args_of(node);
- *
- * Result
- * If 'node's commutative argument types are actually {ARG_TYPE_1, ARG_TYPE_2}
- * (as a set), 'arg1' and 'arg2' set as actual 'node's arguemnts with matching
- * type, and return value 'ok' is true.
- * Otherwise, 'arg1' and 'arg2' not changed, 'ok' is false.
- */
-
-template <class ARG_TYPE_1, class ARG_TYPE_2> class NodeFiller final
-{
-public:
- NodeFiller(ARG_TYPE_1 **arg_1, ARG_TYPE_2 **arg_2) : _arg_1(arg_1), _arg_2(arg_2)
- {
- // DO NOTHING
- }
-
- /**
- * @return true When 'node's argument types are 'ARG_TYPE_1' and 'ARG_TYPE_2'
- * In such case, it assign '_arg_1' and '_arg_2' to actual arguments
- *
- * @return false When 'node's argument types are NOT matched with 'ARG_TYPE_*'
- * In such case, it does not amend '_arg_1' and '_arg_2'
- *
- * @require COMM_NODE has member x() and y()
- */
- template <class COMM_NODE> bool with_commutative_args_of(const COMM_NODE *node);
-
-private:
- ARG_TYPE_1 **_arg_1;
- ARG_TYPE_2 **_arg_2;
-};
-
-template <class ARG_TYPE_1, class ARG_TYPE_2>
-inline NodeFiller<ARG_TYPE_1, ARG_TYPE_2> fill(ARG_TYPE_1 **arg_1, ARG_TYPE_2 **arg_2)
-{
- return NodeFiller<ARG_TYPE_1, ARG_TYPE_2>{arg_1, arg_2};
-}
-
-template <class ARG_TYPE_1, class ARG_TYPE_2>
-template <class COMM_NODE>
-bool NodeFiller<ARG_TYPE_1, ARG_TYPE_2>::with_commutative_args_of(const COMM_NODE *node)
-{
- // Case 1) X == ARG_TYPE_1 / Y == ARG_TYPE_2
- {
- auto x = dynamic_cast<ARG_TYPE_1 *>(node->x());
- auto y = dynamic_cast<ARG_TYPE_2 *>(node->y());
-
- if (x && y)
- {
- *_arg_1 = x;
- *_arg_2 = y;
- return true;
- }
- }
-
- // Case 2) X == ARG_TYPE_2 / Y == ARG_TYPE_1
- {
- auto x = dynamic_cast<ARG_TYPE_2 *>(node->x());
- auto y = dynamic_cast<ARG_TYPE_1 *>(node->y());
-
- if (x && y)
- {
- *_arg_1 = y;
- *_arg_2 = x;
- return true;
- }
- }
-
- return false;
-}
-
-} // namespace
-
// Helper to check detail
/// @return true When node has shape of '1 x .. x 1 x depth'
@@ -150,11 +61,10 @@ bool is_instance_mean_v0(luci::CircleMean *mean)
//
// CHECK 1) input is rank 4
//
- auto input = mean->input();
- if (not loco::shape_known(input))
+ auto input = loco::must_cast<luci::CircleNode *>(mean->input());
+ if (input->shape_status() != luci::ShapeStatus::VALID)
return false;
- auto input_shape = loco::shape_get(input).as<loco::TensorShape>();
- if (input_shape.rank() != 4)
+ if (input->rank() != 4)
return false;
//
@@ -195,11 +105,10 @@ bool is_instance_mean_v1(luci::CircleMean *mean)
//
// CHECK 1) input is rank 5 (NHWCX)
//
- auto input = mean->input();
- if (not loco::shape_known(input))
+ auto input = loco::must_cast<luci::CircleNode *>(mean->input());
+ if (input->shape_status() != luci::ShapeStatus::VALID)
return false;
- auto input_shape = loco::shape_get(input).as<loco::TensorShape>();
- if (input_shape.rank() != 5)
+ if (input->rank() != 5)
return false;
//
@@ -445,8 +354,9 @@ bool InstanceNormPattern::matched()
// So it is handled in the separate if statement
if (_pv == PatternVersion::Version_2)
{
- CHECK_OR_FALSE(fill(&mul_gamma, &const_as_beta).with_commutative_args_of(add_as_terminal));
- CHECK_OR_FALSE(fill(&div, &const_as_gamma).with_commutative_args_of(mul_gamma));
+ CHECK_OR_FALSE(
+ luci::fill(&mul_gamma, &const_as_beta).with_commutative_args_of(add_as_terminal));
+ CHECK_OR_FALSE(luci::fill(&div, &const_as_gamma).with_commutative_args_of(mul_gamma));
sub = dynamic_cast<luci::CircleSub *>(div->x());
CHECK_OR_FALSE(sub);
@@ -456,6 +366,7 @@ bool InstanceNormPattern::matched()
luci::CircleNode *ifm_node = loco::must_cast<luci::CircleNode *>(ifm);
CHECK_OR_FALSE(ifm_node->rank() == 4);
+ CHECK_OR_FALSE(ifm_node->dim(3).known());
uint32_t ifm_channel_depth = ifm_node->dim(3).value();
mean_of_ifm = dynamic_cast<luci::CircleMean *>(sub->y());
@@ -477,7 +388,7 @@ bool InstanceNormPattern::matched()
CHECK_OR_FALSE(zero_point_five->at<loco::DataType::FLOAT32>(0) == 0.5);
CHECK_OR_FALSE(
- fill(&mean_as_variance, &const_as_epsilon).with_commutative_args_of(add_as_variance));
+ luci::fill(&mean_as_variance, &const_as_epsilon).with_commutative_args_of(add_as_variance));
CHECK_OR_FALSE(const_as_epsilon->dtype() == loco::DataType::FLOAT32);
// TODO Support regarding broadcast
CHECK_OR_FALSE(const_as_epsilon->size<loco::DataType::FLOAT32>() == 1);
@@ -489,7 +400,8 @@ bool InstanceNormPattern::matched()
loco::Node *ifm_should_be = nullptr;
luci::CircleMean *mean_of_ifm_should_be = nullptr;
- CHECK_OR_FALSE(fill(&ifm_should_be, &mean_of_ifm_should_be).with_commutative_args_of(sqdiff));
+ CHECK_OR_FALSE(
+ luci::fill(&ifm_should_be, &mean_of_ifm_should_be).with_commutative_args_of(sqdiff));
CHECK_OR_FALSE(ifm == ifm_should_be);
CHECK_OR_FALSE(mean_of_ifm == mean_of_ifm_should_be);
@@ -503,25 +415,25 @@ bool InstanceNormPattern::matched()
if (_pv == PatternVersion::Version_0)
{
- CHECK_OR_FALSE(fill(&mul_as_scaled_ifm, &sub).with_commutative_args_of(add_as_terminal));
- CHECK_OR_FALSE(fill(&ifm, &mul_gamma).with_commutative_args_of(mul_as_scaled_ifm));
+ CHECK_OR_FALSE(luci::fill(&mul_as_scaled_ifm, &sub).with_commutative_args_of(add_as_terminal));
+ CHECK_OR_FALSE(luci::fill(&ifm, &mul_gamma).with_commutative_args_of(mul_as_scaled_ifm));
}
if (_pv == PatternVersion::Version_1)
{
- CHECK_OR_FALSE(fill(&mul_as_scaled_reshape, &sub).with_commutative_args_of(add_as_terminal));
CHECK_OR_FALSE(
- fill(&reshape_of_ifm, &mul_gamma).with_commutative_args_of(mul_as_scaled_reshape));
+ luci::fill(&mul_as_scaled_reshape, &sub).with_commutative_args_of(add_as_terminal));
+ CHECK_OR_FALSE(
+ luci::fill(&reshape_of_ifm, &mul_gamma).with_commutative_args_of(mul_as_scaled_reshape));
ifm = reshape_of_ifm->tensor();
}
- CHECK_OR_FALSE(loco::shape_known(ifm));
- auto ifm_shape = loco::shape_get(ifm);
- CHECK_OR_FALSE(ifm_shape.domain() == loco::Domain::Tensor);
- auto ifm_tensor_shape = ifm_shape.as<loco::TensorShape>();
- CHECK_OR_FALSE(ifm_tensor_shape.rank() == 4);
- uint32_t ifm_channel_depth = ifm_tensor_shape.dim(3).value();
+ auto ifm_circle = loco::must_cast<luci::CircleNode *>(ifm);
+ CHECK_OR_FALSE(ifm_circle->shape_status() == luci::ShapeStatus::VALID);
+ CHECK_OR_FALSE(ifm_circle->rank() == 4);
+ CHECK_OR_FALSE(ifm_circle->dim(3).known());
+ uint32_t ifm_channel_depth = ifm_circle->dim(3).value();
- CHECK_OR_FALSE(fill(&rsqrt, &const_as_gamma).with_commutative_args_of(mul_gamma));
+ CHECK_OR_FALSE(luci::fill(&rsqrt, &const_as_gamma).with_commutative_args_of(mul_gamma));
if (_pv == PatternVersion::Version_0)
{
@@ -536,7 +448,7 @@ bool InstanceNormPattern::matched()
CHECK_OR_FALSE(add_as_variance);
CHECK_OR_FALSE(
- fill(&mean_as_variance, &const_as_epsilon).with_commutative_args_of(add_as_variance));
+ luci::fill(&mean_as_variance, &const_as_epsilon).with_commutative_args_of(add_as_variance));
CHECK_OR_FALSE(const_as_epsilon->dtype() == loco::DataType::FLOAT32);
// TODO Support regarding broadcast
@@ -557,7 +469,7 @@ bool InstanceNormPattern::matched()
if (_pv == PatternVersion::Version_0)
{
loco::Node *ifm_should_be = nullptr;
- CHECK_OR_FALSE(fill(&ifm_should_be, &mean_of_ifm).with_commutative_args_of(sqdiff));
+ CHECK_OR_FALSE(luci::fill(&ifm_should_be, &mean_of_ifm).with_commutative_args_of(sqdiff));
CHECK_OR_FALSE(ifm == ifm_should_be);
CHECK_OR_FALSE(is_instance_mean_v0(mean_of_ifm));
CHECK_OR_FALSE(ifm == mean_of_ifm->input());
@@ -565,7 +477,8 @@ bool InstanceNormPattern::matched()
if (_pv == PatternVersion::Version_1)
{
loco::Node *reshape_should_be = nullptr;
- CHECK_OR_FALSE(fill(&reshape_should_be, &mean_of_reshape).with_commutative_args_of(sqdiff));
+ CHECK_OR_FALSE(
+ luci::fill(&reshape_should_be, &mean_of_reshape).with_commutative_args_of(sqdiff));
CHECK_OR_FALSE(reshape_of_ifm == reshape_should_be);
CHECK_OR_FALSE(is_instance_mean_v1(mean_of_reshape));
CHECK_OR_FALSE(reshape_of_ifm == mean_of_reshape->input());
@@ -592,15 +505,15 @@ bool InstanceNormPattern::matched()
if (_pv == PatternVersion::Version_0)
{
- CHECK_OR_FALSE(fill(&mul_gamma_should_be, &mean_of_ifm_should_be)
- .with_commutative_args_of(mul_as_scaled_mean));
+ CHECK_OR_FALSE(luci::fill(&mul_gamma_should_be, &mean_of_ifm_should_be)
+ .with_commutative_args_of(mul_as_scaled_mean));
CHECK_OR_FALSE(mul_gamma == mul_gamma_should_be);
CHECK_OR_FALSE(mean_of_ifm == mean_of_ifm_should_be);
}
if (_pv == PatternVersion::Version_1)
{
- CHECK_OR_FALSE(fill(&mul_gamma_should_be, &mean_of_reshape_should_be)
- .with_commutative_args_of(mul_as_scaled_mean));
+ CHECK_OR_FALSE(luci::fill(&mul_gamma_should_be, &mean_of_reshape_should_be)
+ .with_commutative_args_of(mul_as_scaled_mean));
CHECK_OR_FALSE(mul_gamma == mul_gamma_should_be);
CHECK_OR_FALSE(mean_of_reshape == mean_of_reshape_should_be);
}
@@ -631,47 +544,59 @@ void fuse_instance_norm(const InstanceNormPattern &p)
auto graph = p.add_as_terminal->graph();
- // Special case for version 2 (no need to reshape)
- if (p.version() == InstanceNormPattern::Version_2)
+ // Version 0 and 1 need to reshape
+ if (p.version() != InstanceNormPattern::Version_2)
{
- // Make Instance Norm to replace
- auto instance_norm = graph->nodes()->create<luci::CircleInstanceNorm>();
- instance_norm->input(p.ifm);
- instance_norm->gamma(p.const_as_gamma);
- instance_norm->beta(p.const_as_beta);
- float epsilon = p.const_as_epsilon->at<loco::DataType::FLOAT32>(0);
- instance_norm->epsilon(epsilon);
- instance_norm->fusedActivationFunction(p.add_as_terminal->fusedActivationFunction());
-
- replace(p.add_as_terminal).with(instance_norm);
-
- return;
- }
-
- // Make reshape for gamma & beta
- auto reshape_gamma = graph->nodes()->create<luci::CircleReshape>();
- auto reshape_beta = graph->nodes()->create<luci::CircleReshape>();
- {
- auto ifm_shape = loco::shape_get(p.ifm).as<loco::TensorShape>();
- uint32_t ifm_channel_depth = ifm_shape.dim(3).value();
-
- int32_t new_shape[1] = {static_cast<int32_t>(ifm_channel_depth)};
-
- reshape_gamma->tensor(p.const_as_gamma);
- reshape_beta->tensor(p.const_as_beta);
+ p.const_as_gamma->rank(1);
+ p.const_as_gamma->dim(0).set(p.const_as_gamma->size<loco::DataType::FLOAT32>());
+ p.const_as_beta->rank(1);
+ p.const_as_beta->dim(0).set(p.const_as_beta->size<loco::DataType::FLOAT32>());
- luci::set_new_shape(reshape_gamma, new_shape, 1);
- luci::set_new_shape(reshape_beta, new_shape, 1);
+ p.const_as_gamma->shape_status(luci::ShapeStatus::UNDEFINED);
+ p.const_as_beta->shape_status(luci::ShapeStatus::UNDEFINED);
}
// Make Instance Norm to replace
auto instance_norm = graph->nodes()->create<luci::CircleInstanceNorm>();
instance_norm->input(p.ifm);
- instance_norm->gamma(reshape_gamma);
- instance_norm->beta(reshape_beta);
+ instance_norm->gamma(p.const_as_gamma);
+ instance_norm->beta(p.const_as_beta);
float epsilon = p.const_as_epsilon->at<loco::DataType::FLOAT32>(0);
instance_norm->epsilon(epsilon);
instance_norm->fusedActivationFunction(p.add_as_terminal->fusedActivationFunction());
+ // NOTE unique name should be assigned in export
+ instance_norm->name("InstanceNorm");
+
+ // set origin
+ std::vector<std::shared_ptr<luci::CircleNodeOrigin>> origin_vec{
+ luci::get_origin(p.sqdiff),
+ luci::get_origin(p.mean_as_variance),
+ luci::get_origin(p.add_as_variance),
+ luci::get_origin(p.mul_gamma),
+ luci::get_origin(p.sub),
+ luci::get_origin(p.add_as_terminal)};
+ if (p.version() == InstanceNormPattern::PatternVersion::Version_0)
+ {
+ origin_vec.push_back(luci::get_origin(p.mean_of_ifm));
+ origin_vec.push_back(luci::get_origin(p.rsqrt));
+ origin_vec.push_back(luci::get_origin(p.mul_as_scaled_ifm));
+ origin_vec.push_back(luci::get_origin(p.mul_as_scaled_mean));
+ }
+ if (p.version() == InstanceNormPattern::PatternVersion::Version_1)
+ {
+ origin_vec.push_back(luci::get_origin(p.reshape_of_ifm));
+ origin_vec.push_back(luci::get_origin(p.mean_of_reshape));
+ origin_vec.push_back(luci::get_origin(p.rsqrt));
+ origin_vec.push_back(luci::get_origin(p.mul_as_scaled_mean));
+ origin_vec.push_back(luci::get_origin(p.mul_as_scaled_reshape));
+ }
+ if (p.version() == InstanceNormPattern::PatternVersion::Version_2)
+ {
+ origin_vec.push_back(luci::get_origin(p.mean_of_ifm));
+ origin_vec.push_back(luci::get_origin(p.pow));
+ origin_vec.push_back(luci::get_origin(p.div));
+ }
+ luci::add_origin(instance_norm, luci::composite_origin(origin_vec));
replace(p.add_as_terminal).with(instance_norm);
}
diff --git a/compiler/luci/pass/src/FuseInstanceNormPass.test.cpp b/compiler/luci/pass/src/FuseInstanceNormPass.test.cpp
index 3037f3def..b83ccca50 100644
--- a/compiler/luci/pass/src/FuseInstanceNormPass.test.cpp
+++ b/compiler/luci/pass/src/FuseInstanceNormPass.test.cpp
@@ -16,6 +16,8 @@
#include "FuseInstanceNormPassInternal.h"
+#include "luci/Pass/FuseInstanceNormPass.h"
+
#include <vector>
#include <gtest/gtest.h>
@@ -34,6 +36,13 @@ void setShape(luci::CircleNode &node, const std::vector<int> &v)
} // namespace
+TEST(FuseInstanceNormPassTest, name)
+{
+ luci::FuseInstanceNormPass pass;
+ auto const name = pass.name();
+ ASSERT_NE(nullptr, name);
+}
+
TEST(FuseInstanceNormPass, is_quasi_1D_with_dummy_dim)
{
luci::CircleConst const_node;
diff --git a/compiler/luci/pass/src/FusePreActivationBatchNormPass.cpp b/compiler/luci/pass/src/FusePreActivationBatchNormPass.cpp
index bcde5fac4..469fcddbb 100644
--- a/compiler/luci/pass/src/FusePreActivationBatchNormPass.cpp
+++ b/compiler/luci/pass/src/FusePreActivationBatchNormPass.cpp
@@ -16,9 +16,11 @@
#include "luci/Pass/FusePreActivationBatchNormPass.h"
#include "FusePreActivationBatchNormPassInternal.h"
+#include "BatchNormPatternFinder.h"
#include <luci/IR/CircleNodes.h>
#include <luci/Log.h>
+#include <luci/Profile/CircleNodeOrigin.h>
namespace
{
@@ -37,83 +39,6 @@ bool is_non_negative(const luci::CircleConst *node)
return true;
}
-// Check if mul is batchnorm mul
-bool is_batchnorm_mul(const luci::CircleMul *mul, luci::CircleNode *&pred_node,
- luci::CircleConst *&gamma)
-{
- auto x = dynamic_cast<luci::CircleConst *>(mul->x());
- auto y = dynamic_cast<luci::CircleConst *>(mul->y());
-
- luci::CircleNode *pred = nullptr;
- luci::CircleConst *constant = nullptr;
-
- if (x != nullptr && y == nullptr)
- {
- pred = loco::must_cast<luci::CircleNode *>(mul->y());
- constant = x;
- }
- else if (x == nullptr && y != nullptr)
- {
- pred = loco::must_cast<luci::CircleNode *>(mul->x());
- constant = y;
- }
- else
- {
- return false;
- }
-
- if (constant->rank() != 1)
- return false;
-
- auto channel_dim = constant->dim(0);
- if (!(channel_dim == mul->dim(mul->rank() - 1)))
- return false;
-
- pred_node = pred;
- gamma = constant;
- return true;
-}
-
-// Check if add is batchnorm add
-bool is_batchnorm_add(const luci::CircleAdd *add, luci::CircleMul *&mul, luci::CircleConst *&beta)
-{
- auto x = loco::must_cast<luci::CircleNode *>(add->x());
- auto y = loco::must_cast<luci::CircleNode *>(add->y());
-
- luci::CircleMul *pred = nullptr;
- luci::CircleConst *constant = nullptr;
-
- if (add->fusedActivationFunction() != luci::FusedActFunc::RELU)
- return false;
-
- if (x->opcode() == luci::CircleOpcode::CIRCLECONST && y->opcode() == luci::CircleOpcode::MUL)
- {
- pred = loco::must_cast<luci::CircleMul *>(y);
- constant = loco::must_cast<luci::CircleConst *>(x);
- }
- else if (x->opcode() == luci::CircleOpcode::MUL && y->opcode() == luci::CircleOpcode::CIRCLECONST)
- {
- pred = loco::must_cast<luci::CircleMul *>(x);
- constant = loco::must_cast<luci::CircleConst *>(y);
- }
- else
- {
- return false;
- }
-
- if (constant->rank() != 1)
- return false;
-
- auto channel_dim = constant->dim(0);
- // Assumption: Layout is channel-last
- if (!(channel_dim == add->dim(add->rank() - 1)))
- return false;
-
- mul = pred;
- beta = constant;
- return true;
-}
-
const luci::CircleConv2D *get_forward_conv2d(const luci::CircleNode *node, uint32_t channel_size)
{
auto opcode = node->opcode();
@@ -249,6 +174,9 @@ bool update_conv_bias_with_beta(luci::CircleConv2D *conv, const luci::CircleCons
auto size = beta->dim(0).value();
auto bias = dynamic_cast<luci::CircleConst *>(conv->bias());
+ auto name = conv->name();
+ assert(name.length() > 0);
+
if (bias == nullptr)
{
bias = conv->graph()->nodes()->create<luci::CircleConst>();
@@ -256,6 +184,7 @@ bool update_conv_bias_with_beta(luci::CircleConv2D *conv, const luci::CircleCons
bias->rank(1);
bias->dim(0).set(size);
bias->size<loco::DataType::FLOAT32>(size);
+ bias->name(name + "/bias");
conv->bias(bias);
}
else
@@ -282,14 +211,12 @@ bool update_conv_bias_with_beta(luci::CircleConv2D *conv, const luci::CircleCons
luci::CircleSub *insert_sub(luci::CircleNode *pred, luci::CircleConst *beta)
{
+ auto name = pred->name();
+ assert(name.length() > 0);
+
auto sub = pred->graph()->nodes()->create<luci::CircleSub>();
- sub->dtype(loco::DataType::FLOAT32);
- sub->rank(pred->rank());
- for (uint32_t i = 0; i < sub->rank(); i++)
- {
- sub->dim(i).set(pred->dim(i).value());
- }
sub->fusedActivationFunction(luci::FusedActFunc::NONE);
+ sub->name(name + "/Sub");
loco::replace(pred).with(sub);
@@ -366,6 +293,8 @@ bool fuse_sub_with_conv(luci::CircleSub *sub)
if (!update_conv_bias_with_beta(conv, beta, false))
return false;
+ luci::add_origin(conv, luci::get_origin(sub));
+
auto pred = sub->x();
loco::replace(sub).with(pred);
@@ -442,6 +371,7 @@ bool fuse_add_with_conv(luci::CircleAdd *add, std::vector<luci::CircleSub *> &su
if (!update_conv_bias_with_beta(conv, beta, true))
return false;
+ luci::add_origin(conv, luci::get_origin(add));
loco::replace(add).with(pred);
add->drop();
@@ -462,6 +392,8 @@ bool fuse_add_with_conv(luci::CircleAdd *add, std::vector<luci::CircleSub *> &su
if (!update_conv_bias_with_beta(conv, beta, true))
return false;
+ luci::add_origin(conv, luci::get_origin(add));
+
auto relu = *loco::succs(add).begin();
auto relu_node = loco::must_cast<luci::CircleRelu *>(relu);
assert(relu_node != nullptr);
@@ -471,6 +403,7 @@ bool fuse_add_with_conv(luci::CircleAdd *add, std::vector<luci::CircleSub *> &su
add->drop();
sub_list.push_back(insert_sub(pred, beta));
+ luci::add_origin(sub_list.back(), luci::get_origin(add));
relu_node->features(pred);
@@ -530,6 +463,11 @@ bool fuse_mul_with_conv(luci::CircleMul *mul)
// Update CONV weights
update_conv_weights_with_gamma(conv, gamma);
+
+ // Update origin
+ // TODO need to remove const
+ luci::add_origin(const_cast<luci::CircleConv2D *>(conv),
+ luci::get_origin(loco::must_cast<luci::CircleNode *>(mul)));
}
loco::replace(mul).with(pred_node);
@@ -568,6 +506,8 @@ bool swap_mul_add(luci::CircleAdd *add, std::vector<luci::CircleMul *> &mul_list
if (!is_batchnorm_add(add, mul, beta))
return false;
+ if (add->fusedActivationFunction() != luci::FusedActFunc::RELU)
+ return false;
if (loco::succs(mul).size() != 1)
return false;
@@ -582,8 +522,13 @@ bool swap_mul_add(luci::CircleAdd *add, std::vector<luci::CircleMul *> &mul_list
return false;
// Insert Relu at the bottom
+ auto name = add->name();
+ assert(name.length() > 0);
+
auto relu = add->graph()->nodes()->create<luci::CircleRelu>();
relu->features(mul);
+ relu->name(name + "/Relu");
+ luci::add_origin(relu, luci::get_origin(add));
loco::replace(add).with(relu);
// Replace beta <- beta / gamma
diff --git a/compiler/luci/pass/src/FusePreActivationBatchNormPass.test.cpp b/compiler/luci/pass/src/FusePreActivationBatchNormPass.test.cpp
index a79b5bd5d..3d5791c9e 100644
--- a/compiler/luci/pass/src/FusePreActivationBatchNormPass.test.cpp
+++ b/compiler/luci/pass/src/FusePreActivationBatchNormPass.test.cpp
@@ -16,6 +16,8 @@
#include "FusePreActivationBatchNormPassInternal.h"
+#include "luci/Pass/FusePreActivationBatchNormPass.h"
+
#include <luci/IR/CircleNodes.h>
#include <math.h>
@@ -148,6 +150,22 @@ public:
conv_filter->at<loco::DataType::FLOAT32>(i * out_size + j) = i * out_size + j;
}
}
+
+ pred_conv->name("pred_conv");
+ pred_conv_filter->name("pred_conv_filter");
+ pred_conv_bias->name("pred_conv_bias");
+ pred_conv2->name("pred_conv2");
+ pred_conv2_filter->name("pred_conv2_filter");
+ pred_conv2_bias->name("pred_conv2_bias");
+ pred_add->name("pred_add");
+ mul->name("mul");
+ mul_gamma->name("mul_gamma");
+ add->name("add");
+ add_beta->name("add_beta");
+ conv->name("conv");
+ conv_filter->name("conv_filter");
+ conv_bias->name("conv_bias");
+ succ_add->name("succ_add");
}
public:
@@ -171,6 +189,13 @@ public:
} // namespace
+TEST(FusePreActivationBatchNormPassTest, name)
+{
+ luci::FusePreActivationBatchNormPass pass;
+ auto const name = pass.name();
+ ASSERT_NE(nullptr, name);
+}
+
TEST(FusePreActivationBatchNorm, swap_mul_add)
{
SimpleGraph g;
diff --git a/compiler/luci/pass/src/MakeBatchNormGammaPositivePass.cpp b/compiler/luci/pass/src/MakeBatchNormGammaPositivePass.cpp
index 281d1b081..96776dc92 100644
--- a/compiler/luci/pass/src/MakeBatchNormGammaPositivePass.cpp
+++ b/compiler/luci/pass/src/MakeBatchNormGammaPositivePass.cpp
@@ -16,6 +16,8 @@
#include "luci/Pass/MakeBatchNormGammaPositivePass.h"
+#include "BatchNormPatternFinder.h"
+
#include <luci/IR/CircleNodes.h>
namespace
@@ -39,71 +41,27 @@ bool negative_gamma_to_positive(luci::CircleConst *gamma)
return changed;
}
-// Check if add is batchnorm add
-bool is_batchnorm_add(const luci::CircleAdd *add)
+bool make_positive_gamma(luci::CircleAdd *add)
{
- auto x = dynamic_cast<luci::CircleConst *>(add->x());
- auto y = dynamic_cast<luci::CircleConst *>(add->y());
-
- luci::CircleConst *constant = nullptr;
+ luci::CircleMul *mul = nullptr;
+ luci::CircleConst *beta = nullptr;
+ luci::CircleConst *gamma = nullptr;
+ luci::CircleNode *pred = nullptr;
- if (x != nullptr && y == nullptr)
- constant = x;
- else if (x == nullptr && y != nullptr)
- constant = y;
- else
+ if (!is_batchnorm_add(add, mul, beta))
return false;
- if (constant->rank() != 1)
+ if (loco::succs(mul).size() != 1)
return false;
+ if (!is_batchnorm_mul(mul, pred, gamma))
+ return false;
+ assert(pred == add);
// Only support Relu
if (add->fusedActivationFunction() != luci::FusedActFunc::RELU)
return false;
- auto channel_dim = constant->dim(0);
- if (!(channel_dim == add->dim(add->rank() - 1)))
- return false;
-
- return true;
-}
-
-// Check if mul is batchnorm mul
-bool is_batchnorm_mul(const luci::CircleMul *mul, luci::CircleConst *&gamma)
-{
- auto x = dynamic_cast<luci::CircleConst *>(mul->x());
- auto y = dynamic_cast<luci::CircleConst *>(mul->y());
-
- luci::CircleConst *constant = nullptr;
-
- if (x != nullptr && y == nullptr)
- constant = x;
- else if (x == nullptr && y != nullptr)
- constant = y;
- else
- return false;
-
- if (constant->rank() != 1)
- return false;
-
- auto channel_dim = constant->dim(0);
- if (!(channel_dim == mul->dim(mul->rank() - 1)))
- return false;
-
- // Check successor is batchnorm add
- auto succs = loco::succs(mul);
- if (succs.size() != 1)
- return false;
-
- auto add = dynamic_cast<luci::CircleAdd *>(*succs.begin());
- if (add == nullptr)
- return false;
-
- if (!is_batchnorm_add(add))
- return false;
-
- gamma = constant;
- return true;
+ return negative_gamma_to_positive(gamma);
}
} // namespace
@@ -111,18 +69,29 @@ bool is_batchnorm_mul(const luci::CircleMul *mul, luci::CircleConst *&gamma)
namespace luci
{
+/**
+ * Make negative gamma values of Mul-Add (as BatchNorm) to a small positive value (1e-10)
+ *
+ * PATTERN:
+ * |
+ * [CircleNode] [CircleConst](as gamma)
+ * | |
+ * [CircleMul] [CircleConst]
+ * | |
+ * [CircleAdd]
+ * |
+ */
bool MakeBatchNormGammaPositivePass::run(loco::Graph *g)
{
bool changed = false;
for (auto node : loco::active_nodes(loco::output_nodes(g)))
{
- auto mul = dynamic_cast<luci::CircleMul *>(node);
- if (mul == nullptr)
+ auto add = dynamic_cast<luci::CircleAdd *>(node);
+ if (add == nullptr)
continue;
- luci::CircleConst *gamma;
- if (is_batchnorm_mul(mul, gamma))
- changed = negative_gamma_to_positive(gamma);
+ if (make_positive_gamma(add))
+ changed = true;
}
return changed;
}
diff --git a/compiler/luci/pass/src/MakeBatchNormGammaPositivePass.test.cpp b/compiler/luci/pass/src/MakeBatchNormGammaPositivePass.test.cpp
new file mode 100644
index 000000000..83093edc8
--- /dev/null
+++ b/compiler/luci/pass/src/MakeBatchNormGammaPositivePass.test.cpp
@@ -0,0 +1,26 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Pass/MakeBatchNormGammaPositivePass.h"
+
+#include <gtest/gtest.h>
+
+TEST(MakeBatchNormGammaPositivePassTest, name)
+{
+ luci::MakeBatchNormGammaPositivePass pass;
+ auto const name = pass.name();
+ ASSERT_NE(nullptr, name);
+}
diff --git a/compiler/luci/pass/src/MigrateLegacyShapeDtypePass.cpp b/compiler/luci/pass/src/MigrateLegacyShapeDtypePass.cpp
deleted file mode 100644
index beb962a05..000000000
--- a/compiler/luci/pass/src/MigrateLegacyShapeDtypePass.cpp
+++ /dev/null
@@ -1,112 +0,0 @@
-/*
- * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "luci/Pass/MigrateLegacyShapeDtypePass.h"
-
-#include <loco/Service/ShapeInference.h>
-#include <loco/Service/TypeInference.h>
-
-#include <luci/IR/CircleNodes.h>
-
-#include <loco.h>
-
-namespace
-{
-
-bool has_same_shape(luci::CircleNode *node, loco::TensorShape shape)
-{
- if (node->rank() != shape.rank())
- return false;
-
- for (uint32_t i = 0; i < shape.rank(); ++i)
- if (!(node->dim(i) == shape.dim(i)))
- return false;
-
- return true;
-}
-
-} // namespace
-
-namespace luci
-{
-
-bool MigrateLegacyShapeDtypePass::run(luci::Module *m)
-{
- bool changed = false;
-
- for (size_t g = 0; g < m->size(); ++g)
- {
- if (run(m->graph(g)))
- changed = true;
- }
-
- return changed;
-}
-
-bool MigrateLegacyShapeDtypePass::run(loco::Graph *g)
-{
- bool changed = false;
-
- for (auto node : loco::all_nodes(g))
- {
- auto circle_node = loco::must_cast<luci::CircleNode *>(node);
- if (loco::shape_known(node))
- {
- auto loco_shape = loco::shape_get(node).as<loco::TensorShape>();
-
- assert(circle_node->shape_signature().rank() == 0 ||
- circle_node->shape_signature().rank() == loco_shape.rank());
-
- // When shape of loco is copied to circle node, ShapeSignature should be applied.
- loco::TensorShape new_shape;
- new_shape.rank(loco_shape.rank());
- for (uint32_t i = 0; i < loco_shape.rank(); ++i)
- {
- if (circle_node->shape_signature().rank() > 0 &&
- circle_node->shape_signature().dim(i) == -1)
- new_shape.dim(i) = 1;
- else
- new_shape.dim(i) = loco_shape.dim(i);
- }
-
- if (circle_node->shape_status() == luci::ShapeStatus::UNDEFINED ||
- !has_same_shape(circle_node, new_shape))
- {
- circle_node->rank(new_shape.rank());
- for (uint32_t i = 0; i < new_shape.rank(); ++i)
- circle_node->dim(i) = new_shape.dim(i);
-
- if (circle_node->shape_status() == luci::ShapeStatus::UNDEFINED)
- circle_node->shape_status(luci::ShapeStatus::VALID);
-
- changed = true;
- }
- }
-
- if (loco::dtype_known(node))
- {
- if (loco::dtype_get(node) != circle_node->dtype())
- {
- circle_node->dtype(loco::dtype_get(node));
- changed = true;
- }
- }
- }
-
- return changed;
-}
-
-} // namespace luci
diff --git a/compiler/luci/pass/src/ModulePhase.test.cpp b/compiler/luci/pass/src/ModulePhase.test.cpp
new file mode 100644
index 000000000..5d92c59f4
--- /dev/null
+++ b/compiler/luci/pass/src/ModulePhase.test.cpp
@@ -0,0 +1,57 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ModulePhase.h"
+
+#include "luci/Pass/CircleShapeInferencePass.h"
+
+#include <loco.h>
+
+#include <gtest/gtest.h>
+
+TEST(ModulePhaseTest, saturate)
+{
+ auto m = luci::make_module();
+ auto g = loco::make_graph();
+ m->add(std::move(g));
+
+ luci::Phase phase;
+
+ // Any Pass will do for testing
+ phase.emplace_back(std::make_unique<luci::CircleShapeInferencePass>());
+
+ luci::PhaseRunner<logo::PhaseStrategy::Saturate> phase_runner{m.get()};
+ phase_runner.run(phase);
+
+ SUCCEED();
+}
+
+TEST(ModulePhaseTest, restart)
+{
+ auto m = luci::make_module();
+ auto g = loco::make_graph();
+ m->add(std::move(g));
+
+ luci::Phase phase;
+
+ // Any Pass will do for testing
+ phase.emplace_back(std::make_unique<luci::CircleShapeInferencePass>());
+
+ luci::PhaseRunner<logo::PhaseStrategy::Restart> phase_runner{m.get()};
+ phase_runner.run(phase);
+
+ SUCCEED();
+}
diff --git a/compiler/luci/pass/src/PassTestGraphs.h b/compiler/luci/pass/src/PassTestGraphs.h
new file mode 100644
index 000000000..f5ae24f0b
--- /dev/null
+++ b/compiler/luci/pass/src/PassTestGraphs.h
@@ -0,0 +1,142 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LUCI_PASS_TEST_GRAPHS_H__
+#define __LUCI_PASS_TEST_GRAPHS_H__
+
+#include <loco.h>
+#include <luci/IR/CircleNodes.h>
+
+namespace luci
+{
+
+/**
+ * ConstantFoldingTestGraph is a base class for testing
+ * constant folding passes. It creates Input and Output
+ * in the below graph. Child classes must implement Connector
+ * and Folded pattern.
+ *
+ * [Input] [Folded pattern] (Implemented by child class)
+ * \ /
+ * [Connector] (Implemented by child class)
+ * |
+ * [Output]
+ *
+ * Connector should satisfy the below conditions
+ * - Input type == Output type == Folded pattern type
+ * - Input shape == Output shape == Folded pattern shape
+ *
+ * For example, Add, Mul, Sub, .. can be a Connector
+ */
+class ConstantFoldingTestGraph
+{
+public:
+ ConstantFoldingTestGraph(std::vector<uint32_t> input_shape, loco::DataType input_dtype)
+ {
+ _input = _g.nodes()->create<luci::CircleInput>();
+ _output = _g.nodes()->create<luci::CircleOutput>();
+
+ auto graph_input = _g.inputs()->create();
+ _input->index(graph_input->index());
+ auto graph_output = _g.outputs()->create();
+ _output->index(graph_output->index());
+
+ graph_input->dtype(input_dtype);
+ graph_output->dtype(input_dtype);
+ _input->dtype(input_dtype);
+ _output->dtype(input_dtype);
+
+ auto input_tensor_shape = std::make_unique<loco::TensorShape>();
+ input_tensor_shape->rank(input_shape.size());
+ for (int i = 0; i < input_shape.size(); i++)
+ input_tensor_shape->dim(i).set(input_shape[i]);
+ graph_input->shape(std::move(input_tensor_shape));
+
+ auto output_tensor_shape = std::make_unique<loco::TensorShape>();
+ output_tensor_shape->rank(input_shape.size());
+ for (int i = 0; i < input_shape.size(); i++)
+ output_tensor_shape->dim(i).set(input_shape[i]);
+ graph_output->shape(std::move(output_tensor_shape));
+
+ _input->rank(input_shape.size());
+ for (int i = 0; i < input_shape.size(); i++)
+ _input->dim(i).set(input_shape[i]);
+
+ _output->rank(input_shape.size());
+ for (int i = 0; i < input_shape.size(); i++)
+ _output->dim(i).set(input_shape[i]);
+
+ _input->name("input");
+ _output->name("output");
+ }
+
+ virtual void init() = 0;
+
+ virtual ~ConstantFoldingTestGraph() = default;
+
+ virtual loco::Node *createFoldedPattern() = 0;
+
+ virtual luci::CircleConst *getFoldedPattern() = 0;
+
+ loco::Graph *graph() { return &_g; }
+
+ // NOTE: we're not adding _ prefix as these class members are public
+protected:
+ loco::Graph _g;
+ luci::CircleInput *_input = nullptr;
+ luci::CircleOutput *_output = nullptr;
+};
+
+/**
+ * ConstantFoldingTestAddGraph is ConstantFoldingTestGraph
+ * whose Connector is Add.
+ */
+class ConstantFoldingAddTestGraph : public ConstantFoldingTestGraph
+{
+protected:
+ ConstantFoldingAddTestGraph(std::vector<uint32_t> input_shape, loco::DataType input_dtype)
+ : ConstantFoldingTestGraph(input_shape, input_dtype)
+ {
+ _add = _g.nodes()->create<luci::CircleAdd>();
+ _add->dtype(input_dtype);
+
+ _add->rank(input_shape.size());
+ for (int i = 0; i < input_shape.size(); i++)
+ _add->dim(i).set(input_shape[i]);
+
+ _add->x(_input);
+
+ _output->from(_add);
+
+ _add->name("add");
+ }
+
+protected:
+ void init() override { _add->y(createFoldedPattern()); }
+
+protected:
+ luci::CircleConst *getFoldedPattern() override
+ {
+ return dynamic_cast<luci::CircleConst *>(_add->y());
+ }
+
+protected:
+ luci::CircleAdd *_add = nullptr;
+};
+
+} // namespace luci
+
+#endif // __LUCI_PASS_TEST_GRAPHS_H__
diff --git a/compiler/luci/pass/src/ProgressReporter.h b/compiler/luci/pass/src/ProgressReporter.h
index cf30da735..8c6c95e65 100644
--- a/compiler/luci/pass/src/ProgressReporter.h
+++ b/compiler/luci/pass/src/ProgressReporter.h
@@ -30,7 +30,7 @@ class ProgressReporter : public logo::PhaseEventListener
{
public:
ProgressReporter(loco::Graph *graph, logo::PhaseStrategy strategy)
- : _graph{graph}, _strategy{strategy}
+ : _graph{graph}, _strategy{strategy}
{
// DO NOTHING
}
@@ -54,7 +54,7 @@ class ModuleProgressReporter : public logo::PhaseEventListener
{
public:
ModuleProgressReporter(luci::Module *module, logo::PhaseStrategy strategy)
- : _module{module}, _strategy{strategy}
+ : _module{module}, _strategy{strategy}
{
// DO NOTHING
}
diff --git a/compiler/luci/pass/src/PropagateConcatenationQparam.test.cpp b/compiler/luci/pass/src/PropagateConcatenationQparam.test.cpp
index 0f8d562e9..de973a431 100644
--- a/compiler/luci/pass/src/PropagateConcatenationQparam.test.cpp
+++ b/compiler/luci/pass/src/PropagateConcatenationQparam.test.cpp
@@ -136,30 +136,34 @@ class ConstInputConcatGraph
public:
ConstInputConcatGraph(loco::DataType quant_type)
{
- concat_node.dtype(quant_type);
- concat_node.fusedActivationFunction(luci::FusedActFunc::NONE);
- input_1.dtype(loco::DataType::FLOAT32);
- input_1.size<loco::DataType::FLOAT32>(5);
+ concat_node = g.nodes()->create<luci::CircleConcatenation>(2);
+ input_1 = g.nodes()->create<luci::CircleConst>();
+ input_2 = g.nodes()->create<luci::CircleConv2D>();
+
+ concat_node->dtype(quant_type);
+ concat_node->fusedActivationFunction(luci::FusedActFunc::NONE);
+ input_1->dtype(loco::DataType::FLOAT32);
+ input_1->size<loco::DataType::FLOAT32>(5);
for (int i = 0; i < 5; i++)
{
// Set data {-2, -1, 0, 1, 2}
- input_1.at<loco::DataType::FLOAT32>(i) = i - 2.0;
+ input_1->at<loco::DataType::FLOAT32>(i) = i - 2.0;
}
- input_2.dtype(quant_type);
+ input_2->dtype(quant_type);
- concat_node.values(0, &input_1);
- concat_node.values(1, &input_2);
+ concat_node->values(0, input_1);
+ concat_node->values(1, input_2);
if (quant_type == loco::DataType::U8)
{
- addQuantParam(concat_node, {0.1}, {10});
- addQuantParam(input_2, {2.0}, {2});
+ addQuantParam(*concat_node, {0.1}, {10});
+ addQuantParam(*input_2, {2.0}, {2});
}
else if (quant_type == loco::DataType::S16)
{
- addQuantParam(concat_node, {0.1}, {0});
- addQuantParam(input_2, {2.0}, {0});
+ addQuantParam(*concat_node, {0.1}, {0});
+ addQuantParam(*input_2, {2.0}, {0});
}
else
{
@@ -167,16 +171,11 @@ public:
}
}
- ~ConstInputConcatGraph()
- {
- concat_node.values(0, nullptr);
- concat_node.values(1, nullptr);
- }
-
public:
- luci::CircleConcatenation concat_node{2};
- luci::CircleConst input_1;
- luci::CircleConv2D input_2;
+ loco::Graph g;
+ luci::CircleConcatenation *concat_node = nullptr;
+ luci::CircleConst *input_1 = nullptr;
+ luci::CircleConv2D *input_2 = nullptr;
};
} // namespace
@@ -223,19 +222,20 @@ TEST(PropagateConcatenationQparam, propagate_concat_quantparam_u8)
// input_1 is const. const values are quantized with the qparam of concat
ConstInputConcatGraph cg(loco::DataType::U8);
- luci::propagate_concat_quantparam(&cg.concat_node, loco::DataType::U8);
- EXPECT_FLOAT_EQ(0.1, cg.concat_node.quantparam()->scale[0]);
- EXPECT_EQ(10, cg.concat_node.quantparam()->zerop[0]);
- EXPECT_FLOAT_EQ(0.1, cg.input_1.quantparam()->scale[0]);
- EXPECT_EQ(10, cg.input_1.quantparam()->zerop[0]);
- EXPECT_FLOAT_EQ(0.1, cg.input_2.quantparam()->scale[0]);
- EXPECT_EQ(10, cg.input_2.quantparam()->zerop[0]);
- EXPECT_EQ(loco::DataType::U8, cg.input_1.dtype());
- EXPECT_EQ(0, cg.input_1.at<loco::DataType::U8>(0));
- EXPECT_EQ(0, cg.input_1.at<loco::DataType::U8>(1));
- EXPECT_EQ(10, cg.input_1.at<loco::DataType::U8>(2));
- EXPECT_EQ(20, cg.input_1.at<loco::DataType::U8>(3));
- EXPECT_EQ(30, cg.input_1.at<loco::DataType::U8>(4));
+ luci::propagate_concat_quantparam(cg.concat_node, loco::DataType::U8);
+ EXPECT_FLOAT_EQ(0.1, cg.concat_node->quantparam()->scale[0]);
+ EXPECT_EQ(10, cg.concat_node->quantparam()->zerop[0]);
+ const auto cg_input_1 = loco::must_cast<luci::CircleConst *>(cg.concat_node->values(0));
+ EXPECT_FLOAT_EQ(0.1, cg_input_1->quantparam()->scale[0]);
+ EXPECT_EQ(10, cg_input_1->quantparam()->zerop[0]);
+ EXPECT_FLOAT_EQ(0.1, cg.input_2->quantparam()->scale[0]);
+ EXPECT_EQ(10, cg.input_2->quantparam()->zerop[0]);
+ EXPECT_EQ(loco::DataType::U8, cg_input_1->dtype());
+ EXPECT_EQ(0, cg_input_1->at<loco::DataType::U8>(0));
+ EXPECT_EQ(0, cg_input_1->at<loco::DataType::U8>(1));
+ EXPECT_EQ(10, cg_input_1->at<loco::DataType::U8>(2));
+ EXPECT_EQ(20, cg_input_1->at<loco::DataType::U8>(3));
+ EXPECT_EQ(30, cg_input_1->at<loco::DataType::U8>(4));
}
TEST(PropagateConcatenationQparam, propagate_concat_quantparam_u8_NEG)
@@ -260,20 +260,21 @@ TEST(PropagateConcatenationQparam, propagate_concat_quantparam_u8_NEG)
// concat has fused activation function and input_1 is const.
// const values are quantized using its min/max
ConstInputConcatGraph cg(loco::DataType::U8);
- cg.concat_node.fusedActivationFunction(luci::FusedActFunc::RELU);
- luci::propagate_concat_quantparam(&cg.concat_node, loco::DataType::U8);
- EXPECT_FLOAT_EQ(0.1, cg.concat_node.quantparam()->scale[0]);
- EXPECT_EQ(10, cg.concat_node.quantparam()->zerop[0]);
- EXPECT_FLOAT_EQ(0.015686275, cg.input_1.quantparam()->scale[0]);
- EXPECT_EQ(128, cg.input_1.quantparam()->zerop[0]);
- EXPECT_FLOAT_EQ(2.0, cg.input_2.quantparam()->scale[0]);
- EXPECT_EQ(2, cg.input_2.quantparam()->zerop[0]);
- EXPECT_EQ(loco::DataType::U8, cg.input_1.dtype());
- EXPECT_EQ(quantize(-2, cg.input_1.quantparam()), cg.input_1.at<loco::DataType::U8>(0));
- EXPECT_EQ(quantize(-1, cg.input_1.quantparam()), cg.input_1.at<loco::DataType::U8>(1));
- EXPECT_EQ(quantize(0, cg.input_1.quantparam()), cg.input_1.at<loco::DataType::U8>(2));
- EXPECT_EQ(quantize(1, cg.input_1.quantparam()), cg.input_1.at<loco::DataType::U8>(3));
- EXPECT_EQ(quantize(2, cg.input_1.quantparam()), cg.input_1.at<loco::DataType::U8>(4));
+ cg.concat_node->fusedActivationFunction(luci::FusedActFunc::RELU);
+ luci::propagate_concat_quantparam(cg.concat_node, loco::DataType::U8);
+ EXPECT_FLOAT_EQ(0.1, cg.concat_node->quantparam()->scale[0]);
+ EXPECT_EQ(10, cg.concat_node->quantparam()->zerop[0]);
+ const auto cg_input_1 = loco::must_cast<luci::CircleConst *>(cg.concat_node->values(0));
+ EXPECT_FLOAT_EQ(0.015686275, cg_input_1->quantparam()->scale[0]);
+ EXPECT_EQ(128, cg_input_1->quantparam()->zerop[0]);
+ EXPECT_FLOAT_EQ(2.0, cg.input_2->quantparam()->scale[0]);
+ EXPECT_EQ(2, cg.input_2->quantparam()->zerop[0]);
+ EXPECT_EQ(loco::DataType::U8, cg_input_1->dtype());
+ EXPECT_EQ(quantize(-2, cg_input_1->quantparam()), cg_input_1->at<loco::DataType::U8>(0));
+ EXPECT_EQ(quantize(-1, cg_input_1->quantparam()), cg_input_1->at<loco::DataType::U8>(1));
+ EXPECT_EQ(quantize(0, cg_input_1->quantparam()), cg_input_1->at<loco::DataType::U8>(2));
+ EXPECT_EQ(quantize(1, cg_input_1->quantparam()), cg_input_1->at<loco::DataType::U8>(3));
+ EXPECT_EQ(quantize(2, cg_input_1->quantparam()), cg_input_1->at<loco::DataType::U8>(4));
}
TEST(PropagateConcatenationQparam, propagate_concat_quantparam_i16)
@@ -318,19 +319,20 @@ TEST(PropagateConcatenationQparam, propagate_concat_quantparam_i16)
// input_1 is const. const values are quantized with the qparam of concat
ConstInputConcatGraph cg(loco::DataType::S16);
- luci::propagate_concat_quantparam(&cg.concat_node, loco::DataType::S16);
- EXPECT_FLOAT_EQ(0.1, cg.concat_node.quantparam()->scale[0]);
- EXPECT_EQ(0, cg.concat_node.quantparam()->zerop[0]);
- EXPECT_FLOAT_EQ(0.1, cg.input_1.quantparam()->scale[0]);
- EXPECT_EQ(0, cg.input_1.quantparam()->zerop[0]);
- EXPECT_FLOAT_EQ(0.1, cg.input_2.quantparam()->scale[0]);
- EXPECT_EQ(0, cg.input_2.quantparam()->zerop[0]);
- EXPECT_EQ(loco::DataType::S16, cg.input_1.dtype());
- EXPECT_EQ(-20, cg.input_1.at<loco::DataType::S16>(0));
- EXPECT_EQ(-10, cg.input_1.at<loco::DataType::S16>(1));
- EXPECT_EQ(0, cg.input_1.at<loco::DataType::S16>(2));
- EXPECT_EQ(10, cg.input_1.at<loco::DataType::S16>(3));
- EXPECT_EQ(20, cg.input_1.at<loco::DataType::S16>(4));
+ luci::propagate_concat_quantparam(cg.concat_node, loco::DataType::S16);
+ EXPECT_FLOAT_EQ(0.1, cg.concat_node->quantparam()->scale[0]);
+ EXPECT_EQ(0, cg.concat_node->quantparam()->zerop[0]);
+ const auto cg_input_1 = loco::must_cast<luci::CircleConst *>(cg.concat_node->values(0));
+ EXPECT_FLOAT_EQ(0.1, cg_input_1->quantparam()->scale[0]);
+ EXPECT_EQ(0, cg_input_1->quantparam()->zerop[0]);
+ EXPECT_FLOAT_EQ(0.1, cg.input_2->quantparam()->scale[0]);
+ EXPECT_EQ(0, cg.input_2->quantparam()->zerop[0]);
+ EXPECT_EQ(loco::DataType::S16, cg_input_1->dtype());
+ EXPECT_EQ(-20, cg_input_1->at<loco::DataType::S16>(0));
+ EXPECT_EQ(-10, cg_input_1->at<loco::DataType::S16>(1));
+ EXPECT_EQ(0, cg_input_1->at<loco::DataType::S16>(2));
+ EXPECT_EQ(10, cg_input_1->at<loco::DataType::S16>(3));
+ EXPECT_EQ(20, cg_input_1->at<loco::DataType::S16>(4));
}
TEST(PropagateConcatenationQparam, propagate_concat_quantparam_i16_NEG)
@@ -355,18 +357,19 @@ TEST(PropagateConcatenationQparam, propagate_concat_quantparam_i16_NEG)
// concat has fused activation function and input_1 is const.
// const values are quantized using its min/max
ConstInputConcatGraph cg(loco::DataType::S16);
- cg.concat_node.fusedActivationFunction(luci::FusedActFunc::RELU);
- luci::propagate_concat_quantparam(&cg.concat_node, loco::DataType::S16);
- EXPECT_FLOAT_EQ(0.1, cg.concat_node.quantparam()->scale[0]);
- EXPECT_EQ(0, cg.concat_node.quantparam()->zerop[0]);
- EXPECT_FLOAT_EQ(0.000061037, cg.input_1.quantparam()->scale[0]);
- EXPECT_EQ(0, cg.input_1.quantparam()->zerop[0]);
- EXPECT_FLOAT_EQ(2.0, cg.input_2.quantparam()->scale[0]);
- EXPECT_EQ(0, cg.input_2.quantparam()->zerop[0]);
- EXPECT_EQ(loco::DataType::S16, cg.input_1.dtype());
- EXPECT_EQ(quantize(-2, cg.input_1.quantparam()), cg.input_1.at<loco::DataType::S16>(0));
- EXPECT_EQ(quantize(-1, cg.input_1.quantparam()), cg.input_1.at<loco::DataType::S16>(1));
- EXPECT_EQ(quantize(0, cg.input_1.quantparam()), cg.input_1.at<loco::DataType::S16>(2));
- EXPECT_EQ(quantize(1, cg.input_1.quantparam()), cg.input_1.at<loco::DataType::S16>(3));
- EXPECT_EQ(quantize(2, cg.input_1.quantparam()), cg.input_1.at<loco::DataType::S16>(4));
+ cg.concat_node->fusedActivationFunction(luci::FusedActFunc::RELU);
+ luci::propagate_concat_quantparam(cg.concat_node, loco::DataType::S16);
+ EXPECT_FLOAT_EQ(0.1, cg.concat_node->quantparam()->scale[0]);
+ EXPECT_EQ(0, cg.concat_node->quantparam()->zerop[0]);
+ const auto cg_input_1 = loco::must_cast<luci::CircleConst *>(cg.concat_node->values(0));
+ EXPECT_FLOAT_EQ(0.000061037, cg_input_1->quantparam()->scale[0]);
+ EXPECT_EQ(0, cg_input_1->quantparam()->zerop[0]);
+ EXPECT_FLOAT_EQ(2.0, cg.input_2->quantparam()->scale[0]);
+ EXPECT_EQ(0, cg.input_2->quantparam()->zerop[0]);
+ EXPECT_EQ(loco::DataType::S16, cg_input_1->dtype());
+ EXPECT_EQ(quantize(-2, cg_input_1->quantparam()), cg_input_1->at<loco::DataType::S16>(0));
+ EXPECT_EQ(quantize(-1, cg_input_1->quantparam()), cg_input_1->at<loco::DataType::S16>(1));
+ EXPECT_EQ(quantize(0, cg_input_1->quantparam()), cg_input_1->at<loco::DataType::S16>(2));
+ EXPECT_EQ(quantize(1, cg_input_1->quantparam()), cg_input_1->at<loco::DataType::S16>(3));
+ EXPECT_EQ(quantize(2, cg_input_1->quantparam()), cg_input_1->at<loco::DataType::S16>(4));
}
diff --git a/compiler/luci/pass/src/PropagateQuantParamPass.cpp b/compiler/luci/pass/src/PropagateQuantParamPass.cpp
index af83cd83b..26282086b 100644
--- a/compiler/luci/pass/src/PropagateQuantParamPass.cpp
+++ b/compiler/luci/pass/src/PropagateQuantParamPass.cpp
@@ -91,9 +91,8 @@ bool PropagateQuantParamPass::run(loco::Graph *g)
INFO(l) << "PropagateQuantParamPass visit node: " << circle_node->name() << std::endl;
PropagateQuantParam pqp;
- changed = circle_node->accept(&pqp);
- if (changed)
- break;
+ if (circle_node->accept(&pqp))
+ changed = true;
}
return changed;
diff --git a/compiler/luci/pass/src/PropagateQuantParamPass.test.cpp b/compiler/luci/pass/src/PropagateQuantParamPass.test.cpp
index 15adbfc01..ed1f96828 100644
--- a/compiler/luci/pass/src/PropagateQuantParamPass.test.cpp
+++ b/compiler/luci/pass/src/PropagateQuantParamPass.test.cpp
@@ -83,6 +83,13 @@ public:
} // namespace
+TEST(PropagateQuantParamPassTest, name)
+{
+ luci::PropagateQuantParamPass pass;
+ auto const name = pass.name();
+ ASSERT_NE(nullptr, name);
+}
+
TEST(PropagateQuantParam, simple)
{
SimpleGraph g;
diff --git a/compiler/luci/pass/src/QuantizationUtils.cpp b/compiler/luci/pass/src/QuantizationUtils.cpp
index fa0141114..85d600e47 100644
--- a/compiler/luci/pass/src/QuantizationUtils.cpp
+++ b/compiler/luci/pass/src/QuantizationUtils.cpp
@@ -96,7 +96,7 @@ void asymmetric_wquant_with_minmax_per_layer(CircleConst *node, float min, float
data = data < nudged_min ? nudged_min : data;
data = data > nudged_max ? nudged_max : data;
quantized_values[i] =
- static_cast<int32_t>(std::round((data - nudged_min) * scaling_factor_inv));
+ static_cast<int32_t>(std::round((data - nudged_min) * scaling_factor_inv));
}
node->dtype(loco::DataType::U8); // change the type of tensor
@@ -133,14 +133,14 @@ void symmetric_wquant_with_minmax_per_layer(CircleConst *node, float min, float
for (uint32_t i = 0; i < size; ++i)
{
node->at<loco::DataType::S16>(i) =
- std::min(kMaxScale, std::max(kMinScale, quantized_values[i]));
+ std::min(kMaxScale, std::max(kMinScale, quantized_values[i]));
}
}
void compute_sym_scale_zp(float min, float max, float &scaling_factor, int64_t &zp,
float &nudged_min, float &nudged_max)
{
- assert(min != max);
+ assert(min <= max);
const int32_t kMaxScale = std::numeric_limits<int16_t>::max();
const int32_t kMinScale = -kMaxScale;
@@ -158,8 +158,8 @@ void compute_sym_scale_zp(float min, float max, float &scaling_factor, int64_t &
scale_factor_from_max_side = rmax / qmax_double;
scaling_factor = scale_factor_from_min_side > scale_factor_from_max_side
- ? scale_factor_from_min_side
- : scale_factor_from_max_side;
+ ? scale_factor_from_min_side
+ : scale_factor_from_max_side;
zp = 0;
nudged_min = static_cast<float>(qmin_double * scaling_factor);
nudged_max = static_cast<float>(qmax_double * scaling_factor);
@@ -226,7 +226,8 @@ void compute_asym_scale_zp(float min, float max, float &scaling_factor, int64_t
zp = nudged_zero_point;
}
-bool get_channel_dim_index(CircleConst *node, loco::TensorShape &dimension, int &channel_dim_index)
+bool get_channel_dim_index(CircleConst *node, loco::TensorShape &dimension,
+ int32_t &channel_dim_index)
{
auto succs = loco::succs(node);
@@ -304,7 +305,7 @@ bool get_channel_dim_index(CircleConst *node, loco::TensorShape &dimension, int
uint32_t cal_offset(loco::TensorShape &dimension, uint32_t *indices)
{
return indices[0] * dimension.dim(1).value() * dimension.dim(2).value() *
- dimension.dim(3).value() +
+ dimension.dim(3).value() +
indices[1] * dimension.dim(2).value() * dimension.dim(3).value() +
indices[2] * dimension.dim(3).value() + indices[3];
}
diff --git a/compiler/luci/pass/src/QuantizationUtils.h b/compiler/luci/pass/src/QuantizationUtils.h
index 22a5cf1ee..c8c558d3c 100644
--- a/compiler/luci/pass/src/QuantizationUtils.h
+++ b/compiler/luci/pass/src/QuantizationUtils.h
@@ -37,7 +37,8 @@ void symmetric_wquant_with_minmax_per_layer(CircleConst *node, float min, float
float &scaling_factor, int64_t &zp, float &nudged_min,
float &nudged_max);
-bool get_channel_dim_index(CircleConst *node, loco::TensorShape &dimension, int &channel_dim_index);
+bool get_channel_dim_index(CircleConst *node, loco::TensorShape &dimension,
+ int32_t &channel_dim_index);
uint32_t cal_offset(loco::TensorShape &dimension, uint32_t *indices);
diff --git a/compiler/luci/pass/src/QuantizeDequantizeWeightsPass.cpp b/compiler/luci/pass/src/QuantizeDequantizeWeightsPass.cpp
index e10c4bb4d..e99c7b389 100644
--- a/compiler/luci/pass/src/QuantizeDequantizeWeightsPass.cpp
+++ b/compiler/luci/pass/src/QuantizeDequantizeWeightsPass.cpp
@@ -24,33 +24,29 @@
#include <iostream>
#include <cmath>
-
-namespace luci
-{
+#include <functional>
namespace
{
-void cal_minmax_per_channel(CircleConst *node, std::vector<float> &min, std::vector<float> &max)
+using namespace luci;
+using IterFunc = std::function<void(uint32_t *, loco::TensorShape &, int32_t)>;
+
+void iterate_per_channel(CircleConst *node, IterFunc func)
{
loco::TensorShape dimension;
dimension.rank(4);
uint32_t indices[4] = {
- 0,
+ 0,
};
- int channel_dim_index{0};
- int size{0};
+ int32_t channel_dim_index{0};
if (!get_channel_dim_index(node, dimension, channel_dim_index))
{
assert(false);
return;
}
- size = dimension.dim(channel_dim_index).value();
- std::vector<bool> has_min_max_value(size, false);
- min.resize(size);
- max.resize(size);
for (indices[0] = 0; indices[0] < dimension.dim(0).value(); indices[0]++)
{
for (indices[1] = 0; indices[1] < dimension.dim(1).value(); indices[1]++)
@@ -59,25 +55,57 @@ void cal_minmax_per_channel(CircleConst *node, std::vector<float> &min, std::vec
{
for (indices[3] = 0; indices[3] < dimension.dim(3).value(); indices[3]++)
{
- int channel_idx = indices[channel_dim_index];
- auto data = node->at<loco::DataType::FLOAT32>(cal_offset(dimension, indices));
- if (has_min_max_value[channel_idx])
- {
- min[channel_idx] = data < min[channel_idx] ? data : min[channel_idx];
- max[channel_idx] = data > max[channel_idx] ? data : max[channel_idx];
- }
- else
- {
- min[channel_idx] = data;
- max[channel_idx] = data;
- has_min_max_value[channel_idx] = true;
- }
+ func(indices, dimension, channel_dim_index);
}
}
}
}
}
+} // namespace
+
+namespace luci
+{
+
+namespace
+{
+
+void cal_minmax_per_channel(CircleConst *node, std::vector<float> &min, std::vector<float> &max)
+{
+ loco::TensorShape dimension;
+ dimension.rank(4);
+ int32_t channel_dim_index{0};
+
+ if (!get_channel_dim_index(node, dimension, channel_dim_index))
+ {
+ assert(false);
+ return;
+ }
+ auto size = dimension.dim(channel_dim_index).value();
+
+ std::vector<bool> has_min_max_value(size, false);
+ min.resize(size);
+ max.resize(size);
+
+ auto cal_minmax = [&](uint32_t *indices, loco::TensorShape &dimension, int channel_dim_index) {
+ int channel_idx = indices[channel_dim_index];
+ auto data = node->at<loco::DataType::FLOAT32>(cal_offset(dimension, indices));
+ if (has_min_max_value[channel_idx])
+ {
+ min[channel_idx] = data < min[channel_idx] ? data : min[channel_idx];
+ max[channel_idx] = data > max[channel_idx] ? data : max[channel_idx];
+ }
+ else
+ {
+ min[channel_idx] = data;
+ max[channel_idx] = data;
+ has_min_max_value[channel_idx] = true;
+ }
+ };
+
+ iterate_per_channel(node, cal_minmax);
+}
+
void sym_wquant_per_channel(CircleConst *node, std::vector<float> &min, std::vector<float> &max,
std::vector<float> &scaling_factor, std::vector<int64_t> &zp,
std::vector<float> &nudged_min, std::vector<float> &nudged_max)
@@ -94,45 +122,24 @@ void sym_wquant_per_channel(CircleConst *node, std::vector<float> &min, std::vec
compute_sym_scale_zp(min[i], max[i], scaling_factor[i], zp[i], nudged_min[i], nudged_max[i]);
}
- loco::TensorShape dimension;
- dimension.rank(4);
- uint32_t indices[4] = {
- 0,
+ auto quantize = [&](uint32_t *indices, loco::TensorShape &dimension, int channel_dim_index) {
+ int channel_idx = indices[channel_dim_index];
+ const float scaling_factor_inv = 1.0 / scaling_factor[channel_idx];
+ auto data = node->at<loco::DataType::FLOAT32>(cal_offset(dimension, indices));
+ data = data < nudged_min[channel_idx] ? nudged_min[channel_idx] : data;
+ data = data > nudged_max[channel_idx] ? nudged_max[channel_idx] : data;
+ quantized_values[cal_offset(dimension, indices)] =
+ static_cast<int32_t>(std::round(data * scaling_factor_inv));
};
- int channel_dim_index{0};
-
- if (!get_channel_dim_index(node, dimension, channel_dim_index))
- {
- assert(false);
- return;
- }
- for (indices[0] = 0; indices[0] < dimension.dim(0).value(); indices[0]++)
- {
- for (indices[1] = 0; indices[1] < dimension.dim(1).value(); indices[1]++)
- {
- for (indices[2] = 0; indices[2] < dimension.dim(2).value(); indices[2]++)
- {
- for (indices[3] = 0; indices[3] < dimension.dim(3).value(); indices[3]++)
- {
- int channel_idx = indices[channel_dim_index];
- const float scaling_factor_inv = 1.0 / scaling_factor[channel_idx];
- auto data = node->at<loco::DataType::FLOAT32>(cal_offset(dimension, indices));
- data = data < nudged_min[channel_idx] ? nudged_min[channel_idx] : data;
- data = data > nudged_max[channel_idx] ? nudged_max[channel_idx] : data;
- quantized_values[cal_offset(dimension, indices)] =
- static_cast<int32_t>(std::round(data * scaling_factor_inv));
- }
- }
- }
- }
+ iterate_per_channel(node, quantize);
node->dtype(loco::DataType::S16); // change the type of tensor
node->size<loco::DataType::S16>(size); // resize tensor
for (uint32_t i = 0; i < size; ++i)
{
node->at<loco::DataType::S16>(i) =
- std::min(kMaxScale, std::max(kMinScale, quantized_values[i]));
+ std::min(kMaxScale, std::max(kMinScale, quantized_values[i]));
}
}
@@ -142,35 +149,14 @@ void sym_wdequant_per_channel(CircleConst *node, std::vector<float> &scaling_fac
uint32_t size = node->size<loco::DataType::S16>();
std::vector<float> dequantized_values(size);
- loco::TensorShape dimension;
- dimension.rank(4);
- uint32_t indices[4] = {
- 0,
+ auto dequantize = [&](uint32_t *indices, loco::TensorShape &dimension, int channel_dim_index) {
+ int channel_idx = indices[channel_dim_index];
+ auto data = node->at<loco::DataType::S16>(cal_offset(dimension, indices));
+ dequantized_values[cal_offset(dimension, indices)] =
+ static_cast<float>(data) * scaling_factor[channel_idx];
};
- int channel_dim_index{0};
-
- if (!get_channel_dim_index(node, dimension, channel_dim_index))
- {
- assert(false);
- return;
- }
- for (indices[0] = 0; indices[0] < dimension.dim(0).value(); indices[0]++)
- {
- for (indices[1] = 0; indices[1] < dimension.dim(1).value(); indices[1]++)
- {
- for (indices[2] = 0; indices[2] < dimension.dim(2).value(); indices[2]++)
- {
- for (indices[3] = 0; indices[3] < dimension.dim(3).value(); indices[3]++)
- {
- int channel_idx = indices[channel_dim_index];
- auto data = node->at<loco::DataType::S16>(cal_offset(dimension, indices));
- dequantized_values[cal_offset(dimension, indices)] =
- static_cast<float>(data) * scaling_factor[channel_idx];
- }
- }
- }
- }
+ iterate_per_channel(node, dequantize);
node->dtype(loco::DataType::FLOAT32); // change the type of tensor
node->size<loco::DataType::FLOAT32>(size); // resize tensor
@@ -198,38 +184,17 @@ void asymmetric_wquant_per_channel(CircleConst *node, std::vector<float> &min,
compute_asym_scale_zp(min[i], max[i], scaling_factor[i], zp[i], nudged_min[i], nudged_max[i]);
}
- loco::TensorShape dimension;
- dimension.rank(4);
- uint32_t indices[4] = {
- 0,
+ auto quantize = [&](uint32_t *indices, loco::TensorShape &dimension, int channel_dim_index) {
+ int channel_idx = indices[channel_dim_index];
+ const float scaling_factor_inv = 1.0 / scaling_factor[channel_idx];
+ auto data = node->at<loco::DataType::FLOAT32>(cal_offset(dimension, indices));
+ data = data < nudged_min[channel_idx] ? nudged_min[channel_idx] : data;
+ data = data > nudged_max[channel_idx] ? nudged_max[channel_idx] : data;
+ quantized_values[cal_offset(dimension, indices)] =
+ static_cast<int32_t>(std::round((data - nudged_min[channel_idx]) * scaling_factor_inv));
};
- int channel_dim_index{0};
-
- if (!get_channel_dim_index(node, dimension, channel_dim_index))
- {
- assert(false);
- return;
- }
- for (indices[0] = 0; indices[0] < dimension.dim(0).value(); indices[0]++)
- {
- for (indices[1] = 0; indices[1] < dimension.dim(1).value(); indices[1]++)
- {
- for (indices[2] = 0; indices[2] < dimension.dim(2).value(); indices[2]++)
- {
- for (indices[3] = 0; indices[3] < dimension.dim(3).value(); indices[3]++)
- {
- int channel_idx = indices[channel_dim_index];
- const float scaling_factor_inv = 1.0 / scaling_factor[channel_idx];
- auto data = node->at<loco::DataType::FLOAT32>(cal_offset(dimension, indices));
- data = data < nudged_min[channel_idx] ? nudged_min[channel_idx] : data;
- data = data > nudged_max[channel_idx] ? nudged_max[channel_idx] : data;
- quantized_values[cal_offset(dimension, indices)] = static_cast<int32_t>(
- std::round((data - nudged_min[channel_idx]) * scaling_factor_inv));
- }
- }
- }
- }
+ iterate_per_channel(node, quantize);
node->dtype(loco::DataType::U8); // change the type of tensor
node->size<loco::DataType::U8>(size); // resize tensor
@@ -246,35 +211,14 @@ void asymmetric_wdequant_per_channel(CircleConst *node, std::vector<float> &scal
uint32_t size = node->size<loco::DataType::U8>();
std::vector<float> dequantized_values(size);
- loco::TensorShape dimension;
- dimension.rank(4);
- uint32_t indices[4] = {
- 0,
+ auto dequantize = [&](uint32_t *indices, loco::TensorShape &dimension, int channel_dim_index) {
+ int channel_idx = indices[channel_dim_index];
+ auto data = node->at<loco::DataType::U8>(cal_offset(dimension, indices));
+ dequantized_values[cal_offset(dimension, indices)] =
+ static_cast<float>(data) * scaling_factor[channel_idx] + nudged_min[channel_idx];
};
- int channel_dim_index{0};
-
- if (!get_channel_dim_index(node, dimension, channel_dim_index))
- {
- assert(false);
- return;
- }
- for (indices[0] = 0; indices[0] < dimension.dim(0).value(); indices[0]++)
- {
- for (indices[1] = 0; indices[1] < dimension.dim(1).value(); indices[1]++)
- {
- for (indices[2] = 0; indices[2] < dimension.dim(2).value(); indices[2]++)
- {
- for (indices[3] = 0; indices[3] < dimension.dim(3).value(); indices[3]++)
- {
- int channel_idx = indices[channel_dim_index];
- auto data = node->at<loco::DataType::U8>(cal_offset(dimension, indices));
- dequantized_values[cal_offset(dimension, indices)] =
- static_cast<float>(data) * scaling_factor[channel_idx] + nudged_min[channel_idx];
- }
- }
- }
- }
+ iterate_per_channel(node, dequantize);
node->dtype(loco::DataType::FLOAT32); // change the type of tensor
node->size<loco::DataType::FLOAT32>(size); // resize tensor
@@ -311,7 +255,7 @@ struct QuantizeDequantizeWeights final : public luci::CircleNodeMutableVisitor<b
{
QuantizeDequantizeWeights(loco::DataType input, loco::DataType output,
QuantizationGranularity granularity)
- : input_type(input), output_type(output), granularity(granularity)
+ : input_type(input), output_type(output), granularity(granularity)
{
}
diff --git a/compiler/luci/pass/src/QuantizeDequantizeWeightsPass.test.cpp b/compiler/luci/pass/src/QuantizeDequantizeWeightsPass.test.cpp
new file mode 100644
index 000000000..f226253c2
--- /dev/null
+++ b/compiler/luci/pass/src/QuantizeDequantizeWeightsPass.test.cpp
@@ -0,0 +1,27 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Pass/QuantizeDequantizeWeightsPass.h"
+
+#include <gtest/gtest.h>
+
+TEST(QuantizeDequantizeWeightsPassTest, name)
+{
+ luci::QuantizeDequantizeWeightsPass pass(loco::DataType::FLOAT32, loco::DataType::U8,
+ luci::QuantizationGranularity::LayerWise);
+ auto const name = pass.name();
+ ASSERT_NE(nullptr, name);
+}
diff --git a/compiler/luci/pass/src/QuantizeWithMinMaxPass.cpp b/compiler/luci/pass/src/QuantizeWithMinMaxPass.cpp
index f6eebe3b9..4707ad0e9 100644
--- a/compiler/luci/pass/src/QuantizeWithMinMaxPass.cpp
+++ b/compiler/luci/pass/src/QuantizeWithMinMaxPass.cpp
@@ -19,12 +19,51 @@
#include <luci/IR/CircleNodes.h>
#include <luci/IR/CircleNodeVisitor.h>
+#include <luci/Service/Nodes/CircleConst.h>
#include <luci/Log.h>
#include <oops/UserExn.h>
#include <iostream>
#include <cmath>
+#include <functional>
+
+namespace
+{
+
+using namespace luci;
+using IterFunc = std::function<void(uint32_t *, loco::TensorShape &, int32_t)>;
+
+void iterate_per_channel(CircleConst *node, int32_t &channel_dim_index, IterFunc func)
+{
+ loco::TensorShape dimension;
+ dimension.rank(4);
+ uint32_t indices[4] = {
+ 0,
+ };
+
+ if (!get_channel_dim_index(node, dimension, channel_dim_index))
+ {
+ assert(false);
+ return;
+ }
+
+ for (indices[0] = 0; indices[0] < dimension.dim(0).value(); indices[0]++)
+ {
+ for (indices[1] = 0; indices[1] < dimension.dim(1).value(); indices[1]++)
+ {
+ for (indices[2] = 0; indices[2] < dimension.dim(2).value(); indices[2]++)
+ {
+ for (indices[3] = 0; indices[3] < dimension.dim(3).value(); indices[3]++)
+ {
+ func(indices, dimension, channel_dim_index);
+ }
+ }
+ }
+ }
+}
+
+} // namespace
namespace luci
{
@@ -32,6 +71,30 @@ namespace luci
namespace
{
+// Create a new const node from an existing node.
+// The new node has the following characteristics
+// type: T
+// shape: same with 'node' (given as an argument)
+// buffer size: 'size' (given as an argument)
+// Note that contents are not filled in this function.
+template <loco::DataType T>
+luci::CircleConst *create_empty_const_from(luci::CircleConst *node, uint32_t size)
+{
+ auto new_node = node->graph()->nodes()->create<CircleConst>();
+ // TODO: We don't have any naming convention for quantized nodes yet.
+ // Fix this when we have one.
+ new_node->name(node->name());
+ new_node->dtype(T);
+ new_node->rank(node->rank());
+ for (uint32_t i = 0; i < node->rank(); i++)
+ new_node->dim(i).set(node->dim(i).value());
+
+ new_node->size<T>(size);
+ new_node->shape_status(luci::ShapeStatus::VALID);
+
+ return new_node;
+}
+
void overwrite_quantparam(luci::CircleConcatenation *concat, luci::CircleNode *target)
{
auto concat_qparam = concat->quantparam();
@@ -44,6 +107,9 @@ void overwrite_quantparam(luci::CircleConcatenation *concat, luci::CircleNode *t
auto quantparam = std::make_unique<CircleQuantParam>();
target->quantparam(std::move(quantparam));
target_qparam = target->quantparam();
+
+ if (target_qparam == nullptr)
+ throw std::runtime_error("Creating new quant param failed");
}
target_qparam->min = concat_qparam->min;
target_qparam->max = concat_qparam->max;
@@ -79,7 +145,7 @@ void quant_const_values(luci::CircleConst *const_node, float scaling_factor, flo
const_node->size<loco::DataType::S16>(size); // resize tensor
for (uint32_t i = 0; i < size; ++i)
const_node->at<loco::DataType::S16>(i) =
- std::min(32767, std::max(-32767, quantized_values[i]));
+ std::min(32767, std::max(-32767, quantized_values[i]));
break;
default:
throw std::runtime_error("Unsupported data type");
@@ -219,17 +285,16 @@ void quant_const(CircleConst *node, loco::DataType quant_type)
}
// Check if the node is the bias of Conv2D, DepthwiseConv2D, FullyConnected, or TransposeConv layer
-// If true, return <input, weight> pair of the successor node (used to quantize bias)
-// If flase, return <nullptr, nullptr>
-std::pair<loco::Node *, loco::Node *> get_input_weight_of_bias(CircleNode *node)
+// Returns a list of <input, weights, output> vectors for the above operators.
+// Note that it returns a 'list' because bias can be used by multiple operators.
+std::vector<std::vector<loco::Node *>> get_input_weight_output_of_bias(CircleNode *node)
{
+ std::vector<std::vector<loco::Node *>> result;
auto circle_const = dynamic_cast<CircleConst *>(node);
if (circle_const == nullptr)
- return std::make_pair(nullptr, nullptr);
+ return result;
auto succs = loco::succs(node);
- if (succs.size() != 1) // assume bias is used by only one node
- return std::make_pair(nullptr, nullptr);
for (auto out : succs)
{
@@ -238,35 +303,39 @@ std::pair<loco::Node *, loco::Node *> get_input_weight_of_bias(CircleNode *node)
{
assert(conv->input() != nullptr);
assert(conv->filter() != nullptr);
- return std::make_pair(conv->input(), conv->filter());
+ result.push_back({conv->input(), conv->filter(), conv});
+ continue;
}
auto dw_conv = dynamic_cast<CircleDepthwiseConv2D *>(out);
if (dw_conv != nullptr && dw_conv->bias() == circle_const)
{
assert(dw_conv->input() != nullptr);
assert(dw_conv->filter() != nullptr);
- return std::make_pair(dw_conv->input(), dw_conv->filter());
+ result.push_back({dw_conv->input(), dw_conv->filter(), dw_conv});
+ continue;
}
auto fc = dynamic_cast<CircleFullyConnected *>(out);
if (fc != nullptr && fc->bias() == circle_const)
{
assert(fc->input() != nullptr);
assert(fc->weights() != nullptr);
- return std::make_pair(fc->input(), fc->weights());
+ result.push_back({fc->input(), fc->weights(), fc});
+ continue;
}
auto tconv = dynamic_cast<CircleTransposeConv *>(out);
if (tconv != nullptr && tconv->bias() == circle_const)
{
assert(tconv->outBackprop() != nullptr);
assert(tconv->filter() != nullptr);
- return std::make_pair(tconv->outBackprop(), tconv->filter());
+ result.push_back({tconv->outBackprop(), tconv->filter(), tconv});
+ continue;
}
}
- return std::make_pair(nullptr, nullptr);
+ return result;
}
-void asym_quant_bias_per_layer(CircleConst *node, float input_scale, float weight_scale,
- float *scaling_factor, int64_t *zp)
+CircleConst *asym_quant_bias_per_layer(CircleConst *node, float input_scale, float weight_scale,
+ float *scaling_factor, int64_t *zp)
{
float scale = input_scale * weight_scale;
const float scaling_factor_inv = (scale == 0) ? 0 : 1.0 / scale;
@@ -276,24 +345,27 @@ void asym_quant_bias_per_layer(CircleConst *node, float input_scale, float weigh
for (uint32_t i = 0; i < size; ++i)
{
quantized_values[i] =
- static_cast<int32_t>(std::round(node->at<loco::DataType::FLOAT32>(i) * scaling_factor_inv));
+ static_cast<int32_t>(std::round(node->at<loco::DataType::FLOAT32>(i) * scaling_factor_inv));
}
- node->dtype(loco::DataType::S32); // change the type of tensor
- node->size<loco::DataType::S32>(size); // resize tensor
+ auto new_bias = create_empty_const_from<loco::DataType::S32>(node, size);
+
const int32_t kMinScale = std::numeric_limits<int32_t>::lowest();
const int32_t kMaxScale = std::numeric_limits<int32_t>::max();
for (uint32_t i = 0; i < size; ++i)
{
- node->at<loco::DataType::S32>(i) =
- std::min(kMaxScale, std::max(kMinScale, quantized_values[i]));
+ new_bias->at<loco::DataType::S32>(i) =
+ std::min(kMaxScale, std::max(kMinScale, quantized_values[i]));
}
*scaling_factor = scale;
*zp = 0;
+
+ return new_bias;
}
-void quant_bias_per_channel(CircleConst *node, float input_scale, std::vector<float> &weight_scale,
- std::vector<float> &scaling_factor, std::vector<int64_t> &zp)
+CircleConst *quant_bias_per_channel(CircleConst *node, float input_scale,
+ std::vector<float> &weight_scale,
+ std::vector<float> &scaling_factor, std::vector<int64_t> &zp)
{
float scaling_factor_inv{0};
@@ -305,24 +377,27 @@ void quant_bias_per_channel(CircleConst *node, float input_scale, std::vector<fl
scaling_factor[i] = input_scale * weight_scale[i];
scaling_factor_inv = (scaling_factor[i] == 0) ? 0 : 1.0 / scaling_factor[i];
quantized_values[i] =
- static_cast<int32_t>(std::round(node->at<loco::DataType::FLOAT32>(i) * scaling_factor_inv));
+ static_cast<int32_t>(std::round(node->at<loco::DataType::FLOAT32>(i) * scaling_factor_inv));
zp[i] = 0;
}
- node->dtype(loco::DataType::S32); // change the type of tensor
- node->size<loco::DataType::S32>(size); // resize tensor
+ auto new_bias = create_empty_const_from<loco::DataType::S32>(node, size);
+
const int32_t kMinScale = std::numeric_limits<int32_t>::lowest();
const int32_t kMaxScale = std::numeric_limits<int32_t>::max();
for (uint32_t i = 0; i < size; ++i)
{
- node->at<loco::DataType::S32>(i) =
- std::min(kMaxScale, std::max(kMinScale, quantized_values[i]));
+ new_bias->at<loco::DataType::S32>(i) =
+ std::min(kMaxScale, std::max(kMinScale, quantized_values[i]));
}
+
+ return new_bias;
}
-void int16_quant_bias_per_channel(CircleConst *node, float input_scale,
- std::vector<float> &weight_scale,
- std::vector<float> &scaling_factor, std::vector<int64_t> &zp)
+CircleConst *int16_quant_bias_per_channel(CircleConst *node, float input_scale,
+ std::vector<float> &weight_scale,
+ std::vector<float> &scaling_factor,
+ std::vector<int64_t> &zp)
{
float scaling_factor_inv{0};
@@ -334,16 +409,18 @@ void int16_quant_bias_per_channel(CircleConst *node, float input_scale,
scaling_factor[i] = input_scale * weight_scale[i];
scaling_factor_inv = (scaling_factor[i] == 0) ? 0 : 1.0 / scaling_factor[i];
quantized_values[i] =
- static_cast<int64_t>(std::round(node->at<loco::DataType::FLOAT32>(i) * scaling_factor_inv));
+ static_cast<int64_t>(std::round(node->at<loco::DataType::FLOAT32>(i) * scaling_factor_inv));
zp[i] = 0;
}
- node->dtype(loco::DataType::S64); // change the type of tensor
- node->size<loco::DataType::S64>(size); // resize tensor
+ auto new_bias = create_empty_const_from<loco::DataType::S64>(node, size);
+
for (uint32_t i = 0; i < size; ++i)
{
- node->at<loco::DataType::S64>(i) = quantized_values[i];
+ new_bias->at<loco::DataType::S64>(i) = quantized_values[i];
}
+
+ return new_bias;
}
bool has_min_max(const CircleNode *node)
@@ -362,42 +439,22 @@ void sym_wquant_per_channel(CircleConst *node, std::vector<float> &scaling_facto
uint32_t size = node->size<loco::DataType::FLOAT32>();
std::vector<int32_t> quantized_values(size);
- loco::TensorShape dimension;
- dimension.rank(4);
- uint32_t indices[4] = {
- 0,
+ auto quantize = [&](uint32_t *indices, loco::TensorShape &dimension, int32_t channel_dim_index) {
+ int channel_idx = indices[channel_dim_index];
+ const float scaling_factor_inv = 1.0 / scaling_factor[channel_idx];
+ auto data = node->at<loco::DataType::FLOAT32>(cal_offset(dimension, indices));
+ quantized_values[cal_offset(dimension, indices)] =
+ static_cast<int32_t>(std::round(data * scaling_factor_inv));
};
- if (!get_channel_dim_index(node, dimension, channel_dim_index))
- {
- assert(false);
- return;
- }
-
- for (indices[0] = 0; indices[0] < dimension.dim(0).value(); indices[0]++)
- {
- for (indices[1] = 0; indices[1] < dimension.dim(1).value(); indices[1]++)
- {
- for (indices[2] = 0; indices[2] < dimension.dim(2).value(); indices[2]++)
- {
- for (indices[3] = 0; indices[3] < dimension.dim(3).value(); indices[3]++)
- {
- int channel_idx = indices[channel_dim_index];
- const float scaling_factor_inv = 1.0 / scaling_factor[channel_idx];
- auto data = node->at<loco::DataType::FLOAT32>(cal_offset(dimension, indices));
- quantized_values[cal_offset(dimension, indices)] =
- static_cast<int32_t>(std::round(data * scaling_factor_inv));
- }
- }
- }
- }
+ iterate_per_channel(node, channel_dim_index, quantize);
node->dtype(loco::DataType::S16); // change the type of tensor
node->size<loco::DataType::S16>(size); // resize tensor
for (uint32_t i = 0; i < size; ++i)
{
node->at<loco::DataType::S16>(i) =
- std::min(kMaxScale, std::max(kMinScale, quantized_values[i]));
+ std::min(kMaxScale, std::max(kMinScale, quantized_values[i]));
}
}
@@ -412,35 +469,15 @@ void asym_wquant_per_channel(CircleConst *node, std::vector<float> &min,
uint32_t size = node->size<loco::DataType::FLOAT32>();
std::vector<int32_t> quantized_values(size);
- loco::TensorShape dimension;
- dimension.rank(4);
- uint32_t indices[4] = {
- 0,
+ auto quantize = [&](uint32_t *indices, loco::TensorShape &dimension, int32_t channel_dim_index) {
+ int channel_idx = indices[channel_dim_index];
+ const float scaling_factor_inv = 1.0 / scaling_factor[channel_idx];
+ auto data = node->at<loco::DataType::FLOAT32>(cal_offset(dimension, indices));
+ quantized_values[cal_offset(dimension, indices)] =
+ static_cast<int32_t>(std::round((data - min[channel_idx]) * scaling_factor_inv));
};
- if (!get_channel_dim_index(node, dimension, channel_dim_index))
- {
- assert(false);
- return;
- }
-
- for (indices[0] = 0; indices[0] < dimension.dim(0).value(); indices[0]++)
- {
- for (indices[1] = 0; indices[1] < dimension.dim(1).value(); indices[1]++)
- {
- for (indices[2] = 0; indices[2] < dimension.dim(2).value(); indices[2]++)
- {
- for (indices[3] = 0; indices[3] < dimension.dim(3).value(); indices[3]++)
- {
- int channel_idx = indices[channel_dim_index];
- const float scaling_factor_inv = 1.0 / scaling_factor[channel_idx];
- auto data = node->at<loco::DataType::FLOAT32>(cal_offset(dimension, indices));
- quantized_values[cal_offset(dimension, indices)] =
- static_cast<int32_t>(std::round((data - min[channel_idx]) * scaling_factor_inv));
- }
- }
- }
- }
+ iterate_per_channel(node, channel_dim_index, quantize);
node->dtype(loco::DataType::U8); // change the type of tensor
node->size<loco::DataType::U8>(size); // resize tensor
@@ -473,6 +510,21 @@ void asym_wquant_per_layer(CircleConst *node, float min, float scaling_factor)
}
}
+void set_bias(luci::CircleNode *node, luci::CircleConst *bias)
+{
+ if (auto conv = dynamic_cast<CircleConv2D *>(node))
+ conv->bias(bias);
+ else if (auto dconv = dynamic_cast<CircleDepthwiseConv2D *>(node))
+ dconv->bias(bias);
+ else if (auto tconv = dynamic_cast<CircleTransposeConv *>(node))
+ tconv->bias(bias);
+ else if (auto fc = dynamic_cast<CircleFullyConnected *>(node))
+ fc->bias(bias);
+ else
+ throw std::runtime_error("Only convolution, depthwise convolution, transposed convolution, and "
+ "fully-connected layer have bias");
+}
+
/**
* @brief QuantizeActivation quantizes tensors for activations
* @details Quantize using recorded min/max values
@@ -480,7 +532,7 @@ void asym_wquant_per_layer(CircleConst *node, float min, float scaling_factor)
struct QuantizeActivation final : public luci::CircleNodeMutableVisitor<bool>
{
QuantizeActivation(loco::DataType input, loco::DataType output)
- : input_type(input), output_type(output)
+ : input_type(input), output_type(output)
{
}
@@ -503,8 +555,12 @@ struct QuantizeActivation final : public luci::CircleNodeMutableVisitor<bool>
continue;
// Check if this is bias (bias is quantized later)
- auto iw = get_input_weight_of_bias(circle_node);
- if (iw.first != nullptr && iw.second != nullptr)
+ auto iwo = get_input_weight_output_of_bias(circle_node);
+ if (iwo.size() > 0)
+ continue;
+
+ // Check if this is bool type (bool type is not quantized)
+ if (circle_node->dtype() == loco::DataType::BOOL)
continue;
// Check if this is activation
@@ -547,7 +603,7 @@ struct QuantizeActivation final : public luci::CircleNodeMutableVisitor<bool>
struct QuantizeBias final : public luci::CircleNodeMutableVisitor<bool>
{
QuantizeBias(loco::DataType input, loco::DataType output, QuantizationGranularity gr)
- : input_type(input), output_type(output), granularity(gr)
+ : input_type(input), output_type(output), granularity(gr)
{
}
@@ -562,65 +618,77 @@ struct QuantizeBias final : public luci::CircleNodeMutableVisitor<bool>
if (is_quantized(node))
return false;
- // Check if this is bias
- auto iw = get_input_weight_of_bias(node);
- if (iw.first == nullptr || iw.second == nullptr)
- return false;
-
- auto input = loco::must_cast<luci::CircleNode *>(iw.first);
- auto weight = loco::must_cast<luci::CircleNode *>(iw.second);
+ auto iwo_list = get_input_weight_output_of_bias(node);
- if (granularity == QuantizationGranularity::ChannelWise)
+ for (auto iwo : iwo_list)
{
- assert(input->quantparam()->scale.size() == 1); // input scale's layer-wise
- auto input_scale = input->quantparam()->scale[0];
+ assert(iwo.size() == 3);
- assert(weight->quantparam() != nullptr); // weight scale's channel-wise
- auto weight_scale = weight->quantparam()->scale;
+ auto input = loco::must_cast<luci::CircleNode *>(iwo[0]);
+ auto weight = loco::must_cast<luci::CircleNode *>(iwo[1]);
+ auto output = loco::must_cast<luci::CircleNode *>(iwo[2]);
- auto circle_const = loco::must_cast<luci::CircleConst *>(node);
+ auto const_bias = loco::must_cast<luci::CircleConst *>(node);
+ assert(const_bias->dtype() == loco::DataType::FLOAT32);
- uint32_t size = circle_const->size<loco::DataType::FLOAT32>();
- assert(size == weight_scale.size());
- std::vector<float> scaling_factor(size);
- std::vector<int64_t> zp(size);
+ CircleConst *new_bias = nullptr;
- if (output_type == loco::DataType::U8)
- {
- quant_bias_per_channel(circle_const, input_scale, weight_scale, scaling_factor, zp);
- }
- else if (output_type == loco::DataType::S16)
+ if (granularity == QuantizationGranularity::ChannelWise)
{
- int16_quant_bias_per_channel(circle_const, input_scale, weight_scale, scaling_factor, zp);
+ assert(input->quantparam()->scale.size() == 1); // input scale's layer-wise
+ auto input_scale = input->quantparam()->scale[0];
+
+ assert(weight->quantparam() != nullptr); // weight scale's channel-wise
+ auto weight_scale = weight->quantparam()->scale;
+
+ uint32_t size = const_bias->size<loco::DataType::FLOAT32>();
+ assert(size == weight_scale.size());
+ std::vector<float> scaling_factor(size);
+ std::vector<int64_t> zp(size);
+
+ if (output_type == loco::DataType::U8)
+ {
+ new_bias =
+ quant_bias_per_channel(const_bias, input_scale, weight_scale, scaling_factor, zp);
+ }
+ else if (output_type == loco::DataType::S16)
+ {
+ new_bias =
+ int16_quant_bias_per_channel(const_bias, input_scale, weight_scale, scaling_factor, zp);
+ }
+ else
+ {
+ throw std::runtime_error("Unsupported quantization type.");
+ }
+
+ auto quantparam = std::make_unique<CircleQuantParam>();
+ quantparam->scale = scaling_factor;
+ quantparam->zerop = zp;
+ assert(new_bias->quantparam() == nullptr); // bias should not be quantized before
+ new_bias->quantparam(std::move(quantparam));
+
+ set_bias(output, new_bias);
}
else
{
- throw std::runtime_error("Unsupported quantization type.");
- }
+ assert(input->quantparam()->scale.size() == 1); // Only support per-layer quant
+ auto input_scale = input->quantparam()->scale[0];
- auto quantparam = std::make_unique<CircleQuantParam>();
- quantparam->scale = scaling_factor;
- quantparam->zerop = zp;
- assert(circle_const->quantparam() == nullptr); // bias should not be quantized before
- circle_const->quantparam(std::move(quantparam));
- }
- else
- {
- assert(input->quantparam()->scale.size() == 1); // Only support per-layer quant
- auto input_scale = input->quantparam()->scale[0];
-
- assert(weight->quantparam()->scale.size() == 1); // Only support per-layer quant
- auto weight_scale = weight->quantparam()->scale[0];
-
- auto circle_const = loco::must_cast<luci::CircleConst *>(node);
- float scaling_factor{0};
- int64_t zp{0};
- asym_quant_bias_per_layer(circle_const, input_scale, weight_scale, &scaling_factor, &zp);
- auto quantparam = std::make_unique<CircleQuantParam>();
- quantparam->scale.push_back(scaling_factor);
- quantparam->zerop.push_back(zp);
- assert(circle_const->quantparam() == nullptr); // bias should not be quantized before
- circle_const->quantparam(std::move(quantparam));
+ assert(weight->quantparam()->scale.size() == 1); // Only support per-layer quant
+ auto weight_scale = weight->quantparam()->scale[0];
+
+ float scaling_factor{0};
+ int64_t zp{0};
+ new_bias =
+ asym_quant_bias_per_layer(const_bias, input_scale, weight_scale, &scaling_factor, &zp);
+ auto quantparam = std::make_unique<CircleQuantParam>();
+ quantparam->scale.push_back(scaling_factor);
+ quantparam->zerop.push_back(zp);
+ assert(new_bias->quantparam() == nullptr); // bias should not be quantized before
+ new_bias->quantparam(std::move(quantparam));
+
+ set_bias(output, new_bias);
+ }
}
return false;
}
@@ -633,7 +701,7 @@ struct QuantizeBias final : public luci::CircleNodeMutableVisitor<bool>
struct QuantizeWeights final : public luci::CircleNodeMutableVisitor<bool>
{
QuantizeWeights(loco::DataType input, loco::DataType output, QuantizationGranularity gr)
- : input_type(input), output_type(output), granularity(gr)
+ : input_type(input), output_type(output), granularity(gr)
{
}
@@ -641,116 +709,179 @@ struct QuantizeWeights final : public luci::CircleNodeMutableVisitor<bool>
loco::DataType output_type;
QuantizationGranularity granularity;
- // Quantize input tensors of each node
- bool visit(luci::CircleNode *node)
+private:
+ void quantize_weights(luci::CircleConst *weights)
{
- LOGGER(l);
- INFO(l) << "QuantizeWeights visit node: " << node->name() << std::endl;
- auto arity = node->arity();
- for (uint32_t i = 0; i < arity; i++)
+ // Find min/max per channel-wise
+ if (granularity == QuantizationGranularity::ChannelWise)
{
- auto input_node = node->arg(i);
- auto circle_node = loco::must_cast<luci::CircleNode *>(input_node);
+ auto quantparam = weights->quantparam();
+ if (quantparam == nullptr)
+ {
+ assert(false && "quantparam is nullptr");
+ return;
+ }
- // Check if this is already quantized
- if (is_quantized(circle_node))
- continue;
+ auto min = quantparam->min;
+ auto scaling_factor = quantparam->scale;
+ int32_t channel_dim_index = 0;
- if (is_weights(circle_node))
+ if (output_type == loco::DataType::U8)
{
- auto circle_const = loco::must_cast<luci::CircleConst *>(circle_node);
-
- // Find min/max per channel-wise
- if (granularity == QuantizationGranularity::ChannelWise)
- {
- auto quantparam = circle_node->quantparam();
- if (quantparam == nullptr)
- {
- assert(false && "quantparam is nullptr");
- return false;
- }
-
- auto min = quantparam->min;
- auto scaling_factor = quantparam->scale;
- int32_t channel_dim_index = 0;
-
- if (output_type == loco::DataType::U8)
- {
- asym_wquant_per_channel(circle_const, min, scaling_factor, channel_dim_index);
- }
- else
- {
- sym_wquant_per_channel(circle_const, scaling_factor, channel_dim_index);
- }
- quantparam->min.clear();
- quantparam->max.clear();
- quantparam->quantized_dimension = channel_dim_index;
- }
- // Find min/max per layer-wise
- else
- {
- // Quantize using recorded quantparam
- auto quantparam = circle_node->quantparam();
- assert(quantparam != nullptr);
- assert(quantparam->min.size() == 1); // only support layer-wise quant
- assert(quantparam->scale.size() == 1); // only support layer-wise quant
- auto min = quantparam->min[0];
- auto scaling_factor = quantparam->scale[0];
- asym_wquant_per_layer(circle_const, min, scaling_factor);
- quantparam->min.clear();
- quantparam->max.clear();
- }
+ asym_wquant_per_channel(weights, min, scaling_factor, channel_dim_index);
+ }
+ else
+ {
+ sym_wquant_per_channel(weights, scaling_factor, channel_dim_index);
}
+ quantparam->min.clear();
+ quantparam->max.clear();
+ quantparam->quantized_dimension = channel_dim_index;
+ }
+ // Find min/max per layer-wise
+ else
+ {
+ // Quantize using recorded quantparam
+ auto quantparam = weights->quantparam();
+ assert(quantparam != nullptr);
+ assert(quantparam->min.size() == 1); // only support layer-wise quant
+ assert(quantparam->scale.size() == 1); // only support layer-wise quant
+ auto min = quantparam->min[0];
+ auto scaling_factor = quantparam->scale[0];
+ asym_wquant_per_layer(weights, min, scaling_factor);
+ quantparam->min.clear();
+ quantparam->max.clear();
}
- return false;
}
-};
-void quant_instnorm(luci::CircleInstanceNorm *node, loco::DataType output_type,
- QuantizationGranularity granularity)
-{
- auto gamma = loco::must_cast<luci::CircleConst *>(node->gamma());
- auto beta = loco::must_cast<luci::CircleConst *>(node->beta());
- assert(gamma->dtype() == loco::DataType::FLOAT32);
- assert(beta->dtype() == loco::DataType::FLOAT32);
+ bool visit(luci::CircleConv2D *node)
+ {
+ LOGGER(l);
+ INFO(l) << "QuantizeWeights visit node: " << node->name() << std::endl;
- if (granularity == QuantizationGranularity::LayerWise)
+ auto weights = loco::must_cast<luci::CircleConst *>(node->filter());
+ if (!is_quantized(weights))
+ {
+ auto new_weights = luci::clone(weights);
+ node->filter(new_weights);
+ quantize_weights(new_weights);
+ return true;
+ }
+ return false;
+ }
+
+ bool visit(luci::CircleDepthwiseConv2D *node)
{
- quant_const(gamma, output_type);
- quant_const(beta, output_type);
+ LOGGER(l);
+ INFO(l) << "QuantizeWeights visit node: " << node->name() << std::endl;
+
+ auto weights = loco::must_cast<luci::CircleConst *>(node->filter());
+ if (!is_quantized(weights))
+ {
+ auto new_weights = luci::clone(weights);
+ node->filter(new_weights);
+ quantize_weights(new_weights);
+ return true;
+ }
+ return false;
}
- else if (granularity == QuantizationGranularity::ChannelWise)
+
+ bool visit(luci::CircleInstanceNorm *node)
{
- quant_const_per_channel(gamma, output_type);
- quant_const_per_channel(beta, output_type);
+ LOGGER(l);
+ INFO(l) << "QuantizeWeights visit node: " << node->name() << std::endl;
+
+ auto gamma = loco::must_cast<luci::CircleConst *>(node->gamma());
+ auto beta = loco::must_cast<luci::CircleConst *>(node->beta());
+
+ bool changed = false;
+ if (!is_quantized(gamma))
+ {
+ assert(gamma->dtype() == loco::DataType::FLOAT32);
+ auto new_gamma = luci::clone(gamma);
+ if (granularity == QuantizationGranularity::LayerWise)
+ quant_const(new_gamma, output_type);
+ else if (granularity == QuantizationGranularity::ChannelWise)
+ quant_const_per_channel(new_gamma, output_type);
+ node->gamma(new_gamma);
+ changed = true;
+ }
+ if (!is_quantized(beta))
+ {
+ assert(beta->dtype() == loco::DataType::FLOAT32);
+ auto new_beta = luci::clone(beta);
+ if (granularity == QuantizationGranularity::LayerWise)
+ quant_const(new_beta, output_type);
+ else if (granularity == QuantizationGranularity::ChannelWise)
+ quant_const_per_channel(new_beta, output_type);
+ node->beta(new_beta);
+ changed = true;
+ }
+
+ return changed;
}
- else
- throw std::runtime_error("Quantization granularity must be either 'layer' or 'channel'");
-}
-void quant_prelu(luci::CirclePRelu *node, loco::DataType output_type,
- QuantizationGranularity granularity)
-{
- auto alpha = loco::must_cast<luci::CircleConst *>(node->alpha());
- assert(alpha->dtype() == loco::DataType::FLOAT32);
+ bool visit(luci::CirclePRelu *node)
+ {
+ LOGGER(l);
+ INFO(l) << "QuantizeWeights visit node: " << node->name() << std::endl;
+
+ auto alpha = loco::must_cast<luci::CircleConst *>(node->alpha());
+
+ if (!is_quantized(alpha))
+ {
+ assert(alpha->dtype() == loco::DataType::FLOAT32);
+ auto new_alpha = luci::clone(alpha);
+ if (granularity == QuantizationGranularity::LayerWise)
+ quant_const(new_alpha, output_type);
+ else if (granularity == QuantizationGranularity::ChannelWise)
+ quant_const_per_channel(new_alpha, output_type);
+ node->alpha(new_alpha);
+ return true;
+ }
- if (granularity == QuantizationGranularity::LayerWise)
+ return false;
+ }
+
+ bool visit(luci::CircleTransposeConv *node)
{
- quant_const(alpha, output_type);
+ LOGGER(l);
+ INFO(l) << "QuantizeWeights visit node: " << node->name() << std::endl;
+
+ auto weights = loco::must_cast<luci::CircleConst *>(node->filter());
+ if (!is_quantized(weights))
+ {
+ auto new_weights = luci::clone(weights);
+ node->filter(new_weights);
+ quantize_weights(new_weights);
+ return true;
+ }
+ return false;
}
- else if (granularity == QuantizationGranularity::ChannelWise)
+
+ bool visit(luci::CircleFullyConnected *node)
{
- quant_const_per_channel(alpha, output_type);
+ LOGGER(l);
+ INFO(l) << "QuantizeWeights visit node: " << node->name() << std::endl;
+
+ auto weights = loco::must_cast<luci::CircleConst *>(node->weights());
+ if (!is_quantized(weights))
+ {
+ auto new_weights = luci::clone(weights);
+ node->weights(new_weights);
+ quantize_weights(new_weights);
+ return true;
+ }
+ return false;
}
- else
- throw std::runtime_error("Quantization granularity must be either 'layer' or 'channel'");
-}
+
+ bool visit(luci::CircleNode *) { return false; }
+};
/**
* @brief Quantize const input tensors using min/max of const values
*/
-void quantize_const_inputs(luci::CircleNode *node, loco::DataType output_type,
- QuantizationGranularity granularity)
+void quantize_const_inputs(luci::CircleNode *node, loco::DataType output_type)
{
auto opcode = node->opcode();
auto arity = node->arity();
@@ -763,6 +894,8 @@ void quantize_const_inputs(luci::CircleNode *node, loco::DataType output_type,
case luci::CircleOpcode::CONV_2D:
case luci::CircleOpcode::DEPTHWISE_CONV_2D:
case luci::CircleOpcode::FULLY_CONNECTED:
+ case luci::CircleOpcode::INSTANCE_NORM:
+ case luci::CircleOpcode::PRELU:
case luci::CircleOpcode::TRANSPOSE_CONV:
// Handled in QuantizeWeights and QuantizeBias
break;
@@ -771,8 +904,13 @@ void quantize_const_inputs(luci::CircleNode *node, loco::DataType output_type,
// Handled in propagate_concat_quantparam
break;
+ case luci::CircleOpcode::LOGICAL_OR:
+ // Inputs of logical Ops are bool, thus not quantized
+ break;
+
case luci::CircleOpcode::ARG_MAX:
case luci::CircleOpcode::ARG_MIN:
+ case luci::CircleOpcode::BATCH_TO_SPACE_ND:
case luci::CircleOpcode::MEAN:
case luci::CircleOpcode::PAD:
case luci::CircleOpcode::REDUCE_ANY:
@@ -783,6 +921,9 @@ void quantize_const_inputs(luci::CircleNode *node, loco::DataType output_type,
case luci::CircleOpcode::RESIZE_BILINEAR:
case luci::CircleOpcode::RESIZE_NEAREST_NEIGHBOR:
case luci::CircleOpcode::REVERSE_SEQUENCE:
+ case luci::CircleOpcode::SLICE:
+ case luci::CircleOpcode::SPACE_TO_BATCH_ND:
+ case luci::CircleOpcode::STRIDED_SLICE:
case luci::CircleOpcode::SUM:
case luci::CircleOpcode::TILE:
case luci::CircleOpcode::TOPK_V2:
@@ -791,41 +932,53 @@ void quantize_const_inputs(luci::CircleNode *node, loco::DataType output_type,
// Ex: axis, paddings
input_node = node->arg(0);
const_node = dynamic_cast<luci::CircleConst *>(input_node);
- if (const_node != nullptr)
+ if (const_node != nullptr && !is_quantized(const_node))
quant_const(const_node, output_type);
break;
- case luci::CircleOpcode::INSTANCE_NORM:
- quant_instnorm(loco::must_cast<luci::CircleInstanceNorm *>(node), output_type, granularity);
- break;
-
- case luci::CircleOpcode::PRELU:
- quant_prelu(loco::must_cast<luci::CirclePRelu *>(node), output_type, granularity);
- break;
-
case luci::CircleOpcode::ADD:
case luci::CircleOpcode::ADD_N:
+ case luci::CircleOpcode::DEPTH_TO_SPACE:
case luci::CircleOpcode::DIV:
+ case luci::CircleOpcode::ELU:
case luci::CircleOpcode::EQUAL:
+ case luci::CircleOpcode::FLOOR:
+ case luci::CircleOpcode::FLOOR_DIV:
case luci::CircleOpcode::GREATER:
case luci::CircleOpcode::GREATER_EQUAL:
case luci::CircleOpcode::LESS:
case luci::CircleOpcode::LESS_EQUAL:
+ case luci::CircleOpcode::LOGISTIC:
case luci::CircleOpcode::MAXIMUM:
case luci::CircleOpcode::MINIMUM:
case luci::CircleOpcode::MUL:
case luci::CircleOpcode::NOT_EQUAL:
+ case luci::CircleOpcode::POW:
+ case luci::CircleOpcode::RSQRT:
+ case luci::CircleOpcode::SOFTMAX:
+ case luci::CircleOpcode::SPACE_TO_DEPTH:
+ case luci::CircleOpcode::SQRT:
case luci::CircleOpcode::SUB:
+ case luci::CircleOpcode::TANH:
// Quantize all const inputs using their values
for (uint32_t i = 0; i < arity; i++)
{
input_node = node->arg(i);
const_node = dynamic_cast<luci::CircleConst *>(input_node);
- if (const_node != nullptr)
+ if (const_node != nullptr && !is_quantized(const_node))
quant_const(const_node, output_type);
}
break;
+ case luci::CircleOpcode::SPLIT:
+ // Only the second input is quantized
+ // First input should not be quantized (e.g., split_dim)
+ input_node = node->arg(1);
+ const_node = dynamic_cast<luci::CircleConst *>(input_node);
+ if (const_node != nullptr && !is_quantized(const_node))
+ quant_const(const_node, output_type);
+ break;
+
default:
for (uint32_t i = 0; i < arity; i++)
{
@@ -850,8 +1003,8 @@ void quantize_const_inputs(luci::CircleNode *node, loco::DataType output_type,
* (U8 qparam2)
*
* AFTER
- * [CircleNode] [CircleConst]
- * (U8 qparam2) (U8 qparam2)
+ * [CircleNode] [CircleConst] [CircleConst] <- Dead node
+ * (U8 qparam2) (U8 qparam2) (FP32)
* \ /
* \ /
* [CircleConcatenation]
@@ -871,7 +1024,11 @@ void propagate_concat_quantparam(luci::CircleConcatenation *concat, loco::DataTy
auto node = concat->arg(i);
auto const_node = dynamic_cast<luci::CircleConst *>(node);
if (const_node != nullptr)
- quant_const(const_node, quant_type);
+ {
+ auto new_const = luci::clone(const_node);
+ quant_const(new_const, quant_type);
+ concat->values(i, new_const);
+ }
}
return;
}
@@ -884,20 +1041,6 @@ void propagate_concat_quantparam(luci::CircleConcatenation *concat, loco::DataTy
if (node->opcode() == luci::CircleOpcode::CONCATENATION)
continue;
- // Skip if this input is used by other Ops
- auto succs = loco::succs(node);
- if (succs.size() != 1)
- {
- if (node->opcode() == luci::CircleOpcode::CIRCLECONST)
- {
- luci::CircleConst *const_node = loco::must_cast<luci::CircleConst *>(node);
- quant_const(const_node, quant_type);
- }
- continue;
- }
-
- assert(succs.find(concat) != succs.end());
-
// Quantize constant values
if (node->opcode() == luci::CircleOpcode::CIRCLECONST)
{
@@ -913,15 +1056,21 @@ void propagate_concat_quantparam(luci::CircleConcatenation *concat, loco::DataTy
const auto scaling_factor = concat_qparam->scale[0];
const auto zerop = concat_qparam->zerop[0];
- quant_const_values(const_node, scaling_factor, zerop, quant_type);
+ auto new_const = luci::clone(const_node);
+ quant_const_values(new_const, scaling_factor, zerop, quant_type);
+ concat->values(i, new_const);
+ overwrite_quantparam(concat, new_const);
}
else
{
+ const auto succs = loco::succs(node);
+ if (succs.size() > 1)
+ continue;
+
// Non-const input must have been quantized
assert(node->quantparam() != nullptr);
+ overwrite_quantparam(concat, node);
}
-
- overwrite_quantparam(concat, node);
}
}
@@ -954,13 +1103,6 @@ bool QuantizeWithMinMaxPass::run(loco::Graph *g)
circle_node->accept(&qb);
}
- // Quantize const inputs other than weights and bias
- for (auto node : loco::active_nodes(loco::output_nodes(g)))
- {
- auto circle_node = loco::must_cast<luci::CircleNode *>(node);
- quantize_const_inputs(circle_node, _output_dtype, _granularity);
- }
-
// Propagate quantization parameters of concat Op
for (auto node : loco::active_nodes(loco::output_nodes(g)))
{
@@ -976,6 +1118,13 @@ bool QuantizeWithMinMaxPass::run(loco::Graph *g)
propagate_concat_quantparam(concat, _output_dtype);
}
+ // Quantize const inputs other than weights and bias
+ for (auto node : loco::active_nodes(loco::output_nodes(g)))
+ {
+ auto circle_node = loco::must_cast<luci::CircleNode *>(node);
+ quantize_const_inputs(circle_node, _output_dtype);
+ }
+
// Update output dtype
auto graph_outputs = g->outputs();
for (auto node : loco::output_nodes(g))
diff --git a/compiler/luci/pass/src/QuantizeWithMinMaxPass.test.cpp b/compiler/luci/pass/src/QuantizeWithMinMaxPass.test.cpp
new file mode 100644
index 000000000..75ec0cfd8
--- /dev/null
+++ b/compiler/luci/pass/src/QuantizeWithMinMaxPass.test.cpp
@@ -0,0 +1,27 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Pass/QuantizeWithMinMaxPass.h"
+
+#include <gtest/gtest.h>
+
+TEST(QuantizeWithMinMaxPassTest, name)
+{
+ luci::QuantizeWithMinMaxPass pass(loco::DataType::FLOAT32, loco::DataType::U8,
+ luci::QuantizationGranularity::LayerWise);
+ auto const name = pass.name();
+ ASSERT_NE(nullptr, name);
+}
diff --git a/compiler/luci/pass/src/QuantizedModelVerifier.cpp b/compiler/luci/pass/src/QuantizedModelVerifier.cpp
new file mode 100644
index 000000000..5ea803cc9
--- /dev/null
+++ b/compiler/luci/pass/src/QuantizedModelVerifier.cpp
@@ -0,0 +1,71 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "QuantizedModelVerifier.h"
+
+#include "VerifyQuantizedNodeLayerWiseGranularity.h"
+#include "VerifyQuantizedNodeChannelWiseGranularity.h"
+#include "VerifyQuantizedNodeU8Type.h"
+#include "VerifyQuantizedNodeS16Type.h"
+
+#include <luci/IR/CircleNodes.h>
+#include <luci/IR/CircleNodeVisitor.h>
+
+namespace luci
+{
+
+void QuantizedModelVerifier::verify(loco::Graph *g)
+{
+ if (_quantized_dtype != Type::U8 && _quantized_dtype != Type::S16)
+ throw std::runtime_error("Unsupported quantized dtype");
+
+ if (_granularity != Granularity::ChannelWise && _granularity != Granularity::LayerWise)
+ throw std::runtime_error("Unsupported granularity");
+
+ for (auto node : loco::active_nodes(loco::output_nodes(g)))
+ {
+ auto circle_node = loco::must_cast<luci::CircleNode *>(node);
+
+ // Verify Type
+ if (_quantized_dtype == Type::U8)
+ {
+ VerifyQuantizedNodeU8Type vt;
+ if (!circle_node->accept(&vt))
+ throw std::runtime_error("Wrong data type");
+ }
+ else if (_quantized_dtype == Type::S16)
+ {
+ VerifyQuantizedNodeS16Type vt;
+ if (!circle_node->accept(&vt))
+ throw std::runtime_error("Wrong data type");
+ }
+
+ // Verify Granularity
+ if (_granularity == Granularity::LayerWise)
+ {
+ VerifyQuantizedNodeLayerWiseGranularity vg;
+ if (!circle_node->accept(&vg))
+ throw std::runtime_error("Wrong granularity");
+ }
+ else if (_granularity == Granularity::ChannelWise)
+ {
+ VerifyQuantizedNodeChannelWiseGranularity vg;
+ if (!circle_node->accept(&vg))
+ throw std::runtime_error("Wrong granularity");
+ }
+ }
+}
+
+} // namespace luci
diff --git a/compiler/luci/pass/src/QuantizedModelVerifier.h b/compiler/luci/pass/src/QuantizedModelVerifier.h
new file mode 100644
index 000000000..d5fbb8e74
--- /dev/null
+++ b/compiler/luci/pass/src/QuantizedModelVerifier.h
@@ -0,0 +1,50 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LUCI_QUANTIZED_MODEL_VERIFIER_H__
+#define __LUCI_QUANTIZED_MODEL_VERIFIER_H__
+
+#include "luci/Pass/QuantizationParameters.h"
+
+#include <loco.h>
+
+namespace luci
+{
+
+/**
+ * @brief Class to verify quantized model
+ *
+ * TODO Move this to luci/service
+ */
+struct QuantizedModelVerifier
+{
+
+public:
+ QuantizedModelVerifier(loco::DataType quantized_dtype, QuantizationGranularity granularity)
+ : _quantized_dtype(quantized_dtype), _granularity(granularity)
+ {
+ }
+
+ void verify(loco::Graph *g);
+
+private:
+ loco::DataType _quantized_dtype;
+ QuantizationGranularity _granularity;
+};
+
+} // namespace luci
+
+#endif // __LUCI_QUANTIZED_MODEL_VERIFIER_H__
diff --git a/compiler/luci/pass/src/QuantizedModelVerifier.test.cpp b/compiler/luci/pass/src/QuantizedModelVerifier.test.cpp
new file mode 100644
index 000000000..eae1b0c1f
--- /dev/null
+++ b/compiler/luci/pass/src/QuantizedModelVerifier.test.cpp
@@ -0,0 +1,1668 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "QuantizedModelVerifier.h"
+
+#include "luci/Pass/QuantizeWithMinMaxPass.h"
+
+#include <luci/test/TestIOGraph.h>
+
+#include <gtest/gtest.h>
+
+using Type = loco::DataType;
+using Granularity = luci::QuantizationGranularity;
+
+namespace
+{
+
+/**
+ * @brief A helper function to create dummy const node
+ */
+template <Type T> luci::CircleConst *create_dummy_const(loco::Graph *g, luci::test::ShapeU32 shape)
+{
+ auto node = g->nodes()->create<luci::CircleConst>();
+ {
+ node->dtype(T);
+ node->shape(shape);
+ node->size<T>(luci::test::num_elements(shape));
+
+ for (int32_t i = 0; i < luci::test::num_elements(shape); i++)
+ {
+ // DESIGN NOTE
+ //
+ // Filling with any random numbers are fine
+ // Q. Should it include minus numbers?
+ switch (T)
+ {
+ case Type::FLOAT32:
+ // Fill with index
+ node->at<T>(i) = static_cast<float>(i);
+ break;
+ case Type::BOOL:
+ // Fill by flip
+ node->at<T>(i) = (i % 2) ? true : false;
+ break;
+ case Type::U8:
+ // Fill with index
+ node->at<T>(i) = static_cast<uint8_t>(i);
+ break;
+ case Type::S16:
+ // Fill with index
+ node->at<T>(i) = static_cast<int16_t>(i);
+ break;
+ }
+ }
+ }
+
+ return node;
+}
+
+/**
+ * @brief A helper function to create const node with value
+ */
+template <Type DT, typename T>
+luci::CircleConst *create_const(loco::Graph *g, luci::test::ShapeU32 shape,
+ std::initializer_list<T> values)
+{
+ auto node = g->nodes()->create<luci::CircleConst>();
+ {
+ node->dtype(DT);
+ node->shape(shape);
+ node->size<DT>(luci::test::num_elements(shape));
+
+ assert(values.size() == node->size<DT>());
+
+ uint32_t index = 0;
+ for (auto val : values)
+ {
+ node->at<DT>(index++) = static_cast<T>(val);
+ }
+ }
+
+ return node;
+}
+
+void insert_scale_zp(luci::CircleNode *node, float scale, int64_t zp)
+{
+ auto qparam = node->quantparam();
+ assert(qparam != nullptr); // FIX_CALLER_UNLESS
+ qparam->scale.push_back(scale);
+ qparam->zerop.push_back(zp);
+}
+
+void quantize_and_verify(loco::Graph *g, Type quantized_dtype, Granularity granularity)
+{
+ luci::QuantizeWithMinMaxPass pass(Type::FLOAT32, quantized_dtype, granularity);
+ pass.run(g);
+
+ luci::QuantizedModelVerifier verifier(quantized_dtype, granularity);
+ verifier.verify(g);
+}
+
+// Helper function to reduce duplicate test codes
+// Assumption: g->output()->from() is the target node
+void quantize_and_verify_with_wrong_type(luci::test::TestIOGraph *g, Type quantized_dtype,
+ Granularity granularity, Type wrong_dtype)
+{
+ luci::QuantizeWithMinMaxPass pass(Type::FLOAT32, quantized_dtype, granularity);
+ pass.run(g->g());
+
+ auto node = loco::must_cast<luci::CircleNode *>(g->output()->from());
+ node->dtype(wrong_dtype);
+
+ luci::QuantizedModelVerifier verifier(quantized_dtype, granularity);
+ verifier.verify(g->g());
+}
+
+// Helper function to reduce duplicate test codes
+// Assumption: g->output()->from() is the target node
+void quantize_and_verify_with_wrong_granularity(luci::test::TestIOGraph *g, Type quantized_dtype,
+ Granularity granularity)
+{
+ luci::QuantizeWithMinMaxPass pass(Type::FLOAT32, quantized_dtype, granularity);
+ pass.run(g->g());
+
+ auto node = loco::must_cast<luci::CircleNode *>(g->output()->from());
+ insert_scale_zp(node, 1.0, 1);
+
+ luci::QuantizedModelVerifier verifier(quantized_dtype, granularity);
+ verifier.verify(g->g());
+}
+
+// Helper function to reduce duplicate test codes
+void quantize_and_verify_with_wrong_granularity(luci::test::TestIOGraph *g, Type quantized_dtype,
+ Granularity granularity, luci::CircleNode *target)
+{
+ luci::QuantizeWithMinMaxPass pass(Type::FLOAT32, quantized_dtype, granularity);
+ pass.run(g->g());
+
+ insert_scale_zp(target, 1.0, 1);
+
+ luci::QuantizedModelVerifier verifier(quantized_dtype, granularity);
+ verifier.verify(g->g());
+}
+
+// Set min/max for all non-const nodes in the graph
+void set_minmax_to_non_const(loco::Graph *g, float min, float max)
+{
+ for (auto node : loco::all_nodes(g))
+ {
+ auto const_node = dynamic_cast<luci::CircleConst *>(node);
+ if (const_node != nullptr)
+ continue;
+
+ // Min/Max is not recorded for ArgMax
+ // See MinMaxObserver.cpp in record_minmax module
+ auto argmax_node = dynamic_cast<luci::CircleArgMax *>(node);
+ if (argmax_node != nullptr)
+ continue;
+
+ // Min/Max is not recorded for Split
+ // See MinMaxObserver.cpp in record_minmax module
+ auto split_node = dynamic_cast<luci::CircleSplit *>(node);
+ if (split_node != nullptr)
+ continue;
+
+ auto circle_node = loco::must_cast<luci::CircleNode *>(node);
+ auto qparam = std::make_unique<luci::CircleQuantParam>();
+ {
+ qparam->min.emplace_back(min);
+ qparam->max.emplace_back(max);
+ }
+ circle_node->quantparam(std::move(qparam));
+ }
+}
+
+/**
+ * @brief Simple Test Graph
+ * @note
+ * The simple test graph's nodes are initialized with
+ * simple shapes and values.
+ */
+class SimpleTestGraph : public luci::test::TestIOGraph
+{
+public:
+ virtual void init(void) = 0;
+};
+
+class InstanceNormTestGraph final : public SimpleTestGraph
+{
+public:
+ void init(void) override
+ {
+ TestIOGraph::init({32}, {32});
+ _gamma = create_dummy_const<Type::FLOAT32>(g(), {32});
+ _beta = create_dummy_const<Type::FLOAT32>(g(), {32});
+ _instnorm = g()->nodes()->create<luci::CircleInstanceNorm>();
+ {
+ _instnorm->input(input());
+ _instnorm->gamma(_gamma);
+ _instnorm->beta(_beta);
+ }
+ output()->from(_instnorm);
+
+ set_minmax_to_non_const(g(), -1, 1);
+ }
+
+public:
+ loco::Node *gamma(void) const { return _instnorm->gamma(); }
+ loco::Node *beta(void) const { return _instnorm->beta(); }
+
+public:
+ luci::CircleInstanceNorm *_instnorm = nullptr;
+ luci::CircleConst *_input = nullptr;
+ luci::CircleConst *_gamma = nullptr;
+ luci::CircleConst *_beta = nullptr;
+};
+
+class LogisticTestGraph final : public SimpleTestGraph
+{
+public:
+ void init(void) override
+ {
+ TestIOGraph::init({32}, {32});
+ _logistic = g()->nodes()->create<luci::CircleLogistic>();
+ {
+ _logistic->x(input());
+ }
+ output()->from(_logistic);
+
+ set_minmax_to_non_const(g(), -1, 1);
+ }
+
+public:
+ luci::CircleLogistic *_logistic = nullptr;
+};
+
+class SoftmaxTestGraph final : public SimpleTestGraph
+{
+public:
+ void init(void) override
+ {
+ TestIOGraph::init({32}, {32});
+ _softmax = g()->nodes()->create<luci::CircleSoftmax>();
+ {
+ _softmax->logits(input());
+ _softmax->beta(0.1);
+ }
+ output()->from(_softmax);
+
+ set_minmax_to_non_const(g(), -1, 1);
+ }
+
+public:
+ luci::CircleSoftmax *_softmax = nullptr;
+};
+
+class SpaceToBatchNDTestGraph final : public SimpleTestGraph
+{
+public:
+ void init(void) override
+ {
+ TestIOGraph::init({1, 2, 2, 1}, {4, 1, 1, 1});
+ _block_shape = create_dummy_const<Type::S32>(g(), {2});
+ for (uint32_t i = 0; i < 2; i++)
+ _block_shape->at<Type::S32>(i) = 2;
+
+ _paddings = create_dummy_const<Type::S32>(g(), {2, 2});
+ for (uint32_t i = 0; i < 4; i++)
+ _paddings->at<Type::S32>(i) = 0;
+
+ _stob = g()->nodes()->create<luci::CircleSpaceToBatchND>();
+ {
+ _stob->input(input());
+ _stob->block_shape(_block_shape);
+ _stob->paddings(_paddings);
+ }
+ output()->from(_stob);
+
+ set_minmax_to_non_const(g(), -1, 1);
+ }
+
+public:
+ luci::CircleSpaceToBatchND *_stob = nullptr;
+ luci::CircleConst *_block_shape = nullptr;
+ luci::CircleConst *_paddings = nullptr;
+};
+
+class SpaceToDepthTestGraph final : public SimpleTestGraph
+{
+public:
+ void init(void) override
+ {
+ TestIOGraph::init({1, 2, 2, 1}, {1, 1, 1, 4});
+ _stod = g()->nodes()->create<luci::CircleSpaceToDepth>();
+ {
+ _stod->input(input());
+ _stod->block_size(2);
+ }
+ output()->from(_stod);
+
+ set_minmax_to_non_const(g(), -1, 1);
+ }
+
+public:
+ luci::CircleSpaceToDepth *_stod = nullptr;
+};
+
+template <Type indexT> class SliceTestGraph final : public SimpleTestGraph
+{
+public:
+ void init(void) override
+ {
+ TestIOGraph::init({32}, {32});
+ _begin = g()->nodes()->create<luci::CircleConst>();
+ {
+ _begin->dtype(indexT);
+ }
+ _size = g()->nodes()->create<luci::CircleConst>();
+ {
+ _size->dtype(indexT);
+ }
+ _slice = g()->nodes()->create<luci::CircleSlice>();
+ {
+ _slice->input(input());
+ _slice->begin(_begin);
+ _slice->size(_size);
+ }
+ output()->from(_slice);
+
+ set_minmax_to_non_const(g(), -1, 1);
+ }
+
+public:
+ luci::CircleSlice *_slice = nullptr;
+ luci::CircleConst *_begin = nullptr;
+ luci::CircleConst *_size = nullptr;
+};
+
+class SplitTestGraph final : public luci::test::TestIOGraph
+{
+public:
+ void init(void)
+ {
+ TestIOGraph::init({1, 32}, {32});
+ _split_dim = create_dummy_const<Type::S32>(g(), {1});
+ _split = g()->nodes()->create<luci::CircleSplit>();
+ {
+ _split->input(input());
+ _split->split_dim(_split_dim);
+ }
+ _split_o1 = g()->nodes()->create<luci::CircleSplitOut>();
+ {
+ _split_o1->input(_split);
+ _split_o1->index(0);
+ }
+
+ output()->from(_split_o1);
+
+ set_minmax_to_non_const(g(), -1, 1);
+ }
+
+public:
+ luci::CircleSplit *_split = nullptr;
+ luci::CircleSplitOut *_split_o1 = nullptr;
+ luci::CircleConst *_split_dim = nullptr;
+};
+
+class StridedSliceTestGraph final : public SimpleTestGraph
+{
+public:
+ void init(void) override
+ {
+ TestIOGraph::init({32}, {32});
+ _begin = g()->nodes()->create<luci::CircleConst>();
+ {
+ _begin->dtype(Type::S32);
+ }
+ _end = g()->nodes()->create<luci::CircleConst>();
+ {
+ _end->dtype(Type::S32);
+ }
+ _strides = g()->nodes()->create<luci::CircleConst>();
+ {
+ _strides->dtype(Type::S32);
+ }
+ _slice = g()->nodes()->create<luci::CircleStridedSlice>();
+ {
+ _slice->input(input());
+ _slice->begin(_begin);
+ _slice->end(_end);
+ _slice->strides(_strides);
+ }
+ output()->from(_slice);
+
+ set_minmax_to_non_const(g(), -1, 1);
+ }
+
+public:
+ luci::CircleStridedSlice *_slice = nullptr;
+ luci::CircleConst *_begin = nullptr;
+ luci::CircleConst *_end = nullptr;
+ luci::CircleConst *_strides = nullptr;
+};
+
+class ReshapeTestGraph final : public SimpleTestGraph
+{
+public:
+ void init(void) override
+ {
+ TestIOGraph::init({32}, {32});
+ _shape = g()->nodes()->create<luci::CircleConst>();
+ {
+ _shape->dtype(Type::S32);
+ }
+ _reshape = g()->nodes()->create<luci::CircleReshape>();
+ {
+ _reshape->tensor(input());
+ _reshape->shape(_shape);
+ }
+ output()->from(_reshape);
+
+ set_minmax_to_non_const(g(), -1, 1);
+ }
+
+public:
+ luci::CircleReshape *_reshape = nullptr;
+ luci::CircleConst *_shape = nullptr;
+};
+
+class TanhTestGraph final : public SimpleTestGraph
+{
+public:
+ void init(void) override
+ {
+ TestIOGraph::init({32}, {32});
+ _tanh = g()->nodes()->create<luci::CircleTanh>();
+ {
+ _tanh->x(input());
+ }
+ output()->from(_tanh);
+
+ set_minmax_to_non_const(g(), -1, 1);
+ }
+
+public:
+ luci::CircleTanh *_tanh = nullptr;
+};
+
+class FloorTestGraph final : public SimpleTestGraph
+{
+public:
+ void init(void) override
+ {
+ TestIOGraph::init({32}, {32});
+ _floor = g()->nodes()->create<luci::CircleFloor>();
+ {
+ _floor->x(input());
+ }
+ output()->from(_floor);
+
+ set_minmax_to_non_const(g(), -1, 1);
+ }
+
+public:
+ luci::CircleFloor *_floor = nullptr;
+};
+
+template <Type indexT> class ArgMaxTestGraph final : public SimpleTestGraph
+{
+public:
+ void init(void) override
+ {
+ TestIOGraph::init({32}, {1});
+ // output dtype is float by default, but ArgMax should have indexType (s32/s64)
+ output()->dtype(indexT);
+ _dimension = g()->nodes()->create<luci::CircleConst>();
+ {
+ _dimension->dtype(indexT);
+ }
+ _argmax = g()->nodes()->create<luci::CircleArgMax>();
+ {
+ _argmax->input(input());
+ _argmax->dimension(_dimension);
+ _argmax->output_type(indexT);
+ _argmax->dtype(indexT);
+ }
+ output()->from(_argmax);
+
+ set_minmax_to_non_const(g(), -1, 1);
+ }
+
+public:
+ luci::CircleArgMax *_argmax = nullptr;
+ luci::CircleConst *_dimension = nullptr;
+};
+
+class BatchToSpaceNDTestGraph final : public SimpleTestGraph
+{
+public:
+ void init(void) override
+ {
+ TestIOGraph::init({32}, {32});
+ _block_shape = g()->nodes()->create<luci::CircleConst>();
+ {
+ _block_shape->dtype(Type::S32);
+ }
+ _crops = g()->nodes()->create<luci::CircleConst>();
+ {
+ _crops->dtype(Type::S32);
+ }
+ _btos = g()->nodes()->create<luci::CircleBatchToSpaceND>();
+ {
+ _btos->input(input());
+ _btos->block_shape(_block_shape);
+ _btos->crops(_crops);
+ }
+ output()->from(_btos);
+
+ set_minmax_to_non_const(g(), -1, 1);
+ }
+
+public:
+ luci::CircleBatchToSpaceND *_btos = nullptr;
+ luci::CircleConst *_block_shape = nullptr;
+ luci::CircleConst *_crops = nullptr;
+};
+
+class DepthToSpaceTestGraph final : public SimpleTestGraph
+{
+public:
+ void init(void) override
+ {
+ TestIOGraph::init({1, 1, 1, 4}, {1, 2, 2, 1});
+ _dtos = g()->nodes()->create<luci::CircleDepthToSpace>();
+ {
+ _dtos->input(input());
+ _dtos->block_size(2);
+ }
+ output()->from(_dtos);
+
+ set_minmax_to_non_const(g(), -1, 1);
+ }
+
+public:
+ luci::CircleDepthToSpace *_dtos = nullptr;
+};
+
+class PadTestGraph final : public SimpleTestGraph
+{
+public:
+ void init(void) override
+ {
+ TestIOGraph::init({32}, {32});
+ _paddings = g()->nodes()->create<luci::CircleConst>();
+ {
+ _paddings->dtype(Type::S32);
+ }
+ _pad = g()->nodes()->create<luci::CirclePad>();
+ {
+ _pad->input(input());
+ _pad->paddings(_paddings);
+ }
+ output()->from(_pad);
+
+ set_minmax_to_non_const(g(), -1, 1);
+ }
+
+public:
+ luci::CirclePad *_pad = nullptr;
+ luci::CircleConst *_paddings = nullptr;
+};
+
+class TransposeTestGraph final : public SimpleTestGraph
+{
+public:
+ void init(void) override
+ {
+ TestIOGraph::init({32}, {32});
+ _perm = g()->nodes()->create<luci::CircleConst>();
+ {
+ _perm->dtype(Type::S32);
+ }
+ _transpose = g()->nodes()->create<luci::CircleTranspose>();
+ {
+ _transpose->a(input());
+ _transpose->perm(_perm);
+ }
+ output()->from(_transpose);
+
+ set_minmax_to_non_const(g(), -1, 1);
+ }
+
+public:
+ luci::CircleTranspose *_transpose = nullptr;
+ luci::CircleConst *_perm = nullptr;
+};
+
+class ConcatenationTestGraph final : public SimpleTestGraph
+{
+public:
+ void init(void) override
+ {
+ TestIOGraph::init({16}, {32});
+ _param = create_dummy_const<Type::FLOAT32>(g(), {16});
+ _concat = g()->nodes()->create<luci::CircleConcatenation>(2);
+ {
+ _concat->values(0, input());
+ _concat->values(1, _param);
+ _concat->axis(0);
+ }
+ output()->from(_concat);
+
+ set_minmax_to_non_const(g(), -1, 1);
+ }
+
+public:
+ luci::CircleConcatenation *_concat = nullptr;
+ luci::CircleConst *_param = nullptr;
+};
+
+// Test graph for comparison Ops
+// GREATER, GREATER_EQUAL, LESS, LESS_EQUAL, EQUAL, NOT_EQUAL
+template <class Op> class ComparisonOpTestGraph final : public SimpleTestGraph
+{
+public:
+ void init(void) override
+ {
+ TestIOGraph::init({32}, {32});
+ output()->dtype(loco::DataType::BOOL);
+ _y = create_dummy_const<Type::FLOAT32>(g(), {32});
+ _op = g()->nodes()->create<Op>();
+ {
+ _op->x(input());
+ _op->y(_y);
+ _op->dtype(loco::DataType::BOOL);
+ }
+ output()->from(_op);
+
+ set_minmax_to_non_const(g(), -1, 1);
+ }
+
+ loco::Node *x(void) const { return _op->x(); }
+ loco::Node *y(void) const { return _op->y(); }
+
+public:
+ Op *_op = nullptr;
+ luci::CircleConst *_y = nullptr;
+};
+
+// Test graph for binary logical Ops
+// LOGICAL_OR, LOGICAL_AND
+template <class Op> class BinaryLogicalOpTestGraph final : public SimpleTestGraph
+{
+public:
+ void init(void) override
+ {
+ TestIOGraph::init({32}, {32});
+ input()->dtype(loco::DataType::BOOL);
+ output()->dtype(loco::DataType::BOOL);
+ _y = create_dummy_const<Type::BOOL>(g(), {32});
+ _op = g()->nodes()->create<Op>();
+ {
+ _op->x(input());
+ _op->y(_y);
+ _op->dtype(loco::DataType::BOOL);
+ }
+ output()->from(_op);
+
+ set_minmax_to_non_const(g(), -1, 1);
+ }
+
+ loco::Node *x(void) const { return _op->x(); }
+ loco::Node *y(void) const { return _op->y(); }
+
+public:
+ Op *_op = nullptr;
+ luci::CircleConst *_y = nullptr;
+};
+
+class DivTestGraph final : public SimpleTestGraph
+{
+public:
+ void init(void) override
+ {
+ TestIOGraph::init({32}, {32});
+
+ _const = create_dummy_const<Type::FLOAT32>(g(), {32});
+ _div = g()->nodes()->create<luci::CircleDiv>();
+ {
+ _div->x(input());
+ _div->y(_const);
+ }
+ output()->from(_div);
+
+ set_minmax_to_non_const(g(), -1, 1);
+ }
+
+ loco::Node *x() { return _div->x(); }
+
+ loco::Node *y() { return _div->y(); }
+
+private:
+ luci::CircleDiv *_div = nullptr;
+ luci::CircleConst *_const = nullptr;
+};
+
+class FloorDivTestGraph final : public SimpleTestGraph
+{
+public:
+ void init(void) override
+ {
+ TestIOGraph::init({32}, {32});
+
+ _const = create_dummy_const<Type::FLOAT32>(g(), {32});
+ _floor_div = g()->nodes()->create<luci::CircleFloorDiv>();
+ {
+ _floor_div->x(input());
+ _floor_div->y(_const);
+ }
+ output()->from(_floor_div);
+
+ set_minmax_to_non_const(g(), -1, 1);
+ }
+
+ loco::Node *x() { return _floor_div->x(); }
+
+ loco::Node *y() { return _floor_div->y(); }
+
+private:
+ luci::CircleFloorDiv *_floor_div = nullptr;
+ luci::CircleConst *_const = nullptr;
+};
+
+class RsqrtTestGraph final : public SimpleTestGraph
+{
+public:
+ void init(void) override
+ {
+ TestIOGraph::init({32}, {32});
+ _rsqrt = g()->nodes()->create<luci::CircleRsqrt>();
+ {
+ _rsqrt->x(input());
+ }
+ output()->from(_rsqrt);
+
+ set_minmax_to_non_const(g(), -1, 1);
+ }
+
+public:
+ luci::CircleRsqrt *_rsqrt = nullptr;
+};
+
+class SqrtTestGraph final : public SimpleTestGraph
+{
+public:
+ void init(void) override
+ {
+ TestIOGraph::init({32}, {32});
+ _sqrt = g()->nodes()->create<luci::CircleSqrt>();
+ {
+ _sqrt->x(input());
+ }
+ output()->from(_sqrt);
+
+ set_minmax_to_non_const(g(), -1, 1);
+ }
+
+public:
+ luci::CircleSqrt *_sqrt = nullptr;
+};
+
+class EluTestGraph final : public SimpleTestGraph
+{
+public:
+ void init(void) override
+ {
+ TestIOGraph::init({32}, {32});
+ _elu = g()->nodes()->create<luci::CircleElu>();
+ {
+ _elu->features(input());
+ }
+ output()->from(_elu);
+
+ set_minmax_to_non_const(g(), -1, 1);
+ }
+
+public:
+ luci::CircleElu *_elu = nullptr;
+};
+
+class PowTestGraph final : public SimpleTestGraph
+{
+public:
+ void init(void) override
+ {
+ TestIOGraph::init({32}, {32});
+
+ _const = create_dummy_const<Type::FLOAT32>(g(), {32});
+ _pow = g()->nodes()->create<luci::CirclePow>();
+ {
+ _pow->x(input());
+ _pow->y(_const);
+ }
+ output()->from(_pow);
+
+ set_minmax_to_non_const(g(), -1, 1);
+ }
+
+ loco::Node *x() { return _pow->x(); }
+
+ loco::Node *y() { return _pow->y(); }
+
+private:
+ luci::CirclePow *_pow = nullptr;
+ luci::CircleConst *_const = nullptr;
+};
+
+class ResizeBilinearTestGraph final : public SimpleTestGraph
+{
+public:
+ void init(void) override
+ {
+ TestIOGraph::init({1, 4, 4, 1}, {1, 8, 8, 1});
+
+ _size = create_const<Type::S32, int32_t>(g(), {2}, {8, 8});
+ _resize_bilinear = g()->nodes()->create<luci::CircleResizeBilinear>();
+ {
+ _resize_bilinear->input(input());
+ _resize_bilinear->size(_size);
+ }
+ output()->from(_resize_bilinear);
+
+ set_minmax_to_non_const(g(), -1, 1);
+ }
+
+private:
+ luci::CircleResizeBilinear *_resize_bilinear = nullptr;
+ luci::CircleConst *_size = nullptr;
+};
+
+} // namespace
+
+// Quantize and verify with given configurations
+#define TEST_WITH_GRAPH(graph, type, granularity) \
+ do \
+ { \
+ graph g; \
+ g.init(); \
+ EXPECT_NO_THROW(quantize_and_verify(g.g(), type, granularity)); \
+ } while (0)
+
+// Quantize and verify with wrong type
+#define TEST_WITH_WRONG_TYPE(graph, type, granularity, wrong_dtype) \
+ do \
+ { \
+ graph g; \
+ g.init(); \
+ EXPECT_ANY_THROW(quantize_and_verify_with_wrong_type(&g, type, granularity, wrong_dtype)); \
+ } while (0)
+
+// Quantize and verify with wrong granularity
+#define TEST_WITH_WRONG_GRANULARITY(graph, type, granularity) \
+ do \
+ { \
+ graph g; \
+ g.init(); \
+ EXPECT_ANY_THROW(quantize_and_verify_with_wrong_granularity(&g, type, granularity)); \
+ } while (0)
+
+// Quantize and verify with wrong granularity
+// Users can specify the test target
+#define TEST_WITH_WRONG_GRANULARITY_TARGET(graph, type, granularity, target) \
+ do \
+ { \
+ graph g; \
+ g.init(); \
+ auto node = loco::must_cast<luci::CircleNode *>(target); \
+ EXPECT_ANY_THROW(quantize_and_verify_with_wrong_granularity(&g, type, granularity, node)); \
+ } while (0)
+
+// Test a local helper function
+TEST(QuantizedModelVerifierTest, LocalCreateDummyConst)
+{
+ loco::Graph g;
+
+ EXPECT_NO_THROW(create_dummy_const<Type::FLOAT32>(&g, {32, 32}));
+}
+
+TEST(QuantizedModelVerifierTest, LocalCreateConst)
+{
+ loco::Graph g;
+ std::initializer_list<float> values = {0.1, 0, -5, 100};
+ luci::CircleConst *node = create_const<Type::FLOAT32, float>(&g, {2, 2}, values);
+
+ uint32_t index = 0;
+ for (auto val : values)
+ {
+ EXPECT_EQ(node->at<Type::FLOAT32>(index++), val);
+ }
+}
+
+TEST(QuantizedModelVerifierTest, InstanceNorm)
+{
+ TEST_WITH_GRAPH(InstanceNormTestGraph, Type::U8, Granularity::LayerWise);
+ TEST_WITH_GRAPH(InstanceNormTestGraph, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_GRAPH(InstanceNormTestGraph, Type::S16, Granularity::ChannelWise);
+ SUCCEED();
+}
+
+TEST(QuantizedModelVerifierTest, InstanceNorm_wrong_type_NEG)
+{
+ TEST_WITH_WRONG_TYPE(InstanceNormTestGraph, Type::U8, Granularity::LayerWise, Type::S16);
+ TEST_WITH_WRONG_TYPE(InstanceNormTestGraph, Type::U8, Granularity::ChannelWise, Type::S16);
+ TEST_WITH_WRONG_TYPE(InstanceNormTestGraph, Type::S16, Granularity::ChannelWise, Type::U8);
+ SUCCEED();
+}
+
+TEST(QuantizedModelVerifierTest, InstanceNorm_wrong_granularity_NEG)
+{
+ TEST_WITH_WRONG_GRANULARITY(InstanceNormTestGraph, Type::U8, Granularity::LayerWise);
+ TEST_WITH_WRONG_GRANULARITY(InstanceNormTestGraph, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_WRONG_GRANULARITY(InstanceNormTestGraph, Type::S16, Granularity::ChannelWise);
+ SUCCEED();
+}
+
+TEST(QuantizedModelVerifierTest, Logistic)
+{
+ TEST_WITH_GRAPH(LogisticTestGraph, Type::U8, Granularity::LayerWise);
+ TEST_WITH_GRAPH(LogisticTestGraph, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_GRAPH(LogisticTestGraph, Type::S16, Granularity::ChannelWise);
+ SUCCEED();
+}
+
+TEST(QuantizedModelVerifierTest, Logistic_wrong_type_NEG)
+{
+ TEST_WITH_WRONG_TYPE(LogisticTestGraph, Type::U8, Granularity::LayerWise, Type::S16);
+ TEST_WITH_WRONG_TYPE(LogisticTestGraph, Type::U8, Granularity::ChannelWise, Type::S16);
+ TEST_WITH_WRONG_TYPE(LogisticTestGraph, Type::S16, Granularity::ChannelWise, Type::U8);
+ SUCCEED();
+}
+
+TEST(QuantizedModelVerifierTest, Logistic_wrong_granularity_NEG)
+{
+ TEST_WITH_WRONG_GRANULARITY(LogisticTestGraph, Type::U8, Granularity::LayerWise);
+ TEST_WITH_WRONG_GRANULARITY(LogisticTestGraph, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_WRONG_GRANULARITY(LogisticTestGraph, Type::S16, Granularity::ChannelWise);
+ SUCCEED();
+}
+
+TEST(QuantizedModelVerifierTest, Softmax)
+{
+ TEST_WITH_GRAPH(SoftmaxTestGraph, Type::U8, Granularity::LayerWise);
+ TEST_WITH_GRAPH(SoftmaxTestGraph, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_GRAPH(SoftmaxTestGraph, Type::S16, Granularity::ChannelWise);
+ SUCCEED();
+}
+
+TEST(QuantizedModelVerifierTest, Softmax_wrong_type_NEG)
+{
+ TEST_WITH_WRONG_TYPE(SoftmaxTestGraph, Type::U8, Granularity::LayerWise, Type::S16);
+ TEST_WITH_WRONG_TYPE(SoftmaxTestGraph, Type::U8, Granularity::ChannelWise, Type::S16);
+ TEST_WITH_WRONG_TYPE(SoftmaxTestGraph, Type::S16, Granularity::ChannelWise, Type::U8);
+ SUCCEED();
+}
+
+TEST(QuantizedModelVerifierTest, Softmax_wrong_granularity_NEG)
+{
+ TEST_WITH_WRONG_GRANULARITY(SoftmaxTestGraph, Type::U8, Granularity::LayerWise);
+ TEST_WITH_WRONG_GRANULARITY(SoftmaxTestGraph, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_WRONG_GRANULARITY(SoftmaxTestGraph, Type::S16, Granularity::ChannelWise);
+ SUCCEED();
+}
+
+TEST(QuantizedModelVerifierTest, SpaceToBatchND)
+{
+ TEST_WITH_GRAPH(SpaceToBatchNDTestGraph, Type::U8, Granularity::LayerWise);
+ TEST_WITH_GRAPH(SpaceToBatchNDTestGraph, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_GRAPH(SpaceToBatchNDTestGraph, Type::S16, Granularity::ChannelWise);
+ SUCCEED();
+}
+
+TEST(QuantizedModelVerifierTest, SpaceToBatchND_wrong_type_NEG)
+{
+ TEST_WITH_WRONG_TYPE(SpaceToBatchNDTestGraph, Type::U8, Granularity::LayerWise, Type::S16);
+ TEST_WITH_WRONG_TYPE(SpaceToBatchNDTestGraph, Type::U8, Granularity::ChannelWise, Type::S16);
+ TEST_WITH_WRONG_TYPE(SpaceToBatchNDTestGraph, Type::S16, Granularity::ChannelWise, Type::U8);
+ SUCCEED();
+}
+
+TEST(QuantizedModelVerifierTest, SpaceToBatchND_wrong_granularity_NEG)
+{
+ TEST_WITH_WRONG_GRANULARITY(SpaceToBatchNDTestGraph, Type::U8, Granularity::LayerWise);
+ TEST_WITH_WRONG_GRANULARITY(SpaceToBatchNDTestGraph, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_WRONG_GRANULARITY(SpaceToBatchNDTestGraph, Type::S16, Granularity::ChannelWise);
+ SUCCEED();
+}
+
+TEST(QuantizedModelVerifierTest, SpaceToDepth)
+{
+ TEST_WITH_GRAPH(SpaceToDepthTestGraph, Type::U8, Granularity::LayerWise);
+ TEST_WITH_GRAPH(SpaceToDepthTestGraph, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_GRAPH(SpaceToDepthTestGraph, Type::S16, Granularity::ChannelWise);
+ SUCCEED();
+}
+
+TEST(QuantizedModelVerifierTest, SpaceToDepth_wrong_type_NEG)
+{
+ TEST_WITH_WRONG_TYPE(SpaceToDepthTestGraph, Type::U8, Granularity::LayerWise, Type::S16);
+ TEST_WITH_WRONG_TYPE(SpaceToDepthTestGraph, Type::U8, Granularity::ChannelWise, Type::S16);
+ TEST_WITH_WRONG_TYPE(SpaceToDepthTestGraph, Type::S16, Granularity::ChannelWise, Type::U8);
+ SUCCEED();
+}
+
+TEST(QuantizedModelVerifierTest, SpaceToDepth_wrong_granularity_NEG)
+{
+ TEST_WITH_WRONG_GRANULARITY(SpaceToDepthTestGraph, Type::U8, Granularity::LayerWise);
+ TEST_WITH_WRONG_GRANULARITY(SpaceToDepthTestGraph, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_WRONG_GRANULARITY(SpaceToDepthTestGraph, Type::S16, Granularity::ChannelWise);
+ SUCCEED();
+}
+
+TEST(QuantizedModelVerifierTest, Slice)
+{
+ TEST_WITH_GRAPH(SliceTestGraph<Type::S32>, Type::U8, Granularity::LayerWise);
+ TEST_WITH_GRAPH(SliceTestGraph<Type::S32>, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_GRAPH(SliceTestGraph<Type::S32>, Type::S16, Granularity::ChannelWise);
+
+ TEST_WITH_GRAPH(SliceTestGraph<Type::S64>, Type::U8, Granularity::LayerWise);
+ TEST_WITH_GRAPH(SliceTestGraph<Type::S64>, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_GRAPH(SliceTestGraph<Type::S64>, Type::S16, Granularity::ChannelWise);
+ SUCCEED();
+}
+
+TEST(QuantizedModelVerifierTest, Slice_wrong_type_NEG)
+{
+ TEST_WITH_WRONG_TYPE(SliceTestGraph<Type::S32>, Type::U8, Granularity::LayerWise, Type::S16);
+ TEST_WITH_WRONG_TYPE(SliceTestGraph<Type::S32>, Type::U8, Granularity::ChannelWise, Type::S16);
+ TEST_WITH_WRONG_TYPE(SliceTestGraph<Type::S32>, Type::S16, Granularity::ChannelWise, Type::U8);
+
+ TEST_WITH_WRONG_TYPE(SliceTestGraph<Type::S64>, Type::U8, Granularity::LayerWise, Type::S16);
+ TEST_WITH_WRONG_TYPE(SliceTestGraph<Type::S64>, Type::U8, Granularity::ChannelWise, Type::S16);
+ TEST_WITH_WRONG_TYPE(SliceTestGraph<Type::S64>, Type::S16, Granularity::ChannelWise, Type::U8);
+ SUCCEED();
+}
+
+TEST(QuantizedModelVerifierTest, Slice_wrong_granularity_NEG)
+{
+ TEST_WITH_WRONG_GRANULARITY(SliceTestGraph<Type::S32>, Type::U8, Granularity::LayerWise);
+ TEST_WITH_WRONG_GRANULARITY(SliceTestGraph<Type::S32>, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_WRONG_GRANULARITY(SliceTestGraph<Type::S32>, Type::S16, Granularity::ChannelWise);
+
+ TEST_WITH_WRONG_GRANULARITY(SliceTestGraph<Type::S64>, Type::U8, Granularity::LayerWise);
+ TEST_WITH_WRONG_GRANULARITY(SliceTestGraph<Type::S64>, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_WRONG_GRANULARITY(SliceTestGraph<Type::S64>, Type::S16, Granularity::ChannelWise);
+ SUCCEED();
+}
+
+TEST(QuantizedModelVerifierTest, Split)
+{
+ TEST_WITH_GRAPH(SplitTestGraph, Type::U8, Granularity::LayerWise);
+ TEST_WITH_GRAPH(SplitTestGraph, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_GRAPH(SplitTestGraph, Type::S16, Granularity::ChannelWise);
+ SUCCEED();
+}
+
+TEST(QuantizedModelVerifierTest, Split_wrong_type_NEG)
+{
+ TEST_WITH_WRONG_TYPE(SplitTestGraph, Type::U8, Granularity::LayerWise, Type::S16);
+ TEST_WITH_WRONG_TYPE(SplitTestGraph, Type::U8, Granularity::ChannelWise, Type::S16);
+ TEST_WITH_WRONG_TYPE(SplitTestGraph, Type::S16, Granularity::ChannelWise, Type::U8);
+ SUCCEED();
+}
+
+TEST(QuantizedModelVerifierTest, Split_wrong_granularity_NEG)
+{
+ TEST_WITH_WRONG_GRANULARITY(SplitTestGraph, Type::U8, Granularity::LayerWise);
+ TEST_WITH_WRONG_GRANULARITY(SplitTestGraph, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_WRONG_GRANULARITY(SplitTestGraph, Type::S16, Granularity::ChannelWise);
+ SUCCEED();
+}
+
+TEST(QuantizedModelVerifierTest, StridedSlice)
+{
+ TEST_WITH_GRAPH(StridedSliceTestGraph, Type::U8, Granularity::LayerWise);
+ TEST_WITH_GRAPH(StridedSliceTestGraph, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_GRAPH(StridedSliceTestGraph, Type::S16, Granularity::ChannelWise);
+ SUCCEED();
+}
+
+TEST(QuantizedModelVerifierTest, StridedSlice_wrong_type_NEG)
+{
+ TEST_WITH_WRONG_TYPE(StridedSliceTestGraph, Type::U8, Granularity::LayerWise, Type::S16);
+ TEST_WITH_WRONG_TYPE(StridedSliceTestGraph, Type::U8, Granularity::ChannelWise, Type::S16);
+ TEST_WITH_WRONG_TYPE(StridedSliceTestGraph, Type::S16, Granularity::ChannelWise, Type::U8);
+ SUCCEED();
+}
+
+TEST(QuantizedModelVerifierTest, StridedSlice_wrong_granularity_NEG)
+{
+ TEST_WITH_WRONG_GRANULARITY(StridedSliceTestGraph, Type::U8, Granularity::LayerWise);
+ TEST_WITH_WRONG_GRANULARITY(StridedSliceTestGraph, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_WRONG_GRANULARITY(StridedSliceTestGraph, Type::S16, Granularity::ChannelWise);
+ SUCCEED();
+}
+
+TEST(QuantizedModelVerifierTest, ArgMax)
+{
+ TEST_WITH_GRAPH(ArgMaxTestGraph<Type::S32>, Type::U8, Granularity::LayerWise);
+ TEST_WITH_GRAPH(ArgMaxTestGraph<Type::S32>, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_GRAPH(ArgMaxTestGraph<Type::S32>, Type::S16, Granularity::ChannelWise);
+
+ TEST_WITH_GRAPH(ArgMaxTestGraph<Type::S64>, Type::U8, Granularity::LayerWise);
+ TEST_WITH_GRAPH(ArgMaxTestGraph<Type::S64>, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_GRAPH(ArgMaxTestGraph<Type::S64>, Type::S16, Granularity::ChannelWise);
+ SUCCEED();
+}
+
+TEST(QuantizedModelVerifierTest, ArgMax_wrong_dimension_type_NEG)
+{
+ ArgMaxTestGraph<Type::S32> g;
+ g.init();
+ luci::QuantizeWithMinMaxPass pass(Type::FLOAT32, Type::U8, Granularity::LayerWise);
+ pass.run(g.g());
+
+ g._dimension->dtype(Type::U8);
+
+ luci::QuantizedModelVerifier verifier(Type::U8, Granularity::LayerWise);
+ EXPECT_ANY_THROW(verifier.verify(g.g()));
+}
+
+TEST(QuantizedModelVerifierTest, ArgMax_wrong_input_granularity_NEG)
+{
+ ArgMaxTestGraph<Type::S32> g;
+ g.init();
+
+ luci::QuantizeWithMinMaxPass pass(Type::FLOAT32, Type::U8, Granularity::LayerWise);
+ pass.run(g.g());
+
+ insert_scale_zp(loco::must_cast<luci::CircleNode *>(g._argmax->input()), 1.0, 1);
+
+ luci::QuantizedModelVerifier verifier(Type::U8, Granularity::LayerWise);
+ EXPECT_ANY_THROW(verifier.verify(g.g()));
+}
+
+TEST(QuantizedModelVerifierTest, BatchToSpaceND)
+{
+ TEST_WITH_GRAPH(BatchToSpaceNDTestGraph, Type::U8, Granularity::LayerWise);
+ TEST_WITH_GRAPH(BatchToSpaceNDTestGraph, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_GRAPH(BatchToSpaceNDTestGraph, Type::S16, Granularity::ChannelWise);
+ SUCCEED();
+}
+
+TEST(QuantizedModelVerifierTest, BatchToSpaceND_wrong_type_NEG)
+{
+ TEST_WITH_WRONG_TYPE(BatchToSpaceNDTestGraph, Type::U8, Granularity::LayerWise, Type::S16);
+ TEST_WITH_WRONG_TYPE(BatchToSpaceNDTestGraph, Type::U8, Granularity::ChannelWise, Type::S16);
+ TEST_WITH_WRONG_TYPE(BatchToSpaceNDTestGraph, Type::S16, Granularity::ChannelWise, Type::U8);
+ SUCCEED();
+}
+
+TEST(QuantizedModelVerifierTest, BatchToSpaceND_wrong_granularity_NEG)
+{
+ TEST_WITH_WRONG_GRANULARITY(BatchToSpaceNDTestGraph, Type::U8, Granularity::LayerWise);
+ TEST_WITH_WRONG_GRANULARITY(BatchToSpaceNDTestGraph, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_WRONG_GRANULARITY(BatchToSpaceNDTestGraph, Type::S16, Granularity::ChannelWise);
+ SUCCEED();
+}
+
+TEST(QuantizedModelVerifierTest, DepthToSpace)
+{
+ TEST_WITH_GRAPH(DepthToSpaceTestGraph, Type::U8, Granularity::LayerWise);
+ TEST_WITH_GRAPH(DepthToSpaceTestGraph, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_GRAPH(DepthToSpaceTestGraph, Type::S16, Granularity::ChannelWise);
+ SUCCEED();
+}
+
+TEST(QuantizedModelVerifierTest, DepthToSpace_wrong_type_NEG)
+{
+ TEST_WITH_WRONG_TYPE(DepthToSpaceTestGraph, Type::U8, Granularity::LayerWise, Type::S16);
+ TEST_WITH_WRONG_TYPE(DepthToSpaceTestGraph, Type::U8, Granularity::ChannelWise, Type::S16);
+ TEST_WITH_WRONG_TYPE(DepthToSpaceTestGraph, Type::S16, Granularity::ChannelWise, Type::U8);
+ SUCCEED();
+}
+
+TEST(QuantizedModelVerifierTest, DepthToSpace_wrong_granularity_NEG)
+{
+ TEST_WITH_WRONG_GRANULARITY(DepthToSpaceTestGraph, Type::U8, Granularity::LayerWise);
+ TEST_WITH_WRONG_GRANULARITY(DepthToSpaceTestGraph, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_WRONG_GRANULARITY(DepthToSpaceTestGraph, Type::S16, Granularity::ChannelWise);
+ SUCCEED();
+}
+
+TEST(QuantizedModelVerifierTest, Concatenation)
+{
+ TEST_WITH_GRAPH(ConcatenationTestGraph, Type::U8, Granularity::LayerWise);
+ TEST_WITH_GRAPH(ConcatenationTestGraph, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_GRAPH(ConcatenationTestGraph, Type::S16, Granularity::ChannelWise);
+ SUCCEED();
+}
+
+TEST(QuantizedModelVerifierTest, Concatenation_wrong_type_NEG)
+{
+ TEST_WITH_WRONG_TYPE(ConcatenationTestGraph, Type::U8, Granularity::LayerWise, Type::S16);
+ TEST_WITH_WRONG_TYPE(ConcatenationTestGraph, Type::U8, Granularity::ChannelWise, Type::S16);
+ TEST_WITH_WRONG_TYPE(ConcatenationTestGraph, Type::S16, Granularity::ChannelWise, Type::U8);
+ SUCCEED();
+}
+
+TEST(QuantizedModelVerifierTest, Concatenation_wrong_granularity_NEG)
+{
+ TEST_WITH_WRONG_GRANULARITY(ConcatenationTestGraph, Type::U8, Granularity::LayerWise);
+ TEST_WITH_WRONG_GRANULARITY(ConcatenationTestGraph, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_WRONG_GRANULARITY(ConcatenationTestGraph, Type::S16, Granularity::ChannelWise);
+ SUCCEED();
+}
+
+TEST(QuantizedModelVerifierTest, LogicalOr)
+{
+ TEST_WITH_GRAPH(BinaryLogicalOpTestGraph<luci::CircleLogicalOr>, Type::U8,
+ Granularity::LayerWise);
+ TEST_WITH_GRAPH(BinaryLogicalOpTestGraph<luci::CircleLogicalOr>, Type::U8,
+ Granularity::ChannelWise);
+ TEST_WITH_GRAPH(BinaryLogicalOpTestGraph<luci::CircleLogicalOr>, Type::S16,
+ Granularity::ChannelWise);
+ SUCCEED();
+}
+
+TEST(QuantizedModelVerifierTest, LogicalOr_wrong_type_NEG)
+{
+ TEST_WITH_WRONG_TYPE(BinaryLogicalOpTestGraph<luci::CircleLogicalOr>, Type::U8,
+ Granularity::LayerWise, Type::U8);
+ TEST_WITH_WRONG_TYPE(BinaryLogicalOpTestGraph<luci::CircleLogicalOr>, Type::U8,
+ Granularity::ChannelWise, Type::U8);
+ TEST_WITH_WRONG_TYPE(BinaryLogicalOpTestGraph<luci::CircleLogicalOr>, Type::S16,
+ Granularity::ChannelWise, Type::S16);
+ SUCCEED();
+}
+
+TEST(QuantizedModelVerifierTest, Reshape)
+{
+ TEST_WITH_GRAPH(ReshapeTestGraph, Type::U8, Granularity::LayerWise);
+ TEST_WITH_GRAPH(ReshapeTestGraph, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_GRAPH(ReshapeTestGraph, Type::S16, Granularity::ChannelWise);
+ SUCCEED();
+}
+
+TEST(QuantizedModelVerifierTest, Reshape_wrong_type_NEG)
+{
+ TEST_WITH_WRONG_TYPE(ReshapeTestGraph, Type::U8, Granularity::LayerWise, Type::S16);
+ TEST_WITH_WRONG_TYPE(ReshapeTestGraph, Type::U8, Granularity::ChannelWise, Type::S16);
+ TEST_WITH_WRONG_TYPE(ReshapeTestGraph, Type::S16, Granularity::ChannelWise, Type::U8);
+ SUCCEED();
+}
+
+TEST(QuantizedModelVerifierTest, Reshape_wrong_granularity_NEG)
+{
+ TEST_WITH_WRONG_GRANULARITY(ReshapeTestGraph, Type::U8, Granularity::LayerWise);
+ TEST_WITH_WRONG_GRANULARITY(ReshapeTestGraph, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_WRONG_GRANULARITY(ReshapeTestGraph, Type::S16, Granularity::ChannelWise);
+ SUCCEED();
+}
+
+TEST(QuantizedModelVerifierTest, Tanh)
+{
+ TEST_WITH_GRAPH(TanhTestGraph, Type::U8, Granularity::LayerWise);
+ TEST_WITH_GRAPH(TanhTestGraph, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_GRAPH(TanhTestGraph, Type::S16, Granularity::ChannelWise);
+ SUCCEED();
+}
+
+TEST(QuantizedModelVerifierTest, Tanh_wrong_type_NEG)
+{
+ TEST_WITH_WRONG_TYPE(TanhTestGraph, Type::U8, Granularity::LayerWise, Type::S16);
+ TEST_WITH_WRONG_TYPE(TanhTestGraph, Type::U8, Granularity::ChannelWise, Type::S16);
+ TEST_WITH_WRONG_TYPE(TanhTestGraph, Type::S16, Granularity::ChannelWise, Type::U8);
+ SUCCEED();
+}
+
+TEST(QuantizedModelVerifierTest, Tanh_wrong_granularity_NEG)
+{
+ TEST_WITH_WRONG_GRANULARITY(TanhTestGraph, Type::U8, Granularity::LayerWise);
+ TEST_WITH_WRONG_GRANULARITY(TanhTestGraph, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_WRONG_GRANULARITY(TanhTestGraph, Type::S16, Granularity::ChannelWise);
+ SUCCEED();
+}
+
+TEST(QuantizedModelVerifierTest, Pad)
+{
+ TEST_WITH_GRAPH(PadTestGraph, Type::U8, Granularity::LayerWise);
+ TEST_WITH_GRAPH(PadTestGraph, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_GRAPH(PadTestGraph, Type::S16, Granularity::ChannelWise);
+ SUCCEED();
+}
+
+TEST(QuantizedModelVerifierTest, Pad_wrong_type_NEG)
+{
+ TEST_WITH_WRONG_TYPE(PadTestGraph, Type::U8, Granularity::LayerWise, Type::S16);
+ TEST_WITH_WRONG_TYPE(PadTestGraph, Type::U8, Granularity::ChannelWise, Type::S16);
+ TEST_WITH_WRONG_TYPE(PadTestGraph, Type::S16, Granularity::ChannelWise, Type::U8);
+ SUCCEED();
+}
+
+TEST(QuantizedModelVerifierTest, Pad_wrong_granularity_NEG)
+{
+ TEST_WITH_WRONG_GRANULARITY(PadTestGraph, Type::U8, Granularity::LayerWise);
+ TEST_WITH_WRONG_GRANULARITY(PadTestGraph, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_WRONG_GRANULARITY(PadTestGraph, Type::S16, Granularity::ChannelWise);
+ SUCCEED();
+}
+
+TEST(QuantizedModelVerifierTest, Transpose)
+{
+ TEST_WITH_GRAPH(TransposeTestGraph, Type::U8, Granularity::LayerWise);
+ TEST_WITH_GRAPH(TransposeTestGraph, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_GRAPH(TransposeTestGraph, Type::S16, Granularity::ChannelWise);
+ SUCCEED();
+}
+
+TEST(QuantizedModelVerifierTest, Transpose_wrong_type_NEG)
+{
+ TEST_WITH_WRONG_TYPE(TransposeTestGraph, Type::U8, Granularity::LayerWise, Type::S16);
+ TEST_WITH_WRONG_TYPE(TransposeTestGraph, Type::U8, Granularity::ChannelWise, Type::S16);
+ TEST_WITH_WRONG_TYPE(TransposeTestGraph, Type::S16, Granularity::ChannelWise, Type::U8);
+ SUCCEED();
+}
+
+TEST(QuantizedModelVerifierTest, Transpose_wrong_granularity_NEG)
+{
+ TEST_WITH_WRONG_GRANULARITY(TransposeTestGraph, Type::U8, Granularity::LayerWise);
+ TEST_WITH_WRONG_GRANULARITY(TransposeTestGraph, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_WRONG_GRANULARITY(TransposeTestGraph, Type::S16, Granularity::ChannelWise);
+ SUCCEED();
+}
+
+TEST(QuantizedModelVerifierTest, Floor)
+{
+ TEST_WITH_GRAPH(FloorTestGraph, Type::U8, Granularity::LayerWise);
+ TEST_WITH_GRAPH(FloorTestGraph, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_GRAPH(FloorTestGraph, Type::S16, Granularity::ChannelWise);
+ SUCCEED();
+}
+
+TEST(QuantizedModelVerifierTest, Floor_wrong_type_NEG)
+{
+ TEST_WITH_WRONG_TYPE(FloorTestGraph, Type::U8, Granularity::LayerWise, Type::S16);
+ TEST_WITH_WRONG_TYPE(FloorTestGraph, Type::U8, Granularity::ChannelWise, Type::S16);
+ TEST_WITH_WRONG_TYPE(FloorTestGraph, Type::S16, Granularity::ChannelWise, Type::U8);
+ SUCCEED();
+}
+
+TEST(QuantizedModelVerifierTest, Floor_wrong_granularity_NEG)
+{
+ TEST_WITH_WRONG_GRANULARITY(FloorTestGraph, Type::U8, Granularity::LayerWise);
+ TEST_WITH_WRONG_GRANULARITY(FloorTestGraph, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_WRONG_GRANULARITY(FloorTestGraph, Type::S16, Granularity::ChannelWise);
+ SUCCEED();
+}
+
+TEST(QuantizedModelVerifierTest, GreaterEqual)
+{
+ TEST_WITH_GRAPH(ComparisonOpTestGraph<luci::CircleGreaterEqual>, Type::U8,
+ Granularity::LayerWise);
+ TEST_WITH_GRAPH(ComparisonOpTestGraph<luci::CircleGreaterEqual>, Type::U8,
+ Granularity::ChannelWise);
+ TEST_WITH_GRAPH(ComparisonOpTestGraph<luci::CircleGreaterEqual>, Type::S16,
+ Granularity::ChannelWise);
+ SUCCEED();
+}
+
+TEST(QuantizedModelVerifierTest, GreaterEqual_wrong_type_NEG)
+{
+ TEST_WITH_WRONG_TYPE(ComparisonOpTestGraph<luci::CircleGreaterEqual>, Type::U8,
+ Granularity::LayerWise, Type::U8);
+ TEST_WITH_WRONG_TYPE(ComparisonOpTestGraph<luci::CircleGreaterEqual>, Type::U8,
+ Granularity::ChannelWise, Type::U8);
+ TEST_WITH_WRONG_TYPE(ComparisonOpTestGraph<luci::CircleGreaterEqual>, Type::S16,
+ Granularity::ChannelWise, Type::S16);
+ SUCCEED();
+}
+
+TEST(QuantizedModelVerifierTest, GreaterEqual_wrong_granularity_NEG)
+{
+ TEST_WITH_WRONG_GRANULARITY_TARGET(ComparisonOpTestGraph<luci::CircleGreaterEqual>, Type::U8,
+ Granularity::LayerWise, g.x());
+ TEST_WITH_WRONG_GRANULARITY_TARGET(ComparisonOpTestGraph<luci::CircleGreaterEqual>, Type::U8,
+ Granularity::ChannelWise, g.x());
+ TEST_WITH_WRONG_GRANULARITY_TARGET(ComparisonOpTestGraph<luci::CircleGreaterEqual>, Type::S16,
+ Granularity::ChannelWise, g.x());
+
+ TEST_WITH_WRONG_GRANULARITY_TARGET(ComparisonOpTestGraph<luci::CircleGreaterEqual>, Type::U8,
+ Granularity::LayerWise, g.y());
+ TEST_WITH_WRONG_GRANULARITY_TARGET(ComparisonOpTestGraph<luci::CircleGreaterEqual>, Type::U8,
+ Granularity::ChannelWise, g.y());
+ TEST_WITH_WRONG_GRANULARITY_TARGET(ComparisonOpTestGraph<luci::CircleGreaterEqual>, Type::S16,
+ Granularity::ChannelWise, g.y());
+ SUCCEED();
+}
+
+TEST(QuantizedModelVerifierTest, Greater)
+{
+ TEST_WITH_GRAPH(ComparisonOpTestGraph<luci::CircleGreater>, Type::U8, Granularity::LayerWise);
+ TEST_WITH_GRAPH(ComparisonOpTestGraph<luci::CircleGreater>, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_GRAPH(ComparisonOpTestGraph<luci::CircleGreater>, Type::S16, Granularity::ChannelWise);
+ SUCCEED();
+}
+
+TEST(QuantizedModelVerifierTest, Greater_wrong_type_NEG)
+{
+ TEST_WITH_WRONG_TYPE(ComparisonOpTestGraph<luci::CircleGreater>, Type::U8, Granularity::LayerWise,
+ Type::U8);
+ TEST_WITH_WRONG_TYPE(ComparisonOpTestGraph<luci::CircleGreater>, Type::U8,
+ Granularity::ChannelWise, Type::U8);
+ TEST_WITH_WRONG_TYPE(ComparisonOpTestGraph<luci::CircleGreater>, Type::S16,
+ Granularity::ChannelWise, Type::S16);
+ SUCCEED();
+}
+
+TEST(QuantizedModelVerifierTest, Greater_wrong_granularity_NEG)
+{
+ TEST_WITH_WRONG_GRANULARITY_TARGET(ComparisonOpTestGraph<luci::CircleGreater>, Type::U8,
+ Granularity::LayerWise, g.x());
+ TEST_WITH_WRONG_GRANULARITY_TARGET(ComparisonOpTestGraph<luci::CircleGreater>, Type::U8,
+ Granularity::ChannelWise, g.x());
+ TEST_WITH_WRONG_GRANULARITY_TARGET(ComparisonOpTestGraph<luci::CircleGreater>, Type::S16,
+ Granularity::ChannelWise, g.x());
+
+ TEST_WITH_WRONG_GRANULARITY_TARGET(ComparisonOpTestGraph<luci::CircleGreater>, Type::U8,
+ Granularity::LayerWise, g.y());
+ TEST_WITH_WRONG_GRANULARITY_TARGET(ComparisonOpTestGraph<luci::CircleGreater>, Type::U8,
+ Granularity::ChannelWise, g.y());
+ TEST_WITH_WRONG_GRANULARITY_TARGET(ComparisonOpTestGraph<luci::CircleGreater>, Type::S16,
+ Granularity::ChannelWise, g.y());
+ SUCCEED();
+}
+
+TEST(QuantizedModelVerifierTest, NotEqual)
+{
+ TEST_WITH_GRAPH(ComparisonOpTestGraph<luci::CircleNotEqual>, Type::U8, Granularity::LayerWise);
+ TEST_WITH_GRAPH(ComparisonOpTestGraph<luci::CircleNotEqual>, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_GRAPH(ComparisonOpTestGraph<luci::CircleNotEqual>, Type::S16, Granularity::ChannelWise);
+ SUCCEED();
+}
+
+TEST(QuantizedModelVerifierTest, NotEqual_wrong_type_NEG)
+{
+ TEST_WITH_WRONG_TYPE(ComparisonOpTestGraph<luci::CircleNotEqual>, Type::U8,
+ Granularity::LayerWise, Type::U8);
+ TEST_WITH_WRONG_TYPE(ComparisonOpTestGraph<luci::CircleNotEqual>, Type::U8,
+ Granularity::ChannelWise, Type::U8);
+ TEST_WITH_WRONG_TYPE(ComparisonOpTestGraph<luci::CircleNotEqual>, Type::S16,
+ Granularity::ChannelWise, Type::S16);
+ SUCCEED();
+}
+
+TEST(QuantizedModelVerifierTest, NotEqual_wrong_granularity_NEG)
+{
+ TEST_WITH_WRONG_GRANULARITY_TARGET(ComparisonOpTestGraph<luci::CircleNotEqual>, Type::U8,
+ Granularity::LayerWise, g.x());
+ TEST_WITH_WRONG_GRANULARITY_TARGET(ComparisonOpTestGraph<luci::CircleNotEqual>, Type::U8,
+ Granularity::ChannelWise, g.x());
+ TEST_WITH_WRONG_GRANULARITY_TARGET(ComparisonOpTestGraph<luci::CircleNotEqual>, Type::S16,
+ Granularity::ChannelWise, g.x());
+
+ TEST_WITH_WRONG_GRANULARITY_TARGET(ComparisonOpTestGraph<luci::CircleNotEqual>, Type::U8,
+ Granularity::LayerWise, g.y());
+ TEST_WITH_WRONG_GRANULARITY_TARGET(ComparisonOpTestGraph<luci::CircleNotEqual>, Type::U8,
+ Granularity::ChannelWise, g.y());
+ TEST_WITH_WRONG_GRANULARITY_TARGET(ComparisonOpTestGraph<luci::CircleNotEqual>, Type::S16,
+ Granularity::ChannelWise, g.y());
+ SUCCEED();
+}
+
+TEST(QuantizedModelVerifierTest, Div)
+{
+ TEST_WITH_GRAPH(DivTestGraph, Type::U8, Granularity::LayerWise);
+ TEST_WITH_GRAPH(DivTestGraph, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_GRAPH(DivTestGraph, Type::S16, Granularity::ChannelWise);
+ SUCCEED();
+}
+
+TEST(QuantizedModelVerifierTest, Div_wrong_type_NEG)
+{
+ TEST_WITH_WRONG_TYPE(DivTestGraph, Type::U8, Granularity::LayerWise, Type::S16);
+ TEST_WITH_WRONG_TYPE(DivTestGraph, Type::U8, Granularity::ChannelWise, Type::S16);
+ TEST_WITH_WRONG_TYPE(DivTestGraph, Type::S16, Granularity::ChannelWise, Type::U8);
+ SUCCEED();
+}
+
+TEST(QuantizedModelVerifierTest, Div_wrong_granularity_NEG)
+{
+ TEST_WITH_WRONG_GRANULARITY_TARGET(DivTestGraph, Type::U8, Granularity::LayerWise, g.x());
+ TEST_WITH_WRONG_GRANULARITY_TARGET(DivTestGraph, Type::U8, Granularity::ChannelWise, g.x());
+ TEST_WITH_WRONG_GRANULARITY_TARGET(DivTestGraph, Type::S16, Granularity::ChannelWise, g.x());
+
+ TEST_WITH_WRONG_GRANULARITY_TARGET(DivTestGraph, Type::U8, Granularity::LayerWise, g.y());
+ TEST_WITH_WRONG_GRANULARITY_TARGET(DivTestGraph, Type::U8, Granularity::ChannelWise, g.y());
+ TEST_WITH_WRONG_GRANULARITY_TARGET(DivTestGraph, Type::S16, Granularity::ChannelWise, g.y());
+ SUCCEED();
+}
+
+TEST(QuantizedModelVerifierTest, FloorDiv)
+{
+ TEST_WITH_GRAPH(FloorDivTestGraph, Type::U8, Granularity::LayerWise);
+ TEST_WITH_GRAPH(FloorDivTestGraph, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_GRAPH(FloorDivTestGraph, Type::S16, Granularity::ChannelWise);
+ SUCCEED();
+}
+
+TEST(QuantizedModelVerifierTest, FloorDiv_wrong_type_NEG)
+{
+ TEST_WITH_WRONG_TYPE(FloorDivTestGraph, Type::U8, Granularity::LayerWise, Type::S16);
+ TEST_WITH_WRONG_TYPE(FloorDivTestGraph, Type::U8, Granularity::ChannelWise, Type::S16);
+ TEST_WITH_WRONG_TYPE(FloorDivTestGraph, Type::S16, Granularity::ChannelWise, Type::U8);
+ SUCCEED();
+}
+
+TEST(QuantizedModelVerifierTest, FloorDiv_wrong_granularity_NEG)
+{
+ TEST_WITH_WRONG_GRANULARITY_TARGET(FloorDivTestGraph, Type::U8, Granularity::LayerWise, g.x());
+ TEST_WITH_WRONG_GRANULARITY_TARGET(FloorDivTestGraph, Type::U8, Granularity::ChannelWise, g.x());
+ TEST_WITH_WRONG_GRANULARITY_TARGET(FloorDivTestGraph, Type::S16, Granularity::ChannelWise, g.x());
+
+ TEST_WITH_WRONG_GRANULARITY_TARGET(FloorDivTestGraph, Type::U8, Granularity::LayerWise, g.y());
+ TEST_WITH_WRONG_GRANULARITY_TARGET(FloorDivTestGraph, Type::U8, Granularity::ChannelWise, g.y());
+ TEST_WITH_WRONG_GRANULARITY_TARGET(FloorDivTestGraph, Type::S16, Granularity::ChannelWise, g.y());
+ SUCCEED();
+}
+
+TEST(QuantizedModelVerifierTest, Rsqrt)
+{
+ TEST_WITH_GRAPH(RsqrtTestGraph, Type::U8, Granularity::LayerWise);
+ TEST_WITH_GRAPH(RsqrtTestGraph, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_GRAPH(RsqrtTestGraph, Type::S16, Granularity::ChannelWise);
+ SUCCEED();
+}
+
+TEST(QuantizedModelVerifierTest, Rsqrt_wrong_type_NEG)
+{
+ TEST_WITH_WRONG_TYPE(RsqrtTestGraph, Type::U8, Granularity::LayerWise, Type::S16);
+ TEST_WITH_WRONG_TYPE(RsqrtTestGraph, Type::U8, Granularity::ChannelWise, Type::S16);
+ TEST_WITH_WRONG_TYPE(RsqrtTestGraph, Type::S16, Granularity::ChannelWise, Type::U8);
+ SUCCEED();
+}
+
+TEST(QuantizedModelVerifierTest, Rsqrt_wrong_granularity_NEG)
+{
+ TEST_WITH_WRONG_GRANULARITY(RsqrtTestGraph, Type::U8, Granularity::LayerWise);
+ TEST_WITH_WRONG_GRANULARITY(RsqrtTestGraph, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_WRONG_GRANULARITY(RsqrtTestGraph, Type::S16, Granularity::ChannelWise);
+ SUCCEED();
+}
+
+TEST(QuantizedModelVerifierTest, Sqrt)
+{
+ TEST_WITH_GRAPH(SqrtTestGraph, Type::U8, Granularity::LayerWise);
+ TEST_WITH_GRAPH(SqrtTestGraph, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_GRAPH(SqrtTestGraph, Type::S16, Granularity::ChannelWise);
+ SUCCEED();
+}
+
+TEST(QuantizedModelVerifierTest, Sqrt_wrong_type_NEG)
+{
+ TEST_WITH_WRONG_TYPE(SqrtTestGraph, Type::U8, Granularity::LayerWise, Type::S16);
+ TEST_WITH_WRONG_TYPE(SqrtTestGraph, Type::U8, Granularity::ChannelWise, Type::S16);
+ TEST_WITH_WRONG_TYPE(SqrtTestGraph, Type::S16, Granularity::ChannelWise, Type::U8);
+ SUCCEED();
+}
+
+TEST(QuantizedModelVerifierTest, Sqrt_wrong_granularity_NEG)
+{
+ TEST_WITH_WRONG_GRANULARITY(SqrtTestGraph, Type::U8, Granularity::LayerWise);
+ TEST_WITH_WRONG_GRANULARITY(SqrtTestGraph, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_WRONG_GRANULARITY(SqrtTestGraph, Type::S16, Granularity::ChannelWise);
+ SUCCEED();
+}
+
+TEST(QuantizedModelVerifierTest, Elu)
+{
+ TEST_WITH_GRAPH(EluTestGraph, Type::U8, Granularity::LayerWise);
+ TEST_WITH_GRAPH(EluTestGraph, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_GRAPH(EluTestGraph, Type::S16, Granularity::ChannelWise);
+ SUCCEED();
+}
+
+TEST(QuantizedModelVerifierTest, Elu_wrong_type_NEG)
+{
+ TEST_WITH_WRONG_TYPE(EluTestGraph, Type::U8, Granularity::LayerWise, Type::S16);
+ TEST_WITH_WRONG_TYPE(EluTestGraph, Type::U8, Granularity::ChannelWise, Type::S16);
+ TEST_WITH_WRONG_TYPE(EluTestGraph, Type::S16, Granularity::ChannelWise, Type::U8);
+ SUCCEED();
+}
+
+TEST(QuantizedModelVerifierTest, Elu_wrong_granularity_NEG)
+{
+ TEST_WITH_WRONG_GRANULARITY(EluTestGraph, Type::U8, Granularity::LayerWise);
+ TEST_WITH_WRONG_GRANULARITY(EluTestGraph, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_WRONG_GRANULARITY(EluTestGraph, Type::S16, Granularity::ChannelWise);
+ SUCCEED();
+}
+
+TEST(QuantizedModelVerifierTest, Pow)
+{
+ TEST_WITH_GRAPH(PowTestGraph, Type::U8, Granularity::LayerWise);
+ TEST_WITH_GRAPH(PowTestGraph, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_GRAPH(PowTestGraph, Type::S16, Granularity::ChannelWise);
+ SUCCEED();
+}
+
+TEST(QuantizedModelVerifierTest, Pow_wrong_type_NEG)
+{
+ TEST_WITH_WRONG_TYPE(PowTestGraph, Type::U8, Granularity::LayerWise, Type::S16);
+ TEST_WITH_WRONG_TYPE(PowTestGraph, Type::U8, Granularity::ChannelWise, Type::S16);
+ TEST_WITH_WRONG_TYPE(PowTestGraph, Type::S16, Granularity::ChannelWise, Type::U8);
+ SUCCEED();
+}
+
+TEST(QuantizedModelVerifierTest, Pow_wrong_granularity_NEG)
+{
+ TEST_WITH_WRONG_GRANULARITY_TARGET(PowTestGraph, Type::U8, Granularity::LayerWise, g.x());
+ TEST_WITH_WRONG_GRANULARITY_TARGET(PowTestGraph, Type::U8, Granularity::ChannelWise, g.x());
+ TEST_WITH_WRONG_GRANULARITY_TARGET(PowTestGraph, Type::S16, Granularity::ChannelWise, g.x());
+
+ TEST_WITH_WRONG_GRANULARITY_TARGET(PowTestGraph, Type::U8, Granularity::LayerWise, g.y());
+ TEST_WITH_WRONG_GRANULARITY_TARGET(PowTestGraph, Type::U8, Granularity::ChannelWise, g.y());
+ TEST_WITH_WRONG_GRANULARITY_TARGET(PowTestGraph, Type::S16, Granularity::ChannelWise, g.y());
+ SUCCEED();
+}
+
+TEST(QuantizedModelVerifierTest, ResizeBilinear)
+{
+ TEST_WITH_GRAPH(ResizeBilinearTestGraph, Type::U8, Granularity::LayerWise);
+ TEST_WITH_GRAPH(ResizeBilinearTestGraph, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_GRAPH(ResizeBilinearTestGraph, Type::S16, Granularity::ChannelWise);
+ SUCCEED();
+}
+
+TEST(QuantizedModelVerifierTest, ResizeBilinear_wrong_type_NEG)
+{
+ TEST_WITH_WRONG_TYPE(ResizeBilinearTestGraph, Type::U8, Granularity::LayerWise, Type::S16);
+ TEST_WITH_WRONG_TYPE(ResizeBilinearTestGraph, Type::U8, Granularity::ChannelWise, Type::S16);
+ TEST_WITH_WRONG_TYPE(ResizeBilinearTestGraph, Type::S16, Granularity::ChannelWise, Type::U8);
+ SUCCEED();
+}
+
+TEST(QuantizedModelVerifierTest, ResizeBilinear_wrong_granularity_NEG)
+{
+ TEST_WITH_WRONG_GRANULARITY(ResizeBilinearTestGraph, Type::U8, Granularity::LayerWise);
+ TEST_WITH_WRONG_GRANULARITY(ResizeBilinearTestGraph, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_WRONG_GRANULARITY(ResizeBilinearTestGraph, Type::S16, Granularity::ChannelWise);
+ SUCCEED();
+}
+
+#undef TEST_WITH_GRAPH
+#undef TEST_WITH_WRONG_TYPE
+#undef TEST_WITH_WRONG_GRANULARITY
diff --git a/compiler/luci/pass/src/RemoveRedundantReshape.cpp b/compiler/luci/pass/src/RemoveRedundantReshape.cpp
new file mode 100644
index 000000000..2f0b22ae6
--- /dev/null
+++ b/compiler/luci/pass/src/RemoveRedundantReshape.cpp
@@ -0,0 +1,72 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Pass/RemoveRedundantReshapePass.h"
+
+#include <luci/IR/CircleNodes.h>
+
+namespace
+{
+
+bool remove_redundant_reshape(luci::CircleReshape *node)
+{
+ auto pred_node = dynamic_cast<luci::CircleReshape *>(node->tensor());
+ if (pred_node == nullptr)
+ return false;
+
+ node->tensor(pred_node->tensor());
+ return true;
+}
+
+} // namespace
+
+namespace luci
+{
+
+/**
+ * BEFORE
+ *
+ * [CircleNode]
+ * |
+ * [CircleReshape_1]
+ * |
+ * [CircleReshape_2]
+ * |
+ * [CircleNode]
+ *
+ * AFTER
+ *
+ * [CircleNode]
+ * / \
+ * [CircleReshape_1] [CircleReshape_2]
+ * |
+ * [CircleNode]
+ **/
+bool RemoveRedundantReshapePass::run(loco::Graph *g)
+{
+ bool changed = false;
+ for (auto node : loco::active_nodes(loco::output_nodes(g)))
+ {
+ if (auto reshape_node = dynamic_cast<luci::CircleReshape *>(node))
+ {
+ if (remove_redundant_reshape(reshape_node))
+ changed = true;
+ }
+ }
+ return changed;
+}
+
+} // namespace luci
diff --git a/compiler/luci/pass/src/RemoveRedundantReshape.test.cpp b/compiler/luci/pass/src/RemoveRedundantReshape.test.cpp
new file mode 100644
index 000000000..617840f3a
--- /dev/null
+++ b/compiler/luci/pass/src/RemoveRedundantReshape.test.cpp
@@ -0,0 +1,110 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "luci/Pass/RemoveRedundantReshapePass.h"
+
+#include <luci/IR/CircleNodes.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+class RemoveRedundantReshape : public ::testing::Test
+{
+public:
+ RemoveRedundantReshape() {}
+
+ void createReshapeConst(luci::CircleReshape *target, const std::vector<int32_t> shape)
+ {
+ auto shape_const = g.nodes()->create<luci::CircleConst>();
+ shape_const->dtype(loco::DataType::S32);
+ shape_const->size<loco::DataType::S32>(shape.size());
+ shape_const->shape_status(luci::ShapeStatus::VALID);
+ shape_const->rank(1);
+ shape_const->dim(0).set(shape.size());
+ for (int32_t i = 0; i < shape.size(); i++)
+ {
+ shape_const->at<loco::DataType::S32>(i) = shape.at(i);
+ }
+ shape_const->name("shape_const");
+ target->shape(shape_const);
+ }
+
+ void buildGraph(const std::initializer_list<uint32_t> base_shape,
+ const std::vector<int32_t> first_shape, const std::vector<int32_t> second_shape)
+ {
+ // Input Create.
+ input = g.nodes()->create<luci::CircleInput>();
+ auto graph_input = g.inputs()->create();
+ input->index(graph_input->index());
+ input->shape_status(luci::ShapeStatus::VALID);
+ input->rank(base_shape.size());
+ input->shape(base_shape);
+ input->name("input");
+
+ // Create first reshape.
+ first_reshape = g.nodes()->create<luci::CircleReshape>();
+ first_reshape->tensor(input);
+ first_reshape->name("Reshape");
+ createReshapeConst(first_reshape, first_shape);
+
+ // Create second reshape.
+ second_reshape = g.nodes()->create<luci::CircleReshape>();
+ second_reshape->tensor(first_reshape);
+ second_reshape->name("second_reshape");
+ createReshapeConst(second_reshape, second_shape);
+
+ // Output Connect.
+ output = g.nodes()->create<luci::CircleOutput>();
+ output->from(second_reshape);
+ output->name("output");
+ auto graph_output = g.outputs()->create();
+ output->index(graph_output->index());
+ }
+
+public:
+ loco::Graph g;
+ luci::CircleInput *input = nullptr;
+ luci::CircleReshape *first_reshape = nullptr;
+ luci::CircleReshape *second_reshape = nullptr;
+ luci::CircleOutput *output = nullptr;
+};
+
+} // namespace
+
+TEST(RemoveRedundantReshapePassTest, name)
+{
+ luci::RemoveRedundantReshapePass pass;
+ auto const name = pass.name();
+ ASSERT_NE(nullptr, name);
+}
+
+TEST_F(RemoveRedundantReshape, simple_case)
+{
+ buildGraph({4, 6}, {-1, 4, 6}, {1, -1, 2, 3});
+ luci::RemoveRedundantReshapePass pass;
+ while (pass.run(&g))
+ ;
+ int count = 0;
+ for (auto node : loco::active_nodes(loco::output_nodes(&g)))
+ {
+ if (auto reshape = dynamic_cast<luci::CircleReshape *>(node))
+ {
+ count++;
+ }
+ }
+ ASSERT_EQ(1, count);
+}
diff --git a/compiler/luci/pass/src/RemoveRedundantTranspose.test.cpp b/compiler/luci/pass/src/RemoveRedundantTranspose.test.cpp
deleted file mode 100644
index db608b674..000000000
--- a/compiler/luci/pass/src/RemoveRedundantTranspose.test.cpp
+++ /dev/null
@@ -1,156 +0,0 @@
-/*
- * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#include "luci/Pass/RemoveRedundantTransposePass.h"
-
-#include <luci/IR/CircleNodes.h>
-
-#include <vector>
-
-#include <gtest/gtest.h>
-
-namespace
-{
-
-void setValue(luci::CircleConst *node, const std::vector<int> &v)
-{
- node->dtype(loco::DataType::S32);
- node->size<loco::DataType::S32>(v.size());
- node->rank(1);
- node->dim(0).set(v.size());
- for (int i = 0; i < v.size(); ++i)
- {
- node->at<loco::DataType::S32>(i) = v[i];
- }
-}
-
-/**
- * Type1
- * BEFORE
- * |
- * [CircleNode] [CircleConst]
- * \ /
- * [CircleTranspose] [CircleConst]
- * \ /
- * [CircleTranspose]
- * |
- *
- * AFTER
- * |
- * [CircleNode]
- * | Remove Both
- *
- * --------------------------------------------
- *
- * Type2
- * BEFORE
- * |
- * [CircleNode] [CircleConst]
- * \ /
- * [CircleTranspose] [CircleConst]
- * \ /
- * [CircleTranspose]
- * |
- *
- * AFTER
- * | |
- * [CircleNode] [CircleConst]
- * \ /
- * [CircleTranspose]
- * |
- *
- */
-void create_redundunt_transpose(loco::Graph *g, const std::vector<int32_t> &perm1,
- const std::vector<int32_t> &perm2)
-{
- assert(g);
-
- auto input = g->nodes()->create<luci::CircleInput>();
- auto graph_input = g->inputs()->create();
- input->index(graph_input->index());
-
- // Create perm1
- auto perm1_node = g->nodes()->create<luci::CircleConst>();
- setValue(perm1_node, perm1);
-
- auto transpose1 = g->nodes()->create<luci::CircleTranspose>();
- transpose1->dtype(loco::DataType::FLOAT32);
- transpose1->a(input);
- transpose1->perm(perm1_node);
-
- // Create perm2
- auto perm2_node = g->nodes()->create<luci::CircleConst>();
- setValue(perm2_node, perm2);
-
- auto transpose2 = g->nodes()->create<luci::CircleTranspose>();
- transpose2->dtype(loco::DataType::FLOAT32);
- transpose2->a(transpose1);
- transpose2->perm(perm2_node);
-
- // Output
- auto output = g->nodes()->create<luci::CircleOutput>();
- output->from(transpose2);
- auto graph_output = g->outputs()->create();
- output->index(graph_output->index());
-}
-
-} // namespace
-
-TEST(RemoveRedundantTransposePass, remove_consecutive_transpose_function_type1)
-{
- auto graph = loco::make_graph();
- create_redundunt_transpose(graph.get(), {1, 0, 2, 3}, {1, 0, 2, 3});
-
- luci::RemoveRedundantTransposePass pass;
- while (pass.run(graph.get()))
- ;
- luci::CircleTranspose *transpose_node = nullptr;
- for (auto node : loco::active_nodes(loco::output_nodes(graph.get())))
- {
- auto trans = dynamic_cast<luci::CircleTranspose *>(node);
- if (not trans)
- continue;
- transpose_node = trans;
- break;
- }
- // No transpose node is in graph.
- ASSERT_EQ(nullptr, transpose_node);
-}
-
-TEST(RemoveRedundantTransposePass, remove_consecutive_transpose_function_type2)
-{
- auto graph = loco::make_graph();
- create_redundunt_transpose(graph.get(), {0, 1, 3, 2}, {1, 0, 2, 3});
-
- luci::RemoveRedundantTransposePass pass;
- while (pass.run(graph.get()))
- ;
- luci::CircleTranspose *transpose_node = nullptr;
- for (auto node : loco::active_nodes(loco::output_nodes(graph.get())))
- {
- auto trans = dynamic_cast<luci::CircleTranspose *>(node);
- if (not trans)
- continue;
- transpose_node = trans;
- break;
- }
- // Just one transpose node, with updated perm constant.
- ASSERT_NE(nullptr, transpose_node);
- auto perm = loco::must_cast<luci::CircleConst *>(transpose_node->perm());
- ASSERT_EQ(1, perm->at<loco::DataType::S32>(0));
- ASSERT_EQ(0, perm->at<loco::DataType::S32>(1));
- ASSERT_EQ(3, perm->at<loco::DataType::S32>(2));
- ASSERT_EQ(2, perm->at<loco::DataType::S32>(3));
-}
diff --git a/compiler/luci/pass/src/RemoveRedundantTranspose.cpp b/compiler/luci/pass/src/RemoveRedundantTransposePass.cpp
index 33cb76520..71c51ecda 100644
--- a/compiler/luci/pass/src/RemoveRedundantTranspose.cpp
+++ b/compiler/luci/pass/src/RemoveRedundantTransposePass.cpp
@@ -17,6 +17,7 @@
#include "luci/Pass/RemoveRedundantTransposePass.h"
#include <luci/IR/CircleNodes.h>
+#include <luci/Profile/CircleNodeOrigin.h>
namespace
{
@@ -35,47 +36,54 @@ bool check_perm(const luci::CircleConst *first_perm, const luci::CircleConst *se
return true;
}
-bool remove_consecutive_transpose_function(luci::CircleNode *node)
+bool remove_consecutive_transpose_function(luci::CircleTranspose *target_node)
{
- auto target_node = dynamic_cast<luci::CircleTranspose *>(node);
- if (target_node == nullptr)
- return false;
auto pred_node = dynamic_cast<luci::CircleTranspose *>(target_node->a());
if (pred_node == nullptr)
return false;
- if (loco::succs(pred_node).size() != 1)
- return false;
- auto pred_perm = dynamic_cast<luci::CircleConst *>(target_node->perm());
- if (pred_perm == nullptr)
+ auto target_perm = dynamic_cast<luci::CircleConst *>(target_node->perm());
+ if (target_perm == nullptr)
return false;
- auto main_perm = dynamic_cast<luci::CircleConst *>(pred_node->perm());
- if (main_perm == nullptr)
+ auto pred_perm = dynamic_cast<luci::CircleConst *>(pred_node->perm());
+ if (pred_perm == nullptr)
return false;
auto main_node = loco::must_cast<luci::CircleNode *>(pred_node->a());
- if (check_perm(pred_perm, main_perm))
+ if (check_perm(target_perm, pred_perm))
{
- replace(node).with(main_node);
+ replace(target_node).with(main_node);
}
else
{
- auto g = main_perm->graph();
+ auto name = target_node->name();
+ assert(name.length() > 0);
+
+ auto g = pred_perm->graph();
auto new_const_node = g->nodes()->create<luci::CircleConst>();
new_const_node->dtype(loco::DataType::S32);
new_const_node->rank(1);
- new_const_node->dim(0) = main_perm->dim(0);
- new_const_node->size<loco::DataType::S32>(main_perm->dim(0).value());
+ new_const_node->dim(0) = pred_perm->dim(0);
+ new_const_node->size<loco::DataType::S32>(pred_perm->dim(0).value());
new_const_node->shape_status(luci::ShapeStatus::VALID);
- for (uint32_t i = 0; i < main_perm->size<loco::DataType::S32>(); i++)
+ for (uint32_t i = 0; i < pred_perm->size<loco::DataType::S32>(); i++)
{
new_const_node->at<loco::DataType::S32>(i) =
- pred_perm->at<loco::DataType::S32>(main_perm->at<loco::DataType::S32>(i));
+ target_perm->at<loco::DataType::S32>(pred_perm->at<loco::DataType::S32>(i));
}
- pred_node->perm(new_const_node);
- replace(node).with(pred_node);
+ new_const_node->name(name + "/Transpose/perm");
+
+ // Create New Transpose Node
+ auto new_transpose_node = g->nodes()->create<luci::CircleTranspose>();
+ new_transpose_node->dtype(target_node->dtype());
+ new_transpose_node->a(main_node);
+ new_transpose_node->perm(new_const_node);
+ new_transpose_node->name(name + "/Transpose");
+ luci::add_origin(new_transpose_node, luci::get_origin(target_node));
+
+ replace(target_node).with(new_transpose_node);
}
return true;
}
@@ -84,41 +92,36 @@ bool remove_consecutive_transpose_function(luci::CircleNode *node)
namespace luci
{
+
/**
* BEFORE
* |
* [CircleNode] [CircleConst]
- * (main_node) (main_perm)
- * \ /
+ * | (pred_perm)
+ * \ /
* [CircleTranspose] [CircleConst]
- * (pred_node) (pred_perm)
+ * (pred_node) (target_perm)
* \ /
* [CircleTranspose]
* (target_node)
* |
*
* AFTER
- * <Optional Case>
- *
- * | | |
- * [CircleNode] [CircleConst] |
- * (main_node) (new_const_node) |
- * \ / or [CircleNode]
- * [CircleTranspose] (main_node)
- * (pred_node) |
+ * | |
+ * [CircleNode] [CircleConst](new) |
+ * \ / or [CircleNode]
+ * [CircleTranspose](new) |
* | |
- *
*/
bool RemoveRedundantTransposePass::run(loco::Graph *g)
{
bool changed = false;
for (auto node : loco::active_nodes(loco::output_nodes(g)))
{
- auto circle_node = loco::must_cast<luci::CircleNode *>(node);
- if (remove_consecutive_transpose_function(circle_node))
+ if (auto transpose = dynamic_cast<luci::CircleTranspose *>(node))
{
- changed = true;
- break;
+ if (remove_consecutive_transpose_function(transpose))
+ changed = true;
}
}
return changed;
diff --git a/compiler/luci/pass/src/RemoveRedundantTransposePass.test.cpp b/compiler/luci/pass/src/RemoveRedundantTransposePass.test.cpp
new file mode 100644
index 000000000..e80623499
--- /dev/null
+++ b/compiler/luci/pass/src/RemoveRedundantTransposePass.test.cpp
@@ -0,0 +1,321 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "luci/Pass/RemoveRedundantTransposePass.h"
+
+#include <luci/IR/CircleNodes.h>
+
+#include <vector>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+void setValue(luci::CircleConst *node, const std::vector<int> &v)
+{
+ node->dtype(loco::DataType::S32);
+ node->size<loco::DataType::S32>(v.size());
+ node->rank(1);
+ node->dim(0).set(v.size());
+ for (int i = 0; i < v.size(); ++i)
+ {
+ node->at<loco::DataType::S32>(i) = v[i];
+ }
+}
+
+/**
+ * Remove for consecutive Transpose
+ *
+ * Type1: Remove both Transpose
+ * BEFORE
+ * |
+ * [CircleNode] [CircleConst]
+ * \ /
+ * [CircleTranspose] [CircleConst]
+ * \ /
+ * [CircleTranspose]
+ * |
+ *
+ * AFTER
+ * |
+ * [CircleNode]
+ * |
+ *
+ * --------------------------------------------
+ *
+ * Type2: Merge to one Transpose
+ * BEFORE
+ * |
+ * [CircleNode] [CircleConst]
+ * \ /
+ * [CircleTranspose] [CircleConst]
+ * \ /
+ * [CircleTranspose]
+ * |
+ *
+ * AFTER
+ * |
+ * [CircleNode] [CircleConst]
+ * \ /
+ * [CircleTranspose]
+ * |
+ *
+ */
+void create_redundunt_transpose(loco::Graph *g, const std::vector<int32_t> &perm1,
+ const std::vector<int32_t> &perm2)
+{
+ assert(g);
+
+ auto input = g->nodes()->create<luci::CircleInput>();
+ auto graph_input = g->inputs()->create();
+ input->index(graph_input->index());
+ input->name("input");
+
+ // Create perm1
+ auto perm1_node = g->nodes()->create<luci::CircleConst>();
+ setValue(perm1_node, perm1);
+ perm1_node->name("perm1_node");
+
+ auto transpose1 = g->nodes()->create<luci::CircleTranspose>();
+ transpose1->dtype(loco::DataType::FLOAT32);
+ transpose1->a(input);
+ transpose1->perm(perm1_node);
+ transpose1->name("transpose1");
+
+ // Create perm2
+ auto perm2_node = g->nodes()->create<luci::CircleConst>();
+ setValue(perm2_node, perm2);
+ perm2_node->name("perm2_node");
+
+ auto transpose2 = g->nodes()->create<luci::CircleTranspose>();
+ transpose2->dtype(loco::DataType::FLOAT32);
+ transpose2->a(transpose1);
+ transpose2->perm(perm2_node);
+ transpose2->name("transpose2");
+
+ // Output
+ auto output = g->nodes()->create<luci::CircleOutput>();
+ output->from(transpose2);
+ auto graph_output = g->outputs()->create();
+ output->index(graph_output->index());
+ output->name("output");
+}
+
+/**
+ * Remove for consecutive Transposes with branching
+ *
+ * BEFORE
+ * |
+ * [CircleNode] [CircleConst]
+ * \ /
+ * [CircleConst] [CircleTranspose] [CircleConst]
+ * \ / \ /
+ * [CircleTranspose] [CircleTranspose]
+ * | |
+ * [CircleNode] [CircleNode]
+ * | |
+ *
+ * AFTER
+ * Type 1: Remove all Transpose
+ * |
+ * [CircleNode]
+ * / \
+ * [CircleNode] [CircleNode]
+ * | |
+ *
+ * Type 2: Remove both for one side and create new for another side
+ * |
+ * [CircleNode] [CircleConst](new)
+ * / \ /
+ * / [CircleTranspose](new)
+ * | |
+ * [CircleNode] [CircleNode]
+ * | |
+ */
+void create_redundunt_transpose_with_branch(loco::Graph *g, const std::vector<int32_t> &perm1,
+ const std::vector<int32_t> &perm2,
+ const std::vector<int32_t> &perm3)
+{
+ assert(g);
+
+ auto input = g->nodes()->create<luci::CircleInput>();
+ auto graph_input = g->inputs()->create();
+ input->dtype(loco::DataType::FLOAT32);
+ input->index(graph_input->index());
+ input->name("input");
+ graph_input->dtype(loco::DataType::FLOAT32);
+
+ graph_input->shape({4, 4, 4, 4});
+ input->shape({4, 4, 4, 4});
+
+ // Create perm1
+ auto perm1_node = g->nodes()->create<luci::CircleConst>();
+ setValue(perm1_node, perm1);
+ perm1_node->name("perm1_node");
+
+ auto transpose1 = g->nodes()->create<luci::CircleTranspose>();
+ transpose1->dtype(loco::DataType::FLOAT32);
+ transpose1->a(input);
+ transpose1->perm(perm1_node);
+ transpose1->name("transpose1");
+
+ // Create perm2
+ auto perm2_node = g->nodes()->create<luci::CircleConst>();
+ setValue(perm2_node, perm2);
+ perm2_node->name("perm2_node");
+
+ auto transpose2 = g->nodes()->create<luci::CircleTranspose>();
+ transpose2->dtype(loco::DataType::FLOAT32);
+ transpose2->a(transpose1);
+ transpose2->perm(perm2_node);
+ transpose2->name("transpose2");
+
+ // create perm3
+ auto perm3_node = g->nodes()->create<luci::CircleConst>();
+ setValue(perm3_node, perm3);
+ perm3_node->name("perm3_node");
+
+ auto transpose3 = g->nodes()->create<luci::CircleTranspose>();
+ transpose3->dtype(loco::DataType::FLOAT32);
+ transpose3->a(transpose1);
+ transpose3->perm(perm3_node);
+ transpose3->name("transpose3");
+
+ // Output
+ auto output1 = g->nodes()->create<luci::CircleOutput>();
+ output1->from(transpose2);
+ output1->name("output1");
+ auto output2 = g->nodes()->create<luci::CircleOutput>();
+ output2->from(transpose3);
+ output2->name("output2");
+ auto graph_output1 = g->outputs()->create();
+ output1->index(graph_output1->index());
+ auto graph_output2 = g->outputs()->create();
+ output2->index(graph_output2->index());
+ output1->dtype(loco::DataType::FLOAT32);
+ output2->dtype(loco::DataType::FLOAT32);
+ graph_output1->dtype(loco::DataType::FLOAT32);
+ graph_output2->dtype(loco::DataType::FLOAT32);
+ output1->shape({4, 4, 4, 4});
+ output2->shape({4, 4, 4, 4});
+ graph_output1->shape({4, 4, 4, 4});
+ graph_output2->shape({4, 4, 4, 4});
+}
+
+} // namespace
+
+TEST(RemoveRedundantTransposePassTest, name)
+{
+ luci::RemoveRedundantTransposePass pass;
+ auto const name = pass.name();
+ ASSERT_NE(nullptr, name);
+}
+
+TEST(RemoveRedundantTransposePass, remove_consecutive_transpose_function_type1)
+{
+ auto graph = loco::make_graph();
+ create_redundunt_transpose(graph.get(), {1, 0, 2, 3}, {1, 0, 2, 3});
+
+ luci::RemoveRedundantTransposePass pass;
+ while (pass.run(graph.get()))
+ ;
+ luci::CircleTranspose *transpose_node = nullptr;
+ for (auto node : loco::active_nodes(loco::output_nodes(graph.get())))
+ {
+ auto trans = dynamic_cast<luci::CircleTranspose *>(node);
+ if (not trans)
+ continue;
+ transpose_node = trans;
+ break;
+ }
+ // No transpose node is in graph.
+ ASSERT_EQ(nullptr, transpose_node);
+}
+
+TEST(RemoveRedundantTransposePass, remove_consecutive_transpose_function_type2)
+{
+ auto graph = loco::make_graph();
+ create_redundunt_transpose(graph.get(), {0, 1, 3, 2}, {1, 0, 2, 3});
+
+ luci::RemoveRedundantTransposePass pass;
+ while (pass.run(graph.get()))
+ ;
+ luci::CircleTranspose *transpose_node = nullptr;
+ for (auto node : loco::active_nodes(loco::output_nodes(graph.get())))
+ {
+ auto trans = dynamic_cast<luci::CircleTranspose *>(node);
+ if (not trans)
+ continue;
+ transpose_node = trans;
+ break;
+ }
+ // Just one transpose node, with updated perm constant.
+ ASSERT_NE(nullptr, transpose_node);
+ auto perm = loco::must_cast<luci::CircleConst *>(transpose_node->perm());
+ ASSERT_EQ(1, perm->at<loco::DataType::S32>(0));
+ ASSERT_EQ(0, perm->at<loco::DataType::S32>(1));
+ ASSERT_EQ(3, perm->at<loco::DataType::S32>(2));
+ ASSERT_EQ(2, perm->at<loco::DataType::S32>(3));
+}
+
+/**
+ * @brief Test case that first transpose output become input of operations more than one.
+ */
+TEST(RemoveRedundantTransposePass, remove_consecutive_transpose_function_with_branch_remove_case)
+{
+ auto graph = loco::make_graph();
+ create_redundunt_transpose_with_branch(graph.get(), {1, 0, 2, 3}, {1, 0, 2, 3}, {1, 0, 2, 3});
+
+ luci::RemoveRedundantTransposePass pass;
+ while (pass.run(graph.get()))
+ ;
+ luci::CircleTranspose *transpose_node = nullptr;
+ for (auto node : loco::active_nodes(loco::output_nodes(graph.get())))
+ {
+ auto trans = dynamic_cast<luci::CircleTranspose *>(node);
+ if (not trans)
+ continue;
+ transpose_node = trans;
+ break;
+ }
+ // No transpose node is in graph.
+ ASSERT_EQ(nullptr, transpose_node);
+}
+
+TEST(RemoveRedundantTransposePass, remove_consecutive_transpose_function_with_branch_leave_one)
+{
+ auto graph = loco::make_graph();
+ create_redundunt_transpose_with_branch(graph.get(), {1, 0, 2, 3}, {1, 0, 2, 3}, {0, 1, 3, 2});
+
+ luci::RemoveRedundantTransposePass pass;
+ while (pass.run(graph.get()))
+ ;
+ luci::CircleTranspose *transpose_node = nullptr;
+ for (auto node : loco::active_nodes(loco::output_nodes(graph.get())))
+ {
+ auto trans = dynamic_cast<luci::CircleTranspose *>(node);
+ if (not trans)
+ continue;
+ transpose_node = trans;
+ break;
+ }
+ ASSERT_NE(nullptr, transpose_node);
+ auto perm = loco::must_cast<luci::CircleConst *>(transpose_node->perm());
+ ASSERT_EQ(1, perm->at<loco::DataType::S32>(0));
+ ASSERT_EQ(0, perm->at<loco::DataType::S32>(1));
+ ASSERT_EQ(3, perm->at<loco::DataType::S32>(2));
+ ASSERT_EQ(2, perm->at<loco::DataType::S32>(3));
+}
diff --git a/compiler/luci/pass/src/RemoveUnnecessaryReshapePass.cpp b/compiler/luci/pass/src/RemoveUnnecessaryReshapePass.cpp
new file mode 100644
index 000000000..3f0c4ee82
--- /dev/null
+++ b/compiler/luci/pass/src/RemoveUnnecessaryReshapePass.cpp
@@ -0,0 +1,75 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Pass/RemoveUnnecessaryReshapePass.h"
+
+#include <luci/IR/CircleNodes.h>
+
+namespace
+{
+
+bool remove_no_effect_reshape(luci::CircleNode *node)
+{
+ auto target_node = dynamic_cast<luci::CircleReshape *>(node);
+ if (target_node == nullptr)
+ return false;
+
+ auto new_shape = dynamic_cast<luci::CircleConst *>(target_node->shape());
+ if (new_shape == nullptr)
+ return false;
+
+ // Compare updated shape and input shape.
+ auto input_node = loco::must_cast<luci::CircleNode *>(target_node->tensor());
+ if (input_node->rank() != new_shape->dim(0).value())
+ return false;
+ for (uint32_t i = 0; i < input_node->rank(); i++)
+ {
+ // If update_shape is -1, don't care
+ // TODO check updated shape has value -1 at most one.
+ if (new_shape->at<loco::DataType::S32>(i) == -1)
+ continue;
+ // If input_shape dynamic, can't remove this.
+ if (!input_node->dim(i).known())
+ return false;
+ // If input_shape and updated shape differ, also can't remove.
+ if (input_node->dim(i).value() != static_cast<uint32_t>(new_shape->at<loco::DataType::S32>(i)))
+ return false;
+ }
+
+ replace(target_node).with(input_node);
+ return true;
+}
+
+} // namespace
+
+namespace luci
+{
+
+bool RemoveUnnecessaryReshapePass::run(loco::Graph *g)
+{
+ bool changed = false;
+ for (auto node : loco::active_nodes(loco::output_nodes(g)))
+ {
+ auto circle_node = loco::must_cast<luci::CircleNode *>(node);
+ if (remove_no_effect_reshape(circle_node))
+ {
+ changed = true;
+ }
+ }
+ return changed;
+}
+
+} // namespace luci
diff --git a/compiler/luci/pass/src/RemoveUnnecessaryReshapePass.test.cpp b/compiler/luci/pass/src/RemoveUnnecessaryReshapePass.test.cpp
new file mode 100644
index 000000000..9d2e758b4
--- /dev/null
+++ b/compiler/luci/pass/src/RemoveUnnecessaryReshapePass.test.cpp
@@ -0,0 +1,141 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Pass/RemoveUnnecessaryReshapePass.h"
+
+#include <luci/IR/CircleNodes.h>
+
+#include <luci/test/TestIOGraph.h>
+#include "test/TestFirstNode.h"
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+using namespace luci::test;
+
+class ReshapeGraphlet
+{
+public:
+ ReshapeGraphlet() = default;
+
+public:
+ void init(loco::Graph *g, const ShapeU32 input_shape, bool remove)
+ {
+ std::vector<uint32_t> shape_vector{input_shape};
+
+ auto dim0_val = remove ? shape_vector.size() : 1;
+ _reshape_shape = g->nodes()->create<luci::CircleConst>();
+ _reshape_shape->rank(1);
+ _reshape_shape->dim(0).set(dim0_val);
+ _reshape_shape->shape_status(luci::ShapeStatus::VALID);
+ _reshape_shape->dtype(loco::DataType::S32);
+
+ _reshape_shape->size<loco::DataType::S32>(dim0_val);
+ for (uint32_t i = 0; i < dim0_val; i++)
+ {
+ if (remove)
+ _reshape_shape->at<loco::DataType::S32>(i) = static_cast<int32_t>(shape_vector.at(i));
+ else
+ _reshape_shape->at<loco::DataType::S32>(i) = -1;
+ }
+ _reshape_shape->name("reshape_shape");
+
+ // Reshape create
+ auto newshape_rank = remove ? shape_vector.size() : 1;
+ _reshape = g->nodes()->create<luci::CircleReshape>();
+ _reshape->newShape()->rank(newshape_rank);
+ for (uint32_t i = 0; i < newshape_rank; i++)
+ {
+ if (remove)
+ _reshape->newShape()->dim(i) = static_cast<int32_t>(shape_vector.at(i));
+ else
+ _reshape->newShape()->dim(i) = -1;
+ }
+ _reshape->name("reshape");
+ }
+
+protected:
+ luci::CircleReshape *_reshape = nullptr;
+ luci::CircleConst *_reshape_shape = nullptr;
+};
+
+class ReshapeGraph : public TestIOGraph, public ReshapeGraphlet
+{
+public:
+ ReshapeGraph() = default;
+
+public:
+ void init(const ShapeU32 shape, bool remove)
+ {
+ TestIOGraph::init(shape, shape);
+ ReshapeGraphlet::init(g(), shape, remove);
+
+ // connect graph
+ _reshape->tensor(input());
+ _reshape->shape(_reshape_shape);
+
+ output()->from(_reshape);
+ }
+};
+
+// TODO use ::testing::Test
+
+} // namespace
+
+TEST(RemoveUnnecessaryReshapePassTest, name)
+{
+ luci::RemoveUnnecessaryReshapePass pass;
+ auto const name = pass.name();
+ ASSERT_NE(nullptr, name);
+}
+
+TEST(RemoveUnnecessaryReshapePass, removed)
+{
+ ReshapeGraph g;
+
+ g.init({1, 2, 3, 4}, true);
+
+ // confirm graph has Reshape
+ auto reshape_node = luci::test::first_node<luci::CircleReshape>(g.g());
+ ASSERT_NE(nullptr, reshape_node);
+ luci::RemoveUnnecessaryReshapePass pass;
+ while (pass.run(g.g()))
+ ;
+
+ // check Reshape is removed
+ reshape_node = luci::test::first_node<luci::CircleReshape>(g.g());
+ ASSERT_EQ(nullptr, reshape_node);
+}
+
+TEST(RemoveUnnecessaryReshapePass, not_removed_NEG)
+{
+ ReshapeGraph g;
+
+ g.init({1, 2, 3, 4}, false);
+
+ // confirm graph has Reshape
+ auto reshape_node = luci::test::first_node<luci::CircleReshape>(g.g());
+ ASSERT_NE(nullptr, reshape_node);
+ luci::RemoveUnnecessaryReshapePass pass;
+ while (pass.run(g.g()))
+ ;
+
+ // check Reshape is NOT removed
+ reshape_node = luci::test::first_node<luci::CircleReshape>(g.g());
+ ASSERT_NE(nullptr, reshape_node);
+}
diff --git a/compiler/luci/pass/src/RemoveUnnecessarySlicePass.cpp b/compiler/luci/pass/src/RemoveUnnecessarySlicePass.cpp
new file mode 100644
index 000000000..0720813cd
--- /dev/null
+++ b/compiler/luci/pass/src/RemoveUnnecessarySlicePass.cpp
@@ -0,0 +1,111 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Pass/RemoveUnnecessarySlicePass.h"
+
+#include <luci/IR/CircleNodes.h>
+
+namespace
+{
+
+/**
+ * @brief Return value in CircleConst.
+ * @details Return value in position on CircleConst with int64 format.
+ * Begin must be larger than or equal to 0. Size must be larger
+ * than or equal to -1.
+ */
+int64_t value_from_circle_const(const luci::CircleConst *node, uint32_t idx)
+{
+ assert(node->rank() == 1 && node->dim(0).value() > idx);
+ assert(node->dtype() == loco::DataType::S64 || node->dtype() == loco::DataType::S32);
+
+ if (node->dtype() == loco::DataType::S64)
+ return node->at<loco::DataType::S64>(idx);
+ return static_cast<int64_t>(node->at<loco::DataType::S32>(idx));
+}
+
+bool remove_no_effect_slice(luci::CircleNode *node)
+{
+ auto target_node = dynamic_cast<luci::CircleSlice *>(node);
+ if (target_node == nullptr)
+ return false;
+
+ auto begin_const = dynamic_cast<luci::CircleConst *>(target_node->begin());
+ if (begin_const == nullptr)
+ return false;
+
+ auto size_const = dynamic_cast<luci::CircleConst *>(target_node->size());
+ if (size_const == nullptr)
+ return false;
+
+ // Check input output shape.
+ auto input_node = loco::must_cast<luci::CircleNode *>(target_node->input());
+ for (uint32_t i = 0; i < input_node->rank(); i++)
+ {
+ if (value_from_circle_const(begin_const, i) != 0)
+ return false;
+
+ int64_t size_value = value_from_circle_const(size_const, i);
+ if (size_value == -1)
+ continue;
+ if (size_value != static_cast<int64_t>(input_node->dim(i).value()))
+ return false;
+
+ if (!input_node->dim(i).known())
+ return false;
+ }
+ replace(target_node).with(input_node);
+ return true;
+}
+
+} // namespace
+
+namespace luci
+{
+/**
+ * BEFORE
+ *
+ * [CircleNode]
+ * |
+ * [CircleSlice]
+ * |
+ * [CircleNode]
+ *
+ * AFTER
+ *
+ * [CircleNode]
+ * |
+ * [CircleNode]
+ *
+ * Slice OP has no effect if,
+ * 1. Static Shape : begin_const[idx] is 0 AND size_const[idx] is (-1 OR input_dimension[idx])
+ * 2. Dynamic Shape : begin_const[idx] is 0 AND size_const[idx] is -1
+ */
+bool RemoveUnnecessarySlicePass::run(loco::Graph *g)
+{
+ bool changed = false;
+ for (auto node : loco::active_nodes(loco::output_nodes(g)))
+ {
+ auto circle_node = loco::must_cast<luci::CircleNode *>(node);
+ if (remove_no_effect_slice(circle_node))
+ {
+ changed = true;
+ }
+ }
+ return changed;
+}
+
+} // namespace luci
diff --git a/compiler/luci/pass/src/RemoveUnnecessarySlicePass.test.cpp b/compiler/luci/pass/src/RemoveUnnecessarySlicePass.test.cpp
new file mode 100644
index 000000000..80921a93a
--- /dev/null
+++ b/compiler/luci/pass/src/RemoveUnnecessarySlicePass.test.cpp
@@ -0,0 +1,134 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "luci/Pass/RemoveUnnecessarySlicePass.h"
+
+#include <luci/IR/CircleNodes.h>
+
+#include <luci/test/TestIOGraph.h>
+#include "test/TestFirstNode.h"
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+using namespace luci::test;
+
+class SliceGraphlet
+{
+public:
+ SliceGraphlet() = default;
+
+public:
+ void init(loco::Graph *g, const ShapeU32 input_shape, bool remove)
+ {
+ // Begin Create.
+ _begin = g->nodes()->create<luci::CircleConst>();
+ _begin->rank(1);
+ _begin->dim(0).set(input_shape.size());
+ _begin->shape_status(luci::ShapeStatus::VALID);
+ _begin->dtype(loco::DataType::S32);
+ _begin->size<loco::DataType::S32>(input_shape.size());
+ for (int i = 0; i < input_shape.size(); ++i)
+ _begin->at<loco::DataType::S32>(i) = remove ? 0 : 1;
+ _begin->name("begin");
+
+ // Size Create.
+ _size = g->nodes()->create<luci::CircleConst>();
+ _size->rank(1);
+ _size->dim(0).set(input_shape.size());
+ _size->shape_status(luci::ShapeStatus::VALID);
+ _size->dtype(loco::DataType::S32);
+ _size->size<loco::DataType::S32>(input_shape.size());
+ for (int i = 0; i < input_shape.size(); ++i)
+ _size->at<loco::DataType::S32>(i) = -1;
+ _size->name("size");
+
+ // Slice Node create.
+ _slice = g->nodes()->create<luci::CircleSlice>();
+ _slice->dtype(loco::DataType::S32);
+ _slice->name("slice");
+ }
+
+protected:
+ luci::CircleSlice *_slice = nullptr;
+ luci::CircleConst *_begin = nullptr;
+ luci::CircleConst *_size = nullptr;
+};
+
+class SliceGraph : public TestIOGraph, public SliceGraphlet
+{
+public:
+ SliceGraph() = default;
+
+public:
+ void init(const ShapeU32 shape, bool remove)
+ {
+ TestIOGraph::init(shape, shape);
+ SliceGraphlet::init(g(), shape, remove);
+
+ _slice->input(input());
+ _slice->begin(_begin);
+ _slice->size(_size);
+
+ output()->from(_slice);
+ }
+};
+
+} // namespace
+
+TEST(RemoveUnnecessarySlicePass, name)
+{
+ luci::RemoveUnnecessarySlicePass pass;
+ auto const name = pass.name();
+ ASSERT_NE(nullptr, name);
+}
+
+TEST(RemoveUnnecessarySlicePass, removed)
+{
+ SliceGraph g;
+
+ g.init({2, 4, 2, 3}, true);
+
+ // confirm graph has Slice
+ auto slice_node = luci::test::first_node<luci::CircleSlice>(g.g());
+ ASSERT_NE(nullptr, slice_node);
+ luci::RemoveUnnecessarySlicePass pass;
+ while (pass.run(g.g()))
+ ;
+
+ // check Slice is removed
+ slice_node = luci::test::first_node<luci::CircleSlice>(g.g());
+ ASSERT_EQ(nullptr, slice_node);
+}
+
+TEST(RemoveUnnecessarySlicePass, not_removed_NEG)
+{
+ SliceGraph g;
+
+ g.init({2, 4, 2, 3}, false);
+
+ // confirm graph has Slice
+ auto slice_node = luci::test::first_node<luci::CircleSlice>(g.g());
+ ASSERT_NE(nullptr, slice_node);
+ luci::RemoveUnnecessarySlicePass pass;
+ while (pass.run(g.g()))
+ ;
+
+ // check Slice is NOT removed
+ slice_node = luci::test::first_node<luci::CircleSlice>(g.g());
+ ASSERT_NE(nullptr, slice_node);
+}
diff --git a/compiler/luci/pass/src/ShapeSignatureInferencePass.cpp b/compiler/luci/pass/src/RemoveUnnecessarySplitPass.cpp
index 115b77a96..3243f6213 100644
--- a/compiler/luci/pass/src/ShapeSignatureInferencePass.cpp
+++ b/compiler/luci/pass/src/RemoveUnnecessarySplitPass.cpp
@@ -14,49 +14,50 @@
* limitations under the License.
*/
-#include "luci/Pass/ShapeSignatureInferencePass.h"
+#include "luci/Pass/RemoveUnnecessarySplitPass.h"
-#include <luci/IR/CircleShapeSignature.h>
-#include <luci/Service/CircleShapeSignatureInference.h>
+#include <luci/IR/CircleNodes.h>
-#include <loco.h>
-
-namespace luci
+namespace
{
-
-bool ShapeSignatureInferencePass::run(luci::Module *m)
+bool remove_unnecessary_split(luci::CircleNode *node)
{
- bool changed = false;
+ auto target_node = dynamic_cast<luci::CircleSplitOut *>(node);
+ if (target_node == nullptr)
+ return false;
+
+ auto split_node = dynamic_cast<luci::CircleSplit *>(target_node->input());
+ if (split_node == nullptr)
+ return false;
- for (size_t g = 0; g < m->size(); ++g)
+ if (loco::succs(split_node).size() != 1)
+ return false;
+
+ if (split_node->num_split() == 1)
{
- if (run(m->graph(g)))
- changed = true;
+ auto input_node = loco::must_cast<luci::CircleNode *>(split_node->input());
+ replace(target_node).with(input_node);
+ return true;
}
-
- return changed;
+ return false;
}
-bool ShapeSignatureInferencePass::run(loco::Graph *g)
+} // namespace
+
+namespace luci
{
- luci::ssinf::Rule signature_inference_rule;
- bool changed = false;
- for (auto node : loco::postorder_traversal(loco::output_nodes(g)))
+bool RemoveUnnecessarySplitPass::run(loco::Graph *g)
+{
+ bool changed = false;
+ for (auto node : loco::active_nodes(loco::output_nodes(g)))
{
- luci::ShapeSignature shape_signature;
-
auto circle_node = loco::must_cast<luci::CircleNode *>(node);
- if (signature_inference_rule.infer(circle_node, shape_signature))
+ if (remove_unnecessary_split(circle_node))
{
- if (!(circle_node->shape_signature() == shape_signature))
- {
- circle_node->shape_signature(shape_signature);
- changed = true;
- }
+ changed = true;
}
}
-
return changed;
}
diff --git a/compiler/luci/pass/src/RemoveUnnecessarySplitPass.test.cpp b/compiler/luci/pass/src/RemoveUnnecessarySplitPass.test.cpp
new file mode 100644
index 000000000..f292b5357
--- /dev/null
+++ b/compiler/luci/pass/src/RemoveUnnecessarySplitPass.test.cpp
@@ -0,0 +1,149 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Pass/RemoveUnnecessarySplitPass.h"
+
+#include <luci/IR/CircleNodes.h>
+
+#include <luci/test/TestIOGraph.h>
+#include "test/TestFirstNode.h"
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+using namespace luci::test;
+
+class SplitGraphlet
+{
+public:
+ SplitGraphlet() = default;
+
+public:
+ void init(loco::Graph *g, uint32_t nout)
+ {
+ assert(nout == 1 || nout == 2);
+
+ _dim = g->nodes()->create<luci::CircleConst>();
+ set_shape_vector(_dim, {0});
+ _dim->name("dim");
+
+ _split = g->nodes()->create<luci::CircleSplit>();
+ _split->num_split(nout);
+ _split->name("split");
+
+ _split_out_0 = g->nodes()->create<luci::CircleSplitOut>();
+ _split_out_0->index(0);
+ _split_out_0->name("split_out_0");
+
+ if (nout == 2)
+ {
+ _split_out_1 = g->nodes()->create<luci::CircleSplitOut>();
+ _split_out_1->index(1);
+ _split_out_1->name("split_out_1");
+ }
+ }
+
+protected:
+ luci::CircleSplit *_split = nullptr;
+ luci::CircleConst *_dim = nullptr;
+ luci::CircleSplitOut *_split_out_0 = nullptr;
+ luci::CircleSplitOut *_split_out_1 = nullptr;
+};
+
+class SplitOneGraph : public TestIGraphlet, public TestOGraphlet, public SplitGraphlet
+{
+public:
+ SplitOneGraph() = default;
+
+public:
+ void init()
+ {
+ TestIGraphlet::init(g(), {1});
+ TestOGraphlet::init(g(), {1});
+ SplitGraphlet::init(g(), 1);
+
+ _split->input(input());
+ _split->split_dim(_dim);
+ _split_out_0->input(_split);
+
+ output()->from(_split_out_0);
+ }
+};
+
+class SplitTwoGraph : public TestIGraphlet, public TestOsGraphlet<2>, public SplitGraphlet
+{
+public:
+ SplitTwoGraph() = default;
+
+public:
+ void init()
+ {
+ TestIGraphlet::init(g(), {1});
+ TestOsGraphlet<2>::init(g(), {{1}, {1}});
+ SplitGraphlet::init(g(), 2);
+
+ _split->input(input());
+ _split->split_dim(_dim);
+ _split_out_0->input(_split);
+ _split_out_1->input(_split);
+
+ output(0)->from(_split_out_0);
+ output(1)->from(_split_out_1);
+ }
+};
+
+// TODO use ::testing::Test
+
+} // namespace
+
+TEST(RemoveUnnecessarySplitPass, name)
+{
+ luci::RemoveUnnecessarySplitPass pass;
+ auto const name = pass.name();
+ ASSERT_NE(nullptr, name);
+}
+
+TEST(RemoveUnnecessarySplitPass, create_unnecessary_split)
+{
+ SplitOneGraph g;
+
+ g.init();
+
+ luci::RemoveUnnecessarySplitPass pass;
+ while (pass.run(g.g()))
+ ;
+
+ auto split_node = luci::test::first_node<luci::CircleSplit>(g.g());
+ // No Split node is in graph.
+ ASSERT_EQ(nullptr, split_node);
+}
+
+TEST(RemoveUnnecessarySplitPass, create_unnecessary_split_NEG)
+{
+ SplitTwoGraph g;
+
+ g.init();
+
+ luci::RemoveUnnecessarySplitPass pass;
+ while (pass.run(g.g()))
+ ;
+
+ auto split_node = luci::test::first_node<luci::CircleSplit>(g.g());
+ // Split node is in graph.
+ ASSERT_NE(nullptr, split_node);
+}
diff --git a/compiler/luci/pass/src/RemoveUnnecessaryStridedSlicePass.cpp b/compiler/luci/pass/src/RemoveUnnecessaryStridedSlicePass.cpp
new file mode 100644
index 000000000..22b1aa64f
--- /dev/null
+++ b/compiler/luci/pass/src/RemoveUnnecessaryStridedSlicePass.cpp
@@ -0,0 +1,124 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Pass/RemoveUnnecessaryStridedSlicePass.h"
+
+#include <luci/IR/CircleNodes.h>
+
+namespace
+{
+
+/**
+ * @brief Return value in CircleConst.
+ * @details Return value in position on CircleConst with int64 format.
+ */
+int64_t value_from_circle_const(const luci::CircleConst *node, uint32_t idx)
+{
+ assert(node->rank() == 1 && node->dim(0).value() > idx);
+ assert(node->dtype() == loco::DataType::S64 || node->dtype() == loco::DataType::S32);
+
+ if (node->dtype() == loco::DataType::S64)
+ return node->at<loco::DataType::S64>(idx);
+ return static_cast<int64_t>(node->at<loco::DataType::S32>(idx));
+}
+
+bool remove_no_effect_strided_slice(luci::CircleStridedSlice *target_node)
+{
+ auto begin_const = dynamic_cast<luci::CircleConst *>(target_node->begin());
+ if (begin_const == nullptr)
+ return false;
+
+ auto strides_const = dynamic_cast<luci::CircleConst *>(target_node->strides());
+ if (strides_const == nullptr)
+ return false;
+
+ auto end_const = dynamic_cast<luci::CircleConst *>(target_node->end());
+ if (end_const == nullptr)
+ return false;
+
+ auto input_node = loco::must_cast<luci::CircleNode *>(target_node->input());
+ for (uint32_t i = 0; i < input_node->rank(); i++)
+ {
+ if (value_from_circle_const(begin_const, i) != 0)
+ return false;
+
+ int64_t strides_value = value_from_circle_const(strides_const, i);
+ if (strides_value != 1)
+ return false;
+
+ int64_t end_value = value_from_circle_const(end_const, i);
+ if (end_value == -1)
+ continue;
+
+ if (end_value != input_node->dim(i).value())
+ return false;
+
+ if (!input_node->dim(i).known())
+ return false;
+ }
+
+ /**
+ * We check additional attributes on zero after shapes
+ * for skipping wrong StridedSlice operator.
+ */
+ if (target_node->new_axis_mask() != 0 || target_node->shrink_axis_mask() != 0)
+ return false;
+
+ replace(target_node).with(input_node);
+ return true;
+}
+
+} // namespace
+
+namespace luci
+{
+/**
+ * BEFORE
+ *
+ * [CircleNode]
+ * |
+ * [CircleStridedSlice]
+ * |
+ * [CircleNode]
+ *
+ * AFTER
+ *
+ * [CircleNode]
+ * |
+ * [CircleNode] [CircleStridedSlice]
+ *
+ * StridedSlice OP has no effect if,
+ * 1. Static Shape : begin_const[idx] is 0 AND strides_const[idx] is (not 1 OR
+ * input_dimension[idx])
+ * 2. Dynamic Shape : begin_const[idx] is 0 AND strides_const[idx] is not 1
+ *
+ * StridedSlice OP has effect if,
+ * 1. begin_const[idx] is 0 AND input_shape[idx] are equal to end_shape[idx]
+ */
+bool RemoveUnnecessaryStridedSlicePass::run(loco::Graph *g)
+{
+ bool changed = false;
+ for (auto node : loco::active_nodes(loco::output_nodes(g)))
+ {
+ auto target_node = dynamic_cast<luci::CircleStridedSlice *>(node);
+ if (target_node != nullptr)
+ if (remove_no_effect_strided_slice(target_node))
+ changed = true;
+ }
+ return changed;
+}
+
+} // namespace luci
diff --git a/compiler/luci/pass/src/RemoveUnnecessaryStridedSlicePass.test.cpp b/compiler/luci/pass/src/RemoveUnnecessaryStridedSlicePass.test.cpp
new file mode 100644
index 000000000..7d611c864
--- /dev/null
+++ b/compiler/luci/pass/src/RemoveUnnecessaryStridedSlicePass.test.cpp
@@ -0,0 +1,142 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "luci/Pass/RemoveUnnecessaryStridedSlicePass.h"
+
+#include <luci/IR/CircleNodes.h>
+
+#include <luci/test/TestIOGraph.h>
+#include "test/TestFirstNode.h"
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+using namespace luci::test;
+
+class StridedSliceGraphlet
+{
+public:
+ StridedSliceGraphlet() = default;
+
+public:
+ void init(loco::Graph *g, const ShapeU32 input_shape, bool remove)
+ {
+ // Begin create
+ _begin = g->nodes()->create<luci::CircleConst>();
+ _begin->rank(1);
+ _begin->dim(0).set(input_shape.size());
+ _begin->shape_status(luci::ShapeStatus::VALID);
+ _begin->dtype(loco::DataType::S32);
+ _begin->size<loco::DataType::S32>(input_shape.size());
+ for (int i = 0; i < input_shape.size(); ++i)
+ {
+ _begin->at<loco::DataType::S32>(i) = remove ? 0 : 1;
+ }
+
+ // Strides create
+ _strides = g->nodes()->create<luci::CircleConst>();
+ _strides->rank(1);
+ _strides->dim(0).set(input_shape.size());
+ _strides->shape_status(luci::ShapeStatus::VALID);
+ _strides->dtype(loco::DataType::S32);
+ _strides->size<loco::DataType::S32>(input_shape.size());
+ for (int i = 0; i < input_shape.size(); ++i)
+ {
+ _strides->at<loco::DataType::S32>(i) = remove ? 1 : -1;
+ }
+
+ std::vector<uint32_t> shape_vector{input_shape};
+
+ _end = g->nodes()->create<luci::CircleConst>();
+ _end->rank(1);
+ _end->dim(0).set(input_shape.size());
+ _end->shape_status(luci::ShapeStatus::VALID);
+ _end->dtype(loco::DataType::S32);
+ _end->size<loco::DataType::S32>(input_shape.size());
+ for (int i = 0; i < input_shape.size(); ++i)
+ {
+ if (remove)
+ _end->at<loco::DataType::S32>(i) = static_cast<int32_t>(shape_vector.at(i));
+ else
+ _end->at<loco::DataType::S32>(i) = -1;
+ }
+
+ // StridedSlice Node create
+ _strided_slice = g->nodes()->create<luci::CircleStridedSlice>();
+ _strided_slice->dtype(loco::DataType::S32);
+ }
+
+protected:
+ luci::CircleStridedSlice *_strided_slice = nullptr;
+ luci::CircleConst *_begin = nullptr;
+ luci::CircleConst *_strides = nullptr;
+ luci::CircleConst *_end = nullptr;
+};
+
+class StridedSliceGraph : public TestIOGraph, public StridedSliceGraphlet
+{
+public:
+ StridedSliceGraph() = default;
+
+public:
+ void init(const ShapeU32 shape, bool remove)
+ {
+ TestIOGraph::init(shape, shape);
+ StridedSliceGraphlet::init(g(), shape, remove);
+
+ _strided_slice->input(input());
+ _strided_slice->begin(_begin);
+ _strided_slice->strides(_strides);
+ _strided_slice->end(_end);
+
+ output()->from(_strided_slice);
+ }
+};
+
+} // namespace
+
+TEST(RemoveUnnecessaryStridedSlicePass, basic_case)
+{
+ StridedSliceGraph g;
+
+ g.init({2, 4, 2, 3}, true);
+
+ auto strided_slice_node = luci::test::first_node<luci::CircleStridedSlice>(g.g());
+ ASSERT_NE(nullptr, strided_slice_node);
+ luci::RemoveUnnecessaryStridedSlicePass pass;
+ while (pass.run(g.g()))
+ ;
+
+ strided_slice_node = luci::test::first_node<luci::CircleStridedSlice>(g.g());
+ ASSERT_EQ(nullptr, strided_slice_node);
+}
+
+TEST(RemoveUnnecessaryStridedSlicePass, basic_fail_case_NEG)
+{
+ StridedSliceGraph g;
+
+ g.init({2, 4, 2, 3}, false);
+
+ auto strided_slice_node = luci::test::first_node<luci::CircleStridedSlice>(g.g());
+ ASSERT_NE(nullptr, strided_slice_node);
+ luci::RemoveUnnecessaryStridedSlicePass pass;
+ while (pass.run(g.g()))
+ ;
+
+ strided_slice_node = luci::test::first_node<luci::CircleStridedSlice>(g.g());
+ ASSERT_NE(nullptr, strided_slice_node);
+}
diff --git a/compiler/luci/pass/src/ReplaceMulAddWithDepthwiseConvPass.cpp b/compiler/luci/pass/src/ReplaceMulAddWithDepthwiseConvPass.cpp
index 7096c2591..a0cc0194f 100644
--- a/compiler/luci/pass/src/ReplaceMulAddWithDepthwiseConvPass.cpp
+++ b/compiler/luci/pass/src/ReplaceMulAddWithDepthwiseConvPass.cpp
@@ -16,7 +16,10 @@
#include "luci/Pass/ReplaceMulAddWithDepthwiseConvPass.h"
+#include "BatchNormPatternFinder.h"
+
#include <luci/IR/CircleNodes.h>
+#include <luci/Profile/CircleNodeOrigin.h>
namespace
{
@@ -26,6 +29,9 @@ luci::CircleConst *create_weights_from_gamma(luci::CircleConst *gamma)
assert(gamma->rank() == 1);
auto channel_size = gamma->dim(0).value();
+ auto name = gamma->name();
+ assert(name.length() > 0);
+
// Channel-wise MUL is the same as DEPTHWISE_CONV2D with filter shape (1,1,1,channel_size)
auto weights = gamma->graph()->nodes()->create<luci::CircleConst>();
weights->dtype(loco::DataType::FLOAT32);
@@ -40,6 +46,7 @@ luci::CircleConst *create_weights_from_gamma(luci::CircleConst *gamma)
{
weights->at<loco::DataType::FLOAT32>(i) = gamma->at<loco::DataType::FLOAT32>(i);
}
+ weights->name(name + "_weights");
return weights;
}
@@ -49,6 +56,9 @@ luci::CircleConst *create_bias_from_beta(luci::CircleConst *beta)
assert(beta->rank() == 1);
auto channel_size = beta->dim(0).value();
+ auto name = beta->name();
+ assert(name.length() > 0);
+
// Channel-wise ADD is the same as bias (shape = (channel_size)) of DEPTHWISE_CONV2D
auto bias = beta->graph()->nodes()->create<luci::CircleConst>();
bias->dtype(loco::DataType::FLOAT32);
@@ -60,83 +70,11 @@ luci::CircleConst *create_bias_from_beta(luci::CircleConst *beta)
{
bias->at<loco::DataType::FLOAT32>(i) = beta->at<loco::DataType::FLOAT32>(i);
}
+ bias->name(name + "_bias");
return bias;
}
-bool is_batchnorm_add(const luci::CircleAdd *add, luci::CircleMul *&mul, luci::CircleConst *&beta)
-{
- auto x = loco::must_cast<luci::CircleNode *>(add->x());
- auto y = loco::must_cast<luci::CircleNode *>(add->y());
-
- luci::CircleMul *pred = nullptr;
- luci::CircleConst *constant = nullptr;
-
- if (x->opcode() == luci::CircleOpcode::CIRCLECONST && y->opcode() == luci::CircleOpcode::MUL)
- {
- pred = loco::must_cast<luci::CircleMul *>(y);
- constant = loco::must_cast<luci::CircleConst *>(x);
- }
- else if (x->opcode() == luci::CircleOpcode::MUL && y->opcode() == luci::CircleOpcode::CIRCLECONST)
- {
- pred = loco::must_cast<luci::CircleMul *>(x);
- constant = loco::must_cast<luci::CircleConst *>(y);
- }
- else
- {
- return false;
- }
-
- if (constant->rank() != 1)
- return false;
-
- auto channel_dim = constant->dim(0);
- // Assumption: Layout is channel-last
- if (!(channel_dim == add->dim(add->rank() - 1)))
- return false;
-
- mul = pred;
- beta = constant;
- return true;
-}
-
-// Check if mul is batchnorm mul
-bool is_batchnorm_mul(const luci::CircleMul *mul, luci::CircleNode *&pred_node,
- luci::CircleConst *&gamma)
-{
- auto x = dynamic_cast<luci::CircleConst *>(mul->x());
- auto y = dynamic_cast<luci::CircleConst *>(mul->y());
-
- luci::CircleNode *pred = nullptr;
- luci::CircleConst *constant = nullptr;
-
- if (x != nullptr && y == nullptr)
- {
- pred = loco::must_cast<luci::CircleNode *>(mul->y());
- constant = x;
- }
- else if (x == nullptr && y != nullptr)
- {
- pred = loco::must_cast<luci::CircleNode *>(mul->x());
- constant = y;
- }
- else
- {
- return false;
- }
-
- if (constant->rank() != 1)
- return false;
-
- auto channel_dim = constant->dim(0);
- if (!(channel_dim == mul->dim(mul->rank() - 1)))
- return false;
-
- pred_node = pred;
- gamma = constant;
- return true;
-}
-
/**
* Replace channel-wise Mul/Add with DepthwiseConv2D
*
@@ -180,6 +118,9 @@ bool replace_mul_add_with_dwconv(luci::CircleAdd *add)
auto weights = create_weights_from_gamma(gamma);
auto bias = create_bias_from_beta(beta);
+ auto name = add->name();
+ assert(name.length() > 0);
+
auto dwconv = add->graph()->nodes()->create<luci::CircleDepthwiseConv2D>();
dwconv->input(pred_node);
dwconv->filter(weights);
@@ -191,6 +132,8 @@ bool replace_mul_add_with_dwconv(luci::CircleAdd *add)
dwconv->dilation()->w(1);
dwconv->dilation()->h(1);
dwconv->fusedActivationFunction(add->fusedActivationFunction());
+ dwconv->name(name + "/DepthwiseConv2D");
+ luci::add_origin(dwconv, luci::composite_origin({luci::get_origin(mul), luci::get_origin(add)}));
loco::replace(add).with(dwconv);
return true;
@@ -206,14 +149,10 @@ bool ReplaceMulAddWithDepthwiseConvPass::run(loco::Graph *g)
bool changed = false;
for (auto node : loco::active_nodes(loco::output_nodes(g)))
{
- auto add = dynamic_cast<luci::CircleAdd *>(node);
- if (not add)
- continue;
-
- if (replace_mul_add_with_dwconv(add))
+ if (auto add = dynamic_cast<luci::CircleAdd *>(node))
{
- changed = true;
- break;
+ if (replace_mul_add_with_dwconv(add))
+ changed = true;
}
}
diff --git a/compiler/luci/pass/src/ReplaceMulAddWithDepthwiseConvPass.test.cpp b/compiler/luci/pass/src/ReplaceMulAddWithDepthwiseConvPass.test.cpp
index a90182aaa..903d4dcc9 100644
--- a/compiler/luci/pass/src/ReplaceMulAddWithDepthwiseConvPass.test.cpp
+++ b/compiler/luci/pass/src/ReplaceMulAddWithDepthwiseConvPass.test.cpp
@@ -85,6 +85,13 @@ public:
add->x(mul);
add->y(beta);
output->from(add);
+
+ input->name("input");
+ mul->name("mul");
+ gamma->name("gamma");
+ add->name("add");
+ beta->name("beta");
+ output->name("output");
}
public:
@@ -99,6 +106,13 @@ public:
} // namespace
+TEST(ReplaceMulAddWithDepthwiseConv, name)
+{
+ luci::ReplaceMulAddWithDepthwiseConvPass pass;
+ auto const name = pass.name();
+ ASSERT_NE(nullptr, name);
+}
+
TEST(ReplaceMulAddWithDepthwiseConv, simple)
{
SimpleGraph g;
diff --git a/compiler/luci/pass/src/RequantizePass.cpp b/compiler/luci/pass/src/RequantizePass.cpp
index fe84e3bc3..a56536251 100644
--- a/compiler/luci/pass/src/RequantizePass.cpp
+++ b/compiler/luci/pass/src/RequantizePass.cpp
@@ -113,7 +113,7 @@ void requant_const_int8_to_uint8(CircleConst *node)
struct RequantizeNonConst final : public luci::CircleNodeMutableVisitor<bool>
{
RequantizeNonConst(loco::DataType input, loco::DataType output)
- : _input_type(input), _output_type(output)
+ : _input_type(input), _output_type(output)
{
}
@@ -157,7 +157,7 @@ struct RequantizeNonConst final : public luci::CircleNodeMutableVisitor<bool>
struct RequantizeConst final : public luci::CircleNodeMutableVisitor<bool>
{
RequantizeConst(loco::DataType input, loco::DataType output)
- : _input_type(input), _output_type(output)
+ : _input_type(input), _output_type(output)
{
}
diff --git a/compiler/luci/pass/src/RequantizePass.test.cpp b/compiler/luci/pass/src/RequantizePass.test.cpp
new file mode 100644
index 000000000..d26743c9d
--- /dev/null
+++ b/compiler/luci/pass/src/RequantizePass.test.cpp
@@ -0,0 +1,26 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Pass/RequantizePass.h"
+
+#include <gtest/gtest.h>
+
+TEST(RequantizePassTest, name)
+{
+ luci::RequantizePass pass(loco::DataType::FLOAT32, loco::DataType::U8);
+ auto const name = pass.name();
+ ASSERT_NE(nullptr, name);
+}
diff --git a/compiler/luci/pass/src/ResolveCustomOpAddPass.cpp b/compiler/luci/pass/src/ResolveCustomOpAddPass.cpp
index e52d667d7..1737e5dd6 100644
--- a/compiler/luci/pass/src/ResolveCustomOpAddPass.cpp
+++ b/compiler/luci/pass/src/ResolveCustomOpAddPass.cpp
@@ -20,6 +20,7 @@
#include <luci/IR/CircleNodes.h>
#include <luci/IR/AttrFusedActFunc.h>
+#include <luci/Profile/CircleNodeOrigin.h>
namespace
{
@@ -67,10 +68,17 @@ bool resolve_with_BroadcastTo(luci::CircleCustom *addv2)
auto input = loco::must_cast<const luci::CircleCustomOut *>(addv2->inputs(broadcastTo_idx));
auto broadcastTo = loco::must_cast<luci::CircleCustom *>(input->input());
+ auto name = addv2->name();
+ assert(name.length() > 0);
+
auto add = addv2->graph()->nodes()->create<luci::CircleAdd>();
add->fusedActivationFunction(luci::FusedActFunc::NONE);
add->x(addv2->inputs(1 - broadcastTo_idx));
add->y(broadcastTo->inputs(0));
+ add->name(name + "/Add");
+ luci::add_origin(
+ add, luci::composite_origin({luci::get_origin(broadcastTo), luci::get_origin(addv2)}));
+
auto customOut = loco::succs(addv2);
assert(customOut.size() == 1);
replace(*customOut.begin()).with(add);
@@ -86,13 +94,39 @@ bool resolve_custom_op(luci::CircleCustom *addv2)
if (custom_code != "AddV2")
return false;
+ if (addv2->numInputs() != 2)
+ return false;
+
+ // check if inputs are suppport data types
+ for (uint32_t i = 0; i < addv2->numInputs(); i++)
+ {
+ auto input = loco::must_cast<luci::CircleNode *>(addv2->inputs(i));
+ switch (input->dtype())
+ {
+ case loco::DataType::U8:
+ case loco::DataType::S8:
+ case loco::DataType::S16:
+ case loco::DataType::S32:
+ case loco::DataType::FLOAT32:
+ break;
+ default:
+ return false;
+ }
+ }
+
if (resolve_with_BroadcastTo(addv2))
return true;
+ auto name = addv2->name();
+ assert(name.length() > 0);
+
auto add = addv2->graph()->nodes()->create<luci::CircleAdd>();
add->fusedActivationFunction(luci::FusedActFunc::NONE);
add->x(addv2->inputs(0));
add->y(addv2->inputs(1));
+ add->name(name + "/Add");
+ luci::add_origin(add, luci::get_origin(addv2));
+
auto customOut = loco::succs(addv2);
assert(customOut.size() == 1);
replace(*customOut.begin()).with(add);
@@ -115,7 +149,8 @@ bool ResolveCustomOpAddPass::run(loco::Graph *g)
if (not cop)
continue;
- changed |= resolve_custom_op(cop);
+ if (resolve_custom_op(cop))
+ changed = true;
}
return changed;
diff --git a/compiler/luci/pass/src/ResolveCustomOpAddPass.test.cpp b/compiler/luci/pass/src/ResolveCustomOpAddPass.test.cpp
new file mode 100644
index 000000000..31c245b0e
--- /dev/null
+++ b/compiler/luci/pass/src/ResolveCustomOpAddPass.test.cpp
@@ -0,0 +1,26 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Pass/ResolveCustomOpAddPass.h"
+
+#include <gtest/gtest.h>
+
+TEST(ResolveCustomOpAddPassTest, name)
+{
+ luci::ResolveCustomOpAddPass pass;
+ auto const name = pass.name();
+ ASSERT_NE(nullptr, name);
+}
diff --git a/compiler/luci/pass/src/ResolveCustomOpBatchMatMulPass.cpp b/compiler/luci/pass/src/ResolveCustomOpBatchMatMulPass.cpp
index 145e9cb62..5e9466a63 100644
--- a/compiler/luci/pass/src/ResolveCustomOpBatchMatMulPass.cpp
+++ b/compiler/luci/pass/src/ResolveCustomOpBatchMatMulPass.cpp
@@ -19,6 +19,7 @@
#include "flatbuffers/flexbuffers.h"
#include <luci/IR/CircleNodes.h>
+#include <luci/Profile/CircleNodeOrigin.h>
namespace
{
@@ -30,6 +31,9 @@ bool resolve_custom_op(luci::CircleCustom *cop)
if (custom_code == "BatchMatMulV2")
{
+ auto name = cop->name();
+ assert(name.length() > 0);
+
auto batch_matmul = cop->graph()->nodes()->create<luci::CircleBatchMatMul>();
// input
batch_matmul->x(cop->inputs(0));
@@ -39,10 +43,16 @@ bool resolve_custom_op(luci::CircleCustom *cop)
auto map = flexbuffers::GetRoot(custom_options).AsMap();
batch_matmul->adj_x(map["adj_x"].AsBool());
batch_matmul->adj_y(map["adj_y"].AsBool());
+ batch_matmul->name(name + "/BatchMatMul");
+ luci::add_origin(batch_matmul, luci::get_origin(cop));
+
+ auto customOut = loco::succs(cop);
+ assert(customOut.size() == 1);
+ replace(*customOut.begin()).with(batch_matmul);
- replace(cop).with(batch_matmul);
return true;
}
+
return false;
}
@@ -51,6 +61,27 @@ bool resolve_custom_op(luci::CircleCustom *cop)
namespace luci
{
+/**
+ * BEFORE
+ * | |
+ * [CircleNode] [CircleNode]
+ * \ /
+ * [CircleCustom]("BatchMatMulV2")
+ * |
+ * [CircleCustomOut]
+ * |
+ * [CircleNode]
+ * |
+ *
+ * AFTER
+ * | |
+ * [CircleNode] [CircleNode]
+ * \ /
+ * [CircleBatchMatMul]
+ * |
+ * [CircleNode]
+ * |
+ */
bool ResolveCustomOpBatchMatMulPass::run(loco::Graph *g)
{
bool changed = false;
@@ -60,7 +91,8 @@ bool ResolveCustomOpBatchMatMulPass::run(loco::Graph *g)
if (not cop)
continue;
- changed |= resolve_custom_op(cop);
+ if (resolve_custom_op(cop))
+ changed = true;
}
return changed;
diff --git a/compiler/luci/pass/src/ResolveCustomOpBatchMatMulPass.test.cpp b/compiler/luci/pass/src/ResolveCustomOpBatchMatMulPass.test.cpp
new file mode 100644
index 000000000..435016f9d
--- /dev/null
+++ b/compiler/luci/pass/src/ResolveCustomOpBatchMatMulPass.test.cpp
@@ -0,0 +1,169 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Pass/ResolveCustomOpBatchMatMulPass.h"
+
+#include <luci/IR/CircleNodes.h>
+
+#include "flatbuffers/flatbuffers.h"
+#include "flatbuffers/flexbuffers.h"
+
+#include <luci/test/TestIOGraph.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+using namespace luci::test;
+
+const int N = 1;
+const int C = 2;
+const int H_X = 1;
+const int W_X = 4;
+const int H_Y = 4;
+const int W_Y = 4;
+
+/**
+ * graph having Custom operator BatchMatMulV2
+ *
+ * [CircleInput] [CircleInput]
+ * \ /
+ * [CircleCustom]
+ * |
+ * [CircleCustomOut]
+ * |
+ * [CircleOutput]
+ */
+class BatchMatmulV2Graphlet
+{
+public:
+ BatchMatmulV2Graphlet() = default;
+
+public:
+ void init(loco::Graph *g)
+ {
+ // custom option
+ auto flatbuffer_builder =
+ std::unique_ptr<flatbuffers::FlatBufferBuilder>(new flatbuffers::FlatBufferBuilder(1024));
+ auto flex_buffers = std::make_unique<flexbuffers::Builder>();
+ size_t map_start = flex_buffers->StartMap();
+ flex_buffers->Bool("adj_x", false);
+ flex_buffers->Bool("adj_y", false);
+ flex_buffers->Int("T", 0 /* circle::TensorType_FLOAT32 */);
+ flex_buffers->EndMap(map_start);
+ flex_buffers->Finish();
+
+ // CircleCustom(BatchMatMulV2, adj_x=False, adj_y=False)
+ _batchmatmulv2 = g->nodes()->create<luci::CircleCustom>(2, 1);
+ _batchmatmulv2->custom_code("BatchMatMulV2");
+ _batchmatmulv2->custom_options(flex_buffers->GetBuffer());
+ _batchmatmulv2->shape({N, C, H_X, W_Y});
+ _batchmatmulv2->dtype(loco::DataType::FLOAT32);
+ _batchmatmulv2->name("batchmatmulv2");
+
+ // CircleCustomOut
+ _batchmatmulv2_out = g->nodes()->create<luci::CircleCustomOut>();
+ _batchmatmulv2_out->shape({N, C, H_X, W_Y});
+ _batchmatmulv2_out->dtype(loco::DataType::FLOAT32);
+ _batchmatmulv2_out->index(0);
+ }
+
+public:
+ luci::CircleCustom *batchmatmulv2() { return _batchmatmulv2; }
+
+protected:
+ luci::CircleCustom *_batchmatmulv2 = nullptr;
+ luci::CircleCustomOut *_batchmatmulv2_out = nullptr;
+};
+
+class BatchMatmulV2Graph : public TestIsGraphlet<2>,
+ public TestOGraphlet,
+ public BatchMatmulV2Graphlet
+{
+public:
+ BatchMatmulV2Graph() = default;
+
+ void init(void)
+ {
+ TestIsGraphlet<2>::init(g(), {{N, C, H_X, W_X}, {N, C, H_X, W_X}});
+ TestOGraphlet::init(g(), {N, C, H_X, W_Y});
+ BatchMatmulV2Graphlet::init(g());
+
+ // TODO how set multiple of shape vector for TestIsGraphlet?
+ // update shape for second input
+ input(1)->shape({N, C, H_Y, W_Y});
+
+ // connect graph
+ _batchmatmulv2->inputs(0, input(0));
+ _batchmatmulv2->inputs(1, input(1));
+ _batchmatmulv2_out->input(_batchmatmulv2);
+
+ output()->from(_batchmatmulv2_out);
+ }
+};
+
+class BatchMatmulV2GraphTest : public ::testing::Test
+{
+public:
+ BatchMatmulV2Graph g;
+ luci::ResolveCustomOpBatchMatMulPass pass;
+};
+
+} // namespace
+
+TEST(ResolveCustomOpBatchMatMulPassTest, name)
+{
+ luci::ResolveCustomOpBatchMatMulPass pass;
+ auto const name = pass.name();
+ ASSERT_NE(nullptr, name);
+}
+
+/**
+ * Optimized graph looks like below.
+ *
+ * [CircleInput]
+ * |
+ * [CircleBatchMatMul]
+ * |
+ * [CircleOutput]
+ */
+TEST_F(BatchMatmulV2GraphTest, simple_test)
+{
+ g.init();
+
+ auto ret = pass.run(g.g());
+ EXPECT_EQ(true, ret);
+
+ auto batchmatmul = dynamic_cast<luci::CircleBatchMatMul *>(g.output()->from());
+ EXPECT_NE(nullptr, batchmatmul);
+
+ auto input_0 = dynamic_cast<luci::CircleInput *>(batchmatmul->x());
+ auto input_1 = dynamic_cast<luci::CircleInput *>(batchmatmul->y());
+ EXPECT_NE(nullptr, input_0);
+ EXPECT_NE(nullptr, input_1);
+}
+
+TEST_F(BatchMatmulV2GraphTest, wrong_condition_NEG)
+{
+ g.init();
+
+ // wrong custom code
+ g.batchmatmulv2()->custom_code("BatchMatMulv2"); // v is lower case
+ auto ret = pass.run(g.g());
+
+ EXPECT_EQ(false, ret);
+}
diff --git a/compiler/luci/pass/src/ResolveCustomOpMatMulPass.cpp b/compiler/luci/pass/src/ResolveCustomOpMatMulPass.cpp
index 547fd22fc..216778066 100644
--- a/compiler/luci/pass/src/ResolveCustomOpMatMulPass.cpp
+++ b/compiler/luci/pass/src/ResolveCustomOpMatMulPass.cpp
@@ -20,11 +20,10 @@
#include <loco/IR/DataTypeTraits.h>
#include <luci/IR/CircleNodes.h>
+#include <luci/Profile/CircleNodeOrigin.h>
#include <loco.h>
#include <oops/InternalExn.h>
-#include <loco/Service/ShapeInference.h>
-#include <loco/Service/TypeInference.h>
namespace
{
@@ -44,6 +43,7 @@ luci::CircleConst *create_const_node(loco::Graph *g, const loco::DataType dtype,
node->dim(i) = shape.at(i);
size *= shape.at(i);
}
+ node->shape_status(luci::ShapeStatus::VALID);
#define INIT_VALUES(DT) \
{ \
@@ -90,6 +90,9 @@ bool resolve_matmul(luci::CircleCustom *cop)
const auto S32 = loco::DataType::S32;
const auto FLOAT32 = loco::DataType::FLOAT32;
+ auto name = cop->name();
+ assert(name.length() > 0);
+
bool transpose_a = map["transpose_a"].AsBool();
bool transpose_b = map["transpose_b"].AsBool();
@@ -97,34 +100,38 @@ bool resolve_matmul(luci::CircleCustom *cop)
loco::Node *rhs = cop->inputs(1);
// Check that the type of the first input is known
- CHECK_OR_FALSE(loco::dtype_known(lhs));
- auto lhs_dtype = loco::dtype_get(cop->inputs(0));
+ auto lhs_dtype = loco::must_cast<luci::CircleNode *>(cop->inputs(0))->dtype();
+ CHECK_OR_FALSE(lhs_dtype != loco::DataType::Unknown);
// If transpose of first input is requested, its shape must be known
- CHECK_OR_FALSE(!transpose_a || loco::shape_known(lhs));
+ auto circle_lhs = loco::must_cast<luci::CircleNode *>(lhs);
+ CHECK_OR_FALSE(!transpose_a || circle_lhs->shape_status() == luci::ShapeStatus::VALID);
// and its rank should be at least 2
- CHECK_OR_FALSE(!transpose_a || loco::shape_get(lhs).as<loco::TensorShape>().rank() >= 2);
+ CHECK_OR_FALSE(!transpose_a || circle_lhs->rank() >= 2);
// Check that the shape of the 2nd input is known
- CHECK_OR_FALSE(loco::shape_known(rhs));
+ auto circle_rhs = loco::must_cast<luci::CircleNode *>(rhs);
+ CHECK_OR_FALSE(circle_rhs->shape_status() == luci::ShapeStatus::VALID);
// TODO as of 06/23/20 TFLite only supports rank 2 for 2nd input. Fix this once that changes!
- CHECK_OR_FALSE(loco::shape_get(rhs).as<loco::TensorShape>().rank() == 2);
+ CHECK_OR_FALSE(circle_rhs->rank() == 2);
// Check that input data type is supported
CHECK_OR_THROW(lhs_dtype == U8 || lhs_dtype == S16 || lhs_dtype == FLOAT32,
"Only UInt8, Int16 and Float32 data types are supported by MatMul");
if (transpose_a)
{
- auto a_shape = loco::shape_get(lhs).as<loco::TensorShape>();
// Create a permutation constant node
std::vector<uint32_t> perm;
- for (uint32_t i = 0; i < a_shape.rank(); ++i)
+ for (uint32_t i = 0; i < circle_lhs->rank(); ++i)
perm.push_back(i);
- std::swap(perm[a_shape.rank() - 1], perm[a_shape.rank() - 2]);
- auto perm_node = create_const_node(graph, S32, {a_shape.rank()}, perm);
+ std::swap(perm[circle_lhs->rank() - 1], perm[circle_lhs->rank() - 2]);
+ auto perm_node = create_const_node(graph, S32, {circle_lhs->rank()}, perm);
+ perm_node->name(name + "/lhs/Transpose/perm");
// Now make a transpose node
auto transpose_node = graph->nodes()->create<luci::CircleTranspose>();
transpose_node->a(lhs);
transpose_node->perm(perm_node);
+ transpose_node->name(name + "/lhs/Transpose");
+ luci::add_origin(transpose_node, luci::get_origin(cop));
lhs = transpose_node;
}
@@ -135,24 +142,29 @@ bool resolve_matmul(luci::CircleCustom *cop)
{
const std::vector<uint32_t> perm{1, 0};
auto perm_node = create_const_node(graph, S32, {2}, perm);
+ perm_node->name(name + "/rhs/Transpose/perm");
auto transpose_node = graph->nodes()->create<luci::CircleTranspose>();
transpose_node->a(rhs);
transpose_node->perm(perm_node);
+ transpose_node->name(name + "/rhs/Transpose");
+ luci::add_origin(transpose_node, luci::get_origin(cop));
rhs = transpose_node;
}
- // Make a constant zero-filled bias node
- auto b_shape = loco::shape_get(cop->inputs(1)).as<loco::TensorShape>();
- uint32_t bias_size = b_shape.dim(transpose_b ? 1 : 0).value();
- const std::vector<float> val(bias_size, .0f);
- auto bias_node = create_const_node(graph, lhs_dtype, {bias_size}, val);
+ auto empty_bias = graph->nodes()->create<luci::CircleOutputExclude>();
+ empty_bias->dtype(loco::DataType::FLOAT32); // Needed for type inference
+
auto fc_node = graph->nodes()->create<luci::CircleFullyConnected>();
fc_node->input(lhs);
fc_node->weights(rhs);
- fc_node->bias(bias_node);
+ fc_node->bias(empty_bias);
fc_node->fusedActivationFunction(luci::FusedActFunc::NONE);
+ fc_node->name(name + "/FullyConnected");
+ luci::add_origin(fc_node, luci::get_origin(cop));
- replace(cop).with(fc_node);
+ auto customOut = loco::succs(cop);
+ assert(customOut.size() == 1);
+ replace(*customOut.begin()).with(fc_node);
return true;
}
diff --git a/compiler/luci/pass/src/ResolveCustomOpMatMulPass.test.cpp b/compiler/luci/pass/src/ResolveCustomOpMatMulPass.test.cpp
new file mode 100644
index 000000000..c4ea3ea06
--- /dev/null
+++ b/compiler/luci/pass/src/ResolveCustomOpMatMulPass.test.cpp
@@ -0,0 +1,26 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Pass/ResolveCustomOpMatMulPass.h"
+
+#include <gtest/gtest.h>
+
+TEST(ResolveCustomOpMatMulPassTest, name)
+{
+ luci::ResolveCustomOpMatMulPass pass;
+ auto const name = pass.name();
+ ASSERT_NE(nullptr, name);
+}
diff --git a/compiler/luci/pass/src/ShapeInferencePass.cpp b/compiler/luci/pass/src/ShapeInferencePass.cpp
deleted file mode 100644
index 4bd0aaed4..000000000
--- a/compiler/luci/pass/src/ShapeInferencePass.cpp
+++ /dev/null
@@ -1,57 +0,0 @@
-/*
- * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "luci/Pass/ShapeInferencePass.h"
-
-#include <luci/IR/CircleDialect.h>
-#include <luci/Service/CircleShapeInferenceRule.h>
-
-#include <loco.h>
-#include <loco/IR/CanonicalDialect.h>
-#include <loco/Service/CanonicalShapeInferenceRule.h>
-#include <loco/Service/ShapeInference.h>
-#include <loco/Service/MultiDialectShapeInferenceRule.h>
-
-namespace luci
-{
-
-bool ShapeInferencePass::run(luci::Module *m)
-{
- bool changed = false;
-
- for (size_t g = 0; g < m->size(); ++g)
- {
- if (run(m->graph(g)))
- changed = true;
- }
-
- return changed;
-}
-
-bool ShapeInferencePass::run(loco::Graph *g)
-{
- loco::CanonicalShapeInferenceRule canonical_rule;
- luci::CircleShapeInferenceRule circle_rule;
-
- loco::MultiDialectShapeInferenceRule rules;
-
- rules.bind(loco::CanonicalDialect::get(), &canonical_rule)
- .bind(luci::CircleDialect::get(), &circle_rule);
-
- return loco::apply(&rules).to(g);
-}
-
-} // namespace luci
diff --git a/compiler/luci/pass/src/ShuffleWeightTo16x1Float32Pass.cpp b/compiler/luci/pass/src/ShuffleWeightTo16x1Float32Pass.cpp
index 6a58f18c5..92060f625 100644
--- a/compiler/luci/pass/src/ShuffleWeightTo16x1Float32Pass.cpp
+++ b/compiler/luci/pass/src/ShuffleWeightTo16x1Float32Pass.cpp
@@ -72,6 +72,9 @@ luci::CircleConst *shuffle_weight(luci::CircleFullyConnected *fc)
{
auto the_weights = loco::must_cast<luci::CircleConst *>(fc->weights());
+ auto name = fc->name();
+ assert(name.length() > 0);
+
// create CircleConst where shuffled data will be stored
luci::CircleConst *new_weights = fc->graph()->nodes()->create<luci::CircleConst>();
new_weights->dtype(loco::DataType::FLOAT32);
@@ -82,6 +85,7 @@ luci::CircleConst *shuffle_weight(luci::CircleFullyConnected *fc)
{
new_weights->dim(r).set(the_weights->dim(r).value());
}
+ new_weights->name(name + "/shuffle_weight");
// suffle weight
const uint32_t MULTIPLE = 16;
@@ -96,7 +100,7 @@ luci::CircleConst *shuffle_weight(luci::CircleFullyConnected *fc)
for (uint32_t i = 0; i < MULTIPLE; i++)
{
new_weights->at<loco::DataType::FLOAT32>(index++) =
- the_weights->at<loco::DataType::FLOAT32>((r * MULTIPLE + i) * cols + c);
+ the_weights->at<loco::DataType::FLOAT32>((r * MULTIPLE + i) * cols + c);
}
}
}
@@ -131,6 +135,8 @@ bool ShuffleWeightTo16x1Float32Pass::run(loco::Graph *g)
fc->weights(new_weights);
fc->weights_format(luci::CircleFullyConnected::WeightsFormat::SHUFFLED16x1FLOAT32);
}
+
+ changed = true;
}
return changed;
diff --git a/compiler/luci/pass/src/ShuffleWeightTo16x1Float32Pass.test.cpp b/compiler/luci/pass/src/ShuffleWeightTo16x1Float32Pass.test.cpp
index 9745e5754..077985977 100644
--- a/compiler/luci/pass/src/ShuffleWeightTo16x1Float32Pass.test.cpp
+++ b/compiler/luci/pass/src/ShuffleWeightTo16x1Float32Pass.test.cpp
@@ -18,61 +18,86 @@
#include <luci/IR/CircleNodes.h>
+#include <luci/test/TestIOGraph.h>
+#include "test/TestFirstNode.h"
+
#include <gtest/gtest.h>
-void create_fc_net(loco::Graph *g)
+namespace
{
- assert(g);
-
- const uint32_t ROW = 16;
- const uint32_t COL = 2;
- const uint32_t elements_num = ROW * COL;
-
- // input
- auto input = g->nodes()->create<luci::CircleInput>();
- auto graph_input = g->inputs()->create();
- input->index(graph_input->index());
-
- // fc weights
- auto weights = g->nodes()->create<luci::CircleConst>();
- weights->dtype(loco::DataType::FLOAT32);
- weights->size<loco::DataType::FLOAT32>(elements_num);
- weights->rank(2);
- weights->dim(0).set(ROW);
- weights->dim(1).set(COL);
- for (uint32_t idx = 0; idx < elements_num; idx++)
+
+using namespace luci::test;
+
+class FCGraphlet
+{
+public:
+ FCGraphlet() = default;
+
+public:
+ void init(loco::Graph *g, const ShapeU32 wshape)
{
- weights->at<loco::DataType::FLOAT32>(idx) = idx;
+ const uint32_t elements_num = num_elements(wshape);
+
+ // fc weights
+ _weights = g->nodes()->create<luci::CircleConst>();
+ _weights->dtype(loco::DataType::FLOAT32);
+ _weights->shape(wshape);
+ _weights->size<loco::DataType::FLOAT32>(elements_num);
+ for (uint32_t idx = 0; idx < elements_num; idx++)
+ {
+ _weights->at<loco::DataType::FLOAT32>(idx) = idx;
+ }
+ _weights->name("weights");
+
+ // fc
+ _fc = g->nodes()->create<luci::CircleFullyConnected>();
+ _fc->dtype(loco::DataType::FLOAT32);
+ _fc->name("fc");
}
- // fc
- auto fc = g->nodes()->create<luci::CircleFullyConnected>();
- fc->dtype(loco::DataType::FLOAT32);
- fc->input(input);
- fc->weights(weights);
-
- // output
- auto output = g->nodes()->create<luci::CircleOutput>();
- output->from(fc);
- auto graph_output = g->outputs()->create();
- output->index(graph_output->index());
-}
+protected:
+ luci::CircleFullyConnected *_fc = nullptr;
+ luci::CircleConst *_weights = nullptr;
+};
-TEST(ShuffleWeightTo16x1Float32PassTest, SimpleTest1)
+class FCGraph : public TestIGraphlet, public TestOGraphlet, public FCGraphlet
{
- auto graph = loco::make_graph();
- create_fc_net(graph.get());
+public:
+ FCGraph() = default;
- luci::CircleFullyConnected *fc_node = nullptr;
- for (auto node : loco::active_nodes(loco::output_nodes(graph.get())))
+ void init(const ShapeU32 shape, const ShapeU32 wshape)
{
- auto fc = dynamic_cast<luci::CircleFullyConnected *>(node);
- if (not fc)
- continue;
+ TestIGraphlet::init(g(), shape);
+ TestOGraphlet::init(g(), shape);
+ FCGraphlet::init(g(), wshape);
+
+ // connect graph
+ _fc->input(input());
+ _fc->weights(_weights);
- fc_node = fc;
- break;
+ output()->from(_fc);
}
+};
+
+} // namespace
+
+TEST(ShuffleWeightTo16x1Float32PassTest, name)
+{
+ luci::ShuffleWeightTo16x1Float32Pass pass;
+ auto const name = pass.name();
+ ASSERT_NE(nullptr, name);
+}
+
+const uint32_t ROW = 16;
+const uint32_t COL = 2;
+
+TEST(ShuffleWeightTo16x1Float32PassTest, SimpleTest1)
+{
+ FCGraph g;
+
+ g.init({ROW, COL}, {ROW, COL});
+
+ auto fc_node = luci::test::first_node<luci::CircleFullyConnected>(g.g());
ASSERT_NE(fc_node, nullptr);
auto weights = loco::must_cast<luci::CircleConst *>(fc_node->weights());
// before
@@ -94,7 +119,7 @@ TEST(ShuffleWeightTo16x1Float32PassTest, SimpleTest1)
ASSERT_EQ(15, weights->at<loco::DataType::FLOAT32>(15));
luci::ShuffleWeightTo16x1Float32Pass pass;
- while (pass.run(graph.get()))
+ while (pass.run(g.g()))
;
weights = loco::must_cast<luci::CircleConst *>(fc_node->weights());
@@ -116,3 +141,33 @@ TEST(ShuffleWeightTo16x1Float32PassTest, SimpleTest1)
ASSERT_EQ(28, weights->at<loco::DataType::FLOAT32>(14));
ASSERT_EQ(30, weights->at<loco::DataType::FLOAT32>(15));
}
+
+TEST(ShuffleWeightTo16x1Float32PassTest, invalid_weight_shape_NEG)
+{
+ FCGraph g;
+
+ g.init({ROW, COL}, {1, ROW, COL, 1});
+
+ auto fc_node = luci::test::first_node<luci::CircleFullyConnected>(g.g());
+ ASSERT_NE(fc_node, nullptr);
+
+ luci::ShuffleWeightTo16x1Float32Pass pass;
+ auto ret = pass.run(g.g());
+
+ ASSERT_FALSE(ret);
+}
+
+TEST(ShuffleWeightTo16x1Float32PassTest, invalid_weight_row16_NEG)
+{
+ FCGraph g;
+
+ g.init({COL, ROW}, {COL, ROW});
+
+ auto fc_node = luci::test::first_node<luci::CircleFullyConnected>(g.g());
+ ASSERT_NE(fc_node, nullptr);
+
+ luci::ShuffleWeightTo16x1Float32Pass pass;
+ auto ret = pass.run(g.g());
+
+ ASSERT_FALSE(ret);
+}
diff --git a/compiler/luci/pass/src/Sparsifier.cpp b/compiler/luci/pass/src/Sparsifier.cpp
index 210c1a34c..18ab45f98 100644
--- a/compiler/luci/pass/src/Sparsifier.cpp
+++ b/compiler/luci/pass/src/Sparsifier.cpp
@@ -26,8 +26,8 @@ Sparsifier<T>::Sparsifier(const std::vector<int32_t> &shape,
const std::vector<DimensionType> &format,
const std::vector<int32_t> &block_size,
const std::vector<int32_t> &block_map)
- : _dense_shape(shape), _traversal_order(traversal_order), _block_size(block_size),
- _block_map(block_map)
+ : _dense_shape(shape), _traversal_order(traversal_order), _block_size(block_size),
+ _block_map(block_map)
{
_dense_size = 1;
int32_t block_dim = 0;
diff --git a/compiler/luci/pass/src/Sparsifier.test.cpp b/compiler/luci/pass/src/Sparsifier.test.cpp
index 272e0e934..14e24aad7 100644
--- a/compiler/luci/pass/src/Sparsifier.test.cpp
+++ b/compiler/luci/pass/src/Sparsifier.test.cpp
@@ -190,6 +190,6 @@ TEST(SparsifierTest, WrongFormatRank_NEG)
const std::vector<int32_t> block_size = {4, 1};
const std::vector<int32_t> block_map = {0, 1};
EXPECT_THROW(
- luci::Sparsifier<int32_t>(dense_shape, traversal_order, format, block_size, block_map),
- std::out_of_range);
+ luci::Sparsifier<int32_t>(dense_shape, traversal_order, format, block_size, block_map),
+ std::out_of_range);
}
diff --git a/compiler/luci/pass/src/SparsifyTensorPass.cpp b/compiler/luci/pass/src/SparsifyTensorPass.cpp
index 2f1a36e77..1a75bfb0c 100644
--- a/compiler/luci/pass/src/SparsifyTensorPass.cpp
+++ b/compiler/luci/pass/src/SparsifyTensorPass.cpp
@@ -69,11 +69,11 @@ template <loco::DataType DT> void SparsifyTensorPass::sparsify_tensor(luci::Circ
else if (_format.at(idx) == DimensionType::SPARSE_CSR)
{
sparsityparam->dim_metadata.emplace_back(
- DimensionType::SPARSE_CSR, /* dense size */ 0,
- /* array_segments */ SparseIndexVector{SparseIndexVectorType::U16,
- dim_metadata.at(idx * 2)},
- /* array_indices */ SparseIndexVector{SparseIndexVectorType::U16,
- dim_metadata.at(idx * 2 + 1)});
+ DimensionType::SPARSE_CSR, /* dense size */ 0,
+ /* array_segments */
+ SparseIndexVector{SparseIndexVectorType::U16, dim_metadata.at(idx * 2)},
+ /* array_indices */
+ SparseIndexVector{SparseIndexVectorType::U16, dim_metadata.at(idx * 2 + 1)});
}
}
for (uint32_t i = 0; i < _block_size.size(); i++)
diff --git a/compiler/luci/pass/src/SparsifyTensorPass.test.cpp b/compiler/luci/pass/src/SparsifyTensorPass.test.cpp
new file mode 100644
index 000000000..372e8e5ca
--- /dev/null
+++ b/compiler/luci/pass/src/SparsifyTensorPass.test.cpp
@@ -0,0 +1,30 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Pass/SparsifyTensorPass.h"
+
+#include <gtest/gtest.h>
+
+TEST(SparsifyTensorPassTest, name)
+{
+ std::vector<int32_t> to;
+ std::vector<luci::DimensionType> vdt;
+ std::vector<int32_t> bs;
+ std::vector<int32_t> bm;
+ luci::SparsifyTensorPass pass("", to, vdt, bs, bm);
+ auto const name = pass.name();
+ ASSERT_NE(nullptr, name);
+}
diff --git a/compiler/luci/pass/src/SubstitutePackToReshapePass.cpp b/compiler/luci/pass/src/SubstitutePackToReshapePass.cpp
index 44e974b91..d8676cd62 100644
--- a/compiler/luci/pass/src/SubstitutePackToReshapePass.cpp
+++ b/compiler/luci/pass/src/SubstitutePackToReshapePass.cpp
@@ -17,10 +17,22 @@
#include "luci/Pass/SubstitutePackToReshapePass.h"
#include <luci/IR/CircleNodes.h>
+#include <luci/Profile/CircleNodeOrigin.h>
namespace
{
+int32_t unknown_dim_count(luci::CircleNode *node)
+{
+ int32_t count = 0;
+
+ for (uint32_t i = 0; i < node->rank(); ++i)
+ if (!node->dim(i).known())
+ ++count;
+
+ return count;
+}
+
bool substitute_pack_to_reshape(luci::CircleNode *node)
{
auto target_node = dynamic_cast<luci::CirclePack *>(node);
@@ -35,9 +47,14 @@ bool substitute_pack_to_reshape(luci::CircleNode *node)
if (axis < 0)
axis = axis + static_cast<int32_t>(value_node->rank()) + 1;
+ auto name = node->name();
+ assert(name.length() > 0);
+
auto graph = target_node->graph();
auto reshape_node = graph->nodes()->create<luci::CircleReshape>();
reshape_node->tensor(value_node);
+ reshape_node->name(name + "/Reshape");
+ luci::add_origin(reshape_node, luci::get_origin(node));
auto const_node = graph->nodes()->create<luci::CircleConst>();
const_node->dtype(loco::DataType::S32);
@@ -53,13 +70,16 @@ bool substitute_pack_to_reshape(luci::CircleNode *node)
}
else if (i < axis)
{
- const_node->at<loco::DataType::S32>(i) = value_node->dim(i).value();
+ const_node->at<loco::DataType::S32>(i) =
+ value_node->dim(i).known() ? value_node->dim(i).value() : -1;
}
else
{
- const_node->at<loco::DataType::S32>(i) = value_node->dim(i - 1).value();
+ const_node->at<loco::DataType::S32>(i) =
+ value_node->dim(i - 1).known() ? value_node->dim(i - 1).value() : -1;
}
}
+ const_node->name(name + "/Reshape/shape");
reshape_node->shape(const_node);
replace(target_node).with(reshape_node);
return true;
@@ -71,24 +91,23 @@ namespace luci
{
/**
- * BEFORE
- * |
- * [CircleNode]
- * |
- * [CirclePack]
- * |
- * [CircleNode]
- * |
+ * BEFORE
+ * |
+ * [CircleNode]
+ * |
+ * [CirclePack]
+ * |
+ * [CircleNode]
+ * |
*
- * AFTER
- * |
- * [CircleNode] [CircleConst]
- * \ /
- * [CircleReshape]
+ * AFTER
* |
- * [CircleNode]
- * |
- *
+ * [CircleNode] [CircleConst]
+ * | \ /
+ * [CirclePack] [CircleReshape]
+ * |
+ * [CircleNode]
+ * |
*/
bool SubstitutePackToReshapePass::run(loco::Graph *g)
{
@@ -96,7 +115,7 @@ bool SubstitutePackToReshapePass::run(loco::Graph *g)
for (auto node : loco::active_nodes(loco::output_nodes(g)))
{
auto circle_node = loco::must_cast<luci::CircleNode *>(node);
- if (substitute_pack_to_reshape(circle_node))
+ if (unknown_dim_count(circle_node) <= 1 && substitute_pack_to_reshape(circle_node))
{
changed = true;
}
diff --git a/compiler/luci/pass/src/SubstitutePackToReshapePass.test.cpp b/compiler/luci/pass/src/SubstitutePackToReshapePass.test.cpp
index 143b88896..3b5d4ea2c 100644
--- a/compiler/luci/pass/src/SubstitutePackToReshapePass.test.cpp
+++ b/compiler/luci/pass/src/SubstitutePackToReshapePass.test.cpp
@@ -22,26 +22,6 @@
namespace
{
-/**
- * BEFORE
- * |
- * [CircleNode]
- * |
- * [CirclePack]
- * |
- * [CircleNode]
- * |
- *
- * AFTER
- * |
- * [CircleNode] [CircleConst]
- * \ /
- * [CircleReshape]
- * |
- * [CircleNode]
- * |
- *
- */
void create_substitute_pack_to_reshape(loco::Graph *g, const std::initializer_list<uint32_t> shape,
int32_t axis)
{
@@ -54,23 +34,33 @@ void create_substitute_pack_to_reshape(loco::Graph *g, const std::initializer_li
input->shape_status(luci::ShapeStatus::VALID);
input->rank(shape.size());
input->shape(shape);
+ input->name("input");
// Pack Node create.
auto pack = g->nodes()->create<luci::CirclePack>(1);
pack->values(0, input);
pack->axis(axis);
+ pack->name("pack");
// Output Connect.
auto output = g->nodes()->create<luci::CircleOutput>();
output->from(pack);
auto graph_output = g->outputs()->create();
output->index(graph_output->index());
+ output->name("output");
return;
}
} // namespace
+TEST(SubstitutePackToReshapePassTest, name)
+{
+ luci::SubstitutePackToReshapePass pass;
+ auto const name = pass.name();
+ ASSERT_NE(nullptr, name);
+}
+
TEST(SubstitutePackToReshapePass, simple_case)
{
auto graph = loco::make_graph();
diff --git a/compiler/luci/pass/src/SubstituteSqueezeToReshapePass.cpp b/compiler/luci/pass/src/SubstituteSqueezeToReshapePass.cpp
new file mode 100644
index 000000000..74be86a4c
--- /dev/null
+++ b/compiler/luci/pass/src/SubstituteSqueezeToReshapePass.cpp
@@ -0,0 +1,183 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Pass/SubstituteSqueezeToReshapePass.h"
+
+#include <luci/IR/CircleNodes.h>
+#include <luci/Profile/CircleNodeOrigin.h>
+
+namespace
+{
+
+/**
+ * @brief return TRUE if all dim is known
+ * @note This pass can be applied even some of dimensions are unknown.
+ For now, do not consider about it and update logic later.
+ */
+bool can_squeeze_shape(const luci::CircleNode *node)
+{
+ for (uint32_t r = 0; r < node->rank(); ++r)
+ {
+ if (not node->dim(r).known())
+ return false;
+ }
+ return true;
+}
+
+/**
+ * @brief return valid unsigned dim value from 0 ~ (rank-1)
+ * @note dim can be -rank to (rank-1)
+ */
+uint32_t valid_unsigned_dim(uint32_t rank, int32_t dim)
+{
+ int32_t irank = static_cast<int32_t>(rank);
+ return dim >= 0 ? static_cast<uint32_t>(dim) : static_cast<uint32_t>(irank + dim);
+}
+
+/**
+ * @brief return TRUE if input dim is 1 for squeeze_dims values
+ */
+bool is_valid_input(const luci::CircleNode *node, const std::vector<int32_t> &squeeze_dims)
+{
+ auto rank = node->rank();
+ for (auto dim : squeeze_dims)
+ {
+ auto udim = valid_unsigned_dim(rank, dim);
+ if (node->dim(udim).value() != 1)
+ return false;
+ }
+ return true;
+}
+
+/**
+ * @brief return shape vector from input
+ */
+std::vector<uint32_t> node_shape(const luci::CircleNode *input)
+{
+ std::vector<uint32_t> shape;
+ uint32_t rank = input->rank();
+ for (uint32_t r = 0; r < rank; ++r)
+ shape.push_back(input->dim(r).value());
+
+ return shape;
+}
+
+/**
+ * @brief return CircleConst ptr with values of new_shape
+ */
+luci::CircleConst *create_shape_const(loco::Graph *graph, const std::vector<uint32_t> &new_shape)
+{
+ // NOTE dim_size can be 0
+ uint32_t dim_size = static_cast<uint32_t>(new_shape.size());
+
+ auto shape_const = graph->nodes()->create<luci::CircleConst>();
+
+ // const shape/dtype
+ shape_const->dtype(loco::DataType::S32);
+ if (dim_size > 0)
+ {
+ shape_const->rank(1);
+ shape_const->dim(0).set(dim_size);
+ }
+ else
+ shape_const->rank(0);
+ shape_const->shape_status(luci::ShapeStatus::VALID);
+
+ // constant values
+ shape_const->size<loco::DataType::S32>(dim_size);
+ for (uint32_t i = 0; i < dim_size; ++i)
+ shape_const->at<loco::DataType::S32>(i) = new_shape.at(i);
+
+ return shape_const;
+}
+
+bool substitute_squeeze_to_reshape(luci::CircleSqueeze *squeeze)
+{
+ assert(squeeze != nullptr);
+
+ auto input = loco::must_cast<luci::CircleNode *>(squeeze->input());
+ // we need input node shape and all dim should be known
+ if (input->shape_status() != luci::ShapeStatus::VALID)
+ return false;
+ if (not can_squeeze_shape(input))
+ return false;
+
+ // we will use squeeze shape for new shape
+ if (squeeze->shape_status() != luci::ShapeStatus::VALID)
+ return false;
+
+ auto squeeze_dims = squeeze->squeeze_dims();
+ if (not is_valid_input(input, squeeze_dims))
+ throw std::runtime_error("Invalid values in squeeze_dims: " + squeeze->name());
+
+ auto name = squeeze->name();
+ assert(name.length() > 0);
+
+ auto reshape_shape = node_shape(squeeze);
+ auto graph = squeeze->graph();
+ auto reshape = graph->nodes()->create<luci::CircleReshape>();
+ auto shape_const = create_shape_const(graph, reshape_shape);
+ reshape->name(name + "/Reshape");
+ luci::add_origin(reshape, luci::get_origin(squeeze));
+ shape_const->name(name + "/Reshape/shape");
+
+ // graph connection
+ reshape->tensor(input);
+ reshape->shape(shape_const);
+ replace(squeeze).with(reshape);
+
+ return true;
+}
+
+} // namespace
+
+namespace luci
+{
+
+/**
+ * BEFORE
+ * |
+ * [CircleNode]
+ * |
+ * [CircleSqueeze]
+ * |
+ * [CircleNode]
+ * |
+ *
+ * AFTER
+ * |
+ * [CircleNode] [CircleConst]
+ * | \ /
+ * [CircleSqueeze] [CircleReshape]
+ * |
+ * [CircleNode]
+ * |
+ */
+bool SubstituteSqueezeToReshapePass::run(loco::Graph *g)
+{
+ bool changed = false;
+ for (auto node : loco::active_nodes(loco::output_nodes(g)))
+ {
+ if (auto squeeze = dynamic_cast<luci::CircleSqueeze *>(node))
+ {
+ if (substitute_squeeze_to_reshape(squeeze))
+ changed = true;
+ }
+ }
+ return changed;
+}
+
+} // namespace luci
diff --git a/compiler/luci/pass/src/SubstituteSqueezeToReshapePass.test.cpp b/compiler/luci/pass/src/SubstituteSqueezeToReshapePass.test.cpp
new file mode 100644
index 000000000..d917af678
--- /dev/null
+++ b/compiler/luci/pass/src/SubstituteSqueezeToReshapePass.test.cpp
@@ -0,0 +1,208 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "luci/Pass/SubstituteSqueezeToReshapePass.h"
+#include "luci/Pass/CircleShapeInferencePass.h"
+
+#include <luci/IR/CircleNodes.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+using uilist = std::initializer_list<uint32_t>;
+using ilist = std::initializer_list<int32_t>;
+
+class PassTestGraph
+{
+public:
+ PassTestGraph() = default;
+
+public:
+ void init(const uilist shape_in, const uilist shape_out)
+ {
+ _graph_input = _g.inputs()->create();
+ _graph_output = _g.outputs()->create();
+
+ _input = _g.nodes()->create<luci::CircleInput>();
+ _input->shape(shape_in);
+ _input->shape_status(luci::ShapeStatus::VALID);
+ _input->name("input");
+
+ _output = _g.nodes()->create<luci::CircleOutput>();
+ _output->shape(shape_out);
+ _output->shape_status(luci::ShapeStatus::VALID);
+ _output->name("output");
+
+ _input->index(_graph_input->index());
+ _output->index(_graph_output->index());
+
+ auto input_shape = std::make_unique<loco::TensorShape>();
+ set(input_shape.get(), shape_in);
+ _graph_input->shape(std::move(input_shape));
+
+ auto output_shape = std::make_unique<loco::TensorShape>();
+ set(output_shape.get(), shape_out);
+ _graph_output->shape(std::move(output_shape));
+ }
+
+protected:
+ void set(loco::TensorShape *shape, const uilist &values)
+ {
+ uint32_t r = 0;
+ shape->rank(values.size());
+ for (auto v : values)
+ shape->dim(r++).set(v);
+ }
+
+public:
+ loco::Graph *g(void) { return &_g; }
+ luci::CircleOutput *output(void) { return _output; }
+
+protected:
+ loco::Graph _g;
+ loco::GraphInput *_graph_input = nullptr;
+ loco::GraphOutput *_graph_output = nullptr;
+ luci::CircleInput *_input = nullptr;
+ luci::CircleOutput *_output = nullptr;
+};
+
+class SubstituteSqueezeToReshapeGraph : public PassTestGraph
+{
+public:
+ SubstituteSqueezeToReshapeGraph() = default;
+
+public:
+ void init(const uilist shape_in, const uilist shape_out, const ilist squeeze_dims)
+ {
+ PassTestGraph::init(shape_in, shape_out);
+
+ _squeeze = _g.nodes()->create<luci::CircleSqueeze>();
+ _squeeze->input(_input);
+ _squeeze->squeeze_dims(squeeze_dims);
+ _squeeze->name("squeeze");
+
+ _output->from(_squeeze);
+ }
+
+protected:
+ luci::CircleSqueeze *_squeeze = nullptr;
+};
+
+class SubstituteSqueezeToReshapeTest : public ::testing::Test
+{
+public:
+ SubstituteSqueezeToReshapeTest() = default;
+
+ void run_pass(void)
+ {
+ while (_shapeinf.run(_graph.g()) || _pass.run(_graph.g()))
+ ;
+ }
+
+protected:
+ SubstituteSqueezeToReshapeGraph _graph;
+ luci::SubstituteSqueezeToReshapePass _pass;
+ luci::CircleShapeInferencePass _shapeinf;
+};
+
+} // namespace
+
+TEST(SubstituteSqueezeToReshapePassTest, name)
+{
+ luci::SubstituteSqueezeToReshapePass pass;
+ auto const name = pass.name();
+ ASSERT_NE(nullptr, name);
+}
+
+TEST_F(SubstituteSqueezeToReshapeTest, simple_with_squeeze_dims)
+{
+ _graph.init({1, 16, 1, 1}, {1, 16}, {2, 3});
+
+ run_pass();
+
+ auto reshape = dynamic_cast<luci::CircleReshape *>(_graph.output()->from());
+ auto squeeze = dynamic_cast<luci::CircleSqueeze *>(_graph.output()->from());
+ ASSERT_NE(nullptr, reshape);
+ ASSERT_EQ(nullptr, squeeze);
+ auto reshape_shape = loco::must_cast<luci::CircleConst *>(reshape->shape());
+ ASSERT_EQ(2, reshape_shape->size<loco::DataType::S32>());
+ ASSERT_EQ(1, reshape_shape->at<loco::DataType::S32>(0));
+ ASSERT_EQ(16, reshape_shape->at<loco::DataType::S32>(1));
+}
+
+TEST_F(SubstituteSqueezeToReshapeTest, simple_without_squeeze_dims)
+{
+ _graph.init({1, 16, 1, 1}, {16}, {});
+
+ run_pass();
+
+ auto reshape = dynamic_cast<luci::CircleReshape *>(_graph.output()->from());
+ auto squeeze = dynamic_cast<luci::CircleSqueeze *>(_graph.output()->from());
+ ASSERT_NE(nullptr, reshape);
+ ASSERT_EQ(nullptr, squeeze);
+ auto reshape_shape = loco::must_cast<luci::CircleConst *>(reshape->shape());
+ ASSERT_EQ(1, reshape_shape->size<loco::DataType::S32>());
+ ASSERT_EQ(16, reshape_shape->at<loco::DataType::S32>(0));
+}
+
+TEST_F(SubstituteSqueezeToReshapeTest, input_with_0_dims)
+{
+ _graph.init({1, 16, 0, 1}, {16, 0}, {});
+
+ run_pass();
+
+ auto reshape = dynamic_cast<luci::CircleReshape *>(_graph.output()->from());
+ auto squeeze = dynamic_cast<luci::CircleSqueeze *>(_graph.output()->from());
+ ASSERT_NE(nullptr, reshape);
+ ASSERT_EQ(nullptr, squeeze);
+ auto reshape_shape = loco::must_cast<luci::CircleConst *>(reshape->shape());
+ ASSERT_EQ(2, reshape_shape->size<loco::DataType::S32>());
+ ASSERT_EQ(16, reshape_shape->at<loco::DataType::S32>(0));
+ ASSERT_EQ(0, reshape_shape->at<loco::DataType::S32>(1));
+}
+
+TEST_F(SubstituteSqueezeToReshapeTest, nothing_to_squeeze)
+{
+ _graph.init({2, 16, 16, 3}, {2, 16, 16, 3}, {});
+
+ run_pass();
+
+ auto reshape = dynamic_cast<luci::CircleReshape *>(_graph.output()->from());
+ auto squeeze = dynamic_cast<luci::CircleSqueeze *>(_graph.output()->from());
+ ASSERT_NE(nullptr, reshape);
+ ASSERT_EQ(nullptr, squeeze);
+}
+
+TEST_F(SubstituteSqueezeToReshapeTest, all_to_squeeze)
+{
+ _graph.init({1, 1}, {}, {});
+
+ run_pass();
+
+ auto reshape = dynamic_cast<luci::CircleReshape *>(_graph.output()->from());
+ auto squeeze = dynamic_cast<luci::CircleSqueeze *>(_graph.output()->from());
+ ASSERT_NE(nullptr, reshape);
+ ASSERT_EQ(nullptr, squeeze);
+}
+
+TEST_F(SubstituteSqueezeToReshapeTest, wrong_squeeze_dims_NEG)
+{
+ _graph.init({1, 16, 1, 1}, {1, 16, 1, 1}, {1});
+
+ // shape inference will throw for invalid squeeze_dims
+ EXPECT_THROW(run_pass(), std::exception);
+}
diff --git a/compiler/luci/pass/src/SubstituteTransposeToReshapePass.cpp b/compiler/luci/pass/src/SubstituteTransposeToReshapePass.cpp
new file mode 100644
index 000000000..dfd5e6cf2
--- /dev/null
+++ b/compiler/luci/pass/src/SubstituteTransposeToReshapePass.cpp
@@ -0,0 +1,137 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Pass/SubstituteTransposeToReshapePass.h"
+
+#include <luci/IR/CircleNodes.h>
+#include <luci/Profile/CircleNodeOrigin.h>
+
+namespace
+{
+
+/**
+ * @brief Convert transpose op in a certain condition to reshape op
+ * @details Convert transpose op if it have condition below
+ * 1. have a CircleConst perm value.
+ * 2. input have an unknown dimension less then 2
+ * 3. the order of shape that except dim value 1 remains same on input and output
+ * eg) input shape = (126, 201, 1, 1) => (126, 201)
+ * output shape = (1, 126, 1, 201) => (126, 201)
+ */
+bool substitute_transpose_to_reshape(luci::CircleTranspose *node)
+{
+ auto perm_const = dynamic_cast<luci::CircleConst *>(node->perm());
+ if (perm_const == nullptr)
+ return false;
+
+ assert(perm_const->dtype() == loco::DataType::S32);
+
+ auto input_node = loco::must_cast<luci::CircleNode *>(node->a());
+ if (perm_const->dim(0).value() != input_node->rank())
+ return false;
+
+ // If input have more than 2 unknown dimension, transpose will not be changed.
+ int count = 0;
+ for (uint32_t i = 0; i < input_node->rank(); i++)
+ if (!input_node->dim(i).known())
+ count++;
+ if (count > 1)
+ return false;
+
+ uint32_t idx = 0;
+ auto size_items = perm_const->size<loco::DataType::S32>();
+ for (uint32_t i = 0; i < size_items; i++)
+ {
+ assert(perm_const->at<loco::DataType::S32>(i) >= 0 &&
+ perm_const->at<loco::DataType::S32>(i) < static_cast<int32_t>(input_node->rank()));
+ const auto perm_value = static_cast<uint32_t>(perm_const->at<loco::DataType::S32>(i));
+ if (input_node->dim(perm_value).known() && input_node->dim(perm_value).value() == 1)
+ continue;
+ // To check idx values are increasing
+ if (idx > perm_value)
+ return false;
+ idx = perm_value;
+ }
+
+ auto name = node->name();
+ assert(name.length() > 0);
+
+ auto new_const_node = node->graph()->nodes()->create<luci::CircleConst>();
+ new_const_node->dtype(loco::DataType::S32);
+ new_const_node->size<loco::DataType::S32>(size_items);
+ new_const_node->shape_status(luci::ShapeStatus::VALID);
+ new_const_node->rank(1);
+ new_const_node->dim(0).set(size_items);
+ for (uint32_t i = 0; i < size_items; i++)
+ {
+ if (input_node->dim(static_cast<uint32_t>(perm_const->at<loco::DataType::S32>(i))).known())
+ new_const_node->at<loco::DataType::S32>(i) = static_cast<int32_t>(
+ input_node->dim(static_cast<uint32_t>(perm_const->at<loco::DataType::S32>(i))).value());
+ else
+ new_const_node->at<loco::DataType::S32>(i) = -1;
+ }
+
+ auto new_reshape_node = node->graph()->nodes()->create<luci::CircleReshape>();
+ new_reshape_node->tensor(input_node);
+ new_reshape_node->shape(new_const_node);
+ new_reshape_node->name(name + "/Reshape");
+ luci::add_origin(new_reshape_node, luci::get_origin(node));
+ new_const_node->name(name + "/Reshape/shape");
+
+ replace(node).with(new_reshape_node);
+ return true;
+}
+
+} // namespace
+
+namespace luci
+{
+
+/**
+ * BEFORE
+ *
+ * [CircleNode] [CircleConst]
+ * \ /
+ * [CircleTranspose]
+ * |
+ * [CircleNode]
+ *
+ * AFTER
+ *
+ * [CircleNode] [CircleConst]
+ * \ /
+ * [CircleReshape]
+ * |
+ * [CircleNode]
+ *
+ */
+bool SubstituteTransposeToReshapePass::run(loco::Graph *g)
+{
+ bool changed = false;
+ for (auto node : loco::active_nodes(loco::output_nodes(g)))
+ {
+ if (auto circle_node = dynamic_cast<luci::CircleTranspose *>(node))
+ {
+ if (substitute_transpose_to_reshape(circle_node))
+ {
+ changed = true;
+ }
+ }
+ }
+ return changed;
+}
+
+} // namespace luci
diff --git a/compiler/luci/pass/src/SubstituteTransposeToReshapePass.test.cpp b/compiler/luci/pass/src/SubstituteTransposeToReshapePass.test.cpp
new file mode 100644
index 000000000..f81f7e615
--- /dev/null
+++ b/compiler/luci/pass/src/SubstituteTransposeToReshapePass.test.cpp
@@ -0,0 +1,120 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "luci/Pass/SubstituteTransposeToReshapePass.h"
+
+#include <luci/IR/CircleNodes.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+class SubstituteTransposeToReshapeTest : public ::testing::Test
+{
+public:
+ SubstituteTransposeToReshapeTest() {}
+
+ void buildGraph(const std::initializer_list<uint32_t> shape, const std::vector<int32_t> perm)
+ {
+ // Input Create.
+ input = g.nodes()->create<luci::CircleInput>();
+ auto graph_input = g.inputs()->create();
+ input->index(graph_input->index());
+ input->shape_status(luci::ShapeStatus::VALID);
+ input->rank(shape.size());
+ input->shape(shape);
+ input->name("input");
+
+ // Permutation Create.
+ auto perm_const = g.nodes()->create<luci::CircleConst>();
+ perm_const->dtype(loco::DataType::S32);
+ perm_const->size<loco::DataType::S32>(perm.size());
+ perm_const->shape_status(luci::ShapeStatus::VALID);
+ perm_const->rank(1);
+ perm_const->dim(0).set(perm.size());
+ for (uint32_t i = 0; i < static_cast<uint32_t>(perm.size()); i++)
+ {
+ perm_const->at<loco::DataType::S32>(i) = perm.at(i);
+ }
+ perm_const->name("perm_const");
+
+ // Transpose Create.
+ auto transpose_node = g.nodes()->create<luci::CircleTranspose>();
+ transpose_node->a(input);
+ transpose_node->perm(perm_const);
+ transpose_node->name("transpose_node");
+
+ // Output Connect.
+ output = g.nodes()->create<luci::CircleOutput>();
+ output->from(transpose_node);
+ auto graph_output = g.outputs()->create();
+ output->index(graph_output->index());
+ output->name("output");
+ }
+
+public:
+ loco::Graph g;
+ luci::CircleInput *input = nullptr;
+ luci::CircleOutput *output = nullptr;
+};
+
+} // namespace
+
+TEST(SubstituteTransposeToReshapePassTest, name)
+{
+ luci::SubstituteTransposeToReshapePass pass;
+ auto const name = pass.name();
+ ASSERT_NE(nullptr, name);
+}
+
+TEST_F(SubstituteTransposeToReshapeTest, simple_case)
+{
+ // Create graph that tranpose input {126, 201, 1, 1} with permutation {2, 0, 3, 1}
+ buildGraph({126, 201, 1, 1}, std::vector<int32_t>({2, 0, 3, 1}));
+ // With this input shape and permutation values, output shape will be [1, 126, 1, 201].
+ // The order of non-one values is unchanged (126, 201).
+ // So this Transpose op can be converted to Reshape op.
+ luci::SubstituteTransposeToReshapePass pass;
+ while (pass.run(&g))
+ ;
+
+ auto reshape_node = dynamic_cast<luci::CircleReshape *>(output->from());
+ auto transpose_node = dynamic_cast<luci::CircleTranspose *>(output->from());
+ ASSERT_NE(nullptr, reshape_node);
+ ASSERT_EQ(nullptr, transpose_node);
+ auto new_shape = loco::must_cast<luci::CircleConst *>(reshape_node->shape());
+ ASSERT_EQ(1, new_shape->at<loco::DataType::S32>(0));
+ ASSERT_EQ(126, new_shape->at<loco::DataType::S32>(1));
+ ASSERT_EQ(1, new_shape->at<loco::DataType::S32>(2));
+ ASSERT_EQ(201, new_shape->at<loco::DataType::S32>(3));
+}
+
+TEST_F(SubstituteTransposeToReshapeTest, failed_to_substitute_NEG)
+{
+ // Create graph that tranpose input {126, 201, 1, 1} with permutation {2, 1, 3, 0}
+ buildGraph({126, 201, 1, 1}, std::vector<int32_t>({2, 1, 3, 0}));
+ // With this input shape and permutation values, output shape will be [1, 201, 1, 126].
+ // The order of non-one values is changed (126, 201) -> (201, 126).
+ // So this Transpose op cannot be converted to Reshape op.
+ luci::SubstituteTransposeToReshapePass pass;
+ while (pass.run(&g))
+ ;
+
+ auto reshape_node = dynamic_cast<luci::CircleReshape *>(output->from());
+ auto transpose_node = dynamic_cast<luci::CircleTranspose *>(output->from());
+ ASSERT_EQ(nullptr, reshape_node);
+ ASSERT_NE(nullptr, transpose_node);
+}
diff --git a/compiler/luci/pass/src/TransformMinMaxToRelu6Pass.cpp b/compiler/luci/pass/src/TransformMinMaxToRelu6Pass.cpp
new file mode 100644
index 000000000..c15a3b676
--- /dev/null
+++ b/compiler/luci/pass/src/TransformMinMaxToRelu6Pass.cpp
@@ -0,0 +1,134 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Pass/TransformMinMaxToRelu6Pass.h"
+
+#include "helpers/NodeFiller.h"
+#include "helpers/TypeMapper.h"
+
+#include <luci/IR/CircleNodes.h>
+#include <luci/Profile/CircleNodeOrigin.h>
+
+namespace
+{
+
+template <loco::DataType DT>
+bool is_scalar_with_value(luci::CircleConst *node, typename loco::DataTypeImpl<DT>::Type val)
+{
+ if (node->dtype() != DT)
+ return false;
+ if (node->rank() != 0)
+ return false;
+ if (node->size<DT>() != 1)
+ return false;
+ if (node->at<DT>(0) != static_cast<typename loco::DataTypeImpl<DT>::Type>(val))
+ return false;
+
+ return true;
+}
+
+/**
+ * BEFORE
+ * [CircleNode]
+ * |
+ * [CircleMinimum]
+ * |
+ * [CircleMaximum]
+ * |
+ * [CircleNode]
+ *
+ * AFTER
+ *
+ * [CircleNode]
+ * |
+ * [CircleRelu6]
+ * |
+ * [CircleNode]
+ *
+ * NOTE Only max(min(input, 6), 0) pattern will be transformed.
+ */
+template <loco::DataType DT> bool transform_min_max_pattern(luci::CircleMaximum *maxi)
+{
+ if (not maxi)
+ return false;
+
+ if (maxi->dtype() != DT)
+ return false;
+
+ luci::CircleConst *maxi_const = nullptr;
+ luci::CircleMinimum *mini = nullptr;
+
+ // There are two ways Maximum takes inputs.
+ // 1. Maximum(x = CircleConst, y = CircleMinimum)
+ // 2. Maximum(x = CircleMinimum, y = CircleConst)
+ if (not luci::fill(&maxi_const, &mini).with_commutative_args_of(maxi))
+ return false;
+
+ // Maximum constant should be scalar whose value is 0.
+ if (not is_scalar_with_value<DT>(maxi_const,
+ static_cast<typename loco::DataTypeImpl<DT>::Type>(0)))
+ return false;
+
+ luci::CircleConst *mini_const = nullptr;
+ loco::Node *mini_input = nullptr;
+
+ // There are two ways Miminum takes inputs.
+ // 1. Miminum(x = CircleNode, y = CircleMinimum)
+ // 2. Miminum(x = CircleMinimum, y = CircleNode)
+ if (not luci::fill(&mini_const, &mini_input).with_commutative_args_of(mini))
+ return false;
+
+ // Miminum constant should be scalar whose value is 6.
+ if (not is_scalar_with_value<DT>(mini_const,
+ static_cast<typename loco::DataTypeImpl<DT>::Type>(6)))
+ return false;
+
+ auto name = maxi->name();
+ assert(name.length() > 0);
+
+ // Create Relu6 op
+ auto relu6 = mini->graph()->nodes()->create<luci::CircleRelu6>();
+ relu6->features(mini_input);
+ relu6->name(name + "/Relu6");
+ luci::add_origin(relu6, luci::composite_origin({luci::get_origin(maxi), luci::get_origin(mini)}));
+
+ replace(maxi).with(relu6);
+
+ return true;
+}
+
+} // namespace
+
+namespace luci
+{
+
+bool TransformMinMaxToRelu6Pass::run(loco::Graph *g)
+{
+ bool changed = false;
+
+ for (auto node : loco::active_nodes(loco::output_nodes(g)))
+ {
+ if (auto maxi = dynamic_cast<luci::CircleMaximum *>(node))
+ {
+ if (transform_min_max_pattern<loco::DataType::FLOAT32>(maxi))
+ changed = true;
+ }
+ }
+
+ return changed;
+}
+
+} // namespace luci
diff --git a/compiler/luci/pass/src/TransformMinMaxToRelu6Pass.test.cpp b/compiler/luci/pass/src/TransformMinMaxToRelu6Pass.test.cpp
new file mode 100644
index 000000000..9755a70cf
--- /dev/null
+++ b/compiler/luci/pass/src/TransformMinMaxToRelu6Pass.test.cpp
@@ -0,0 +1,151 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Pass/TransformMinMaxToRelu6Pass.h"
+
+#include <luci/IR/CircleNodes.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+/**
+ * Minimum-Maximum pattern graph
+ *
+ * [CircleInput] [CircleConst]
+ * \ /
+ * [CircleMinimum] [CircleConst]
+ * | /
+ * [CircleMaximum]
+ * |
+ * [CircleOutput]
+ */
+struct MinMaxGraph
+{
+ loco::Graph _g;
+ luci::CircleInput *_input = nullptr;
+ luci::CircleMinimum *_mini = nullptr;
+ luci::CircleConst *_mini_const = nullptr;
+ luci::CircleMaximum *_maxi = nullptr;
+ luci::CircleConst *_maxi_const = nullptr;
+ luci::CircleOutput *_output = nullptr;
+};
+
+class TransformMinMaxToRelu6PassTest : public ::testing::Test
+{
+protected:
+ virtual void SetUp()
+ {
+ const int N = 1;
+ const int H = 4;
+ const int W = 4;
+ const int C = 3;
+
+ // graph input and output
+ auto graph_input = _min_max_g._g.inputs()->create();
+ auto graph_output = _min_max_g._g.outputs()->create();
+
+ // CircleInput
+ _min_max_g._input = _min_max_g._g.nodes()->create<luci::CircleInput>();
+ _min_max_g._input->index(graph_input->index());
+ _min_max_g._input->shape({N, H, W, C});
+ _min_max_g._input->dtype(loco::DataType::FLOAT32);
+ _min_max_g._input->name("input");
+
+ // CircleConst
+ _min_max_g._mini_const = _min_max_g._g.nodes()->create<luci::CircleConst>();
+ _min_max_g._mini_const->shape({}); // scalar
+ _min_max_g._mini_const->dtype(loco::DataType::FLOAT32);
+ _min_max_g._mini_const->size<loco::DataType::FLOAT32>(1);
+ _min_max_g._mini_const->at<loco::DataType::FLOAT32>(0) = 6.;
+ _min_max_g._mini_const->name("mini_const");
+
+ // CircleMinimum
+ _min_max_g._mini = _min_max_g._g.nodes()->create<luci::CircleMinimum>();
+ _min_max_g._mini->x(_min_max_g._input);
+ _min_max_g._mini->y(_min_max_g._mini_const);
+ _min_max_g._mini->shape({N, H, W, C});
+ _min_max_g._mini->dtype(loco::DataType::FLOAT32);
+ _min_max_g._mini->name("mini");
+
+ // CircleConst
+ _min_max_g._maxi_const = _min_max_g._g.nodes()->create<luci::CircleConst>();
+ _min_max_g._mini_const->shape({}); // scalar
+ _min_max_g._maxi_const->dtype(loco::DataType::FLOAT32);
+ _min_max_g._maxi_const->size<loco::DataType::FLOAT32>(1);
+ _min_max_g._maxi_const->at<loco::DataType::FLOAT32>(0) = 0.;
+ _min_max_g._maxi_const->name("maxi_const");
+
+ // CircleMaximum
+ _min_max_g._maxi = _min_max_g._g.nodes()->create<luci::CircleMaximum>();
+ _min_max_g._maxi->x(_min_max_g._mini);
+ _min_max_g._maxi->y(_min_max_g._maxi_const);
+ _min_max_g._maxi->shape({N, H, W, C});
+ _min_max_g._maxi->dtype(loco::DataType::FLOAT32);
+ _min_max_g._maxi->name("maxi");
+
+ // CircleOutput
+ _min_max_g._output = _min_max_g._g.nodes()->create<luci::CircleOutput>();
+ _min_max_g._output->index(graph_output->index());
+ _min_max_g._output->from(_min_max_g._maxi);
+ _min_max_g._output->shape({N, H, W, C});
+ _min_max_g._output->dtype(loco::DataType::FLOAT32);
+ _min_max_g._output->name("output");
+ }
+
+protected:
+ luci::TransformMinMaxToRelu6Pass _pass;
+ MinMaxGraph _min_max_g;
+};
+
+} // namespace
+
+TEST_F(TransformMinMaxToRelu6PassTest, name)
+{
+ auto const name = _pass.name();
+ ASSERT_NE(nullptr, name);
+}
+
+/**
+ * Optimized graph looks like below.
+ *
+ * [CircleInput]
+ * |
+ * [CircleRelu6]
+ * |
+ * [CircleOutput]
+ */
+TEST_F(TransformMinMaxToRelu6PassTest, simple_test)
+{
+ auto ret = _pass.run(&_min_max_g._g);
+ EXPECT_TRUE(ret);
+
+ auto relu6 = dynamic_cast<luci::CircleRelu6 *>(_min_max_g._output->from());
+ EXPECT_NE(nullptr, relu6);
+
+ auto input = dynamic_cast<luci::CircleInput *>(relu6->features());
+ EXPECT_NE(nullptr, input);
+}
+
+TEST_F(TransformMinMaxToRelu6PassTest, wrong_condition_NEG)
+{
+ _min_max_g._maxi_const->at<loco::DataType::FLOAT32>(0) = 2.;
+
+ auto ret = _pass.run(&_min_max_g._g);
+
+ EXPECT_FALSE(ret);
+}
diff --git a/compiler/luci/pass/src/TypeInferencePass.cpp b/compiler/luci/pass/src/TypeInferencePass.cpp
deleted file mode 100644
index 63744045c..000000000
--- a/compiler/luci/pass/src/TypeInferencePass.cpp
+++ /dev/null
@@ -1,55 +0,0 @@
-/*
- * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "luci/Pass/TypeInferencePass.h"
-
-#include <luci/IR/CircleDialect.h>
-#include <luci/Service/CircleTypeInferenceRule.h>
-
-#include <loco.h>
-#include <loco/IR/CanonicalDialect.h>
-#include <loco/Service/TypeInference.h>
-
-namespace luci
-{
-
-bool TypeInferencePass::run(luci::Module *m)
-{
- bool changed = false;
-
- for (size_t g = 0; g < m->size(); ++g)
- {
- if (run(m->graph(g)))
- changed = true;
- }
-
- return changed;
-}
-
-bool TypeInferencePass::run(loco::Graph *g)
-{
- loco::CanonicalTypeInferenceRule canonical_rule;
- luci::CircleTypeInferenceRule circle_rule;
-
- loco::MultiDialectTypeInferenceRule rules;
-
- rules.bind(loco::CanonicalDialect::get(), &canonical_rule)
- .bind(luci::CircleDialect::get(), &circle_rule);
-
- return loco::apply(&rules).to(g);
-}
-
-} // namespace luci
diff --git a/compiler/luci/pass/src/VerifyQuantizedNodeChannelWiseGranularity.h b/compiler/luci/pass/src/VerifyQuantizedNodeChannelWiseGranularity.h
new file mode 100644
index 000000000..32f0d1a34
--- /dev/null
+++ b/compiler/luci/pass/src/VerifyQuantizedNodeChannelWiseGranularity.h
@@ -0,0 +1,401 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LUCI_VERIFY_QUANTIZED_NODE_CHANNELWISE_GRANULARITY_H__
+#define __LUCI_VERIFY_QUANTIZED_NODE_CHANNELWISE_GRANULARITY_H__
+
+#include <luci/IR/CircleNodes.h>
+#include <luci/IR/CircleNodeVisitor.h>
+#include <luci/Pass/QuantizationParameters.h>
+
+using Granularity = luci::QuantizationGranularity;
+
+// This macro is undef at the end of the file
+#define RETURN_FALSE_UNLESS(ARG) \
+ if (not(ARG)) \
+ { \
+ return false; \
+ }
+
+namespace luci
+{
+
+/**
+ * @brief Verify the granualrity of channel-wise quantized node
+ * @details
+ *
+ * Targets to verify
+ * - node's output (i.e., node itself)
+ * - node's inputs
+ */
+struct VerifyQuantizedNodeChannelWiseGranularity final : public luci::CircleNodeVisitor<bool>
+{
+private:
+ bool is_lwq(const loco::Node *node)
+ {
+ auto circle_node = loco::must_cast<const luci::CircleNode *>(node);
+
+ if (circle_node->quantparam() == nullptr)
+ return false;
+
+ if (circle_node->quantparam()->scale.size() != 1)
+ return false;
+
+ if (circle_node->quantparam()->zerop.size() != 1)
+ return false;
+
+ return true;
+ }
+
+ uint32_t rank(const loco::Node *node)
+ {
+ auto circle_node = loco::must_cast<const luci::CircleNode *>(node);
+ return circle_node->rank();
+ }
+
+ bool is_cwq_const(const loco::Node *node, uint32_t channel_dim)
+ {
+ auto circle_node = loco::must_cast<const luci::CircleConst *>(node);
+
+ assert(channel_dim < circle_node->rank()); // FIX_CALLER_UNLESS
+ auto channel_size = circle_node->dim(channel_dim).value();
+
+ if (circle_node->quantparam() == nullptr)
+ return false;
+
+ if (circle_node->quantparam()->quantized_dimension != static_cast<int32_t>(channel_dim))
+ return false;
+
+ if (circle_node->quantparam()->scale.size() != channel_size)
+ return false;
+
+ if (circle_node->quantparam()->zerop.size() != channel_size)
+ return false;
+
+ return true;
+ }
+
+private:
+ bool visit(const luci::CircleConv2D *node)
+ {
+ RETURN_FALSE_UNLESS(is_lwq(node))
+ RETURN_FALSE_UNLESS(is_lwq(node->input()))
+ RETURN_FALSE_UNLESS(is_cwq_const(node->filter(), 0))
+ RETURN_FALSE_UNLESS(is_cwq_const(node->bias(), rank(node->bias()) - 1))
+ return true;
+ }
+
+ bool visit(const luci::CircleConcatenation *node)
+ {
+ RETURN_FALSE_UNLESS(is_lwq(node))
+ for (uint32_t i = 0; i < node->numValues(); i++)
+ {
+ RETURN_FALSE_UNLESS(is_lwq(node->values(i)));
+ }
+ return true;
+ }
+
+ bool visit(const luci::CircleDepthToSpace *node)
+ {
+ RETURN_FALSE_UNLESS(is_lwq(node))
+ RETURN_FALSE_UNLESS(is_lwq(node->input()))
+ return true;
+ }
+
+ bool visit(const luci::CircleDepthwiseConv2D *node)
+ {
+ RETURN_FALSE_UNLESS(is_lwq(node))
+ RETURN_FALSE_UNLESS(is_lwq(node->input()))
+ RETURN_FALSE_UNLESS(is_cwq_const(node->filter(), 3))
+ RETURN_FALSE_UNLESS(is_cwq_const(node->bias(), rank(node->bias()) - 1))
+ return true;
+ }
+
+ bool visit(const luci::CircleInstanceNorm *node)
+ {
+ RETURN_FALSE_UNLESS(is_lwq(node))
+ RETURN_FALSE_UNLESS(is_lwq(node->input()))
+ RETURN_FALSE_UNLESS(is_cwq_const(node->gamma(), rank(node->gamma()) - 1))
+ RETURN_FALSE_UNLESS(is_cwq_const(node->beta(), rank(node->beta()) - 1))
+ return true;
+ }
+
+ bool visit(const luci::CirclePad *node)
+ {
+ RETURN_FALSE_UNLESS(is_lwq(node))
+ RETURN_FALSE_UNLESS(is_lwq(node->input()))
+ return true;
+ }
+
+ bool visit(const luci::CirclePRelu *node)
+ {
+ RETURN_FALSE_UNLESS(is_lwq(node))
+ RETURN_FALSE_UNLESS(is_lwq(node->input()))
+ RETURN_FALSE_UNLESS(is_cwq_const(node->alpha(), rank(node->alpha()) - 1))
+ return true;
+ }
+
+ bool visit(const luci::CircleTransposeConv *node)
+ {
+ RETURN_FALSE_UNLESS(is_lwq(node))
+ RETURN_FALSE_UNLESS(is_lwq(node->outBackprop()))
+ RETURN_FALSE_UNLESS(is_cwq_const(node->filter(), 0))
+ luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias());
+ if (bias != nullptr)
+ RETURN_FALSE_UNLESS(is_cwq_const(node->bias(), rank(node->bias()) - 1))
+
+ return true;
+ }
+
+ bool visit(const luci::CircleFullyConnected *node)
+ {
+ RETURN_FALSE_UNLESS(is_lwq(node))
+ RETURN_FALSE_UNLESS(is_lwq(node->input()))
+ RETURN_FALSE_UNLESS(is_cwq_const(node->weights(), 0))
+ RETURN_FALSE_UNLESS(is_cwq_const(node->bias(), rank(node->bias()) - 1))
+ return true;
+ }
+
+ bool visit(const luci::CircleAdd *node)
+ {
+ RETURN_FALSE_UNLESS(is_lwq(node));
+ RETURN_FALSE_UNLESS(is_lwq(node->x()));
+ RETURN_FALSE_UNLESS(is_lwq(node->y()));
+ return true;
+ }
+
+ bool visit(const luci::CircleAveragePool2D *node)
+ {
+ RETURN_FALSE_UNLESS(is_lwq(node));
+ RETURN_FALSE_UNLESS(is_lwq(node->value()));
+ return true;
+ }
+
+ bool visit(const luci::CircleLogicalOr *)
+ {
+ // Logical OR has bool-type inputs and output
+ // Nothing to be checked
+ return true;
+ }
+
+ bool visit(const luci::CircleMaxPool2D *node)
+ {
+ RETURN_FALSE_UNLESS(is_lwq(node));
+ RETURN_FALSE_UNLESS(is_lwq(node->value()));
+ return true;
+ }
+
+ bool visit(const luci::CircleMean *node)
+ {
+ RETURN_FALSE_UNLESS(is_lwq(node));
+ RETURN_FALSE_UNLESS(is_lwq(node->input()));
+ return true;
+ }
+
+ bool visit(const luci::CircleMul *node)
+ {
+ RETURN_FALSE_UNLESS(is_lwq(node));
+ RETURN_FALSE_UNLESS(is_lwq(node->x()));
+ RETURN_FALSE_UNLESS(is_lwq(node->y()));
+ return true;
+ }
+
+ bool visit(const luci::CircleNotEqual *node)
+ {
+ RETURN_FALSE_UNLESS(is_lwq(node->x()));
+ RETURN_FALSE_UNLESS(is_lwq(node->y()));
+ return true;
+ }
+
+ bool visit(const luci::CircleRelu *node)
+ {
+ RETURN_FALSE_UNLESS(is_lwq(node));
+ RETURN_FALSE_UNLESS(is_lwq(node->features()));
+ return true;
+ }
+
+ bool visit(const luci::CircleReshape *node)
+ {
+ RETURN_FALSE_UNLESS(is_lwq(node))
+ RETURN_FALSE_UNLESS(is_lwq(node->tensor()));
+ return true;
+ }
+
+ bool visit(const luci::CircleLogistic *node)
+ {
+ RETURN_FALSE_UNLESS(is_lwq(node));
+ RETURN_FALSE_UNLESS(is_lwq(node->x()));
+ return true;
+ }
+
+ bool visit(const luci::CircleSoftmax *node)
+ {
+ RETURN_FALSE_UNLESS(is_lwq(node));
+ RETURN_FALSE_UNLESS(is_lwq(node->logits()));
+ return true;
+ }
+
+ bool visit(const luci::CircleSpaceToBatchND *node)
+ {
+ RETURN_FALSE_UNLESS(is_lwq(node));
+ RETURN_FALSE_UNLESS(is_lwq(node->input()));
+ return true;
+ }
+
+ bool visit(const luci::CircleSpaceToDepth *node)
+ {
+ RETURN_FALSE_UNLESS(is_lwq(node));
+ RETURN_FALSE_UNLESS(is_lwq(node->input()));
+ return true;
+ }
+
+ bool visit(const luci::CircleSlice *node)
+ {
+ RETURN_FALSE_UNLESS(is_lwq(node));
+ RETURN_FALSE_UNLESS(is_lwq(node->input()));
+ return true;
+ }
+
+ bool visit(const luci::CircleSplit *node)
+ {
+ // node's output is the input of CircleSplitOut, thus not quantized
+ RETURN_FALSE_UNLESS(is_lwq(node->input()));
+ return true;
+ }
+
+ bool visit(const luci::CircleSplitOut *node)
+ {
+ RETURN_FALSE_UNLESS(is_lwq(node));
+ return true;
+ }
+
+ bool visit(const luci::CircleStridedSlice *node)
+ {
+ RETURN_FALSE_UNLESS(is_lwq(node));
+ RETURN_FALSE_UNLESS(is_lwq(node->input()));
+ return true;
+ }
+
+ bool visit(const luci::CircleArgMax *node)
+ {
+ // node's output is index, thus not quantized
+ RETURN_FALSE_UNLESS(is_lwq(node->input()));
+ return true;
+ }
+
+ bool visit(const luci::CircleBatchToSpaceND *node)
+ {
+ RETURN_FALSE_UNLESS(is_lwq(node));
+ RETURN_FALSE_UNLESS(is_lwq(node->input()));
+ return true;
+ }
+
+ bool visit(const luci::CircleTanh *node)
+ {
+ RETURN_FALSE_UNLESS(is_lwq(node));
+ RETURN_FALSE_UNLESS(is_lwq(node->x()));
+ return true;
+ }
+
+ bool visit(const luci::CircleTranspose *node)
+ {
+ RETURN_FALSE_UNLESS(is_lwq(node));
+ RETURN_FALSE_UNLESS(is_lwq(node->a()));
+ return true;
+ }
+
+ bool visit(const luci::CircleFloor *node)
+ {
+ RETURN_FALSE_UNLESS(is_lwq(node));
+ RETURN_FALSE_UNLESS(is_lwq(node->x()));
+ return true;
+ }
+
+ bool visit(const luci::CircleGreater *node)
+ {
+ RETURN_FALSE_UNLESS(is_lwq(node->x()));
+ RETURN_FALSE_UNLESS(is_lwq(node->y()));
+ return true;
+ }
+
+ bool visit(const luci::CircleGreaterEqual *node)
+ {
+ RETURN_FALSE_UNLESS(is_lwq(node->x()));
+ RETURN_FALSE_UNLESS(is_lwq(node->y()));
+ return true;
+ }
+
+ bool visit(const luci::CircleDiv *node)
+ {
+ RETURN_FALSE_UNLESS(is_lwq(node));
+ RETURN_FALSE_UNLESS(is_lwq(node->x()));
+ RETURN_FALSE_UNLESS(is_lwq(node->y()));
+ return true;
+ }
+
+ bool visit(const luci::CircleFloorDiv *node)
+ {
+ RETURN_FALSE_UNLESS(is_lwq(node));
+ RETURN_FALSE_UNLESS(is_lwq(node->x()));
+ RETURN_FALSE_UNLESS(is_lwq(node->y()));
+ return true;
+ }
+
+ bool visit(const luci::CircleRsqrt *node)
+ {
+ RETURN_FALSE_UNLESS(is_lwq(node));
+ RETURN_FALSE_UNLESS(is_lwq(node->x()));
+ return true;
+ }
+
+ bool visit(const luci::CircleSqrt *node)
+ {
+ RETURN_FALSE_UNLESS(is_lwq(node));
+ RETURN_FALSE_UNLESS(is_lwq(node->x()));
+ return true;
+ }
+
+ bool visit(const luci::CircleElu *node)
+ {
+ RETURN_FALSE_UNLESS(is_lwq(node));
+ RETURN_FALSE_UNLESS(is_lwq(node->features()));
+ return true;
+ }
+
+ bool visit(const luci::CirclePow *node)
+ {
+ RETURN_FALSE_UNLESS(is_lwq(node));
+ RETURN_FALSE_UNLESS(is_lwq(node->x()));
+ RETURN_FALSE_UNLESS(is_lwq(node->y()));
+ return true;
+ }
+
+ bool visit(const luci::CircleResizeBilinear *node)
+ {
+ RETURN_FALSE_UNLESS(is_lwq(node));
+ RETURN_FALSE_UNLESS(is_lwq(node->input()));
+ return true;
+ }
+
+ // TODO: Implement more Ops
+
+ bool visit(const luci::CircleNode *) { return true; }
+};
+
+} // namespace luci
+
+#undef RETURN_FALSE_UNLESS
+
+#endif // __LUCI_VERIFY_QUANTIZED_NODE_CHANNELWISE_GRANULARITY_H__
diff --git a/compiler/luci/pass/src/VerifyQuantizedNodeLayerWiseGranularity.h b/compiler/luci/pass/src/VerifyQuantizedNodeLayerWiseGranularity.h
new file mode 100644
index 000000000..1e6fd53c0
--- /dev/null
+++ b/compiler/luci/pass/src/VerifyQuantizedNodeLayerWiseGranularity.h
@@ -0,0 +1,388 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LUCI_VERIFY_QUANTIZED_NODE_LAYERWISE_GRANULARITY_H__
+#define __LUCI_VERIFY_QUANTIZED_NODE_LAYERWISE_GRANULARITY_H__
+
+#include <luci/IR/CircleNodes.h>
+#include <luci/IR/CircleNodeVisitor.h>
+#include <luci/Pass/QuantizationParameters.h>
+
+using Granularity = luci::QuantizationGranularity;
+
+// This macro is undef at the end of the file
+#define RETURN_FALSE_UNLESS(ARG) \
+ if (not(ARG)) \
+ { \
+ return false; \
+ }
+
+namespace luci
+{
+
+/**
+ * @brief Verify the granualrity of layer-wise quantized node
+ * @details
+ *
+ * Targets to verify
+ * - node's output (i.e., node itself)
+ * - node's inputs
+ */
+struct VerifyQuantizedNodeLayerWiseGranularity final : public luci::CircleNodeVisitor<bool>
+{
+private:
+ bool is_lwq(const loco::Node *node)
+ {
+ auto circle_node = loco::must_cast<const luci::CircleNode *>(node);
+
+ if (circle_node->quantparam() == nullptr)
+ return false;
+
+ if (circle_node->quantparam()->scale.size() != 1)
+ return false;
+
+ if (circle_node->quantparam()->zerop.size() != 1)
+ return false;
+
+ return true;
+ }
+
+ bool is_lwq_const(const loco::Node *node)
+ {
+ auto circle_node = loco::must_cast<const luci::CircleConst *>(node);
+
+ if (circle_node->quantparam() == nullptr)
+ return false;
+
+ if (circle_node->quantparam()->scale.size() != 1)
+ return false;
+
+ if (circle_node->quantparam()->zerop.size() != 1)
+ return false;
+
+ return true;
+ }
+
+private:
+ bool visit(const luci::CircleConv2D *node)
+ {
+ RETURN_FALSE_UNLESS(is_lwq(node))
+ RETURN_FALSE_UNLESS(is_lwq(node->input()))
+ RETURN_FALSE_UNLESS(is_lwq_const(node->filter()))
+ RETURN_FALSE_UNLESS(is_lwq_const(node->bias()))
+ return true;
+ }
+
+ bool visit(const luci::CircleConcatenation *node)
+ {
+ RETURN_FALSE_UNLESS(is_lwq(node))
+ for (uint32_t i = 0; i < node->numValues(); i++)
+ {
+ RETURN_FALSE_UNLESS(is_lwq(node->values(i)));
+ }
+ return true;
+ }
+
+ bool visit(const luci::CircleDepthToSpace *node)
+ {
+ RETURN_FALSE_UNLESS(is_lwq(node))
+ RETURN_FALSE_UNLESS(is_lwq(node->input()))
+ return true;
+ }
+
+ bool visit(const luci::CircleDepthwiseConv2D *node)
+ {
+ RETURN_FALSE_UNLESS(is_lwq(node))
+ RETURN_FALSE_UNLESS(is_lwq(node->input()))
+ RETURN_FALSE_UNLESS(is_lwq_const(node->filter()))
+ RETURN_FALSE_UNLESS(is_lwq_const(node->bias()))
+ return true;
+ }
+
+ bool visit(const luci::CircleInstanceNorm *node)
+ {
+ RETURN_FALSE_UNLESS(is_lwq(node))
+ RETURN_FALSE_UNLESS(is_lwq(node->input()))
+ RETURN_FALSE_UNLESS(is_lwq_const(node->gamma()))
+ RETURN_FALSE_UNLESS(is_lwq_const(node->beta()))
+ return true;
+ }
+
+ bool visit(const luci::CirclePad *node)
+ {
+ RETURN_FALSE_UNLESS(is_lwq(node))
+ RETURN_FALSE_UNLESS(is_lwq(node->input()))
+ return true;
+ }
+
+ bool visit(const luci::CirclePRelu *node)
+ {
+ RETURN_FALSE_UNLESS(is_lwq(node))
+ RETURN_FALSE_UNLESS(is_lwq(node->input()))
+ RETURN_FALSE_UNLESS(is_lwq_const(node->alpha()))
+ return true;
+ }
+
+ bool visit(const luci::CircleTransposeConv *node)
+ {
+ RETURN_FALSE_UNLESS(is_lwq(node))
+ RETURN_FALSE_UNLESS(is_lwq(node->outBackprop()))
+ RETURN_FALSE_UNLESS(is_lwq_const(node->filter()))
+ luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias());
+ if (bias != nullptr)
+ RETURN_FALSE_UNLESS(is_lwq_const(node->bias()))
+ return true;
+ }
+
+ bool visit(const luci::CircleFullyConnected *node)
+ {
+ RETURN_FALSE_UNLESS(is_lwq(node))
+ RETURN_FALSE_UNLESS(is_lwq(node->input()))
+ RETURN_FALSE_UNLESS(is_lwq_const(node->weights()))
+ RETURN_FALSE_UNLESS(is_lwq_const(node->bias()))
+ return true;
+ }
+
+ bool visit(const luci::CircleAdd *node)
+ {
+ RETURN_FALSE_UNLESS(is_lwq(node))
+ RETURN_FALSE_UNLESS(is_lwq(node->x()));
+ RETURN_FALSE_UNLESS(is_lwq(node->y()));
+ return true;
+ }
+
+ bool visit(const luci::CircleAveragePool2D *node)
+ {
+ RETURN_FALSE_UNLESS(is_lwq(node))
+ RETURN_FALSE_UNLESS(is_lwq(node->value()));
+ return true;
+ }
+
+ bool visit(const luci::CircleLogicalOr *)
+ {
+ // Logical OR has bool-type inputs and output
+ // Nothing to be checked
+ return true;
+ }
+
+ bool visit(const luci::CircleMaxPool2D *node)
+ {
+ RETURN_FALSE_UNLESS(is_lwq(node))
+ RETURN_FALSE_UNLESS(is_lwq(node->value()));
+ return true;
+ }
+
+ bool visit(const luci::CircleMean *node)
+ {
+ RETURN_FALSE_UNLESS(is_lwq(node))
+ RETURN_FALSE_UNLESS(is_lwq(node->input()));
+ return true;
+ }
+
+ bool visit(const luci::CircleMul *node)
+ {
+ RETURN_FALSE_UNLESS(is_lwq(node))
+ RETURN_FALSE_UNLESS(is_lwq(node->x()));
+ RETURN_FALSE_UNLESS(is_lwq(node->y()));
+ return true;
+ }
+
+ bool visit(const luci::CircleNotEqual *node)
+ {
+ RETURN_FALSE_UNLESS(is_lwq(node->x()));
+ RETURN_FALSE_UNLESS(is_lwq(node->y()));
+ return true;
+ }
+
+ bool visit(const luci::CircleRelu *node)
+ {
+ RETURN_FALSE_UNLESS(is_lwq(node))
+ RETURN_FALSE_UNLESS(is_lwq(node->features()));
+ return true;
+ }
+
+ bool visit(const luci::CircleReshape *node)
+ {
+ RETURN_FALSE_UNLESS(is_lwq(node))
+ RETURN_FALSE_UNLESS(is_lwq(node->tensor()));
+ return true;
+ }
+
+ bool visit(const luci::CircleLogistic *node)
+ {
+ RETURN_FALSE_UNLESS(is_lwq(node));
+ RETURN_FALSE_UNLESS(is_lwq(node->x()));
+ return true;
+ }
+
+ bool visit(const luci::CircleSoftmax *node)
+ {
+ RETURN_FALSE_UNLESS(is_lwq(node));
+ RETURN_FALSE_UNLESS(is_lwq(node->logits()));
+ return true;
+ }
+
+ bool visit(const luci::CircleSpaceToBatchND *node)
+ {
+ RETURN_FALSE_UNLESS(is_lwq(node));
+ RETURN_FALSE_UNLESS(is_lwq(node->input()));
+ return true;
+ }
+
+ bool visit(const luci::CircleSpaceToDepth *node)
+ {
+ RETURN_FALSE_UNLESS(is_lwq(node));
+ RETURN_FALSE_UNLESS(is_lwq(node->input()));
+ return true;
+ }
+
+ bool visit(const luci::CircleSlice *node)
+ {
+ RETURN_FALSE_UNLESS(is_lwq(node));
+ RETURN_FALSE_UNLESS(is_lwq(node->input()));
+ return true;
+ }
+
+ bool visit(const luci::CircleSplit *node)
+ {
+ // node's output is the input of CircleSplitOut, thus not quantized
+ RETURN_FALSE_UNLESS(is_lwq(node->input()));
+ return true;
+ }
+
+ bool visit(const luci::CircleSplitOut *node)
+ {
+ RETURN_FALSE_UNLESS(is_lwq(node));
+ return true;
+ }
+
+ bool visit(const luci::CircleStridedSlice *node)
+ {
+ RETURN_FALSE_UNLESS(is_lwq(node));
+ RETURN_FALSE_UNLESS(is_lwq(node->input()));
+ return true;
+ }
+
+ bool visit(const luci::CircleArgMax *node)
+ {
+ // node's output is index, thus not quantized
+ RETURN_FALSE_UNLESS(is_lwq(node->input()));
+ return true;
+ }
+
+ bool visit(const luci::CircleBatchToSpaceND *node)
+ {
+ RETURN_FALSE_UNLESS(is_lwq(node));
+ RETURN_FALSE_UNLESS(is_lwq(node->input()));
+ return true;
+ }
+
+ bool visit(const luci::CircleTanh *node)
+ {
+ RETURN_FALSE_UNLESS(is_lwq(node));
+ RETURN_FALSE_UNLESS(is_lwq(node->x()));
+ return true;
+ }
+
+ bool visit(const luci::CircleTranspose *node)
+ {
+ RETURN_FALSE_UNLESS(is_lwq(node));
+ RETURN_FALSE_UNLESS(is_lwq(node->a()));
+ return true;
+ }
+
+ bool visit(const luci::CircleFloor *node)
+ {
+ RETURN_FALSE_UNLESS(is_lwq(node));
+ RETURN_FALSE_UNLESS(is_lwq(node->x()));
+ return true;
+ }
+
+ bool visit(const luci::CircleGreater *node)
+ {
+ RETURN_FALSE_UNLESS(is_lwq(node->x()));
+ RETURN_FALSE_UNLESS(is_lwq(node->y()));
+ return true;
+ }
+
+ bool visit(const luci::CircleGreaterEqual *node)
+ {
+ RETURN_FALSE_UNLESS(is_lwq(node->x()));
+ RETURN_FALSE_UNLESS(is_lwq(node->y()));
+ return true;
+ }
+
+ bool visit(const luci::CircleDiv *node)
+ {
+ RETURN_FALSE_UNLESS(is_lwq(node));
+ RETURN_FALSE_UNLESS(is_lwq(node->x()));
+ RETURN_FALSE_UNLESS(is_lwq(node->y()));
+ return true;
+ }
+
+ bool visit(const luci::CircleFloorDiv *node)
+ {
+ RETURN_FALSE_UNLESS(is_lwq(node));
+ RETURN_FALSE_UNLESS(is_lwq(node->x()));
+ RETURN_FALSE_UNLESS(is_lwq(node->y()));
+ return true;
+ }
+
+ bool visit(const luci::CircleRsqrt *node)
+ {
+ RETURN_FALSE_UNLESS(is_lwq(node));
+ RETURN_FALSE_UNLESS(is_lwq(node->x()));
+ return true;
+ }
+
+ bool visit(const luci::CircleSqrt *node)
+ {
+ RETURN_FALSE_UNLESS(is_lwq(node));
+ RETURN_FALSE_UNLESS(is_lwq(node->x()));
+ return true;
+ }
+
+ bool visit(const luci::CircleElu *node)
+ {
+ RETURN_FALSE_UNLESS(is_lwq(node));
+ RETURN_FALSE_UNLESS(is_lwq(node->features()));
+ return true;
+ }
+
+ bool visit(const luci::CirclePow *node)
+ {
+ RETURN_FALSE_UNLESS(is_lwq(node));
+ RETURN_FALSE_UNLESS(is_lwq(node->x()));
+ RETURN_FALSE_UNLESS(is_lwq(node->y()));
+ return true;
+ }
+
+ bool visit(const luci::CircleResizeBilinear *node)
+ {
+ RETURN_FALSE_UNLESS(is_lwq(node));
+ RETURN_FALSE_UNLESS(is_lwq(node->input()));
+ return true;
+ }
+
+ // TODO: Implement more Ops
+
+ bool visit(const luci::CircleNode *) { return true; }
+};
+
+} // namespace luci
+
+#undef RETURN_FALSE_UNLESS
+
+#endif // __LUCI_VERIFY_QUANTIZED_NODE_LAYERWISE_GRANULARITY_H__
diff --git a/compiler/luci/pass/src/VerifyQuantizedNodeS16Type.h b/compiler/luci/pass/src/VerifyQuantizedNodeS16Type.h
new file mode 100644
index 000000000..e05d8325f
--- /dev/null
+++ b/compiler/luci/pass/src/VerifyQuantizedNodeS16Type.h
@@ -0,0 +1,375 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LUCI_VERIFY_QUANTIZED_NODE_S16_TYPE_H__
+#define __LUCI_VERIFY_QUANTIZED_NODE_S16_TYPE_H__
+
+#include <luci/IR/CircleNodes.h>
+#include <luci/IR/CircleNodeVisitor.h>
+
+using Type = loco::DataType;
+
+// This macro is undef at the end of the file
+#define RETURN_FALSE_UNLESS(ARG) \
+ if (not(ARG)) \
+ { \
+ return false; \
+ }
+
+namespace luci
+{
+
+/**
+ * @brief Verify the data type of INT16 quantized node
+ * @details
+ *
+ * Targets to verify
+ * - node's output (i.e., node itself)
+ * - node's inputs
+ */
+struct VerifyQuantizedNodeS16Type final : public luci::CircleNodeVisitor<bool>
+{
+private:
+ bool has_type(const loco::Node *node, Type dtype)
+ {
+ auto circle_node = loco::must_cast<const luci::CircleNode *>(node);
+ return circle_node->dtype() == dtype;
+ }
+
+private:
+ bool visit(const luci::CircleConv2D *node)
+ {
+ RETURN_FALSE_UNLESS(has_type(node, Type::S16))
+ RETURN_FALSE_UNLESS(has_type(node->input(), Type::S16))
+ RETURN_FALSE_UNLESS(has_type(node->filter(), Type::S16))
+ RETURN_FALSE_UNLESS(has_type(node->bias(), Type::S64))
+ return true;
+ }
+
+ bool visit(const luci::CircleConcatenation *node)
+ {
+ RETURN_FALSE_UNLESS(has_type(node, Type::S16))
+ for (uint32_t i = 0; i < node->numValues(); i++)
+ {
+ RETURN_FALSE_UNLESS(has_type(node->values(i), Type::S16))
+ }
+ return true;
+ }
+
+ bool visit(const luci::CircleDepthToSpace *node)
+ {
+ RETURN_FALSE_UNLESS(has_type(node, Type::S16))
+ RETURN_FALSE_UNLESS(has_type(node->input(), Type::S16))
+ return true;
+ }
+
+ bool visit(const luci::CircleDepthwiseConv2D *node)
+ {
+ RETURN_FALSE_UNLESS(has_type(node, Type::S16))
+ RETURN_FALSE_UNLESS(has_type(node->input(), Type::S16))
+ RETURN_FALSE_UNLESS(has_type(node->filter(), Type::S16))
+ RETURN_FALSE_UNLESS(has_type(node->bias(), Type::S64))
+ return true;
+ }
+
+ bool visit(const luci::CircleInstanceNorm *node)
+ {
+ RETURN_FALSE_UNLESS(has_type(node, Type::S16))
+ RETURN_FALSE_UNLESS(has_type(node->input(), Type::S16))
+ RETURN_FALSE_UNLESS(has_type(node->gamma(), Type::S16))
+ RETURN_FALSE_UNLESS(has_type(node->beta(), Type::S16))
+ return true;
+ }
+
+ bool visit(const luci::CirclePad *node)
+ {
+ RETURN_FALSE_UNLESS(has_type(node, Type::S16))
+ RETURN_FALSE_UNLESS(has_type(node->input(), Type::S16))
+ RETURN_FALSE_UNLESS(has_type(node->paddings(), Type::S32))
+ return true;
+ }
+
+ bool visit(const luci::CirclePRelu *node)
+ {
+ RETURN_FALSE_UNLESS(has_type(node, Type::S16))
+ RETURN_FALSE_UNLESS(has_type(node->input(), Type::S16))
+ RETURN_FALSE_UNLESS(has_type(node->alpha(), Type::S16))
+ return true;
+ }
+
+ bool visit(const luci::CircleTransposeConv *node)
+ {
+ RETURN_FALSE_UNLESS(has_type(node, Type::S16))
+ RETURN_FALSE_UNLESS(has_type(node->outBackprop(), Type::S16))
+ RETURN_FALSE_UNLESS(has_type(node->filter(), Type::S16))
+ luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias());
+ if (bias != nullptr)
+ RETURN_FALSE_UNLESS(has_type(bias, Type::S64))
+ return true;
+ }
+
+ bool visit(const luci::CircleFullyConnected *node)
+ {
+ RETURN_FALSE_UNLESS(has_type(node, Type::S16))
+ RETURN_FALSE_UNLESS(has_type(node->input(), Type::S16))
+ RETURN_FALSE_UNLESS(has_type(node->weights(), Type::S16))
+ RETURN_FALSE_UNLESS(has_type(node->bias(), Type::S64))
+ return true;
+ }
+
+ bool visit(const luci::CircleAdd *node)
+ {
+ RETURN_FALSE_UNLESS(has_type(node, Type::S16))
+ RETURN_FALSE_UNLESS(has_type(node->x(), Type::S16))
+ RETURN_FALSE_UNLESS(has_type(node->y(), Type::S16))
+ return true;
+ }
+
+ bool visit(const luci::CircleAveragePool2D *node)
+ {
+ RETURN_FALSE_UNLESS(has_type(node, Type::S16))
+ RETURN_FALSE_UNLESS(has_type(node->value(), Type::S16))
+ return true;
+ }
+
+ bool visit(const luci::CircleLogicalOr *node)
+ {
+ RETURN_FALSE_UNLESS(has_type(node, Type::BOOL))
+ RETURN_FALSE_UNLESS(has_type(node->x(), Type::BOOL))
+ RETURN_FALSE_UNLESS(has_type(node->y(), Type::BOOL))
+ return true;
+ }
+
+ bool visit(const luci::CircleMaxPool2D *node)
+ {
+ RETURN_FALSE_UNLESS(has_type(node, Type::S16))
+ RETURN_FALSE_UNLESS(has_type(node->value(), Type::S16))
+ return true;
+ }
+
+ bool visit(const luci::CircleMean *node)
+ {
+ RETURN_FALSE_UNLESS(has_type(node, Type::S16))
+ RETURN_FALSE_UNLESS(has_type(node->input(), Type::S16))
+ RETURN_FALSE_UNLESS(has_type(node->reduction_indices(), Type::S32))
+ return true;
+ }
+
+ bool visit(const luci::CircleMul *node)
+ {
+ RETURN_FALSE_UNLESS(has_type(node, Type::S16))
+ RETURN_FALSE_UNLESS(has_type(node->x(), Type::S16))
+ RETURN_FALSE_UNLESS(has_type(node->y(), Type::S16))
+ return true;
+ }
+
+ bool visit(const luci::CircleNotEqual *node)
+ {
+ RETURN_FALSE_UNLESS(has_type(node, Type::BOOL))
+ RETURN_FALSE_UNLESS(has_type(node->x(), Type::S16))
+ RETURN_FALSE_UNLESS(has_type(node->y(), Type::S16))
+ return true;
+ }
+
+ bool visit(const luci::CircleRelu *node)
+ {
+ RETURN_FALSE_UNLESS(has_type(node, Type::S16))
+ RETURN_FALSE_UNLESS(has_type(node->features(), Type::S16))
+ return true;
+ }
+
+ bool visit(const luci::CircleReshape *node)
+ {
+ RETURN_FALSE_UNLESS(has_type(node, Type::S16))
+ RETURN_FALSE_UNLESS(has_type(node->tensor(), Type::S16))
+ luci::CircleConst *shape = dynamic_cast<luci::CircleConst *>(node->shape());
+ if (shape != nullptr)
+ RETURN_FALSE_UNLESS(has_type(shape, Type::S32))
+ return true;
+ }
+
+ bool visit(const luci::CircleLogistic *node)
+ {
+ RETURN_FALSE_UNLESS(has_type(node, Type::S16))
+ RETURN_FALSE_UNLESS(has_type(node->x(), Type::S16))
+ return true;
+ }
+
+ bool visit(const luci::CircleSoftmax *node)
+ {
+ RETURN_FALSE_UNLESS(has_type(node, Type::S16))
+ RETURN_FALSE_UNLESS(has_type(node->logits(), Type::S16))
+ return true;
+ }
+
+ bool visit(const luci::CircleSpaceToBatchND *node)
+ {
+ RETURN_FALSE_UNLESS(has_type(node, Type::S16))
+ RETURN_FALSE_UNLESS(has_type(node->input(), Type::S16))
+ return true;
+ }
+
+ bool visit(const luci::CircleSpaceToDepth *node)
+ {
+ RETURN_FALSE_UNLESS(has_type(node, Type::S16))
+ RETURN_FALSE_UNLESS(has_type(node->input(), Type::S16))
+ return true;
+ }
+
+ bool visit(const luci::CircleSlice *node)
+ {
+ RETURN_FALSE_UNLESS(has_type(node, Type::S16))
+ RETURN_FALSE_UNLESS(has_type(node->input(), Type::S16))
+ RETURN_FALSE_UNLESS(has_type(node->begin(), Type::S32) || has_type(node->begin(), Type::S64))
+ RETURN_FALSE_UNLESS(has_type(node->size(), Type::S32) || has_type(node->size(), Type::S64))
+ return true;
+ }
+
+ bool visit(const luci::CircleSplit *node)
+ {
+ // node's output is the input of CircleSplitOut, thus not quantized
+ RETURN_FALSE_UNLESS(has_type(node->input(), Type::S16))
+ return true;
+ }
+
+ bool visit(const luci::CircleSplitOut *node)
+ {
+ RETURN_FALSE_UNLESS(has_type(node, Type::S16))
+ return true;
+ }
+
+ bool visit(const luci::CircleStridedSlice *node)
+ {
+ RETURN_FALSE_UNLESS(has_type(node, Type::S16))
+ RETURN_FALSE_UNLESS(has_type(node->input(), Type::S16))
+ return true;
+ }
+
+ bool visit(const luci::CircleArgMax *node)
+ {
+ RETURN_FALSE_UNLESS(has_type(node, node->output_type()))
+ RETURN_FALSE_UNLESS(has_type(node->input(), Type::S16))
+ RETURN_FALSE_UNLESS(has_type(node->dimension(), Type::S32) ||
+ has_type(node->dimension(), Type::S64))
+ return true;
+ }
+
+ bool visit(const luci::CircleBatchToSpaceND *node)
+ {
+ RETURN_FALSE_UNLESS(has_type(node, Type::S16))
+ RETURN_FALSE_UNLESS(has_type(node->input(), Type::S16))
+ return true;
+ }
+
+ bool visit(const luci::CircleTanh *node)
+ {
+ RETURN_FALSE_UNLESS(has_type(node, Type::S16))
+ RETURN_FALSE_UNLESS(has_type(node->x(), Type::S16))
+ return true;
+ }
+
+ bool visit(const luci::CircleTranspose *node)
+ {
+ RETURN_FALSE_UNLESS(has_type(node, Type::S16))
+ RETURN_FALSE_UNLESS(has_type(node->a(), Type::S16))
+ RETURN_FALSE_UNLESS(has_type(node->perm(), Type::S32))
+ return true;
+ }
+
+ bool visit(const luci::CircleFloor *node)
+ {
+ RETURN_FALSE_UNLESS(has_type(node, Type::S16))
+ RETURN_FALSE_UNLESS(has_type(node->x(), Type::S16))
+ return true;
+ }
+
+ bool visit(const luci::CircleGreater *node)
+ {
+ RETURN_FALSE_UNLESS(has_type(node, Type::BOOL))
+ RETURN_FALSE_UNLESS(has_type(node->x(), Type::S16))
+ RETURN_FALSE_UNLESS(has_type(node->y(), Type::S16))
+ return true;
+ }
+
+ bool visit(const luci::CircleGreaterEqual *node)
+ {
+ RETURN_FALSE_UNLESS(has_type(node, Type::BOOL))
+ RETURN_FALSE_UNLESS(has_type(node->x(), Type::S16))
+ RETURN_FALSE_UNLESS(has_type(node->y(), Type::S16))
+ return true;
+ }
+
+ bool visit(const luci::CircleDiv *node)
+ {
+ RETURN_FALSE_UNLESS(has_type(node, Type::S16))
+ RETURN_FALSE_UNLESS(has_type(node->x(), Type::S16))
+ RETURN_FALSE_UNLESS(has_type(node->y(), Type::S16))
+ return true;
+ }
+
+ bool visit(const luci::CircleFloorDiv *node)
+ {
+ RETURN_FALSE_UNLESS(has_type(node, Type::S16))
+ RETURN_FALSE_UNLESS(has_type(node->x(), Type::S16))
+ RETURN_FALSE_UNLESS(has_type(node->y(), Type::S16))
+ return true;
+ }
+
+ bool visit(const luci::CircleRsqrt *node)
+ {
+ RETURN_FALSE_UNLESS(has_type(node, Type::S16))
+ RETURN_FALSE_UNLESS(has_type(node->x(), Type::S16))
+ return true;
+ }
+
+ bool visit(const luci::CircleSqrt *node)
+ {
+ RETURN_FALSE_UNLESS(has_type(node, Type::S16))
+ RETURN_FALSE_UNLESS(has_type(node->x(), Type::S16))
+ return true;
+ }
+
+ bool visit(const luci::CircleElu *node)
+ {
+ RETURN_FALSE_UNLESS(has_type(node, Type::S16))
+ RETURN_FALSE_UNLESS(has_type(node->features(), Type::S16))
+ return true;
+ }
+
+ bool visit(const luci::CirclePow *node)
+ {
+ RETURN_FALSE_UNLESS(has_type(node, Type::S16))
+ RETURN_FALSE_UNLESS(has_type(node->x(), Type::S16))
+ RETURN_FALSE_UNLESS(has_type(node->y(), Type::S16))
+ return true;
+ }
+
+ bool visit(const luci::CircleResizeBilinear *node)
+ {
+ RETURN_FALSE_UNLESS(has_type(node, Type::S16))
+ RETURN_FALSE_UNLESS(has_type(node->input(), Type::S16))
+ return true;
+ }
+
+ // TODO: Implement more Ops
+
+ bool visit(const luci::CircleNode *) { return true; }
+};
+
+} // namespace luci
+
+#undef RETURN_FALSE_UNLESS
+
+#endif // __LUCI_VERIFY_QUNTIZED_NODE_S16_TYPE_H__
diff --git a/compiler/luci/pass/src/VerifyQuantizedNodeU8Type.h b/compiler/luci/pass/src/VerifyQuantizedNodeU8Type.h
new file mode 100644
index 000000000..72ce5b8f8
--- /dev/null
+++ b/compiler/luci/pass/src/VerifyQuantizedNodeU8Type.h
@@ -0,0 +1,375 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LUCI_VERIFY_QUANTIZED_NODE_U8_TYPE_H__
+#define __LUCI_VERIFY_QUANTIZED_NODE_U8_TYPE_H__
+
+#include <luci/IR/CircleNodes.h>
+#include <luci/IR/CircleNodeVisitor.h>
+
+using Type = loco::DataType;
+
+// This macro is undef at the end of the file
+#define RETURN_FALSE_UNLESS(ARG) \
+ if (not(ARG)) \
+ { \
+ return false; \
+ }
+
+namespace luci
+{
+
+/**
+ * @brief Verify the data type of UINT8 quantized node
+ * @details
+ *
+ * Targets to verify
+ * - node's output (i.e., node itself)
+ * - node's inputs
+ */
+struct VerifyQuantizedNodeU8Type final : public luci::CircleNodeVisitor<bool>
+{
+private:
+ bool has_type(const loco::Node *node, Type dtype)
+ {
+ auto circle_node = loco::must_cast<const luci::CircleNode *>(node);
+ return circle_node->dtype() == dtype;
+ }
+
+private:
+ bool visit(const luci::CircleConv2D *node)
+ {
+ RETURN_FALSE_UNLESS(has_type(node, Type::U8))
+ RETURN_FALSE_UNLESS(has_type(node->input(), Type::U8))
+ RETURN_FALSE_UNLESS(has_type(node->filter(), Type::U8))
+ RETURN_FALSE_UNLESS(has_type(node->bias(), Type::S32))
+ return true;
+ }
+
+ bool visit(const luci::CircleConcatenation *node)
+ {
+ RETURN_FALSE_UNLESS(has_type(node, Type::U8))
+ for (uint32_t i = 0; i < node->numValues(); i++)
+ {
+ RETURN_FALSE_UNLESS(has_type(node->values(i), Type::U8))
+ }
+ return true;
+ }
+
+ bool visit(const luci::CircleDepthToSpace *node)
+ {
+ RETURN_FALSE_UNLESS(has_type(node, Type::U8))
+ RETURN_FALSE_UNLESS(has_type(node->input(), Type::U8))
+ return true;
+ }
+
+ bool visit(const luci::CircleDepthwiseConv2D *node)
+ {
+ RETURN_FALSE_UNLESS(has_type(node, Type::U8))
+ RETURN_FALSE_UNLESS(has_type(node->input(), Type::U8))
+ RETURN_FALSE_UNLESS(has_type(node->filter(), Type::U8))
+ RETURN_FALSE_UNLESS(has_type(node->bias(), Type::S32))
+ return true;
+ }
+
+ bool visit(const luci::CircleInstanceNorm *node)
+ {
+ RETURN_FALSE_UNLESS(has_type(node, Type::U8))
+ RETURN_FALSE_UNLESS(has_type(node->input(), Type::U8))
+ RETURN_FALSE_UNLESS(has_type(node->gamma(), Type::U8))
+ RETURN_FALSE_UNLESS(has_type(node->beta(), Type::U8))
+ return true;
+ }
+
+ bool visit(const luci::CirclePad *node)
+ {
+ RETURN_FALSE_UNLESS(has_type(node, Type::U8))
+ RETURN_FALSE_UNLESS(has_type(node->input(), Type::U8))
+ RETURN_FALSE_UNLESS(has_type(node->paddings(), Type::S32))
+ return true;
+ }
+
+ bool visit(const luci::CirclePRelu *node)
+ {
+ RETURN_FALSE_UNLESS(has_type(node, Type::U8))
+ RETURN_FALSE_UNLESS(has_type(node->input(), Type::U8))
+ RETURN_FALSE_UNLESS(has_type(node->alpha(), Type::U8))
+ return true;
+ }
+
+ bool visit(const luci::CircleTransposeConv *node)
+ {
+ RETURN_FALSE_UNLESS(has_type(node, Type::U8))
+ RETURN_FALSE_UNLESS(has_type(node->outBackprop(), Type::U8))
+ RETURN_FALSE_UNLESS(has_type(node->filter(), Type::U8))
+ luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias());
+ if (bias != nullptr)
+ RETURN_FALSE_UNLESS(has_type(bias, Type::S32))
+ return true;
+ }
+
+ bool visit(const luci::CircleFullyConnected *node)
+ {
+ RETURN_FALSE_UNLESS(has_type(node, Type::U8))
+ RETURN_FALSE_UNLESS(has_type(node->input(), Type::U8))
+ RETURN_FALSE_UNLESS(has_type(node->weights(), Type::U8))
+ RETURN_FALSE_UNLESS(has_type(node->bias(), Type::S32))
+ return true;
+ }
+
+ bool visit(const luci::CircleAdd *node)
+ {
+ RETURN_FALSE_UNLESS(has_type(node, Type::U8))
+ RETURN_FALSE_UNLESS(has_type(node->x(), Type::U8))
+ RETURN_FALSE_UNLESS(has_type(node->y(), Type::U8))
+ return true;
+ }
+
+ bool visit(const luci::CircleAveragePool2D *node)
+ {
+ RETURN_FALSE_UNLESS(has_type(node, Type::U8))
+ RETURN_FALSE_UNLESS(has_type(node->value(), Type::U8))
+ return true;
+ }
+
+ bool visit(const luci::CircleBatchToSpaceND *node)
+ {
+ RETURN_FALSE_UNLESS(has_type(node, Type::U8))
+ RETURN_FALSE_UNLESS(has_type(node->input(), Type::U8))
+ return true;
+ }
+
+ bool visit(const luci::CircleLogicalOr *node)
+ {
+ RETURN_FALSE_UNLESS(has_type(node, Type::BOOL))
+ RETURN_FALSE_UNLESS(has_type(node->x(), Type::BOOL))
+ RETURN_FALSE_UNLESS(has_type(node->y(), Type::BOOL))
+ return true;
+ }
+
+ bool visit(const luci::CircleMaxPool2D *node)
+ {
+ RETURN_FALSE_UNLESS(has_type(node, Type::U8))
+ RETURN_FALSE_UNLESS(has_type(node->value(), Type::U8))
+ return true;
+ }
+
+ bool visit(const luci::CircleMean *node)
+ {
+ RETURN_FALSE_UNLESS(has_type(node, Type::U8))
+ RETURN_FALSE_UNLESS(has_type(node->input(), Type::U8))
+ RETURN_FALSE_UNLESS(has_type(node->reduction_indices(), Type::S32))
+ return true;
+ }
+
+ bool visit(const luci::CircleMul *node)
+ {
+ RETURN_FALSE_UNLESS(has_type(node, Type::U8))
+ RETURN_FALSE_UNLESS(has_type(node->x(), Type::U8))
+ RETURN_FALSE_UNLESS(has_type(node->y(), Type::U8))
+ return true;
+ }
+
+ bool visit(const luci::CircleNotEqual *node)
+ {
+ RETURN_FALSE_UNLESS(has_type(node, Type::BOOL))
+ RETURN_FALSE_UNLESS(has_type(node->x(), Type::U8))
+ RETURN_FALSE_UNLESS(has_type(node->y(), Type::U8))
+ return true;
+ }
+
+ bool visit(const luci::CircleRelu *node)
+ {
+ RETURN_FALSE_UNLESS(has_type(node, Type::U8))
+ RETURN_FALSE_UNLESS(has_type(node->features(), Type::U8))
+ return true;
+ }
+
+ bool visit(const luci::CircleReshape *node)
+ {
+ RETURN_FALSE_UNLESS(has_type(node, Type::U8))
+ RETURN_FALSE_UNLESS(has_type(node->tensor(), Type::U8))
+ luci::CircleConst *shape = dynamic_cast<luci::CircleConst *>(node->shape());
+ if (shape != nullptr)
+ RETURN_FALSE_UNLESS(has_type(shape, Type::S32))
+ return true;
+ }
+
+ bool visit(const luci::CircleLogistic *node)
+ {
+ RETURN_FALSE_UNLESS(has_type(node, Type::U8))
+ RETURN_FALSE_UNLESS(has_type(node->x(), Type::U8))
+ return true;
+ }
+
+ bool visit(const luci::CircleSoftmax *node)
+ {
+ RETURN_FALSE_UNLESS(has_type(node, Type::U8))
+ RETURN_FALSE_UNLESS(has_type(node->logits(), Type::U8))
+ return true;
+ }
+
+ bool visit(const luci::CircleSpaceToBatchND *node)
+ {
+ RETURN_FALSE_UNLESS(has_type(node, Type::U8))
+ RETURN_FALSE_UNLESS(has_type(node->input(), Type::U8))
+ return true;
+ }
+
+ bool visit(const luci::CircleSpaceToDepth *node)
+ {
+ RETURN_FALSE_UNLESS(has_type(node, Type::U8))
+ RETURN_FALSE_UNLESS(has_type(node->input(), Type::U8))
+ return true;
+ }
+
+ bool visit(const luci::CircleSlice *node)
+ {
+ RETURN_FALSE_UNLESS(has_type(node, Type::U8))
+ RETURN_FALSE_UNLESS(has_type(node->input(), Type::U8))
+ RETURN_FALSE_UNLESS(has_type(node->begin(), Type::S32) || has_type(node->begin(), Type::S64))
+ RETURN_FALSE_UNLESS(has_type(node->size(), Type::S32) || has_type(node->size(), Type::S64))
+ return true;
+ }
+
+ bool visit(const luci::CircleSplit *node)
+ {
+ // node's output is the input of CircleSplitOut, thus not quantized
+ RETURN_FALSE_UNLESS(has_type(node->input(), Type::U8))
+ return true;
+ }
+
+ bool visit(const luci::CircleSplitOut *node)
+ {
+ RETURN_FALSE_UNLESS(has_type(node, Type::U8))
+ return true;
+ }
+
+ bool visit(const luci::CircleStridedSlice *node)
+ {
+ RETURN_FALSE_UNLESS(has_type(node, Type::U8))
+ RETURN_FALSE_UNLESS(has_type(node->input(), Type::U8))
+ return true;
+ }
+
+ bool visit(const luci::CircleArgMax *node)
+ {
+ RETURN_FALSE_UNLESS(has_type(node, node->output_type()))
+ RETURN_FALSE_UNLESS(has_type(node->input(), Type::U8))
+ RETURN_FALSE_UNLESS(has_type(node->dimension(), Type::S32) ||
+ has_type(node->dimension(), Type::S64))
+ return true;
+ }
+
+ bool visit(const luci::CircleTanh *node)
+ {
+ RETURN_FALSE_UNLESS(has_type(node, Type::U8))
+ RETURN_FALSE_UNLESS(has_type(node->x(), Type::U8))
+ return true;
+ }
+
+ bool visit(const luci::CircleTranspose *node)
+ {
+ RETURN_FALSE_UNLESS(has_type(node, Type::U8))
+ RETURN_FALSE_UNLESS(has_type(node->a(), Type::U8))
+ RETURN_FALSE_UNLESS(has_type(node->perm(), Type::S32))
+ return true;
+ }
+
+ bool visit(const luci::CircleFloor *node)
+ {
+ RETURN_FALSE_UNLESS(has_type(node, Type::U8))
+ RETURN_FALSE_UNLESS(has_type(node->x(), Type::U8))
+ return true;
+ }
+
+ bool visit(const luci::CircleGreater *node)
+ {
+ RETURN_FALSE_UNLESS(has_type(node, Type::BOOL))
+ RETURN_FALSE_UNLESS(has_type(node->x(), Type::U8))
+ RETURN_FALSE_UNLESS(has_type(node->y(), Type::U8))
+ return true;
+ }
+
+ bool visit(const luci::CircleGreaterEqual *node)
+ {
+ RETURN_FALSE_UNLESS(has_type(node, Type::BOOL))
+ RETURN_FALSE_UNLESS(has_type(node->x(), Type::U8))
+ RETURN_FALSE_UNLESS(has_type(node->y(), Type::U8))
+ return true;
+ }
+
+ bool visit(const luci::CircleDiv *node)
+ {
+ RETURN_FALSE_UNLESS(has_type(node, Type::U8))
+ RETURN_FALSE_UNLESS(has_type(node->x(), Type::U8))
+ RETURN_FALSE_UNLESS(has_type(node->y(), Type::U8))
+ return true;
+ }
+
+ bool visit(const luci::CircleFloorDiv *node)
+ {
+ RETURN_FALSE_UNLESS(has_type(node, Type::U8))
+ RETURN_FALSE_UNLESS(has_type(node->x(), Type::U8))
+ RETURN_FALSE_UNLESS(has_type(node->y(), Type::U8))
+ return true;
+ }
+
+ bool visit(const luci::CircleRsqrt *node)
+ {
+ RETURN_FALSE_UNLESS(has_type(node, Type::U8))
+ RETURN_FALSE_UNLESS(has_type(node->x(), Type::U8))
+ return true;
+ }
+
+ bool visit(const luci::CircleSqrt *node)
+ {
+ RETURN_FALSE_UNLESS(has_type(node, Type::U8))
+ RETURN_FALSE_UNLESS(has_type(node->x(), Type::U8))
+ return true;
+ }
+
+ bool visit(const luci::CircleElu *node)
+ {
+ RETURN_FALSE_UNLESS(has_type(node, Type::U8))
+ RETURN_FALSE_UNLESS(has_type(node->features(), Type::U8))
+ return true;
+ }
+
+ bool visit(const luci::CirclePow *node)
+ {
+ RETURN_FALSE_UNLESS(has_type(node, Type::U8))
+ RETURN_FALSE_UNLESS(has_type(node->x(), Type::U8))
+ RETURN_FALSE_UNLESS(has_type(node->y(), Type::U8))
+ return true;
+ }
+
+ bool visit(const luci::CircleResizeBilinear *node)
+ {
+ RETURN_FALSE_UNLESS(has_type(node, Type::U8))
+ RETURN_FALSE_UNLESS(has_type(node->input(), Type::U8))
+ return true;
+ }
+
+ // TODO: Implement more Ops
+
+ bool visit(const luci::CircleNode *) { return true; }
+};
+
+} // namespace luci
+
+#undef RETURN_FALSE_UNLESS
+
+#endif // __LUCI_VERIFY_QUNTIZED_NODE_U8_TYPE_H__
diff --git a/compiler/luci/pass/src/helpers/InferenceCandidates.cpp b/compiler/luci/pass/src/helpers/InferenceCandidates.cpp
new file mode 100644
index 000000000..2c8565932
--- /dev/null
+++ b/compiler/luci/pass/src/helpers/InferenceCandidates.cpp
@@ -0,0 +1,45 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "InferenceCandidates.h"
+
+#include <luci/IR/DeadNodeQueryService.h>
+
+namespace luci
+{
+
+std::vector<loco::Node *> inference_candidates(loco::Graph *g)
+{
+ auto candidates = loco::postorder_traversal(loco::output_nodes(g));
+
+ for (auto node : loco::all_nodes(g))
+ {
+ // already included as candidate
+ if (std::find(candidates.begin(), candidates.end(), node) != candidates.end())
+ continue;
+
+ // As the node is not used for both graph output and multiple output operation,
+ // it cannot be candidate.
+ if (node->dialect()->service<DeadNodeQueryServiceImpl>()->isDeadNode(node))
+ continue;
+
+ candidates.emplace_back(node);
+ }
+
+ return candidates;
+}
+
+} // namespace luci
diff --git a/compiler/luci/pass/src/helpers/InferenceCandidates.h b/compiler/luci/pass/src/helpers/InferenceCandidates.h
new file mode 100644
index 000000000..f27e4fe60
--- /dev/null
+++ b/compiler/luci/pass/src/helpers/InferenceCandidates.h
@@ -0,0 +1,34 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LUCI_INFERENCE_CANDIDATES_H__
+#define __LUCI_INFERENCE_CANDIDATES_H__
+
+#include <loco.h>
+
+#include <vector>
+
+namespace luci
+{
+
+/**
+ * @brief Enumerate all the nodes whose shape/dtype should be inferenced to export graph.
+ */
+std::vector<loco::Node *> inference_candidates(loco::Graph *g);
+
+} // namespace luci
+
+#endif // __LUCI_INFERENCE_CANDIDATES_H__
diff --git a/compiler/luci/pass/src/helpers/InferenceCandidates.test.cpp b/compiler/luci/pass/src/helpers/InferenceCandidates.test.cpp
new file mode 100644
index 000000000..e34421f5e
--- /dev/null
+++ b/compiler/luci/pass/src/helpers/InferenceCandidates.test.cpp
@@ -0,0 +1,122 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "InferenceCandidates.h"
+#include "luci/IR/CircleNode.h"
+
+#include <algorithm>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+bool contains(const std::vector<loco::Node *> &vec, loco::Node *val)
+{
+ return std::any_of(vec.begin(), vec.end(), [val](loco::Node *node) { return node == val; });
+}
+
+} // namespace
+
+TEST(LuciPassHelpersInferenceCandidates, inference_candidates)
+{
+ auto g = loco::make_graph();
+
+ // Create nodes
+ auto input = g->nodes()->create<luci::CircleInput>();
+ auto split = g->nodes()->create<luci::CircleSplit>();
+ auto split_out1 = g->nodes()->create<luci::CircleSplitOut>();
+ auto split_out2 = g->nodes()->create<luci::CircleSplitOut>();
+ auto split_dim = g->nodes()->create<luci::CircleConst>();
+ auto output = g->nodes()->create<luci::CircleOutput>();
+
+ // Build up initial graph
+ auto graph_input1 = g->inputs()->create();
+ input->index(graph_input1->index());
+
+ split->split_dim(split_dim);
+ split->input(input);
+ split->num_split(2);
+
+ split_out1->input(split);
+ split_out1->index(0);
+
+ split_out2->input(split);
+ split_out2->index(1);
+
+ auto graph_output = g->outputs()->create();
+ output->from(split_out1);
+ output->index(graph_output->index());
+
+ auto s = luci::inference_candidates(g.get());
+
+ ASSERT_EQ(6, s.size());
+ ASSERT_TRUE(contains(s, input));
+ ASSERT_TRUE(contains(s, split));
+ ASSERT_TRUE(contains(s, split_out1));
+ ASSERT_TRUE(contains(s, split_out2));
+ ASSERT_TRUE(contains(s, split_dim));
+ ASSERT_TRUE(contains(s, output));
+}
+
+TEST(LuciPassHelpersInferenceCandidates, inference_candidates_NEG)
+{
+ auto g = loco::make_graph();
+
+ // Create nodes
+ auto input = g->nodes()->create<luci::CircleInput>();
+ auto split = g->nodes()->create<luci::CircleSplit>();
+ auto split_out1 = g->nodes()->create<luci::CircleSplitOut>();
+ auto split_out2 = g->nodes()->create<luci::CircleSplitOut>();
+ auto split_dim = g->nodes()->create<luci::CircleConst>();
+ auto relu1 = g->nodes()->create<luci::CircleRelu>();
+ auto relu2 = g->nodes()->create<luci::CircleRelu>();
+ auto output = g->nodes()->create<luci::CircleOutput>();
+
+ // Build up initial graph
+ auto graph_input1 = g->inputs()->create();
+ input->index(graph_input1->index());
+
+ split->split_dim(split_dim);
+ split->input(input);
+ split->num_split(2);
+
+ split_out1->input(split);
+ split_out1->index(0);
+
+ split_out2->input(split);
+ split_out2->index(1);
+
+ relu1->features(split_out2);
+
+ relu2->features(input);
+
+ auto graph_output = g->outputs()->create();
+ output->from(split_out1);
+ output->index(graph_output->index());
+
+ auto s = luci::inference_candidates(g.get());
+
+ ASSERT_EQ(6, s.size());
+ ASSERT_TRUE(contains(s, input));
+ ASSERT_TRUE(contains(s, split));
+ ASSERT_TRUE(contains(s, split_out1));
+ ASSERT_TRUE(contains(s, split_out2));
+ ASSERT_TRUE(contains(s, split_dim));
+ ASSERT_TRUE(contains(s, output));
+ ASSERT_FALSE(contains(s, relu1));
+ ASSERT_FALSE(contains(s, relu2));
+}
diff --git a/compiler/luci/pass/src/helpers/NodeFiller.cpp b/compiler/luci/pass/src/helpers/NodeFiller.cpp
new file mode 100644
index 000000000..b1416655d
--- /dev/null
+++ b/compiler/luci/pass/src/helpers/NodeFiller.cpp
@@ -0,0 +1,20 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "NodeFiller.h"
+
+// NOTE Do NOT delete this file; this file enforces compiler to check whether 'NodeFiller.h' is
+// complete.
diff --git a/compiler/luci/pass/src/helpers/NodeFiller.h b/compiler/luci/pass/src/helpers/NodeFiller.h
new file mode 100644
index 000000000..b80f085b0
--- /dev/null
+++ b/compiler/luci/pass/src/helpers/NodeFiller.h
@@ -0,0 +1,104 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+namespace luci
+{
+
+/**
+ * INTRODUCTION
+ * Binary operation f(x,y) is 'commutative' when
+ * f(x,y) == f(y,x) holds for all x, y.
+ * For examples, ADD, MUL and SQUARED_DIFFERENCE are commutative.
+ * These helpers make it easy to find commutative arguments of commutative node.
+ *
+ * HOW TO USE
+ * COMM_NODE *node;
+ * ARG_TYPE_1 *arg1;
+ * ARG_TYPE_2 *arg2;
+ *
+ * bool ok = fill(&arg1, &arg2).with_commutative_args_of(node);
+ *
+ * Result
+ * If 'node's commutative argument types are actually {ARG_TYPE_1, ARG_TYPE_2}
+ * (as a set), 'arg1' and 'arg2' set as actual 'node's arguments with matching
+ * type, and return value 'ok' is true.
+ * Otherwise, 'arg1' and 'arg2' not changed, 'ok' is false.
+ */
+
+template <class ARG_TYPE_1, class ARG_TYPE_2> class NodeFiller final
+{
+public:
+ NodeFiller(ARG_TYPE_1 **arg_1, ARG_TYPE_2 **arg_2) : _arg_1(arg_1), _arg_2(arg_2)
+ {
+ // DO NOTHING
+ }
+
+ /**
+ * @return true When 'node's argument types are 'ARG_TYPE_1' and 'ARG_TYPE_2'
+ * In such case, it assign '_arg_1' and '_arg_2' to actual arguments
+ *
+ * @return false When 'node's argument types are NOT matched with 'ARG_TYPE_*'
+ * In such case, it does not amend '_arg_1' and '_arg_2'
+ *
+ * @require COMM_NODE has member x() and y()
+ */
+ template <class COMM_NODE> bool with_commutative_args_of(const COMM_NODE *node);
+
+private:
+ ARG_TYPE_1 **_arg_1;
+ ARG_TYPE_2 **_arg_2;
+};
+
+template <class ARG_TYPE_1, class ARG_TYPE_2>
+inline NodeFiller<ARG_TYPE_1, ARG_TYPE_2> fill(ARG_TYPE_1 **arg_1, ARG_TYPE_2 **arg_2)
+{
+ return NodeFiller<ARG_TYPE_1, ARG_TYPE_2>{arg_1, arg_2};
+}
+
+template <class ARG_TYPE_1, class ARG_TYPE_2>
+template <class COMM_NODE>
+bool NodeFiller<ARG_TYPE_1, ARG_TYPE_2>::with_commutative_args_of(const COMM_NODE *node)
+{
+ // Case 1) X == ARG_TYPE_1 / Y == ARG_TYPE_2
+ {
+ auto x = dynamic_cast<ARG_TYPE_1 *>(node->x());
+ auto y = dynamic_cast<ARG_TYPE_2 *>(node->y());
+
+ if (x && y)
+ {
+ *_arg_1 = x;
+ *_arg_2 = y;
+ return true;
+ }
+ }
+
+ // Case 2) X == ARG_TYPE_2 / Y == ARG_TYPE_1
+ {
+ auto x = dynamic_cast<ARG_TYPE_2 *>(node->x());
+ auto y = dynamic_cast<ARG_TYPE_1 *>(node->y());
+
+ if (x && y)
+ {
+ *_arg_1 = y;
+ *_arg_2 = x;
+ return true;
+ }
+ }
+
+ return false;
+}
+
+} // namespace luci
diff --git a/compiler/luci/pass/src/helpers/NodeFiller.test.cpp b/compiler/luci/pass/src/helpers/NodeFiller.test.cpp
new file mode 100644
index 000000000..9bbc7f264
--- /dev/null
+++ b/compiler/luci/pass/src/helpers/NodeFiller.test.cpp
@@ -0,0 +1,59 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <luci/IR/CircleNodes.h>
+
+#include <gtest/gtest.h>
+
+#include "NodeFiller.h"
+
+TEST(NodeFillerTest, simple_test)
+{
+ luci::CircleConst maxi_const;
+ luci::CircleMinimum mini;
+ luci::CircleMaximum maxi;
+ maxi.x(&maxi_const);
+ maxi.y(&mini);
+
+ luci::CircleConst *x = nullptr;
+ luci::CircleMinimum *y = nullptr;
+
+ EXPECT_TRUE(luci::fill(&x, &y).with_commutative_args_of(&maxi));
+ EXPECT_TRUE(x == &maxi_const);
+ EXPECT_TRUE(y == &mini);
+
+ x = nullptr;
+ y = nullptr;
+
+ EXPECT_TRUE(luci::fill(&y, &x).with_commutative_args_of(&maxi));
+ EXPECT_TRUE(x == &maxi_const);
+ EXPECT_TRUE(y == &mini);
+}
+
+TEST(NodeFillerTest, wrong_condition_NEG)
+{
+ luci::CircleConst add_const;
+ luci::CircleMinimum mini;
+ luci::CircleAdd add;
+ add.x(&add_const);
+ add.y(&mini);
+
+ luci::CircleMul *x = nullptr;
+ luci::CircleMinimum *y = nullptr;
+
+ EXPECT_FALSE(luci::fill(&x, &y).with_commutative_args_of(&add));
+ EXPECT_FALSE(luci::fill(&y, &x).with_commutative_args_of(&add));
+}
diff --git a/compiler/luci/pass/src/helpers/Strings.cpp b/compiler/luci/pass/src/helpers/Strings.cpp
new file mode 100644
index 000000000..d020f6ddc
--- /dev/null
+++ b/compiler/luci/pass/src/helpers/Strings.cpp
@@ -0,0 +1,91 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "Strings.h"
+
+#include <algorithm>
+
+namespace luci
+{
+
+bool in_array(const std::string &str, const std::vector<std::string> &array)
+{
+ return std::find(array.begin(), array.end(), str) != array.end();
+}
+
+std::string to_string(const std::vector<std::string> &strings)
+{
+ assert(!strings.empty());
+
+ std::string res;
+ for (unsigned int i = 0; i < strings.size() - 1; i++)
+ res += strings[i] + ", ";
+
+ res += strings[strings.size() - 1];
+ return res;
+}
+
+std::string to_lower_case(std::string s)
+{
+ std::transform(s.begin(), s.end(), s.begin(), [](unsigned char c) { return std::tolower(c); });
+ return s;
+}
+
+loco::DataType str_to_dtype(const std::string &str)
+{
+ if (to_lower_case(str).compare("uint8") == 0)
+ return loco::DataType::U8;
+ if (to_lower_case(str).compare("uint16") == 0)
+ return loco::DataType::U16;
+ if (to_lower_case(str).compare("uint32") == 0)
+ return loco::DataType::U32;
+ if (to_lower_case(str).compare("uint64") == 0)
+ return loco::DataType::U64;
+
+ if (to_lower_case(str).compare("int8") == 0)
+ return loco::DataType::S8;
+ if (to_lower_case(str).compare("int16") == 0)
+ return loco::DataType::S16;
+ if (to_lower_case(str).compare("int32") == 0)
+ return loco::DataType::S32;
+ if (to_lower_case(str).compare("int64") == 0)
+ return loco::DataType::S64;
+
+ if (to_lower_case(str).compare("float16") == 0)
+ return loco::DataType::FLOAT16;
+ if (to_lower_case(str).compare("float32") == 0)
+ return loco::DataType::FLOAT32;
+ if (to_lower_case(str).compare("float64") == 0)
+ return loco::DataType::FLOAT64;
+
+ if (to_lower_case(str).compare("bool") == 0)
+ return loco::DataType::BOOL;
+
+ return loco::DataType::Unknown;
+}
+
+QuantizationGranularity str_to_granularity(const std::string &str)
+{
+ if (to_lower_case(str).compare("layer") == 0)
+ return QuantizationGranularity::LayerWise;
+
+ if (to_lower_case(str).compare("channel") == 0)
+ return QuantizationGranularity::ChannelWise;
+
+ throw std::runtime_error("Quantization granularity must be either 'layer' or 'channel'");
+}
+
+} // namespace luci
diff --git a/compiler/luci/pass/src/helpers/Strings.h b/compiler/luci/pass/src/helpers/Strings.h
new file mode 100644
index 000000000..793d137fb
--- /dev/null
+++ b/compiler/luci/pass/src/helpers/Strings.h
@@ -0,0 +1,57 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LUCI_PASS_HELPERS_STRINGS_H__
+#define __LUCI_PASS_HELPERS_STRINGS_H__
+
+#include "luci/Pass/QuantizationParameters.h"
+
+#include <loco.h>
+
+#include <vector>
+#include <sstream>
+#include <string>
+
+namespace luci
+{
+
+bool in_array(const std::string &, const std::vector<std::string> &);
+
+std::string to_string(const std::vector<std::string> &);
+
+std::string to_lower_case(std::string);
+
+loco::DataType str_to_dtype(const std::string &);
+
+QuantizationGranularity str_to_granularity(const std::string &);
+
+template <typename T> std::vector<T> csv_to_vector(const std::string &str)
+{
+ std::vector<T> ret;
+ std::istringstream is(str);
+ for (T i; is >> i;)
+ {
+ assert(i != ',');
+ ret.push_back(i);
+ if (is.peek() == ',')
+ is.ignore();
+ }
+ return ret;
+}
+
+} // namespace luci
+
+#endif // __LUCI_PASS_HELPERS_STRINGS_H__
diff --git a/compiler/luci/pass/src/helpers/Strings.test.cpp b/compiler/luci/pass/src/helpers/Strings.test.cpp
new file mode 100644
index 000000000..f6bb48951
--- /dev/null
+++ b/compiler/luci/pass/src/helpers/Strings.test.cpp
@@ -0,0 +1,58 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "Strings.h"
+
+#include "luci/Pass/QuantizationParameters.h"
+
+#include <gtest/gtest.h>
+
+TEST(StringsTest, str_to_dtype)
+{
+ ASSERT_EQ(loco::DataType::U8, luci::str_to_dtype("uint8"));
+ ASSERT_EQ(loco::DataType::U16, luci::str_to_dtype("uint16"));
+ ASSERT_EQ(loco::DataType::U32, luci::str_to_dtype("uint32"));
+ ASSERT_EQ(loco::DataType::U64, luci::str_to_dtype("uint64"));
+
+ ASSERT_EQ(loco::DataType::S8, luci::str_to_dtype("int8"));
+ ASSERT_EQ(loco::DataType::S16, luci::str_to_dtype("int16"));
+ ASSERT_EQ(loco::DataType::S32, luci::str_to_dtype("int32"));
+ ASSERT_EQ(loco::DataType::S64, luci::str_to_dtype("int64"));
+
+ ASSERT_EQ(loco::DataType::FLOAT16, luci::str_to_dtype("float16"));
+ ASSERT_EQ(loco::DataType::FLOAT32, luci::str_to_dtype("float32"));
+ ASSERT_EQ(loco::DataType::FLOAT64, luci::str_to_dtype("float64"));
+
+ ASSERT_EQ(loco::DataType::BOOL, luci::str_to_dtype("bool"));
+
+ ASSERT_EQ(loco::DataType::Unknown, luci::str_to_dtype("foo"));
+}
+
+TEST(StringsTest, str_to_granularity)
+{
+ ASSERT_EQ(luci::QuantizationGranularity::LayerWise, luci::str_to_granularity("layer"));
+ ASSERT_EQ(luci::QuantizationGranularity::ChannelWise, luci::str_to_granularity("channel"));
+
+ EXPECT_THROW(luci::str_to_granularity("foo"), std::runtime_error);
+}
+
+TEST(StringsTest, csv_to_vector_int32)
+{
+ auto ret = luci::csv_to_vector<int32_t>("1,2,3");
+ ASSERT_EQ(3, ret.size());
+ ASSERT_EQ(1, ret.at(0));
+ ASSERT_EQ(3, ret.at(2));
+}
diff --git a/compiler/luci/pass/src/helpers/TypeMapper.cpp b/compiler/luci/pass/src/helpers/TypeMapper.cpp
new file mode 100644
index 000000000..ffa0159dd
--- /dev/null
+++ b/compiler/luci/pass/src/helpers/TypeMapper.cpp
@@ -0,0 +1,20 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "TypeMapper.h"
+
+// NOTE Do NOT delete this file; this file enforces compiler to check whether 'TypeMapper.h' is
+// complete.
diff --git a/compiler/luci/pass/src/helpers/TypeMapper.h b/compiler/luci/pass/src/helpers/TypeMapper.h
new file mode 100644
index 000000000..90760e95b
--- /dev/null
+++ b/compiler/luci/pass/src/helpers/TypeMapper.h
@@ -0,0 +1,77 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <loco/IR/DataType.h>
+
+#include <cstdint>
+
+namespace luci
+{
+
+/**
+ * @brief TypeMapper maps between c++ primitive data type and loco::DataType.
+ */
+template <typename T> struct TypeMapper
+{
+ static constexpr loco::DataType get() { return loco::DataType::Unknown; }
+};
+
+template <> struct TypeMapper<float>
+{
+ static constexpr loco::DataType get() { return loco::DataType::FLOAT32; }
+};
+
+template <> struct TypeMapper<uint8_t>
+{
+ static constexpr loco::DataType get() { return loco::DataType::U8; }
+};
+
+template <> struct TypeMapper<uint16_t>
+{
+ static constexpr loco::DataType get() { return loco::DataType::U16; }
+};
+
+template <> struct TypeMapper<uint32_t>
+{
+ static constexpr loco::DataType get() { return loco::DataType::U32; }
+};
+
+template <> struct TypeMapper<uint64_t>
+{
+ static constexpr loco::DataType get() { return loco::DataType::U64; }
+};
+
+template <> struct TypeMapper<int8_t>
+{
+ static constexpr loco::DataType get() { return loco::DataType::S8; }
+};
+
+template <> struct TypeMapper<int16_t>
+{
+ static constexpr loco::DataType get() { return loco::DataType::S16; }
+};
+
+template <> struct TypeMapper<int32_t>
+{
+ static constexpr loco::DataType get() { return loco::DataType::S32; }
+};
+
+template <> struct TypeMapper<int64_t>
+{
+ static constexpr loco::DataType get() { return loco::DataType::S64; }
+};
+
+} // namespace luci
diff --git a/compiler/luci/pass/src/helpers/TypeMapper.test.cpp b/compiler/luci/pass/src/helpers/TypeMapper.test.cpp
new file mode 100644
index 000000000..a7ac08a63
--- /dev/null
+++ b/compiler/luci/pass/src/helpers/TypeMapper.test.cpp
@@ -0,0 +1,93 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <luci/IR/CircleNodes.h>
+
+#include <gtest/gtest.h>
+
+#include "TypeMapper.h"
+
+#include <vector>
+
+namespace
+{
+
+template <typename T> bool fill_const_node(luci::CircleConst *node, std::vector<T> &data)
+{
+ if (node->dtype() != luci::TypeMapper<T>::get())
+ return false;
+
+ node->size<luci::TypeMapper<T>::get()>(data.size());
+ for (uint32_t i = 0; i < data.size(); i++)
+ {
+ node->at<luci::TypeMapper<T>::get()>(i) = data.at(i);
+ }
+
+ return true;
+}
+
+class STRANGER
+{
+};
+
+} // namespace
+
+TEST(TypeMapperTest, simple_test)
+{
+ EXPECT_EQ(loco::DataType::FLOAT32, luci::TypeMapper<float>::get());
+ EXPECT_EQ(loco::DataType::U8, luci::TypeMapper<uint8_t>::get());
+ EXPECT_EQ(loco::DataType::U16, luci::TypeMapper<uint16_t>::get());
+ EXPECT_EQ(loco::DataType::U32, luci::TypeMapper<uint32_t>::get());
+ EXPECT_EQ(loco::DataType::U64, luci::TypeMapper<uint64_t>::get());
+ EXPECT_EQ(loco::DataType::S8, luci::TypeMapper<int8_t>::get());
+ EXPECT_EQ(loco::DataType::S16, luci::TypeMapper<int16_t>::get());
+ EXPECT_EQ(loco::DataType::S32, luci::TypeMapper<int32_t>::get());
+ EXPECT_EQ(loco::DataType::S64, luci::TypeMapper<int64_t>::get());
+}
+
+TEST(TypeMapperTest, with_template_test)
+{
+ std::vector<int32_t> int32_vec{0, 1, 2, 3, 4, 5, 6, 7};
+ luci::CircleConst const_node;
+ const_node.dtype(loco::DataType::S32);
+ EXPECT_TRUE(fill_const_node(&const_node, int32_vec));
+ EXPECT_EQ(8, const_node.size<loco::DataType::S32>());
+ EXPECT_EQ(0, const_node.at<loco::DataType::S32>(0));
+ EXPECT_EQ(1, const_node.at<loco::DataType::S32>(1));
+ EXPECT_EQ(2, const_node.at<loco::DataType::S32>(2));
+ EXPECT_EQ(3, const_node.at<loco::DataType::S32>(3));
+ EXPECT_EQ(4, const_node.at<loco::DataType::S32>(4));
+ EXPECT_EQ(5, const_node.at<loco::DataType::S32>(5));
+ EXPECT_EQ(6, const_node.at<loco::DataType::S32>(6));
+ EXPECT_EQ(7, const_node.at<loco::DataType::S32>(7));
+
+ std::vector<float> f32_vec{0.0, 1.1, 2.2, 3.3, 4.4, 5.5};
+ const_node.dtype(loco::DataType::FLOAT32);
+ EXPECT_FALSE(fill_const_node(&const_node, int32_vec));
+ EXPECT_TRUE(fill_const_node(&const_node, f32_vec));
+ EXPECT_EQ(6, const_node.size<loco::DataType::FLOAT32>());
+ EXPECT_FLOAT_EQ(0.0, const_node.at<loco::DataType::FLOAT32>(0));
+ EXPECT_FLOAT_EQ(1.1, const_node.at<loco::DataType::FLOAT32>(1));
+ EXPECT_FLOAT_EQ(2.2, const_node.at<loco::DataType::FLOAT32>(2));
+ EXPECT_FLOAT_EQ(3.3, const_node.at<loco::DataType::FLOAT32>(3));
+ EXPECT_FLOAT_EQ(4.4, const_node.at<loco::DataType::FLOAT32>(4));
+ EXPECT_FLOAT_EQ(5.5, const_node.at<loco::DataType::FLOAT32>(5));
+}
+
+TEST(TypeMapperTest, wrong_condition_NEG)
+{
+ EXPECT_EQ(loco::DataType::Unknown, luci::TypeMapper<STRANGER>::get());
+}
diff --git a/compiler/luci/pass/src/test/TestFirstNode.h b/compiler/luci/pass/src/test/TestFirstNode.h
new file mode 100644
index 000000000..21f859fcd
--- /dev/null
+++ b/compiler/luci/pass/src/test/TestFirstNode.h
@@ -0,0 +1,43 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LUCI_PASS_TEST_FIRST_NODE_H__
+#define __LUCI_PASS_TEST_FIRST_NODE_H__
+
+#include <luci/IR/CircleNodes.h>
+
+#include <loco.h>
+
+namespace luci
+{
+namespace test
+{
+
+template <class T> T *first_node(loco::Graph *g)
+{
+ for (auto node : loco::active_nodes(loco::output_nodes(g)))
+ {
+ auto target_node = dynamic_cast<T *>(node);
+ if (target_node != nullptr)
+ return target_node;
+ }
+ return nullptr;
+}
+
+} // namespace test
+} // namespace luci
+
+#endif // __LUCI_PASS_TEST_FIRST_NODE_H__
diff --git a/compiler/luci/pass/src/test/TestFirstNode.test.cpp b/compiler/luci/pass/src/test/TestFirstNode.test.cpp
new file mode 100644
index 000000000..b07ac6199
--- /dev/null
+++ b/compiler/luci/pass/src/test/TestFirstNode.test.cpp
@@ -0,0 +1,19 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "TestFirstNode.h"
+
+// This file validates "TestFirstNode.h". Pleaes DO NOT remove this file.
diff --git a/compiler/luci/pass/src/test/TestIOGraph.h b/compiler/luci/pass/src/test/TestIOGraph.h
new file mode 100644
index 000000000..b1fc41f90
--- /dev/null
+++ b/compiler/luci/pass/src/test/TestIOGraph.h
@@ -0,0 +1,161 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LUCI_PASS_TEST_IO_GRAPH_H__
+#define __LUCI_PASS_TEST_IO_GRAPH_H__
+
+#include "TestShape.h"
+
+#include <luci/IR/CircleNodes.h>
+
+namespace luci
+{
+namespace test
+{
+
+/**
+ * @brief Graphlet with Inputs and loco::Graph for multiple inputs
+ * @note Every Graph will have Input(s) and Output(s)
+ * We put loco::Graph only in IsGraphlet not to declare separate
+ * class for loco::Graph
+ */
+template <unsigned N> class TestIsGraphlet
+{
+public:
+ TestIsGraphlet()
+ {
+ for (uint32_t n = 0; n < N; ++n)
+ {
+ _graph_inputs[n] = nullptr;
+ _inputs[n] = nullptr;
+ }
+ }
+
+public:
+ virtual void init(loco::Graph *g, const ShapeU32 shape_in)
+ {
+ for (uint32_t n = 0; n < N; ++n)
+ {
+ _graph_inputs[n] = g->inputs()->create();
+
+ _inputs[n] = g->nodes()->create<luci::CircleInput>();
+ _inputs[n]->shape(shape_in);
+ _inputs[n]->shape_status(luci::ShapeStatus::VALID);
+ _inputs[n]->dtype(loco::DataType::FLOAT32);
+ _inputs[n]->name("input_" + std::to_string(n));
+
+ _inputs[n]->index(_graph_inputs[n]->index());
+
+ auto input_shape = std::make_unique<loco::TensorShape>();
+ set_shape_vector(input_shape.get(), shape_in);
+ _graph_inputs[n]->shape(std::move(input_shape));
+ _graph_inputs[n]->dtype(loco::DataType::FLOAT32);
+ }
+ }
+
+public:
+ loco::Graph *g(void) { return &_g; }
+ luci::CircleInput *input(int idx) { return _inputs[idx]; }
+
+protected:
+ loco::Graph _g;
+ std::array<loco::GraphInput *, N> _graph_inputs;
+ std::array<luci::CircleInput *, N> _inputs;
+};
+
+/**
+ * @brief Graphlet with one Input
+ */
+class TestIGraphlet : public TestIsGraphlet<1>
+{
+public:
+ luci::CircleInput *input() { return _inputs[0]; }
+};
+
+/**
+ * @brief Graphlet with Outputs for multiple outputs
+ */
+template <unsigned N> class TestOsGraphlet
+{
+public:
+ TestOsGraphlet()
+ {
+ for (uint32_t n = 0; n < N; ++n)
+ {
+ _graph_outputs[n] = nullptr;
+ _outputs[n] = nullptr;
+ }
+ }
+
+public:
+ virtual void init(loco::Graph *g, const ShapeU32 shape_out)
+ {
+ for (uint32_t n = 0; n < N; ++n)
+ {
+ _graph_outputs[n] = g->outputs()->create();
+
+ _outputs[n] = g->nodes()->create<luci::CircleOutput>();
+ _outputs[n]->shape(shape_out);
+ _outputs[n]->shape_status(luci::ShapeStatus::VALID);
+ _outputs[n]->dtype(loco::DataType::FLOAT32);
+ _outputs[n]->name("output_" + std::to_string(n));
+
+ _outputs[n]->index(_graph_outputs[n]->index());
+
+ auto output_shape = std::make_unique<loco::TensorShape>();
+ set_shape_vector(output_shape.get(), shape_out);
+ _graph_outputs[n]->shape(std::move(output_shape));
+ _graph_outputs[n]->dtype(loco::DataType::FLOAT32);
+ }
+ }
+
+public:
+ luci::CircleOutput *output(int idx) { return _outputs[idx]; }
+
+protected:
+ std::array<loco::GraphOutput *, N> _graph_outputs;
+ std::array<luci::CircleOutput *, N> _outputs;
+};
+
+/**
+ * @brief Graphlet with one Output
+ */
+class TestOGraphlet : public TestOsGraphlet<1>
+{
+public:
+ luci::CircleOutput *output() { return _outputs[0]; }
+};
+
+/**
+ * @brief Graph with Input and Output
+ */
+class TestIOGraph : public TestIGraphlet, public TestOGraphlet
+{
+public:
+ TestIOGraph() = default;
+
+public:
+ virtual void init(const ShapeU32 shape_in, const ShapeU32 shape_out)
+ {
+ TestIsGraphlet<1>::init(g(), shape_in);
+ TestOsGraphlet<1>::init(g(), shape_out);
+ }
+};
+
+} // namespace test
+} // namespace luci
+
+#endif // __LUCI_PASS_TEST_IO_GRAPH_H__
diff --git a/compiler/luci/pass/src/test/TestIOGraph.test.cpp b/compiler/luci/pass/src/test/TestIOGraph.test.cpp
new file mode 100644
index 000000000..e58a13f2b
--- /dev/null
+++ b/compiler/luci/pass/src/test/TestIOGraph.test.cpp
@@ -0,0 +1,19 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "TestIOGraph.h"
+
+// This file validates "TestIOGraph.h". Pleaes DO NOT remove this file.
diff --git a/compiler/luci/export/src/TypeBridge.h b/compiler/luci/pass/src/test/TestShape.h
index a63fbce54..ccc55c9da 100644
--- a/compiler/luci/export/src/TypeBridge.h
+++ b/compiler/luci/pass/src/test/TestShape.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -14,31 +14,27 @@
* limitations under the License.
*/
-#ifndef __TYPE_BRIDGE_H__
-#define __TYPE_BRIDGE_H__
+#ifndef __LUCI_PASS_TEST_SHAPE_H__
+#define __LUCI_PASS_TEST_SHAPE_H__
#include <luci/IR/CircleNode.h>
-#include <loco.h>
+#include <initializer_list>
namespace luci
{
+namespace test
+{
-/**
- * @brief node_shape() will return loco::TensorShape of CircleNode
- */
-loco::TensorShape node_shape(CircleNode *node);
+using ShapeU32 = std::initializer_list<uint32_t>;
+using ShapeI32 = std::initializer_list<int32_t>;
-/**
- * @brief node_dtype() will return loco::DataType of CircleNode
- */
-loco::DataType node_dtype(CircleNode *node);
+void set_shape_vector(loco::TensorShape *shape, const ShapeU32 &values);
+void set_shape_vector(luci::CircleConst *const_node, const ShapeI32 &values);
-/**
- * @brief copy_shape_dtype() will copy shape and dtype inference data to CircleNode
- */
-void copy_shape_dtype(loco::Graph *graph);
+uint32_t num_elements(const ShapeU32 shape);
+} // namespace test
} // namespace luci
-#endif // __TYPE_BRIDGE_H__
+#endif // __LUCI_PASS_TEST_SHAPE_H__
diff --git a/compiler/luci/pass/src/test/TestShape.test.cpp b/compiler/luci/pass/src/test/TestShape.test.cpp
new file mode 100644
index 000000000..39790c614
--- /dev/null
+++ b/compiler/luci/pass/src/test/TestShape.test.cpp
@@ -0,0 +1,57 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "TestShape.h"
+
+/**
+ * @note This file does not hold any test cases but provides methods for tests
+ */
+
+namespace luci
+{
+namespace test
+{
+
+void set_shape_vector(loco::TensorShape *shape, const ShapeU32 &values)
+{
+ uint32_t r = 0;
+ shape->rank(values.size());
+ for (auto v : values)
+ shape->dim(r++).set(v);
+}
+
+void set_shape_vector(luci::CircleConst *const_node, const ShapeI32 &values)
+{
+ const_node->rank(1);
+ const_node->dim(0).set(values.size());
+ const_node->shape_status(luci::ShapeStatus::VALID);
+ const_node->dtype(loco::DataType::S32);
+ const_node->size<loco::DataType::S32>(values.size());
+ uint32_t idx = 0;
+ for (auto val : values)
+ const_node->at<loco::DataType::S32>(idx++) = val;
+}
+
+uint32_t num_elements(const ShapeU32 shape)
+{
+ uint32_t result = 1;
+ for (auto val : shape)
+ result = result * val;
+ return result;
+}
+
+} // namespace test
+} // namespace luci
diff --git a/compiler/luci/profile/CMakeLists.txt b/compiler/luci/profile/CMakeLists.txt
new file mode 100644
index 000000000..f2c6665da
--- /dev/null
+++ b/compiler/luci/profile/CMakeLists.txt
@@ -0,0 +1,22 @@
+file(GLOB_RECURSE SOURCES "src/*.cpp")
+file(GLOB_RECURSE TESTS "src/*.test.cpp")
+list(REMOVE_ITEM SOURCES ${TESTS})
+
+add_library(luci_profile SHARED ${SOURCES})
+target_include_directories(luci_profile PRIVATE src)
+target_include_directories(luci_profile PUBLIC include)
+target_link_libraries(luci_profile PUBLIC loco)
+target_link_libraries(luci_profile PUBLIC luci_lang)
+
+install(TARGETS luci_profile DESTINATION lib)
+
+if(NOT ENABLE_TEST)
+ return()
+endif(NOT ENABLE_TEST)
+
+nnas_find_package(GTest REQUIRED)
+
+GTest_AddTest(luci_profile_test ${TESTS})
+target_include_directories(luci_profile_test PRIVATE src)
+target_link_libraries(luci_profile_test luci_lang)
+target_link_libraries(luci_profile_test luci_profile)
diff --git a/compiler/luci/profile/README.md b/compiler/luci/profile/README.md
new file mode 100644
index 000000000..577e60a7c
--- /dev/null
+++ b/compiler/luci/profile/README.md
@@ -0,0 +1,119 @@
+# luci-profile
+
+`luci-profile` provides profiling related items.
+
+## CircleNodeOrigin
+
+`CircleNodeOrigin` allow us know where some node is originated from.
+
+Let's assume following graph transformations are done.
+
+```
+ | | |
+ [node1] --------+ | |
+(id = 1) | | |
+ | +--------> [node5] ----------------> [node6]
+ | | (origin = [1,2]) (origin = [1,2])
+ [node2] --------+ | |
+(id = 2) | |
+ | | |
+ [node3] -----------------> [node3] --------+-------> [node3]
+(id = 3) (origin = [3]) | (origin = [3,4])
+ | | | |
+ [node4] -----------------> [node4] --------+ |
+(id = 4) (origin = [4]) |
+ | | |
+
+<Circle1> -- optimizer --> <circle2> -- quantizer --> <circle3>
+```
+
+The most important purpose of using `CircleNodeOrigin` is preserving origin information.
+Following changes show how origin information is preserved even after graph is transformed.
+
+- `node3`
+ - `node4` is absorbed to **existing** `node3`.
+ - origin of `node4` is absorbed to origin of `node3`.
+- `node5`
+ - `node1` and `node2` are fused to **newly created** `node5`.
+ - origin of `node1` and `node2` are inherited to origin of `node4`.
+- `node6`
+ - `node5` is **replaced with newly created** `node6`.
+ - origin of `node5` is copied to origin of `node6`.
+
+**Therefore, when using `CircleNodeOrigin`, please aware of the most important principle. "Preserve origin information"**
+
+Next items are about implementation details to store the origin information.
+
+### Source Table
+
+Source table includes a set of id and name of origin node.
+
+#### Binary format
+
+```
+[ entry_number : uint32_t ]
+[ id : uint32_t ][ length : uint32_t ][ data : char * length ] * entry_number
+```
+- entry_number : The number of entries
+ - Each entry consists of id, length, and data.
+- id : ID of origin node
+- length : Length of data
+- data : Name of origin node **(null-terminated string)**
+
+#### In-memory format
+```cpp
+// size = entry_number
+std::map<uint32_t /* id */, std::string /* name */>
+```
+
+#### Example
+
+Following example means "Name of origin 1 is node1".
+
+```
+[Binary Format]
+ 0x01 00 00 00 0x01 00 00 00 0x06 00 00 00 0x6e 0x6f 0x64 0x65 0x31 00
+ ------------- ------------- ------------- ---- ---- ---- ---- ---- ----
+entry_number=1 id=1 length=6 'n' 'o' 'd' 'e' '1' '\0'
+```
+```cpp
+[In-memory Format]
+std::map<uint32_t, std::string>({1, "node1"});
+```
+
+### Op Table
+
+Op table includes a set of id of operation and id(s) of operation's origin nodes.
+
+#### Binary format
+
+Op table is stored in circle file as binary with following format.
+```
+[ entry_number : uint32_t ]
+[ id : uint32_t ][ node_num : uint32_t ][ node_ids : uint32_t * node_num ] * entry_number
+```
+- entry_number : The number of entries
+ - Each entry consists of id, node_num, and node_ids.
+- id : ID of operation in circle model file
+- node_num : The number of operation's origin nodes
+- node_ids : Set of IDs of origin nodes
+
+#### In-memory format
+```cpp
+std::map<uint32_t /* id */, std::set<uint32_t> /* node_ids */>
+```
+
+#### Example
+
+Following example means "Operation 5 is originated from origin 1 and origin 2".
+
+```
+[Binary Format]
+ 0x01 00 00 00 0x05 00 00 00 0x02 00 00 00 0x01 00 00 00 0x02 00 00 00
+ ------------- ------------- ------------- ---------------------------
+entry_number=1 id=5 node_num=2 node_ids : 1, 2
+```
+```cpp
+[In-memory Format]
+std::map<uint32_t, std::set<uint32_t>>({5, std::set{1, 2}});
+```
diff --git a/compiler/luci/pass/src/FuseActivationFunctionPassInternal.h b/compiler/luci/profile/include/luci/Profile/CircleNodeID.h
index 0cfb9d507..165866bcf 100644
--- a/compiler/luci/pass/src/FuseActivationFunctionPassInternal.h
+++ b/compiler/luci/profile/include/luci/Profile/CircleNodeID.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -14,18 +14,22 @@
* limitations under the License.
*/
-#ifndef __LUCI_CIRCLE_FUSE_ACTIVATION_FUNCTION_PASS_INTERNAL_H__
-#define __LUCI_CIRCLE_FUSE_ACTIVATION_FUNCTION_PASS_INTERNAL_H__
+#ifndef __LUCI_PROFILE_CIRCLE_NODE_ID_H__
+#define __LUCI_PROFILE_CIRCLE_NODE_ID_H__
-#include <luci/IR/CircleNodes.h>
+#include <luci/IR/CircleNode.h>
namespace luci
{
-// Fuse activation function with preceding Op
-/// @return true if success
-bool fuse_activation_function(luci::CircleNode *node);
+using CircleNodeID = uint32_t;
+
+bool has_node_id(const luci::CircleNode *circle_node);
+
+void set_node_id(luci::CircleNode *circle_node, CircleNodeID id);
+
+CircleNodeID get_node_id(const luci::CircleNode *circle_node);
} // namespace luci
-#endif // __LUCI_CIRCLE_FUSE_ACTIVATION_FUNCTION_PASS_INTERNAL_H__
+#endif // __LUCI_PROFILE_CIRCLE_NODE_ID_H__
diff --git a/compiler/luci/profile/include/luci/Profile/CircleNodeOrigin.h b/compiler/luci/profile/include/luci/Profile/CircleNodeOrigin.h
new file mode 100644
index 000000000..2d6558c92
--- /dev/null
+++ b/compiler/luci/profile/include/luci/Profile/CircleNodeOrigin.h
@@ -0,0 +1,72 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LUCI_PROFILE_CIRCLE_NODE_ORIGIN_H__
+#define __LUCI_PROFILE_CIRCLE_NODE_ORIGIN_H__
+
+#include "CircleNodeID.h"
+
+#include <luci/IR/CircleNode.h>
+
+#include <set>
+
+namespace luci
+{
+
+class CircleNodeOrigin
+{
+protected:
+ struct Source
+ {
+ public:
+ std::string name(void) const { return _name; }
+ void name(const std::string &name) { _name = name; }
+
+ uint32_t id(void) const { return _id; }
+ void id(const uint32_t id) { _id = id; }
+
+ private:
+ std::string _name;
+ uint32_t _id = 0;
+ };
+
+public:
+ virtual std::set<const Source *> sources(void) const = 0;
+};
+
+std::shared_ptr<CircleNodeOrigin> single_origin(uint32_t id, const std::string &name);
+
+std::shared_ptr<CircleNodeOrigin>
+composite_origin(const std::initializer_list<std::shared_ptr<CircleNodeOrigin>> origins);
+
+std::shared_ptr<CircleNodeOrigin>
+composite_origin(const std::vector<std::shared_ptr<CircleNodeOrigin>> &origins);
+
+} // namespace luci
+
+namespace luci
+{
+
+bool has_origin(const luci::CircleNode *circle_node);
+
+void add_origin(luci::CircleNode *circle_node, const std::shared_ptr<CircleNodeOrigin> origin);
+
+// NOTE When circle_node does not have origin, nullptr is returned
+const std::shared_ptr<luci::CircleNodeOrigin> get_origin(const luci::CircleNode *circle_node);
+
+} // namespace luci
+
+#endif // __LUCI_PROFILE_CIRCLE_NODE_ORIGIN_H__
diff --git a/compiler/luci/profile/src/CircleNodeID.cpp b/compiler/luci/profile/src/CircleNodeID.cpp
new file mode 100644
index 000000000..750b36cae
--- /dev/null
+++ b/compiler/luci/profile/src/CircleNodeID.cpp
@@ -0,0 +1,73 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Profile/CircleNodeID.h"
+
+#include <loco.h>
+
+#include <stdexcept>
+
+namespace
+{
+
+/**
+ * @brief Set annotation for circle node id
+ * @note Once CircleNodeID is annotated, it should not be changed.
+ * If CircleNodeID is needed to be changed, create new CircleNodeID.
+ */
+class CircleNodeIDAnnotation final : public loco::NodeAnnotation
+{
+public:
+ CircleNodeIDAnnotation() = delete;
+
+ CircleNodeIDAnnotation(luci::CircleNodeID node_id) : _node_id{node_id}
+ {
+ // Do nothing
+ }
+
+public:
+ luci::CircleNodeID node_id(void) const { return _node_id; }
+ // No setter
+
+private:
+ luci::CircleNodeID _node_id;
+};
+
+} // namespace
+
+namespace luci
+{
+
+bool has_node_id(const luci::CircleNode *circle_node)
+{
+ return circle_node->annot<CircleNodeIDAnnotation>() != nullptr;
+}
+
+void set_node_id(luci::CircleNode *circle_node, luci::CircleNodeID id)
+{
+ circle_node->annot<CircleNodeIDAnnotation>(nullptr);
+ circle_node->annot(std::make_unique<CircleNodeIDAnnotation>(id));
+}
+
+luci::CircleNodeID get_node_id(const luci::CircleNode *circle_node)
+{
+ if (!has_node_id(circle_node))
+ throw std::runtime_error("Cannot find CircleNodeID");
+
+ return circle_node->annot<CircleNodeIDAnnotation>()->node_id();
+}
+
+} // namespace luci
diff --git a/compiler/luci/profile/src/CircleNodeID.test.cpp b/compiler/luci/profile/src/CircleNodeID.test.cpp
new file mode 100644
index 000000000..d80c09b2c
--- /dev/null
+++ b/compiler/luci/profile/src/CircleNodeID.test.cpp
@@ -0,0 +1,44 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Profile/CircleNodeID.h"
+
+#include <luci/IR/CircleNodes.h>
+
+#include <gtest/gtest.h>
+
+TEST(LuciCircleNodeID, simple_circle_node_id)
+{
+ auto g = loco::make_graph();
+ auto add = g->nodes()->create<luci::CircleAdd>();
+
+ ASSERT_FALSE(has_node_id(add));
+
+ set_node_id(add, 3);
+
+ ASSERT_TRUE(has_node_id(add));
+ ASSERT_EQ(3, get_node_id(add));
+}
+
+TEST(LuciCircleNodeID, simple_circle_node_id_NEG)
+{
+ auto g = loco::make_graph();
+ auto add = g->nodes()->create<luci::CircleAdd>();
+
+ ASSERT_FALSE(has_node_id(add));
+
+ ASSERT_ANY_THROW(get_node_id(add));
+}
diff --git a/compiler/luci/profile/src/CircleNodeOrigin.cpp b/compiler/luci/profile/src/CircleNodeOrigin.cpp
new file mode 100644
index 000000000..0a731a9ad
--- /dev/null
+++ b/compiler/luci/profile/src/CircleNodeOrigin.cpp
@@ -0,0 +1,168 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Profile/CircleNodeOrigin.h"
+
+#include <loco.h>
+
+#include <cassert>
+#include <vector>
+
+namespace
+{
+
+/**
+ * @brief Set annotation for recording origin information
+ * @note Once CircleNodeOrigin is annotated, it should not be changed.
+ * If CircleNodeOrigin is needed to be changed, create new CircleNodeOrigin.
+ */
+class CircleNodeOriginAnnotation final : public loco::NodeAnnotation
+{
+public:
+ CircleNodeOriginAnnotation() = delete;
+
+ CircleNodeOriginAnnotation(const std::shared_ptr<luci::CircleNodeOrigin> origin) : _origin(origin)
+ {
+ // Do nothing
+ }
+
+public:
+ const std::shared_ptr<luci::CircleNodeOrigin> origin(void) const { return _origin; }
+ // No setter
+
+private:
+ const std::shared_ptr<luci::CircleNodeOrigin> _origin;
+};
+
+} // namespace
+
+namespace
+{
+
+class SingleOrigin final : public luci::CircleNodeOrigin
+{
+public:
+ SingleOrigin() = delete;
+
+ SingleOrigin(uint32_t id, const std::string &name)
+ {
+ _source.id(id);
+ _source.name(name);
+ }
+
+public:
+ std::set<const Source *> sources(void) const final
+ {
+ std::set<const Source *> res;
+ res.emplace(&_source);
+ return res;
+ }
+
+private:
+ Source _source;
+};
+
+class CompositeOrigin final : public luci::CircleNodeOrigin
+{
+public:
+ CompositeOrigin() = delete;
+
+ template <typename T> CompositeOrigin(T origins)
+ {
+ if (origins.size() == 0)
+ throw std::invalid_argument("No origins provided");
+
+ for (auto &origin : origins)
+ {
+ if (origin != nullptr)
+ _origins.emplace_back(origin);
+ }
+ }
+
+public:
+ std::set<const Source *> sources(void) const final
+ {
+ std::set<const Source *> res;
+
+ for (auto &origin : _origins)
+ {
+ for (auto source : origin->sources())
+ {
+ res.emplace(source);
+ }
+ }
+
+ return res;
+ }
+
+private:
+ std::vector<std::shared_ptr<CircleNodeOrigin>> _origins;
+};
+
+} // namespace
+
+namespace luci
+{
+
+std::shared_ptr<CircleNodeOrigin> single_origin(uint32_t id, const std::string &name)
+{
+ return std::make_shared<SingleOrigin>(id, name);
+}
+
+std::shared_ptr<CircleNodeOrigin>
+composite_origin(const std::initializer_list<std::shared_ptr<CircleNodeOrigin>> origins)
+{
+ return std::make_shared<CompositeOrigin>(origins);
+}
+
+std::shared_ptr<CircleNodeOrigin>
+composite_origin(const std::vector<std::shared_ptr<CircleNodeOrigin>> &origins)
+{
+ return std::make_shared<CompositeOrigin>(origins);
+}
+
+} // namespace luci
+
+namespace luci
+{
+
+bool has_origin(const luci::CircleNode *circle_node)
+{
+ return circle_node->annot<CircleNodeOriginAnnotation>() != nullptr;
+}
+
+/**
+ * @brief 'origin' is added to the existing origin of circle_node.
+ * @note If 'origin' is nullptr, nothing is changed.
+ * For more detail, please refer to CompositeOrigin constructor.
+ */
+void add_origin(luci::CircleNode *circle_node, const std::shared_ptr<CircleNodeOrigin> origin)
+{
+ auto new_origin = composite_origin({get_origin(circle_node), origin});
+ circle_node->annot<CircleNodeOriginAnnotation>(nullptr);
+ circle_node->annot(std::make_unique<CircleNodeOriginAnnotation>(new_origin));
+}
+
+const std::shared_ptr<luci::CircleNodeOrigin> get_origin(const luci::CircleNode *circle_node)
+{
+ if (!has_origin(circle_node))
+ return nullptr;
+
+ assert(circle_node->annot<CircleNodeOriginAnnotation>()->origin() != nullptr);
+ return circle_node->annot<CircleNodeOriginAnnotation>()->origin();
+}
+
+} // namespace luci
diff --git a/compiler/luci/profile/src/CircleNodeOrigin.test.cpp b/compiler/luci/profile/src/CircleNodeOrigin.test.cpp
new file mode 100644
index 000000000..34618e1ab
--- /dev/null
+++ b/compiler/luci/profile/src/CircleNodeOrigin.test.cpp
@@ -0,0 +1,108 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Profile/CircleNodeID.h"
+#include "luci/Profile/CircleNodeOrigin.h"
+
+#include <luci/IR/CircleNodes.h>
+
+#include <gtest/gtest.h>
+
+TEST(LuciCircleNodeOrigin, simple_single_origin)
+{
+ auto g = loco::make_graph();
+ auto add = g->nodes()->create<luci::CircleAdd>();
+
+ ASSERT_FALSE(has_origin(add));
+
+ auto origin = luci::single_origin(3, "add");
+ add_origin(add, origin);
+
+ ASSERT_TRUE(has_origin(add));
+
+ auto sources = get_origin(add)->sources();
+ ASSERT_EQ(1, sources.size());
+ for (auto source : sources)
+ {
+ ASSERT_EQ(3, source->id());
+ ASSERT_EQ(0, source->name().compare("add"));
+ }
+}
+
+TEST(LuciCircleNodeOrigin, simple_composite_origin_with_initializer)
+{
+ auto g = loco::make_graph();
+ auto mul = g->nodes()->create<luci::CircleMul>();
+
+ ASSERT_FALSE(has_origin(mul));
+
+ auto origin =
+ luci::composite_origin({luci::single_origin(3, "add"), luci::single_origin(7, "sub")});
+ add_origin(mul, origin);
+
+ ASSERT_TRUE(has_origin(mul));
+
+ bool add_origin_passed = false;
+ bool sub_origin_passed = false;
+ auto sources = get_origin(mul)->sources();
+ ASSERT_EQ(2, sources.size());
+ for (auto source : sources)
+ {
+ if (source->id() == 3 && source->name().compare("add") == 0)
+ add_origin_passed = true;
+ if (source->id() == 7 && source->name().compare("sub") == 0)
+ sub_origin_passed = true;
+ }
+
+ ASSERT_EQ(true, add_origin_passed);
+ ASSERT_EQ(true, sub_origin_passed);
+}
+
+TEST(LuciCircleNodeOrigin, simple_composite_origin_with_vector)
+{
+ auto g = loco::make_graph();
+ auto mul = g->nodes()->create<luci::CircleMul>();
+
+ ASSERT_FALSE(has_origin(mul));
+
+ std::vector<std::shared_ptr<luci::CircleNodeOrigin>> vec;
+ vec.push_back(luci::single_origin(3, "add"));
+ vec.push_back(luci::single_origin(7, "sub"));
+ auto origin = luci::composite_origin(vec);
+ add_origin(mul, origin);
+
+ ASSERT_TRUE(has_origin(mul));
+
+ bool add_origin_passed = false;
+ bool sub_origin_passed = false;
+ auto sources = get_origin(mul)->sources();
+ ASSERT_EQ(2, sources.size());
+ for (auto source : sources)
+ {
+ if (source->id() == 3 && source->name().compare("add") == 0)
+ add_origin_passed = true;
+ if (source->id() == 7 && source->name().compare("sub") == 0)
+ sub_origin_passed = true;
+ }
+
+ ASSERT_EQ(true, add_origin_passed);
+ ASSERT_EQ(true, sub_origin_passed);
+}
+
+TEST(LuciCircleNodeOrigin, composite_origin_empty_ctor_NEG)
+{
+ ASSERT_ANY_THROW(luci::composite_origin({}));
+}
diff --git a/compiler/luci/service/CMakeLists.txt b/compiler/luci/service/CMakeLists.txt
index 9f50c9c4f..1c78031ab 100644
--- a/compiler/luci/service/CMakeLists.txt
+++ b/compiler/luci/service/CMakeLists.txt
@@ -22,4 +22,5 @@ nnas_find_package(GTest REQUIRED)
GTest_AddTest(luci_service_test ${TESTS})
target_include_directories(luci_service_test PRIVATE src)
target_link_libraries(luci_service_test luci_service)
+target_link_libraries(luci_service_test luci_testhelper)
target_link_libraries(luci_service_test oops)
diff --git a/compiler/luci/service/include/luci/Service/CircleNodeClone.h b/compiler/luci/service/include/luci/Service/CircleNodeClone.h
new file mode 100644
index 000000000..2429997cc
--- /dev/null
+++ b/compiler/luci/service/include/luci/Service/CircleNodeClone.h
@@ -0,0 +1,40 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LUCI_CIRCLE_NODE_CLONE__
+#define __LUCI_CIRCLE_NODE_CLONE__
+
+#include <luci/IR/CircleNodes.h>
+
+#include <loco/IR/Graph.h>
+
+namespace luci
+{
+
+/**
+ * @brief Copy common attributes of CircleNode from src to dst.
+ */
+void copy_common_attributes(const luci::CircleNode *src, luci::CircleNode *dst);
+
+/**
+ * @brief Return a new cloned CircleNode object with same attributes value of node to graph.
+ * @note Will return nullptr if clone has failed
+ */
+CircleNode *clone_node(const CircleNode *node, loco::Graph *graph);
+
+} // namespace luci
+
+#endif // __LUCI_CIRCLE_NODE_CLONE__
diff --git a/compiler/luci/service/include/luci/Service/CircleShapeInference.h b/compiler/luci/service/include/luci/Service/CircleShapeInference.h
index c301db5f4..60bc16e48 100644
--- a/compiler/luci/service/include/luci/Service/CircleShapeInference.h
+++ b/compiler/luci/service/include/luci/Service/CircleShapeInference.h
@@ -17,29 +17,15 @@
#ifndef __LUCI_CIRCLE_SHAPE_INFERENCE_H__
#define __LUCI_CIRCLE_SHAPE_INFERENCE_H__
-#include "ShapeDescription.h"
-
#include <loco/IR/Nodes.h>
#include <luci/IR/CircleNodes.h>
#include <luci/IR/CircleNodeVisitor.h>
-#include <luci/Service/CircleShapeInferenceHelper.h>
+#include <luci/Service/CircleShapeInferenceRule.h>
namespace luci
{
-/**
- * @brief Get the shape of each node as a node annotation
- *
- * HOW TO USE
- *
- * ShapeInference::get(g->nodes()->at(..));
- */
-struct ShapeInference
-{
- static ShapeDescription get(loco::Node *node);
-};
-
namespace sinf // namespace for Shape Inference
{
@@ -52,7 +38,12 @@ class Algorithm final : public luci::CircleNodeVisitor<loco::TensorShape>
{
public:
// TODO Remove this when all of visit function is implemented
- loco::TensorShape visit(const luci::CircleNode *node) final { return sinf::circle_shape(node); }
+ loco::TensorShape visit(const luci::CircleNode *node) final
+ {
+ loco::NodeShape shape;
+ luci::CircleShapeInferenceRule().infer(node, shape);
+ return shape.as<loco::TensorShape>();
+ }
// loco::TensorShape visit(const luci::CircleAbs *node) final;
// loco::TensorShape visit(const luci::CircleAdd *node) final;
@@ -77,6 +68,7 @@ public:
// loco::TensorShape visit(const luci::CircleEqual *node) final;
// loco::TensorShape visit(const luci::CircleExp *node) final;
// loco::TensorShape visit(const luci::CircleExpandDims *node) final;
+ // loco::TensorShape visit(const luci::CircleFakeQuant *node) final;
// loco::TensorShape visit(const luci::CircleFill *node) final;
// loco::TensorShape visit(const luci::CircleFloor *node) final;
// loco::TensorShape visit(const luci::CircleFloorDiv *node) final;
@@ -106,10 +98,12 @@ public:
// loco::TensorShape visit(const luci::CircleMean *node) final;
// loco::TensorShape visit(const luci::CircleMinimum *node) final;
// loco::TensorShape visit(const luci::CircleMirrorPad *node) final;
+ // loco::TensorShape visit(const luci::CircleMul *node) final;
// loco::TensorShape visit(const luci::CircleNeg *node) final;
// loco::TensorShape visit(const luci::CircleNonMaxSuppressionV4 *node) final;
// loco::TensorShape visit(const luci::CircleNonMaxSuppressionV5 *node) final;
// loco::TensorShape visit(const luci::CircleNotEqual *node) final;
+ // loco::TensorShape visit(const luci::CircleOneHot *node) final;
// loco::TensorShape visit(const luci::CirclePack *node) final;
// loco::TensorShape visit(const luci::CirclePad *node) final;
// loco::TensorShape visit(const luci::CirclePadV2 *node) final;
@@ -117,8 +111,6 @@ public:
// loco::TensorShape visit(const luci::CirclePRelu *node) final;
// loco::TensorShape visit(const luci::CircleRange *node) final;
// loco::TensorShape visit(const luci::CircleRank *node) final;
- // loco::TensorShape visit(const luci::CircleMul *node) final;
- // loco::TensorShape visit(const luci::CircleOneHot *node) final;
// loco::TensorShape visit(const luci::CircleReduceAny *node) final;
// loco::TensorShape visit(const luci::CircleReduceMax *node) final;
// loco::TensorShape visit(const luci::CircleReduceMin *node) final;
@@ -171,14 +163,14 @@ public:
// loco::TensorShape visit(const luci::CircleInstanceNorm *node) final;
// Virtual
+ // loco::TensorShape visit(const luci::CircleCustomOut *node) final;
+ loco::TensorShape visit(const luci::CircleIfOut *node) final;
// loco::TensorShape visit(const luci::CircleInput *node) final;
+ // loco::TensorShape visit(const luci::CircleNonMaxSuppressionV4Out *node) final;
+ // loco::TensorShape visit(const luci::CircleNonMaxSuppressionV5Out *node) final;
// loco::TensorShape visit(const luci::CircleOutput *node) final;
// loco::TensorShape visit(const luci::CircleOutputDummy *node) final;
// loco::TensorShape visit(const luci::CircleOutputExclude *node) final;
- // loco::TensorShape visit(const luci::CircleCustomOut *node) final;
- // loco::TensorShape visit(const luci::CircleIfOut *node) final;
- // loco::TensorShape visit(const luci::CircleNonMaxSuppressionV4Out *node) final;
- // loco::TensorShape visit(const luci::CircleNonMaxSuppressionV5Out *node) final;
// loco::TensorShape visit(const luci::CircleSplitOut *node) final;
// loco::TensorShape visit(const luci::CircleSplitVOut *node) final;
// loco::TensorShape visit(const luci::CircleTopKV2Out *node) final;
diff --git a/compiler/luci/service/include/luci/Service/CircleShapeSignatureInference.h b/compiler/luci/service/include/luci/Service/CircleShapeSignatureInference.h
deleted file mode 100644
index f7ea89bb8..000000000
--- a/compiler/luci/service/include/luci/Service/CircleShapeSignatureInference.h
+++ /dev/null
@@ -1,179 +0,0 @@
-/*
- * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef __LUCI_CIRCLE_SHAPE_SIGNATURE_INFERENCE_H__
-#define __LUCI_CIRCLE_SHAPE_SIGNATURE_INFERENCE_H__
-
-#include <luci/IR/CircleNodes.h>
-#include <luci/IR/CircleNodeVisitor.h>
-#include <luci/IR/CircleShapeSignature.h>
-#include <luci/Service/CircleShapeSignatureInferenceHelper.h>
-
-namespace luci
-{
-
-namespace ssinf // namespace for Shape Signature Inference
-{
-
-struct Rule
-{
- bool infer(const luci::CircleNode *, ShapeSignature &) const;
-};
-
-class Algorithm final : public luci::CircleNodeVisitor<ShapeSignature>
-{
-public:
- // TODO Remove this when visit function is implemented for all the operations.
- ShapeSignature visit(const luci::CircleNode *node) final { return node->shape_signature(); }
-
- // ShapeSignature visit(const luci::CircleAbs *node) final;
- // ShapeSignature visit(const luci::CircleAdd *node) final;
- // ShapeSignature visit(const luci::CircleAddN *node) final;
- // ShapeSignature visit(const luci::CircleArgMax *node) final;
- // ShapeSignature visit(const luci::CircleArgMin *node) final;
- // ShapeSignature visit(const luci::CircleAveragePool2D *node) final;
- // ShapeSignature visit(const luci::CircleBatchMatMul *node) final;
- // ShapeSignature visit(const luci::CircleBatchToSpaceND *node) final;
- // ShapeSignature visit(const luci::CircleCast *node) final;
- // ShapeSignature visit(const luci::CircleCeil *node) final;
- // ShapeSignature visit(const luci::CircleConcatenation *node) final;
- // ShapeSignature visit(const luci::CircleConst *node) final;
- // ShapeSignature visit(const luci::CircleConv2D *node) final;
- // ShapeSignature visit(const luci::CircleCos *node) final;
- // ShapeSignature visit(const luci::CircleCustom *node) final;
- // ShapeSignature visit(const luci::CircleDepthToSpace *node) final;
- // ShapeSignature visit(const luci::CircleDepthwiseConv2D *node) final;
- // ShapeSignature visit(const luci::CircleDequantize *node) final;
- // ShapeSignature visit(const luci::CircleDiv *node) final;
- // ShapeSignature visit(const luci::CircleElu *node) final;
- // ShapeSignature visit(const luci::CircleEqual *node) final;
- // ShapeSignature visit(const luci::CircleExp *node) final;
- // ShapeSignature visit(const luci::CircleExpandDims *node) final;
- // ShapeSignature visit(const luci::CircleFill *node) final;
- // ShapeSignature visit(const luci::CircleFloor *node) final;
- // ShapeSignature visit(const luci::CircleFloorDiv *node) final;
- // ShapeSignature visit(const luci::CircleFloorMod *node) final;
- // ShapeSignature visit(const luci::CircleFullyConnected *node) final;
- // ShapeSignature visit(const luci::CircleGather *node) final;
- // ShapeSignature visit(const luci::CircleGatherNd *node) final;
- // ShapeSignature visit(const luci::CircleGreater *node) final;
- // ShapeSignature visit(const luci::CircleGreaterEqual *node) final;
- // ShapeSignature visit(const luci::CircleIf *node) final;
- // ShapeSignature visit(const luci::CircleL2Normalize *node) final;
- // ShapeSignature visit(const luci::CircleL2Pool2D *node) final;
- // ShapeSignature visit(const luci::CircleLeakyRelu *node) final;
- // ShapeSignature visit(const luci::CircleLess *node) final;
- // ShapeSignature visit(const luci::CircleLessEqual *node) final;
- // ShapeSignature visit(const luci::CircleLocalResponseNormalization *node) final;
- // ShapeSignature visit(const luci::CircleLog *node) final;
- // ShapeSignature visit(const luci::CircleLogicalAnd *node) final;
- // ShapeSignature visit(const luci::CircleLogicalNot *node) final;
- // ShapeSignature visit(const luci::CircleLogicalOr *node) final;
- // ShapeSignature visit(const luci::CircleLogistic *node) final;
- // ShapeSignature visit(const luci::CircleLogSoftmax *node) final;
- // ShapeSignature visit(const luci::CircleMatrixDiag *node) final;
- // ShapeSignature visit(const luci::CircleMatrixSetDiag *node) final;
- // ShapeSignature visit(const luci::CircleMaximum *node) final;
- // ShapeSignature visit(const luci::CircleMaxPool2D *node) final;
- ShapeSignature visit(const luci::CircleMean *node) final;
- // ShapeSignature visit(const luci::CircleMinimum *node) final;
- // ShapeSignature visit(const luci::CircleMirrorPad *node) final;
- // ShapeSignature visit(const luci::CircleNeg *node) final;
- // ShapeSignature visit(const luci::CircleNonMaxSuppressionV4 *node) final;
- // ShapeSignature visit(const luci::CircleNonMaxSuppressionV5 *node) final;
- // ShapeSignature visit(const luci::CircleNotEqual *node) final;
- // ShapeSignature visit(const luci::CirclePack *node) final;
- // ShapeSignature visit(const luci::CirclePad *node) final;
- // ShapeSignature visit(const luci::CirclePadV2 *node) final;
- // ShapeSignature visit(const luci::CirclePow *node) final;
- // ShapeSignature visit(const luci::CirclePRelu *node) final;
- // ShapeSignature visit(const luci::CircleRange *node) final;
- // ShapeSignature visit(const luci::CircleRank *node) final;
- // ShapeSignature visit(const luci::CircleMul *node) final;
- // ShapeSignature visit(const luci::CircleOneHot *node) final;
- ShapeSignature visit(const luci::CircleReduceAny *node) final;
- ShapeSignature visit(const luci::CircleReduceMax *node) final;
- ShapeSignature visit(const luci::CircleReduceMin *node) final;
- ShapeSignature visit(const luci::CircleReduceProd *node) final;
- ShapeSignature visit(const luci::CircleRelu *node) final;
- ShapeSignature visit(const luci::CircleRelu6 *node) final;
- ShapeSignature visit(const luci::CircleReluN1To1 *node) final;
- // ShapeSignature visit(const luci::CircleReshape *node) final;
- // ShapeSignature visit(const luci::CircleResizeBilinear *node) final;
- // ShapeSignature visit(const luci::CircleResizeNearestNeighbor *node) final;
- // ShapeSignature visit(const luci::CircleReverseSequence *node) final;
- // ShapeSignature visit(const luci::CircleReverseV2 *node) final;
- // ShapeSignature visit(const luci::CircleRound *node) final;
- // ShapeSignature visit(const luci::CircleRsqrt *node) final;
- // ShapeSignature visit(const luci::CircleScatterNd *node) final;
- // ShapeSignature visit(const luci::CircleSegmentSum *node) final;
- // ShapeSignature visit(const luci::CircleSelect *node) final;
- // ShapeSignature visit(const luci::CircleSelectV2 *node) final;
- // ShapeSignature visit(const luci::CircleShape *node) final;
- // ShapeSignature visit(const luci::CircleSin *node) final;
- // ShapeSignature visit(const luci::CircleSlice *node) final;
- // ShapeSignature visit(const luci::CircleSoftmax *node) final;
- // ShapeSignature visit(const luci::CircleSpaceToBatchND *node) final;
- // ShapeSignature visit(const luci::CircleSpaceToDepth *node) final;
- // ShapeSignature visit(const luci::CircleSparseToDense *node) final;
- // ShapeSignature visit(const luci::CircleSplit *node) final;
- // ShapeSignature visit(const luci::CircleSplitV *node) final;
- // ShapeSignature visit(const luci::CircleSqrt *node) final;
- // ShapeSignature visit(const luci::CircleSquare *node) final;
- // ShapeSignature visit(const luci::CircleSquaredDifference *node) final;
- // ShapeSignature visit(const luci::CircleSqueeze *node) final;
- // ShapeSignature visit(const luci::CircleStridedSlice *node) final;
- // ShapeSignature visit(const luci::CircleSub *node) final;
- ShapeSignature visit(const luci::CircleSum *node) final;
- // ShapeSignature visit(const luci::CircleTanh *node) final;
- // ShapeSignature visit(const luci::CircleTile *node) final;
- // ShapeSignature visit(const luci::CircleTopKV2 *node) final;
- // ShapeSignature visit(const luci::CircleTranspose *node) final;
- // ShapeSignature visit(const luci::CircleTransposeConv *node) final;
- // ShapeSignature visit(const luci::CircleUnidirectionalSequenceLSTM *node) final;
- // ShapeSignature visit(const luci::CircleUnique *node) final;
- // ShapeSignature visit(const luci::CircleUnpack *node) final;
- // ShapeSignature visit(const luci::CircleWhere *node) final ;
- // ShapeSignature visit(const luci::CircleWhile *node) final;
- // ShapeSignature visit(const luci::CircleZerosLike *node) final;
-
- // Circle Only
- // ShapeSignature visit(const luci::CircleBCQFullyConnected *node) final;
- // ShapeSignature visit(const luci::CircleBCQGather *node) final;
- // ShapeSignature visit(const luci::CircleInstanceNorm *node) final;
-
- // Virtual
- ShapeSignature visit(const luci::CircleInput *node) final;
- ShapeSignature visit(const luci::CircleOutput *node) final;
- ShapeSignature visit(const luci::CircleOutputDummy *node) final;
- ShapeSignature visit(const luci::CircleOutputExclude *node) final;
- // ShapeSignature visit(const luci::CircleCustomOut *node) final;
- // ShapeSignature visit(const luci::CircleIfOut *node) final;
- // ShapeSignature visit(const luci::CircleNonMaxSuppressionV4Out *node) final;
- // ShapeSignature visit(const luci::CircleNonMaxSuppressionV5Out *node) final;
- // ShapeSignature visit(const luci::CircleSplitOut *node) final;
- // ShapeSignature visit(const luci::CircleSplitVOut *node) final;
- // ShapeSignature visit(const luci::CircleTopKV2Out *node) final;
- // ShapeSignature visit(const luci::CircleUniqueOut *node) final;
- // ShapeSignature visit(const luci::CircleUnpackOut *node) final;
- // ShapeSignature visit(const luci::CircleWhileOut *node) final;
-};
-
-} // namespace ssinf
-
-} // namespace luci
-
-#endif // __LUCI_CIRCLE_SHAPE_SIGNATURE_INFERENCE_H__
diff --git a/compiler/luci/service/include/luci/Service/CircleShapeSignatureInferenceHelper.h b/compiler/luci/service/include/luci/Service/CircleShapeSignatureInferenceHelper.h
deleted file mode 100644
index fb5b3b302..000000000
--- a/compiler/luci/service/include/luci/Service/CircleShapeSignatureInferenceHelper.h
+++ /dev/null
@@ -1,45 +0,0 @@
-/*
- * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef __LUCI_CIRCLE_SHAPE_SIGNATURE_INFERENCE_HELPER_H__
-#define __LUCI_CIRCLE_SHAPE_SIGNATURE_INFERENCE_HELPER_H__
-
-#include <luci/IR/CircleNodes.h>
-#include <luci/IR/CircleShapeSignature.h>
-
-namespace luci
-{
-
-namespace ssinf // Namespace for Shape Signature Inference
-{
-
-// Return empty signature if all of dimensions are known.
-// If at least one of dimensions is unknown, return signature without change.
-ShapeSignature legalized_signature(const luci::ShapeSignature &signature);
-
-// Return reduced input_signature with indices and keep_dims.
-// - indices : reduction index
-// - keep_dims : If true, rank is not changed. If false, rank is reduced along indices.
-ShapeSignature reduced_signature(const loco::Node *node, const loco::Node *indices, bool keep_dims);
-
-// Return signature of index-th argument of node.
-ShapeSignature input_arg_signature(const luci::CircleNode *node, uint32_t index);
-
-} // namespace ssinf
-
-} // namespace luci
-
-#endif // __LUCI_CIRCLE_SHAPE_SIGNATURE_INFERENCE_HELPER_H__
diff --git a/compiler/luci/service/include/luci/Service/CircleTypeInference.h b/compiler/luci/service/include/luci/Service/CircleTypeInference.h
index 342214887..8eef469ac 100644
--- a/compiler/luci/service/include/luci/Service/CircleTypeInference.h
+++ b/compiler/luci/service/include/luci/Service/CircleTypeInference.h
@@ -23,24 +23,11 @@
#include <luci/IR/CircleNodes.h>
#include <luci/IR/CircleNodeVisitor.h>
-#include <luci/Service/CircleTypeInferenceHelper.h>
+#include <luci/Service/CircleTypeInferenceRule.h>
namespace luci
{
-/**
- * @brief Get the type of each node as NodeAnnotation
- *
- * HOW TO USE
- *
- * TypeInference::get(g->nodes()->at(0));
- * TypeInference::get(g->nodes()->at(...));
- */
-struct TypeInference
-{
- static circle::TensorType get(loco::Node *node);
-};
-
namespace tinf // namespace for Type Inference
{
@@ -53,7 +40,12 @@ class Algorithm final : public luci::CircleNodeVisitor<loco::DataType>
{
public:
// TODO Remove this when all of visit function is implemented
- loco::DataType visit(const luci::CircleNode *node) final { return node->dtype(); }
+ loco::DataType visit(const luci::CircleNode *node) final
+ {
+ loco::DataType dtype;
+ luci::CircleTypeInferenceRule().infer(node, dtype);
+ return dtype;
+ }
// loco::DataType visit(const luci::CircleAbs *node) final;
// loco::DataType visit(const luci::CircleAdd *node) final;
@@ -78,6 +70,7 @@ public:
// loco::DataType visit(const luci::CircleEqual *node) final;
// loco::DataType visit(const luci::CircleExp *node) final;
// loco::DataType visit(const luci::CircleExpandDims *node) final;
+ // loco::DataType visit(const luci::CircleFakeQuant *node) final;
// loco::DataType visit(const luci::CircleFill *node) final;
// loco::DataType visit(const luci::CircleFloor *node) final;
// loco::DataType visit(const luci::CircleFloorDiv *node) final;
@@ -177,7 +170,7 @@ public:
// loco::DataType visit(const luci::CircleOutputDummy *node) final;
// loco::DataType visit(const luci::CircleOutputExclude *node) final;
// loco::DataType visit(const luci::CircleCustomOut *node) final;
- // loco::DataType visit(const luci::CircleIfOut *node) final;
+ loco::DataType visit(const luci::CircleIfOut *node) final;
// loco::DataType visit(const luci::CircleNonMaxSuppressionV4Out *node) final;
// loco::DataType visit(const luci::CircleNonMaxSuppressionV5Out *node) final;
// loco::DataType visit(const luci::CircleSplitOut *node) final;
diff --git a/compiler/luci/service/include/luci/Service/Nodes/CircleConst.h b/compiler/luci/service/include/luci/Service/Nodes/CircleConst.h
new file mode 100644
index 000000000..6049b4297
--- /dev/null
+++ b/compiler/luci/service/include/luci/Service/Nodes/CircleConst.h
@@ -0,0 +1,32 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LUCI_SERVICE_CIRCLE_CONST_H__
+#define __LUCI_SERVICE_CIRCLE_CONST_H__
+
+#include <luci/IR/Nodes/CircleConst.h>
+
+namespace luci
+{
+
+/**
+ * @brief Return cloned object of CircleConst node
+ */
+luci::CircleConst *clone(luci::CircleConst *node);
+
+} // namespace luci
+
+#endif // __LUCI_SERVICE_CIRCLE_CONST_H__
diff --git a/compiler/luci/service/include/luci/Service/ShapeDescription.h b/compiler/luci/service/include/luci/Service/ShapeDescription.h
index 4d92be13f..4671096fd 100644
--- a/compiler/luci/service/include/luci/Service/ShapeDescription.h
+++ b/compiler/luci/service/include/luci/Service/ShapeDescription.h
@@ -37,10 +37,6 @@ struct ShapeDescription
// TODO remove these when CircleDialect is fully functioal
ShapeDescription to_shape_description(const luci::CircleNode *node);
ShapeDescription to_shape_description(const loco::TensorShape &shape);
-ShapeDescription to_shape_description(const loco::FeatureShape &shape);
-ShapeDescription to_shape_description(const loco::FilterShape &shape);
-ShapeDescription to_shape_description(const loco::BiasShape &shape);
-ShapeDescription to_shape_description(const loco::MatrixShape &shape);
ShapeDescription to_shape_description(const loco::NodeShape &shape);
template <typename Permutation> inline bool isNHWC(Permutation *perm);
diff --git a/compiler/luci/service/include/luci/Service/Validate.h b/compiler/luci/service/include/luci/Service/Validate.h
index 4b80d1d16..456d6e504 100644
--- a/compiler/luci/service/include/luci/Service/Validate.h
+++ b/compiler/luci/service/include/luci/Service/Validate.h
@@ -17,6 +17,8 @@
#ifndef __LUCI_SERVICE_VALIDATE_H__
#define __LUCI_SERVICE_VALIDATE_H__
+#include <luci/IR/Module.h>
+
#include <loco.h>
namespace luci
@@ -24,6 +26,17 @@ namespace luci
bool validate(loco::Graph *);
+/**
+ * @brief Return true if all nodes in graph have non empty name
+ */
+bool validate_name(loco::Graph *);
+
+/**
+ * @brief Return true if all names in the Module are unique
+ * @note CircleOutput may have duplicate name
+ */
+bool validate_unique_name(luci::Module *);
+
} // namespace luci
#endif // __LUCI_SERVICE_VALIDATE_H__
diff --git a/compiler/luci/service/src/CircleCloneNode.h b/compiler/luci/service/src/CircleCloneNode.h
new file mode 100644
index 000000000..02c7cd256
--- /dev/null
+++ b/compiler/luci/service/src/CircleCloneNode.h
@@ -0,0 +1,174 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __CIRCLE_CLONE_NODE_H__
+#define __CIRCLE_CLONE_NODE_H__
+
+#include <luci/IR/CircleNodes.h>
+
+#include <luci/IR/CircleNodeVisitor.h>
+
+namespace luci
+{
+
+class CloneNode final : public luci::CircleNodeVisitor<luci::CircleNode *>
+{
+public:
+ CloneNode(loco::Graph *graph) : _graph(graph){};
+
+public:
+ luci::CircleNode *visit(const luci::CircleAbs *) final;
+ luci::CircleNode *visit(const luci::CircleAdd *) final;
+ luci::CircleNode *visit(const luci::CircleAddN *) final;
+ luci::CircleNode *visit(const luci::CircleArgMax *) final;
+ luci::CircleNode *visit(const luci::CircleArgMin *) final;
+ luci::CircleNode *visit(const luci::CircleAveragePool2D *) final;
+ luci::CircleNode *visit(const luci::CircleBatchMatMul *) final;
+ luci::CircleNode *visit(const luci::CircleBatchToSpaceND *) final;
+ luci::CircleNode *visit(const luci::CircleCast *) final;
+ luci::CircleNode *visit(const luci::CircleCeil *) final;
+ luci::CircleNode *visit(const luci::CircleConcatenation *) final;
+ luci::CircleNode *visit(const luci::CircleConst *) final;
+ luci::CircleNode *visit(const luci::CircleConv2D *) final;
+ luci::CircleNode *visit(const luci::CircleCos *) final;
+ luci::CircleNode *visit(const luci::CircleCustom *) final;
+ luci::CircleNode *visit(const luci::CircleDepthToSpace *) final;
+ luci::CircleNode *visit(const luci::CircleDepthwiseConv2D *) final;
+ luci::CircleNode *visit(const luci::CircleDequantize *) final;
+ luci::CircleNode *visit(const luci::CircleDiv *) final;
+ luci::CircleNode *visit(const luci::CircleElu *) final;
+ luci::CircleNode *visit(const luci::CircleEqual *) final;
+ luci::CircleNode *visit(const luci::CircleExp *) final;
+ luci::CircleNode *visit(const luci::CircleExpandDims *) final;
+ luci::CircleNode *visit(const luci::CircleFakeQuant *) final;
+ luci::CircleNode *visit(const luci::CircleFill *) final;
+ luci::CircleNode *visit(const luci::CircleFloor *) final;
+ luci::CircleNode *visit(const luci::CircleFloorDiv *) final;
+ luci::CircleNode *visit(const luci::CircleFloorMod *) final;
+ luci::CircleNode *visit(const luci::CircleFullyConnected *) final;
+ luci::CircleNode *visit(const luci::CircleGather *) final;
+ luci::CircleNode *visit(const luci::CircleGatherNd *) final;
+ luci::CircleNode *visit(const luci::CircleGreater *) final;
+ luci::CircleNode *visit(const luci::CircleGreaterEqual *) final;
+ // luci::CircleNode *visit(const luci::CircleIf *) final;
+ luci::CircleNode *visit(const luci::CircleL2Normalize *) final;
+ luci::CircleNode *visit(const luci::CircleL2Pool2D *) final;
+ luci::CircleNode *visit(const luci::CircleLeakyRelu *) final;
+ luci::CircleNode *visit(const luci::CircleLess *) final;
+ luci::CircleNode *visit(const luci::CircleLessEqual *) final;
+ luci::CircleNode *visit(const luci::CircleLocalResponseNormalization *) final;
+ luci::CircleNode *visit(const luci::CircleLog *) final;
+ luci::CircleNode *visit(const luci::CircleLogicalAnd *) final;
+ luci::CircleNode *visit(const luci::CircleLogicalNot *) final;
+ luci::CircleNode *visit(const luci::CircleLogicalOr *) final;
+ luci::CircleNode *visit(const luci::CircleLogistic *) final;
+ luci::CircleNode *visit(const luci::CircleLogSoftmax *) final;
+ luci::CircleNode *visit(const luci::CircleMatrixDiag *) final;
+ luci::CircleNode *visit(const luci::CircleMatrixSetDiag *) final;
+ luci::CircleNode *visit(const luci::CircleMaximum *) final;
+ luci::CircleNode *visit(const luci::CircleMaxPool2D *) final;
+ luci::CircleNode *visit(const luci::CircleMean *) final;
+ luci::CircleNode *visit(const luci::CircleMinimum *) final;
+ luci::CircleNode *visit(const luci::CircleMirrorPad *) final;
+ luci::CircleNode *visit(const luci::CircleMul *) final;
+ luci::CircleNode *visit(const luci::CircleNeg *) final;
+ luci::CircleNode *visit(const luci::CircleNonMaxSuppressionV4 *) final;
+ luci::CircleNode *visit(const luci::CircleNonMaxSuppressionV5 *) final;
+ luci::CircleNode *visit(const luci::CircleNotEqual *) final;
+ luci::CircleNode *visit(const luci::CircleOneHot *) final;
+ luci::CircleNode *visit(const luci::CirclePack *) final;
+ luci::CircleNode *visit(const luci::CirclePad *) final;
+ luci::CircleNode *visit(const luci::CirclePadV2 *) final;
+ luci::CircleNode *visit(const luci::CirclePow *) final;
+ luci::CircleNode *visit(const luci::CirclePRelu *) final;
+ luci::CircleNode *visit(const luci::CircleRange *) final;
+ luci::CircleNode *visit(const luci::CircleRank *) final;
+ luci::CircleNode *visit(const luci::CircleReduceAny *) final;
+ luci::CircleNode *visit(const luci::CircleReduceMax *) final;
+ luci::CircleNode *visit(const luci::CircleReduceMin *) final;
+ luci::CircleNode *visit(const luci::CircleReduceProd *) final;
+ luci::CircleNode *visit(const luci::CircleRelu *) final;
+ luci::CircleNode *visit(const luci::CircleRelu6 *) final;
+ luci::CircleNode *visit(const luci::CircleReluN1To1 *) final;
+ luci::CircleNode *visit(const luci::CircleReshape *) final;
+ luci::CircleNode *visit(const luci::CircleResizeBilinear *) final;
+ luci::CircleNode *visit(const luci::CircleResizeNearestNeighbor *) final;
+ luci::CircleNode *visit(const luci::CircleReverseSequence *) final;
+ luci::CircleNode *visit(const luci::CircleReverseV2 *) final;
+ luci::CircleNode *visit(const luci::CircleRound *) final;
+ luci::CircleNode *visit(const luci::CircleRsqrt *) final;
+ luci::CircleNode *visit(const luci::CircleScatterNd *) final;
+ luci::CircleNode *visit(const luci::CircleSegmentSum *) final;
+ luci::CircleNode *visit(const luci::CircleSelect *) final;
+ luci::CircleNode *visit(const luci::CircleSelectV2 *) final;
+ luci::CircleNode *visit(const luci::CircleShape *) final;
+ luci::CircleNode *visit(const luci::CircleSin *) final;
+ luci::CircleNode *visit(const luci::CircleSlice *) final;
+ luci::CircleNode *visit(const luci::CircleSoftmax *) final;
+ luci::CircleNode *visit(const luci::CircleSpaceToBatchND *) final;
+ luci::CircleNode *visit(const luci::CircleSpaceToDepth *) final;
+ luci::CircleNode *visit(const luci::CircleSparseToDense *) final;
+ luci::CircleNode *visit(const luci::CircleSplit *) final;
+ luci::CircleNode *visit(const luci::CircleSplitV *) final;
+ luci::CircleNode *visit(const luci::CircleSqrt *) final;
+ luci::CircleNode *visit(const luci::CircleSquare *) final;
+ luci::CircleNode *visit(const luci::CircleSquaredDifference *) final;
+ luci::CircleNode *visit(const luci::CircleSqueeze *) final;
+ luci::CircleNode *visit(const luci::CircleStridedSlice *) final;
+ luci::CircleNode *visit(const luci::CircleSub *) final;
+ luci::CircleNode *visit(const luci::CircleSum *) final;
+ luci::CircleNode *visit(const luci::CircleTanh *) final;
+ luci::CircleNode *visit(const luci::CircleTile *) final;
+ luci::CircleNode *visit(const luci::CircleTopKV2 *) final;
+ luci::CircleNode *visit(const luci::CircleTranspose *) final;
+ luci::CircleNode *visit(const luci::CircleTransposeConv *) final;
+ luci::CircleNode *visit(const luci::CircleUnidirectionalSequenceLSTM *) final;
+ luci::CircleNode *visit(const luci::CircleUnique *) final;
+ luci::CircleNode *visit(const luci::CircleUnpack *) final;
+ luci::CircleNode *visit(const luci::CircleWhere *) final;
+ // luci::CircleNode *visit(const luci::CircleWhile *) final;
+ luci::CircleNode *visit(const luci::CircleZerosLike *) final;
+
+ // Circle Only
+ luci::CircleNode *visit(const luci::CircleBCQFullyConnected *) final;
+ luci::CircleNode *visit(const luci::CircleBCQGather *) final;
+ luci::CircleNode *visit(const luci::CircleInstanceNorm *) final;
+
+ // Virtual
+ luci::CircleNode *visit(const luci::CircleCustomOut *) final;
+ // luci::CircleNode *visit(const luci::CircleIfOut *) final;
+ // luci::CircleNode *visit(const luci::CircleInput *) final;
+ luci::CircleNode *visit(const luci::CircleNonMaxSuppressionV4Out *) final;
+ luci::CircleNode *visit(const luci::CircleNonMaxSuppressionV5Out *) final;
+ // luci::CircleNode *visit(const luci::CircleOutput *) final;
+ luci::CircleNode *visit(const luci::CircleOutputDummy *) final;
+ luci::CircleNode *visit(const luci::CircleOutputExclude *) final;
+ luci::CircleNode *visit(const luci::CircleSplitOut *) final;
+ luci::CircleNode *visit(const luci::CircleSplitVOut *) final;
+ luci::CircleNode *visit(const luci::CircleTopKV2Out *) final;
+ luci::CircleNode *visit(const luci::CircleUniqueOut *) final;
+ luci::CircleNode *visit(const luci::CircleUnpackOut *) final;
+ // luci::CircleNode *visit(const luci::CircleWhileOut *) final;
+
+ // NOTE CircleNodeVisitor will throw if not supported here
+
+protected:
+ loco::Graph *_graph = nullptr;
+};
+
+} // namespace luci
+
+#endif // __CIRCLE_CLONE_NODE_H__
diff --git a/compiler/luci/service/src/CircleNodeClone.cpp b/compiler/luci/service/src/CircleNodeClone.cpp
new file mode 100644
index 000000000..d2033dd0c
--- /dev/null
+++ b/compiler/luci/service/src/CircleNodeClone.cpp
@@ -0,0 +1,92 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Service/CircleNodeClone.h"
+
+#include "CircleCloneNode.h"
+
+#include <oops/UserExn.h>
+
+#include <cassert>
+
+namespace luci
+{
+
+/**
+ * @note Attributes of specific node type like keep_dims() of CircleSum are
+ * not copied.
+ */
+void copy_common_attributes(const luci::CircleNode *src, luci::CircleNode *dst)
+{
+ assert(src != nullptr);
+ assert(dst != nullptr);
+
+ dst->name(src->name());
+ dst->dtype(src->dtype());
+
+ dst->rank(src->rank());
+ for (uint32_t i = 0; i < src->rank(); i++)
+ {
+ dst->dim(i) = src->dim(i);
+ }
+ dst->shape_status(src->shape_status());
+
+ // quantparam
+ const auto *quantparam = src->quantparam();
+ if (quantparam != nullptr)
+ {
+ auto qparam = std::make_unique<luci::CircleQuantParam>();
+ qparam->scale = quantparam->scale;
+ qparam->zerop = quantparam->zerop;
+ qparam->min = quantparam->min;
+ qparam->max = quantparam->max;
+ qparam->quantized_dimension = quantparam->quantized_dimension;
+
+ dst->quantparam(std::move(qparam));
+ }
+
+ // sparsity
+ const auto *sparsity = src->sparsityparam();
+ if (sparsity != nullptr)
+ {
+ auto sparam = std::make_unique<luci::SparsityParam>();
+ sparam->traversal_order = sparsity->traversal_order;
+ sparam->block_map = sparsity->block_map;
+ sparam->dim_metadata = sparsity->dim_metadata;
+
+ dst->sparsityparam(std::move(sparam));
+ }
+
+ // op version
+ dst->op_version(src->op_version());
+}
+
+/**
+ * @note Each visit implementation must copy node specific attributes.
+ */
+luci::CircleNode *clone_node(const luci::CircleNode *node, loco::Graph *graph)
+{
+ if (node == nullptr || graph == nullptr)
+ return nullptr;
+
+ CloneNode cn(graph);
+ auto cloned = node->accept(&cn);
+ if (cloned != nullptr)
+ copy_common_attributes(node, cloned);
+ return cloned;
+}
+
+} // namespace luci
diff --git a/compiler/luci/service/src/CircleNodeClone.test.cpp b/compiler/luci/service/src/CircleNodeClone.test.cpp
new file mode 100644
index 000000000..5908eeb82
--- /dev/null
+++ b/compiler/luci/service/src/CircleNodeClone.test.cpp
@@ -0,0 +1,109 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Service/CircleNodeClone.h"
+
+// NOTE any node will do for testing
+#include <luci/IR/Nodes/CircleAdd.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+luci::CircleAdd *build_simple_add_graph(loco::Graph *g)
+{
+ auto node = g->nodes()->create<luci::CircleAdd>();
+
+ node->name("name");
+ node->dtype(loco::DataType::FLOAT32);
+ node->rank(1);
+ node->dim(0).set(3);
+ node->shape_status(luci::ShapeStatus::VALID);
+ node->fusedActivationFunction(luci::FusedActFunc::NONE);
+
+ auto qparam = std::make_unique<luci::CircleQuantParam>();
+ qparam->scale = {1.0};
+ qparam->zerop = {0};
+ qparam->min = {0.0};
+ qparam->max = {1.0};
+ qparam->quantized_dimension = 0;
+ node->quantparam(std::move(qparam));
+
+ auto sparam = std::make_unique<luci::SparsityParam>();
+ sparam->traversal_order = {0};
+ sparam->block_map = {0};
+ sparam->dim_metadata = {luci::DimMetaData(luci::DimensionType::DENSE, 1)};
+ node->sparsityparam(std::move(sparam));
+
+ node->op_version(2);
+
+ return node;
+}
+
+} // namespace
+
+TEST(CircleNodeCloneTest, copy_attribites)
+{
+ auto g = loco::make_graph();
+ auto node = build_simple_add_graph(g.get());
+
+ auto copy = g->nodes()->create<luci::CircleAdd>();
+ luci::copy_common_attributes(node, copy);
+
+ ASSERT_EQ(node->name(), copy->name());
+ ASSERT_EQ(node->dtype(), copy->dtype());
+ ASSERT_EQ(node->rank(), copy->rank());
+ ASSERT_EQ(node->shape_status(), copy->shape_status());
+
+ const auto *qparam_node = node->quantparam();
+ const auto *qparam_copy = copy->quantparam();
+ ASSERT_EQ(qparam_node->scale, qparam_copy->scale);
+
+ const auto *sparsity_node = node->sparsityparam();
+ const auto *sparsity_copy = copy->sparsityparam();
+ ASSERT_EQ(sparsity_node->traversal_order, sparsity_copy->traversal_order);
+
+ ASSERT_EQ(node->op_version(), copy->op_version());
+}
+
+TEST(CircleNodeCloneTest, clone_add_node)
+{
+ auto g = loco::make_graph();
+ auto node = build_simple_add_graph(g.get());
+
+ auto cg = loco::make_graph();
+ auto clone = clone_node(node, cg.get());
+
+ ASSERT_NE(nullptr, clone);
+ ASSERT_EQ(cg.get(), clone->graph());
+ ASSERT_EQ(node->name(), clone->name());
+ ASSERT_EQ(node->dtype(), clone->dtype());
+ ASSERT_EQ(node->rank(), clone->rank());
+ ASSERT_EQ(node->shape_status(), clone->shape_status());
+}
+
+TEST(CircleNodeCloneTest, clone_node_NEG)
+{
+ auto g = loco::make_graph();
+ auto node = build_simple_add_graph(g.get());
+
+ auto cg = loco::make_graph();
+ auto clone = luci::clone_node(nullptr, cg.get());
+ ASSERT_EQ(nullptr, clone);
+ auto clone2 = luci::clone_node(node, nullptr);
+ ASSERT_EQ(nullptr, clone2);
+}
diff --git a/compiler/luci/service/src/CircleShapeInference.cpp b/compiler/luci/service/src/CircleShapeInference.cpp
index db8ffd8ad..73472069b 100644
--- a/compiler/luci/service/src/CircleShapeInference.cpp
+++ b/compiler/luci/service/src/CircleShapeInference.cpp
@@ -15,27 +15,16 @@
*/
#include "luci/Service/CircleShapeInference.h"
-#include "luci/Service/ShapeDescription.h"
+
+#include "CircleShapeInferenceHelper.h"
#include <loco.h>
-#include <loco/Service/ShapeInference.h>
#include <luci/Log.h>
#include <cassert>
#include <iostream>
-namespace luci
-{
-
-ShapeDescription ShapeInference::get(loco::Node *node)
-{
- assert(loco::shape_known(node));
- return to_shape_description(loco::shape_get(node));
-}
-
-} // namespace luci
-
namespace
{
@@ -46,7 +35,11 @@ std::ostream &operator<<(std::ostream &os, const loco::TensorShape &tensor_shape
{
if (r)
os << ",";
- os << tensor_shape.dim(r).value();
+
+ if (tensor_shape.dim(r).known())
+ os << tensor_shape.dim(r).value();
+ else
+ os << "?";
}
os << "]";
return os;
@@ -90,5 +83,5 @@ bool Rule::infer(const luci::CircleNode *circle_node, loco::TensorShape &shape)
return true;
}
-} // namespace ssinf
+} // namespace sinf
} // namespace luci
diff --git a/compiler/luci/service/src/CircleShapeInferenceHelper.cpp b/compiler/luci/service/src/CircleShapeInferenceHelper.cpp
index f7eb6c3ec..2009aa59f 100644
--- a/compiler/luci/service/src/CircleShapeInferenceHelper.cpp
+++ b/compiler/luci/service/src/CircleShapeInferenceHelper.cpp
@@ -14,7 +14,24 @@
* limitations under the License.
*/
-#include "luci/Service/CircleShapeInferenceHelper.h"
+#include "CircleShapeInferenceHelper.h"
+
+namespace luci
+{
+
+loco::NodeShape shape_get(const loco::Node *node)
+{
+ assert(luci::shape_known(node));
+ return loco::NodeShape{sinf::circle_shape(loco::must_cast<const luci::CircleNode *>(node))};
+}
+
+bool shape_known(const loco::Node *node)
+{
+ return loco::must_cast<const luci::CircleNode *>(node)->shape_status() !=
+ luci::ShapeStatus::UNDEFINED;
+}
+
+} // namespace luci
namespace luci
{
@@ -26,7 +43,7 @@ loco::TensorShape circle_shape(const luci::CircleNode *node)
loco::TensorShape shape;
shape.rank(node->rank());
for (uint32_t r = 0; r < node->rank(); ++r)
- shape.dim(r) = loco::Dimension(node->dim(r).value());
+ shape.dim(r) = node->dim(r);
return shape;
}
diff --git a/compiler/luci/service/include/luci/Service/CircleShapeInferenceHelper.h b/compiler/luci/service/src/CircleShapeInferenceHelper.h
index dd6a5a454..7c7ea496c 100644
--- a/compiler/luci/service/include/luci/Service/CircleShapeInferenceHelper.h
+++ b/compiler/luci/service/src/CircleShapeInferenceHelper.h
@@ -17,10 +17,24 @@
#ifndef __LUCI_CIRCLE_SHAPE_INFERENCE_HELPER_H__
#define __LUCI_CIRCLE_SHAPE_INFERENCE_HELPER_H__
+#include <loco/IR/NodeShape.h>
#include <loco/IR/TensorShape.h>
#include <luci/IR/CircleNodes.h>
-#include <luci/IR/CircleShapeSignature.h>
+
+namespace luci
+{
+
+// NOTE Functions in this namespace will be removed after new inference
+// algorithms are fully implemented.
+
+// This function is temporary function for deprecating loco::shape_get
+loco::NodeShape shape_get(const loco::Node *node);
+
+// This function is temporary function for deprecating loco::shape_known
+bool shape_known(const loco::Node *node);
+
+} // namespace luci
namespace luci
{
diff --git a/compiler/luci/service/src/CircleShapeInferenceRule.cpp b/compiler/luci/service/src/CircleShapeInferenceRule.cpp
index 38ff619ab..c6d8232c3 100644
--- a/compiler/luci/service/src/CircleShapeInferenceRule.cpp
+++ b/compiler/luci/service/src/CircleShapeInferenceRule.cpp
@@ -17,6 +17,7 @@
#include "luci/Service/CircleShapeInferenceRule.h"
#include "Check.h"
+#include "CircleShapeInferenceHelper.h"
#include "ShapeInfer_StridedSlice.h"
#include <luci/IR/CircleNodes.h>
@@ -41,7 +42,11 @@ std::ostream &operator<<(std::ostream &os, const loco::TensorShape &tensor_shape
{
if (r)
os << ",";
- os << tensor_shape.dim(r).value();
+
+ if (tensor_shape.dim(r).known())
+ os << tensor_shape.dim(r).value();
+ else
+ os << "?";
}
os << "]";
return os;
@@ -52,7 +57,15 @@ loco::TensorShape own_shape(const luci::CircleNode *node)
loco::TensorShape shape;
shape.rank(node->rank());
for (uint32_t r = 0; r < node->rank(); ++r)
- shape.dim(r) = loco::Dimension(node->dim(r).value());
+ {
+ // Shape inference rules in this file did not consider unknown dimension.
+ // If some node has unknown dimension, 0 is inserted and wrong shape
+ // inference was done as a result.
+ // To fix this, new shape inference algorithm is being implemented.
+ // Until new inference algorithm is fully implemented, unknown dimension
+ // would be represented as 1 along with TFLite expression.
+ shape.dim(r) = node->dim(r).known() ? node->dim(r).value() : 1;
+ }
return shape;
}
@@ -135,10 +148,8 @@ loco::TensorShape expand_dimension(const loco::TensorShape &x, const loco::Tenso
output_shape.rank(rank);
for (uint32_t axis = 0; axis < rank; ++axis)
{
- assert(x.dim(axis).known() && y.dim(axis).known());
-
- auto x_dim = x.dim(axis).value();
- auto y_dim = y.dim(axis).value();
+ auto x_dim = x.dim(axis).known() ? x.dim(axis).value() : 1;
+ auto y_dim = y.dim(axis).known() ? y.dim(axis).value() : 1;
// each dimension of x and y should be same or one must be 1 if different
if (!((x_dim == y_dim) || (x_dim == 1 || y_dim == 1)))
@@ -177,23 +188,29 @@ template <loco::DataType T> std::vector<int64_t> vector_from_constant(luci::Circ
template <class CIRCLENODE> loco::NodeShape broadcast_xy(const CIRCLENODE *node)
{
- auto x_shape = loco::shape_get(node->x()).template as<loco::TensorShape>();
- auto y_shape = loco::shape_get(node->y()).template as<loco::TensorShape>();
+ auto x_shape = luci::shape_get(node->x()).template as<loco::TensorShape>();
+ auto y_shape = luci::shape_get(node->y()).template as<loco::TensorShape>();
auto output_shape = broadcast_shape(x_shape, y_shape);
return loco::NodeShape{output_shape};
}
+template <class CIRCLENODE> loco::NodeShape use_inputs(const CIRCLENODE *node)
+{
+ auto inputs_shape = luci::shape_get(node->inputs()).template as<loco::TensorShape>();
+ return loco::NodeShape{inputs_shape};
+}
+
template <class CIRCLENODE> loco::NodeShape use_x(const CIRCLENODE *node)
{
- auto x_shape = loco::shape_get(node->x()).template as<loco::TensorShape>();
+ auto x_shape = luci::shape_get(node->x()).template as<loco::TensorShape>();
return loco::NodeShape{x_shape};
}
template <class CIRCLENODE> loco::NodeShape use_logits(const CIRCLENODE *node)
{
- auto shape = loco::shape_get(node->logits()).template as<loco::TensorShape>();
+ auto shape = luci::shape_get(node->logits()).template as<loco::TensorShape>();
return loco::NodeShape{shape};
}
@@ -202,7 +219,7 @@ loco::NodeShape use_paddings(const CIRCLENODE *node, const luci::CircleConst *pa
{
const loco::DataType S32 = loco::DataType::S32;
- auto input_shape = loco::shape_get(node->input()).template as<loco::TensorShape>();
+ auto input_shape = luci::shape_get(node->input()).template as<loco::TensorShape>();
// TODO support other data type
LUCI_ASSERT(paddings->dtype() == S32, "Only support int 32 for now");
@@ -232,11 +249,11 @@ loco::NodeShape use_paddings(const CIRCLENODE *node, const luci::CircleConst *pa
loco::NodeShape infer_add_n(const luci::CircleAddN *node)
{
- auto shape = loco::shape_get(node->inputs(0)).as<loco::TensorShape>();
+ auto shape = luci::shape_get(node->inputs(0)).as<loco::TensorShape>();
for (uint32_t idx = 1; idx < node->arity(); ++idx)
{
- auto shape_idx = loco::shape_get(node->inputs(idx)).as<loco::TensorShape>();
+ auto shape_idx = luci::shape_get(node->inputs(idx)).as<loco::TensorShape>();
if (!(shape == shape_idx))
{
INTERNAL_EXN_V("ADD_N shape not same as the first input: ", idx);
@@ -247,8 +264,8 @@ loco::NodeShape infer_add_n(const luci::CircleAddN *node)
loco::NodeShape infer_arg_max(const luci::CircleArgMax *node)
{
- auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>();
- auto dimension_shape = loco::shape_get(node->dimension()).as<loco::TensorShape>();
+ auto input_shape = luci::shape_get(node->input()).as<loco::TensorShape>();
+ auto dimension_shape = luci::shape_get(node->dimension()).as<loco::TensorShape>();
int64_t select_axis = 0;
{
@@ -286,8 +303,8 @@ loco::NodeShape infer_arg_max(const luci::CircleArgMax *node)
loco::NodeShape infer_arg_min(const luci::CircleArgMin *node)
{
- auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>();
- auto dimension_shape = loco::shape_get(node->dimension()).as<loco::TensorShape>();
+ auto input_shape = luci::shape_get(node->input()).as<loco::TensorShape>();
+ auto dimension_shape = luci::shape_get(node->dimension()).as<loco::TensorShape>();
int64_t select_axis = 0;
{
@@ -326,9 +343,7 @@ loco::NodeShape infer_arg_min(const luci::CircleArgMin *node)
// Call this for CircleAvgPool2D and CircleMaxPool2D only
template <class Pool2DType> loco::NodeShape infer_pool_2d_shape(const Pool2DType *node)
{
- LUCI_ASSERT(loco::shape_known(node->value()), "Shape must be known");
-
- auto ifm_shape = loco::shape_get(node->value()).template as<loco::TensorShape>();
+ auto ifm_shape = luci::shape_get(node->value()).template as<loco::TensorShape>();
assert(ifm_shape.rank() == 4);
uint32_t input_height = ifm_shape.dim(1).value();
@@ -372,7 +387,7 @@ loco::NodeShape infer_batch_to_space_nd(const luci::CircleBatchToSpaceND *node)
{
const loco::DataType S32 = loco::DataType::S32;
- auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>();
+ auto input_shape = luci::shape_get(node->input()).as<loco::TensorShape>();
// Support only input rank is 3 and 4
assert(input_shape.rank() == 3 || input_shape.rank() == 4);
@@ -384,8 +399,8 @@ loco::NodeShape infer_batch_to_space_nd(const luci::CircleBatchToSpaceND *node)
auto const_crops = loco::must_cast<luci::CircleConst *>(node->crops());
LUCI_ASSERT(const_crops->dtype() == loco::DataType::S32, "Only support int32 crops");
- auto const_block_shape_shape = loco::shape_get(const_block_shape).as<loco::TensorShape>();
- auto const_crops_shape = loco::shape_get(const_crops).as<loco::TensorShape>();
+ auto const_block_shape_shape = luci::shape_get(const_block_shape).as<loco::TensorShape>();
+ auto const_crops_shape = luci::shape_get(const_crops).as<loco::TensorShape>();
assert(const_block_shape_shape.rank() == 1);
assert(const_crops_shape.rank() == 2);
@@ -423,8 +438,8 @@ struct OutputSize
template <class Conv2DType> OutputSize infer_conv2d_type(const Conv2DType *node)
{
- auto ifm_shape = loco::shape_get(node->input()).template as<loco::TensorShape>();
- auto ker_shape = loco::shape_get(node->filter()).template as<loco::TensorShape>();
+ auto ifm_shape = luci::shape_get(node->input()).template as<loco::TensorShape>();
+ auto ker_shape = luci::shape_get(node->filter()).template as<loco::TensorShape>();
assert(ifm_shape.rank() == 4);
assert(ker_shape.rank() == 4);
@@ -496,7 +511,7 @@ loco::NodeShape infer_batchmatmul_shape(const loco::TensorShape &x_shape,
loco::Dimension y_lhs = adj_y ? y_shape.dim(y_rank - 1) : y_shape.dim(y_rank - 2);
loco::Dimension y_rhs = adj_y ? y_shape.dim(y_rank - 2) : y_shape.dim(y_rank - 1);
- if (not(x_rhs == y_lhs))
+ if (x_rhs.known() && y_lhs.known() && not(x_rhs == y_lhs))
INTERNAL_EXN("x_rhs and y_lhs should be same");
uint32_t out_rank = output_shape.rank();
@@ -511,7 +526,7 @@ loco::NodeShape infer_concatenation(const luci::CircleConcatenation *node)
// TODO Support when CircleConcatenation has 0 input
assert(node->numValues() > 0);
- auto first_shape = loco::shape_get(node->values(0)).as<loco::TensorShape>();
+ auto first_shape = luci::shape_get(node->values(0)).as<loco::TensorShape>();
auto axis = node->axis();
if (axis < 0)
axis += first_shape.rank();
@@ -527,14 +542,20 @@ loco::NodeShape infer_concatenation(const luci::CircleConcatenation *node)
for (uint32_t i = 1; i < node->numValues(); ++i)
{
- auto input_shape = loco::shape_get(node->values(i)).as<loco::TensorShape>();
+ auto input_shape = luci::shape_get(node->values(i)).as<loco::TensorShape>();
for (uint32_t j = 0; j < output_shape.rank(); ++j)
{
if (j == static_cast<uint32_t>(axis))
+ {
+ // If dimension is unknown, value() will return 0.
+ // This is wrong but until new inference algorithm is implemented,
+ // this code will not be modified to keep compatibility.
output_shape.dim(j) = output_shape.dim(j).value() + input_shape.dim(j).value();
+ }
else
- assert(output_shape.dim(j) == input_shape.dim(j));
+ assert(!output_shape.dim(j).known() || !input_shape.dim(j).known() ||
+ output_shape.dim(j) == input_shape.dim(j));
}
}
@@ -545,8 +566,8 @@ loco::NodeShape infer_conv2d(const luci::CircleConv2D *node)
{
LOGGER(l);
- auto ifm_shape = loco::shape_get(node->input()).as<loco::TensorShape>(); // in NHWC
- auto ker_shape = loco::shape_get(node->filter()).as<loco::TensorShape>(); // in OHWI
+ auto ifm_shape = luci::shape_get(node->input()).as<loco::TensorShape>(); // in NHWC
+ auto ker_shape = luci::shape_get(node->filter()).as<loco::TensorShape>(); // in OHWI
INFO(l) << "[luci] CircleConv2D ShapeInf ifm(" << ifm_shape.rank() << ") ker(" << ker_shape.rank()
<< ")" << std::endl;
@@ -569,7 +590,7 @@ loco::NodeShape infer_conv2d(const luci::CircleConv2D *node)
loco::NodeShape infer_depth_to_space(const luci::CircleDepthToSpace *node)
{
- auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>();
+ auto input_shape = luci::shape_get(node->input()).as<loco::TensorShape>();
LUCI_ASSERT(input_shape.rank() == 4, "Only input rank 4 is supported");
// Only data format NHWC is supported
@@ -601,12 +622,13 @@ loco::NodeShape infer_depth_to_space(const luci::CircleDepthToSpace *node)
loco::NodeShape infer_depthwise_conv2d(const luci::CircleDepthwiseConv2D *node)
{
- auto ifm_shape = loco::shape_get(node->input()).as<loco::TensorShape>(); // in NHWC
- auto ker_shape = loco::shape_get(node->filter()).as<loco::TensorShape>(); // in 1 H W CM
+ auto ifm_shape = luci::shape_get(node->input()).as<loco::TensorShape>(); // in NHWC
+ auto ker_shape = luci::shape_get(node->filter()).as<loco::TensorShape>(); // in 1 H W CM
assert(ifm_shape.rank() == 4);
assert(ker_shape.rank() == 4);
assert(ker_shape.dim(0).value() == 1);
+ assert(ifm_shape.dim(3).value() * node->depthMultiplier() == ker_shape.dim(3).value());
auto os = infer_conv2d_type(node);
@@ -623,7 +645,7 @@ loco::NodeShape infer_depthwise_conv2d(const luci::CircleDepthwiseConv2D *node)
loco::NodeShape infer_expand_dims(const luci::CircleExpandDims *node)
{
const loco::DataType S32 = loco::DataType::S32;
- auto x_shape = loco::shape_get(node->input()).as<loco::TensorShape>();
+ auto x_shape = luci::shape_get(node->input()).as<loco::TensorShape>();
if (x_shape.rank() == 0)
{
// This maybe for unknown shape. We use shape from the node itself.
@@ -637,7 +659,7 @@ loco::NodeShape infer_expand_dims(const luci::CircleExpandDims *node)
}
int32_t axis = const_axis->at<S32>(0);
LUCI_ASSERT((axis <= static_cast<int32_t>(x_shape.rank())) &&
- (axis >= -1 - static_cast<int32_t>(x_shape.rank())),
+ (axis >= -1 - static_cast<int32_t>(x_shape.rank())),
"Axis has to be between [-(D+1), D], where D is rank of input.");
size_t positive_axis = axis < 0 ? x_shape.rank() + axis + 1 : axis;
loco::TensorShape output_shape;
@@ -684,8 +706,8 @@ loco::NodeShape infer_fill(const luci::CircleFill *node)
loco::NodeShape infer_fully_connected(const luci::CircleFullyConnected *node)
{
- auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>();
- auto weights_shape = loco::shape_get(node->weights()).as<loco::TensorShape>();
+ auto input_shape = luci::shape_get(node->input()).as<loco::TensorShape>();
+ auto weights_shape = luci::shape_get(node->weights()).as<loco::TensorShape>();
// Checking shape capability for fully connected layer
// Input: a tensor of at least rank 2 [D1, D2, ... Dn]
@@ -715,8 +737,8 @@ loco::NodeShape infer_gather(const luci::CircleGather *node)
{
loco::TensorShape output_shape;
- const auto input_shape = loco::shape_get(node->params()).as<loco::TensorShape>();
- const auto positions_shape = loco::shape_get(node->indices()).as<loco::TensorShape>();
+ const auto input_shape = luci::shape_get(node->params()).as<loco::TensorShape>();
+ const auto positions_shape = luci::shape_get(node->indices()).as<loco::TensorShape>();
int32_t axis = node->axis();
// If CircleGather input has a dynamic shape, it can't inference this shape. So, it returns the
@@ -743,8 +765,8 @@ loco::NodeShape infer_gather_nd(const luci::CircleGatherNd *node)
{
loco::TensorShape output_shape;
- const auto params_shape = loco::shape_get(node->params()).as<loco::TensorShape>();
- const auto indices_shape = loco::shape_get(node->indices()).as<loco::TensorShape>();
+ const auto params_shape = luci::shape_get(node->params()).as<loco::TensorShape>();
+ const auto indices_shape = luci::shape_get(node->indices()).as<loco::TensorShape>();
const auto params_rank = params_shape.rank();
const auto indices_rank = indices_shape.rank();
@@ -791,7 +813,7 @@ loco::NodeShape infer_matrix_diag(const luci::CircleMatrixDiag *node)
{
loco::TensorShape output_shape;
- auto diagonal_shape = loco::shape_get(node->diagonal()).as<loco::TensorShape>();
+ auto diagonal_shape = luci::shape_get(node->diagonal()).as<loco::TensorShape>();
auto rank = diagonal_shape.rank();
output_shape.rank(rank + 1);
@@ -808,8 +830,8 @@ loco::NodeShape infer_matrix_diag(const luci::CircleMatrixDiag *node)
loco::NodeShape infer_matrix_set_diag(const luci::CircleMatrixSetDiag *node)
{
- auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>();
- auto diagonal_shape = loco::shape_get(node->diagonal()).as<loco::TensorShape>();
+ auto input_shape = luci::shape_get(node->input()).as<loco::TensorShape>();
+ auto diagonal_shape = luci::shape_get(node->diagonal()).as<loco::TensorShape>();
auto rank = diagonal_shape.rank();
@@ -831,7 +853,7 @@ loco::TensorShape infer_reducer(const loco::Node *input, const loco::Node *indic
{
const loco::DataType S32 = loco::DataType::S32;
- auto input_shape = loco::shape_get(input).as<loco::TensorShape>();
+ auto input_shape = luci::shape_get(input).as<loco::TensorShape>();
auto reduction_indices = loco::must_cast<const luci::CircleConst *>(indices);
{ // Exceptions
@@ -892,7 +914,7 @@ loco::NodeShape infer_mirror_pad(const luci::CircleMirrorPad *node)
loco::NodeShape infer_one_hot(const luci::CircleOneHot *node)
{
const loco::DataType S32 = loco::DataType::S32;
- auto indices_shape = loco::shape_get(node->indices()).as<loco::TensorShape>();
+ auto indices_shape = luci::shape_get(node->indices()).as<loco::TensorShape>();
// Only support OneHot node's depth() is CircleConst with type S32
// TODO support depth with other types
auto depth = loco::must_cast<luci::CircleConst *>(node->depth());
@@ -925,11 +947,11 @@ loco::NodeShape infer_pack(const luci::CirclePack *node)
{
LUCI_ASSERT(node->values_count() > 0, "Only support one or more inputs");
- auto first_shape = loco::shape_get(node->values(0)).as<loco::TensorShape>();
+ auto first_shape = luci::shape_get(node->values(0)).as<loco::TensorShape>();
// Make sure all inputs have the same shape.
for (uint32_t i = 1; i < node->values_count(); ++i)
{
- auto in_shape = loco::shape_get(node->values(i)).as<loco::TensorShape>();
+ auto in_shape = luci::shape_get(node->values(i)).as<loco::TensorShape>();
LUCI_ASSERT(loco::NodeShape{first_shape} == loco::NodeShape{in_shape},
"All inputs must have the same shape");
}
@@ -985,8 +1007,8 @@ loco::NodeShape infer_pad_v2(const luci::CirclePadV2 *node)
loco::NodeShape infer_p_relu(const luci::CirclePRelu *node)
{
- auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>();
- auto alpha_shape = loco::shape_get(node->alpha()).as<loco::TensorShape>();
+ auto input_shape = luci::shape_get(node->input()).as<loco::TensorShape>();
+ auto alpha_shape = luci::shape_get(node->alpha()).as<loco::TensorShape>();
auto output_shape = broadcast_shape(input_shape, alpha_shape);
@@ -1087,10 +1109,12 @@ loco::NodeShape infer_reshape(const luci::CircleReshape *node)
loco::TensorShape output_shape = shape_by_input;
// One of the dimensions can have special value -1, meaning its actual value should be inferred.
- const auto input_shape = loco::shape_get(node->tensor()).as<loco::TensorShape>();
- const uint32_t input_element_count = loco::element_count(&input_shape);
+ const auto input_shape = luci::shape_get(node->tensor()).as<loco::TensorShape>();
+ uint32_t input_element_count = 1;
uint32_t output_element_count = 1;
uint32_t unknown_dim_index = UINT32_MAX;
+ for (uint32_t i = 0; i < input_shape.rank(); ++i)
+ input_element_count *= (input_shape.dim(i).known() ? input_shape.dim(i).value() : 1);
for (uint32_t dim_index = 0; dim_index < output_shape.rank(); ++dim_index)
{
const uint32_t dim_value = output_shape.dim(dim_index).value();
@@ -1114,7 +1138,7 @@ loco::NodeShape infer_reshape(const luci::CircleReshape *node)
loco::NodeShape infer_resize_bilinear(const luci::CircleResizeBilinear *node)
{
- auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>();
+ auto input_shape = luci::shape_get(node->input()).as<loco::TensorShape>();
if (input_shape.rank() != 4)
INTERNAL_EXN("Expected ResizeBilinear input to have rank 4");
@@ -1142,7 +1166,7 @@ loco::NodeShape infer_resize_bilinear(const luci::CircleResizeBilinear *node)
loco::NodeShape infer_resize_nearest_neighbor(const luci::CircleResizeNearestNeighbor *node)
{
- auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>();
+ auto input_shape = luci::shape_get(node->input()).as<loco::TensorShape>();
if (input_shape.rank() != 4)
INTERNAL_EXN("Expected ResizeNearesNeighbor input to have rank 4");
@@ -1195,8 +1219,8 @@ loco::NodeShape infer_scatter_nd(const luci::CircleScatterNd *node)
loco::NodeShape infer_segment_sum(const luci::CircleSegmentSum *node)
{
- auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>();
- auto segment_shape = loco::shape_get(node->segment_ids()).as<loco::TensorShape>();
+ auto input_shape = luci::shape_get(node->input()).as<loco::TensorShape>();
+ auto segment_shape = luci::shape_get(node->segment_ids()).as<loco::TensorShape>();
LUCI_ASSERT(segment_shape.rank() == 1, "segment_ids must be 1-D tensor");
LUCI_ASSERT(segment_shape.dim(0).value() == input_shape.dim(0).value(),
@@ -1226,11 +1250,11 @@ loco::NodeShape infer_segment_sum(const luci::CircleSegmentSum *node)
loco::NodeShape infer_select(const luci::CircleSelect *node)
{
- auto t_shape = loco::shape_get(node->t()).as<loco::TensorShape>();
- assert(t_shape == loco::shape_get(node->e()).as<loco::TensorShape>());
+ auto t_shape = luci::shape_get(node->t()).as<loco::TensorShape>();
+ assert(t_shape == luci::shape_get(node->e()).as<loco::TensorShape>());
// condition shape validation
- auto c_shape = loco::shape_get(node->condition()).as<loco::TensorShape>();
+ auto c_shape = luci::shape_get(node->condition()).as<loco::TensorShape>();
if (c_shape.rank() != t_shape.rank())
{
if (c_shape.rank() != 0 && c_shape.rank() != 1)
@@ -1248,9 +1272,9 @@ loco::NodeShape infer_select(const luci::CircleSelect *node)
loco::NodeShape infer_select_v2(const luci::CircleSelectV2 *node)
{
- auto c_shape = loco::shape_get(node->condition()).as<loco::TensorShape>();
- auto t_shape = loco::shape_get(node->t()).as<loco::TensorShape>();
- auto e_shape = loco::shape_get(node->e()).as<loco::TensorShape>();
+ auto c_shape = luci::shape_get(node->condition()).as<loco::TensorShape>();
+ auto t_shape = luci::shape_get(node->t()).as<loco::TensorShape>();
+ auto e_shape = luci::shape_get(node->e()).as<loco::TensorShape>();
// validate ability to broadcast shapes to each other
auto b_shape = broadcast_shape(broadcast_shape(c_shape, t_shape), e_shape);
@@ -1259,7 +1283,7 @@ loco::NodeShape infer_select_v2(const luci::CircleSelectV2 *node)
loco::NodeShape infer_shape(const luci::CircleShape *node)
{
- auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>();
+ auto input_shape = luci::shape_get(node->input()).as<loco::TensorShape>();
loco::TensorShape output_shape;
@@ -1274,7 +1298,7 @@ loco::NodeShape infer_slice(const luci::CircleSlice *node)
const loco::DataType S32 = loco::DataType::S32;
const loco::DataType S64 = loco::DataType::S64;
- auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>();
+ auto input_shape = luci::shape_get(node->input()).as<loco::TensorShape>();
auto const_begin = loco::must_cast<luci::CircleConst *>(node->begin());
auto const_size = loco::must_cast<luci::CircleConst *>(node->size());
@@ -1318,7 +1342,7 @@ loco::NodeShape infer_space_to_batch_nd(const luci::CircleSpaceToBatchND *node)
{
const loco::DataType S32 = loco::DataType::S32;
- auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>();
+ auto input_shape = luci::shape_get(node->input()).as<loco::TensorShape>();
// Support only input rank is 3 and 4
assert(input_shape.rank() == 3 || input_shape.rank() == 4);
@@ -1330,8 +1354,8 @@ loco::NodeShape infer_space_to_batch_nd(const luci::CircleSpaceToBatchND *node)
auto const_paddings = loco::must_cast<luci::CircleConst *>(node->paddings());
LUCI_ASSERT(const_paddings->dtype() == S32, "Only support int32 paddings");
- auto const_block_shape_shape = loco::shape_get(const_block_shape).as<loco::TensorShape>();
- auto const_paddings_shape = loco::shape_get(const_paddings).as<loco::TensorShape>();
+ auto const_block_shape_shape = luci::shape_get(const_block_shape).as<loco::TensorShape>();
+ auto const_paddings_shape = luci::shape_get(const_paddings).as<loco::TensorShape>();
assert(const_block_shape_shape.rank() == 1);
assert(const_paddings_shape.rank() == 2);
@@ -1374,7 +1398,7 @@ loco::NodeShape infer_space_to_batch_nd(const luci::CircleSpaceToBatchND *node)
loco::NodeShape infer_space_to_depth(const luci::CircleSpaceToDepth *node)
{
- auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>();
+ auto input_shape = luci::shape_get(node->input()).as<loco::TensorShape>();
LUCI_ASSERT(input_shape.rank() == 4, "Only input rank 4 is supported");
// Only data format NHWC is supported
@@ -1412,19 +1436,33 @@ loco::NodeShape infer_sparse_to_dense(const luci::CircleSparseToDense *node)
auto output_shape_node = dynamic_cast<luci::CircleConst *>(node->output_shape());
if (output_shape_node != nullptr)
{
- // Only support node with S32
- LUCI_ASSERT(output_shape_node->dtype() == loco::DataType::S32,
- "Only support int32 CircleConst");
+ const auto output_shape_type = output_shape_node->dtype();
if (output_shape_node->rank() != 1)
INTERNAL_EXN_V("Only support rank 1 CircleConst",
oops::to_uint32(output_shape_node->rank()));
- shape.rank(output_shape_node->size<loco::DataType::S32>());
+ if (output_shape_type == loco::DataType::S32)
+ {
+ shape.rank(output_shape_node->size<loco::DataType::S32>());
- for (uint32_t axis = 0; axis < shape.rank(); ++axis)
+ for (uint32_t axis = 0; axis < shape.rank(); ++axis)
+ {
+ shape.dim(axis) = output_shape_node->at<loco::DataType::S32>(axis);
+ }
+ }
+ else if (output_shape_type == loco::DataType::S64)
{
- shape.dim(axis) = output_shape_node->at<loco::DataType::S32>(axis);
+ shape.rank(output_shape_node->size<loco::DataType::S64>());
+
+ for (uint32_t axis = 0; axis < shape.rank(); ++axis)
+ {
+ shape.dim(axis) = output_shape_node->at<loco::DataType::S64>(axis);
+ }
+ }
+ else
+ {
+ INTERNAL_EXN("Output shape of SparseToDense must be either int32 or int64");
}
}
else
@@ -1453,7 +1491,7 @@ loco::NodeShape infer_strided_slice(const luci::CircleStridedSlice *node)
loco::NodeShape infer_squeeze(const luci::CircleSqueeze *node)
{
- auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>();
+ auto input_shape = luci::shape_get(node->input()).as<loco::TensorShape>();
// TODO input shape may be unknown before runtime
std::vector<bool> do_squeeze(input_shape.rank(), false);
@@ -1508,7 +1546,7 @@ loco::NodeShape infer_tile(const luci::CircleTile *node)
{
const loco::DataType S32 = loco::DataType::S32;
- auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>();
+ auto input_shape = luci::shape_get(node->input()).as<loco::TensorShape>();
auto multiples = loco::must_cast<luci::CircleConst *>(node->multiples());
// TODO support non-const case
@@ -1534,7 +1572,7 @@ loco::NodeShape infer_tile(const luci::CircleTile *node)
loco::NodeShape infer_transpose(const luci::CircleTranspose *node)
{
- auto input_shape = loco::shape_get(node->a()).as<loco::TensorShape>();
+ auto input_shape = luci::shape_get(node->a()).as<loco::TensorShape>();
auto perm_node = loco::must_cast<luci::CircleConst *>(node->perm());
@@ -1576,7 +1614,7 @@ loco::NodeShape infer_unpack(const luci::CircleUnpack *node)
// CircleUnpack provides list(array) of Tensors which has one less dimension of the input
// We'll set shape of CircleUnpack to shape of actual outputs
// TODO fix this if any problem rises
- auto value_shape = loco::shape_get(node->value()).as<loco::TensorShape>();
+ auto value_shape = luci::shape_get(node->value()).as<loco::TensorShape>();
auto axis = node->axis();
auto num = node->num();
@@ -1610,9 +1648,9 @@ loco::NodeShape infer_unpack(const luci::CircleUnpack *node)
loco::NodeShape infer_unidirectionalsequencelstm(const luci::CircleUnidirectionalSequenceLSTM *node)
{
- auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>();
+ auto input_shape = luci::shape_get(node->input()).as<loco::TensorShape>();
auto recurrent_to_output_weights =
- loco::shape_get(node->recurrent_to_output_weights()).as<loco::TensorShape>();
+ luci::shape_get(node->recurrent_to_output_weights()).as<loco::TensorShape>();
auto rank = input_shape.rank();
loco::TensorShape output_shape;
output_shape.rank(rank);
@@ -1626,7 +1664,7 @@ loco::NodeShape infer_unidirectionalsequencelstm(const luci::CircleUnidirectiona
loco::NodeShape infer_unique(const luci::CircleUnique *node)
{
- auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>();
+ auto input_shape = luci::shape_get(node->input()).as<loco::TensorShape>();
assert(input_shape.rank() == 1);
@@ -1641,7 +1679,7 @@ loco::NodeShape infer_bcq_fully_connected(const luci::CircleBCQFullyConnected *n
{
loco::TensorShape out_shape;
- auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>();
+ auto input_shape = luci::shape_get(node->input()).as<loco::TensorShape>();
auto weights_clusters = loco::must_cast<luci::CircleConst *>(node->weights_clusters());
LUCI_ASSERT(input_shape.rank() == 2, "Input rank of BCQFullyConnected should be 2");
@@ -1664,8 +1702,8 @@ loco::NodeShape infer_bcq_gather(const luci::CircleBCQGather *node)
loco::TensorShape input_shape;
loco::TensorShape output_shape;
- const auto input_binary_shape = loco::shape_get(node->input_binary()).as<loco::TensorShape>();
- const auto indices_shape = loco::shape_get(node->indices()).as<loco::TensorShape>();
+ const auto input_binary_shape = luci::shape_get(node->input_binary()).as<loco::TensorShape>();
+ const auto indices_shape = luci::shape_get(node->indices()).as<loco::TensorShape>();
auto axis = node->axis();
auto input_clusters = loco::must_cast<luci::CircleConst *>(node->input_clusters());
@@ -1712,46 +1750,6 @@ loco::NodeShape infer_output(const luci::CircleOutput *node)
return loco::NodeShape{*output_shape};
}
-loco::NodeShape infer_if_out(const luci::CircleIfOut *node)
-{
- /**
- * @note IF operator type and shape are that of the "then" and "else"
- * Graph Outputs.
- */
- auto circle_if = dynamic_cast<const luci::CircleIf *>(node->input());
- if (circle_if == nullptr)
- {
- INTERNAL_EXN("CircleIf IR is not configured correctly");
- }
-
- auto index = node->index();
- auto then_graph = circle_if->then_graph();
- auto else_graph = circle_if->else_graph();
- assert(then_graph != nullptr);
- assert(else_graph != nullptr);
-
- // shape and type are assumed to be same
- // these are checked at post_import_graph() in Import
- auto then_outputs = loco::output_nodes(then_graph);
- auto else_outputs = loco::output_nodes(else_graph);
- assert(then_outputs.size() == else_outputs.size());
- assert(index < static_cast<int32_t>(then_outputs.size()));
-
- auto then_out = loco::must_cast<luci::CircleOutput *>(then_outputs.at(index));
- auto else_out = loco::must_cast<luci::CircleOutput *>(else_outputs.at(index));
-
- auto then_graph_outputs = then_graph->outputs(); // loco::GraphOutput items
- auto else_graph_outputs = else_graph->outputs();
- assert(then_graph_outputs->size() == else_graph_outputs->size());
-
- auto then_graph_output = then_graph_outputs->at(then_out->index());
- auto else_graph_output = else_graph_outputs->at(else_out->index());
- (void)else_graph_output; // make compiler happy for unused variable warnings
- assert(*then_graph_output->shape() == *else_graph_output->shape());
-
- return loco::NodeShape{*then_graph_output->shape()};
-}
-
loco::NodeShape infer_non_max_suppression_v4_out(const luci::CircleNonMaxSuppressionV4Out *node)
{
const loco::DataType S32 = loco::DataType::S32;
@@ -1818,7 +1816,7 @@ loco::NodeShape infer_split_out(const luci::CircleSplitOut *node)
loco::NodeShape unknown;
- auto split_shape = loco::shape_get(split).as<loco::TensorShape>();
+ auto split_shape = luci::shape_get(split).as<loco::TensorShape>();
auto split_dim = dynamic_cast<const luci::CircleConst *>(split->split_dim());
if (split_dim == nullptr)
@@ -1852,7 +1850,7 @@ loco::NodeShape infer_split_v_out(const luci::CircleSplitVOut *node)
loco::NodeShape unknown;
- auto split_shape = loco::shape_get(split).as<loco::TensorShape>();
+ auto split_shape = luci::shape_get(split).as<loco::TensorShape>();
auto size_splits = dynamic_cast<const luci::CircleConst *>(split->size_splits());
if (size_splits == nullptr)
@@ -1913,7 +1911,7 @@ loco::NodeShape infer_top_k_v2_out(const luci::CircleTopKV2Out *node)
INTERNAL_EXN("CircleSplit IR is not configured correctly");
// shape of topkv2 is same as topkv2->input()
- auto input_shape = loco::shape_get(topkv2).as<loco::TensorShape>();
+ auto input_shape = luci::shape_get(topkv2).as<loco::TensorShape>();
auto node_k = loco::must_cast<const luci::CircleConst *>(topkv2->k());
LUCI_ASSERT(node_k->dtype() == S32, "Only support Int32");
@@ -1940,7 +1938,7 @@ loco::NodeShape infer_unique_out(const luci::CircleUniqueOut *node)
}
assert(node->index() == 1);
auto unique = loco::must_cast<luci::CircleUnique *>(node->input());
- auto unique_shape = loco::shape_get(unique->input()).as<loco::TensorShape>();
+ auto unique_shape = luci::shape_get(unique->input()).as<loco::TensorShape>();
assert(unique_shape.rank() == 1);
@@ -1958,7 +1956,7 @@ loco::NodeShape infer_unpack_out(const luci::CircleUnpackOut *node)
INTERNAL_EXN("CircleUnpack IR is not configured correctly");
}
- auto unpack_shape = loco::shape_get(unpack).as<loco::TensorShape>();
+ auto unpack_shape = luci::shape_get(unpack).as<loco::TensorShape>();
return loco::NodeShape{unpack_shape};
}
@@ -2025,8 +2023,8 @@ public:
loco::NodeShape visit(const luci::CircleBatchMatMul *node) final
{
- auto x_shape = loco::shape_get(node->x()).as<loco::TensorShape>();
- auto y_shape = loco::shape_get(node->y()).as<loco::TensorShape>();
+ auto x_shape = luci::shape_get(node->x()).as<loco::TensorShape>();
+ auto y_shape = luci::shape_get(node->y()).as<loco::TensorShape>();
return infer_batchmatmul_shape(x_shape, y_shape, node->adj_x(), node->adj_y());
}
@@ -2065,7 +2063,7 @@ public:
loco::NodeShape visit(const luci::CircleDequantize *node) final
{
- const auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>();
+ const auto input_shape = luci::shape_get(node->input()).as<loco::TensorShape>();
return loco::NodeShape{input_shape};
}
@@ -2073,7 +2071,7 @@ public:
loco::NodeShape visit(const luci::CircleElu *node) final
{
- auto input_shape = loco::shape_get(node->features()).as<loco::TensorShape>();
+ auto input_shape = luci::shape_get(node->features()).as<loco::TensorShape>();
return loco::NodeShape{input_shape};
}
@@ -2087,6 +2085,8 @@ public:
return infer_expand_dims(node);
}
+ loco::NodeShape visit(const luci::CircleFakeQuant *node) final { return use_inputs(node); }
+
loco::NodeShape visit(const luci::CircleFill *node) final { return infer_fill(node); }
loco::NodeShape visit(const luci::CircleFloor *node) final { return use_x(node); }
@@ -2112,7 +2112,7 @@ public:
{
// Shape of CircleIf is not used. Just use input 0
assert(node->input_count() > 0);
- const auto input_shape = loco::shape_get(node->input(0)).as<loco::TensorShape>();
+ const auto input_shape = luci::shape_get(node->input(0)).as<loco::TensorShape>();
return loco::NodeShape{input_shape};
}
@@ -2125,7 +2125,7 @@ public:
loco::NodeShape visit(const luci::CircleLeakyRelu *node) final
{
- const auto input_shape = loco::shape_get(node->features()).as<loco::TensorShape>();
+ const auto input_shape = luci::shape_get(node->features()).as<loco::TensorShape>();
return loco::NodeShape{input_shape};
}
@@ -2135,7 +2135,7 @@ public:
loco::NodeShape visit(const luci::CircleLocalResponseNormalization *node) final
{
- const auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>();
+ const auto input_shape = luci::shape_get(node->input()).as<loco::TensorShape>();
return loco::NodeShape{input_shape};
}
@@ -2184,13 +2184,13 @@ public:
loco::NodeShape visit(const luci::CircleNonMaxSuppressionV4 *node) final
{
- const auto boxes_shape = loco::shape_get(node->boxes()).as<loco::TensorShape>();
+ const auto boxes_shape = luci::shape_get(node->boxes()).as<loco::TensorShape>();
return loco::NodeShape{boxes_shape};
}
loco::NodeShape visit(const luci::CircleNonMaxSuppressionV5 *node) final
{
- const auto boxes_shape = loco::shape_get(node->boxes()).as<loco::TensorShape>();
+ const auto boxes_shape = luci::shape_get(node->boxes()).as<loco::TensorShape>();
return loco::NodeShape{boxes_shape};
}
@@ -2244,21 +2244,21 @@ public:
loco::NodeShape visit(const luci::CircleRelu *node) final
{
- auto input_shape = loco::shape_get(node->features()).as<loco::TensorShape>();
+ auto input_shape = luci::shape_get(node->features()).as<loco::TensorShape>();
return loco::NodeShape{input_shape};
}
loco::NodeShape visit(const luci::CircleRelu6 *node) final
{
- auto input_shape = loco::shape_get(node->features()).as<loco::TensorShape>();
+ auto input_shape = luci::shape_get(node->features()).as<loco::TensorShape>();
return loco::NodeShape{input_shape};
}
loco::NodeShape visit(const luci::CircleReluN1To1 *node) final
{
- auto input_shape = loco::shape_get(node->features()).as<loco::TensorShape>();
+ auto input_shape = luci::shape_get(node->features()).as<loco::TensorShape>();
return loco::NodeShape{input_shape};
}
@@ -2284,7 +2284,7 @@ public:
loco::NodeShape visit(const luci::CircleReverseSequence *node) final
{
- auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>();
+ auto input_shape = luci::shape_get(node->input()).as<loco::TensorShape>();
return loco::NodeShape{input_shape};
}
@@ -2293,9 +2293,9 @@ public:
loco::NodeShape visit(const luci::CircleReverseV2 *node) final
{
- auto input_shape = loco::shape_get(node->tensor()).as<loco::TensorShape>();
+ auto input_shape = luci::shape_get(node->tensor()).as<loco::TensorShape>();
- LUCI_ASSERT(loco::shape_get(node->axis()).as<loco::TensorShape>().rank() == 1,
+ LUCI_ASSERT(luci::shape_get(node->axis()).as<loco::TensorShape>().rank() == 1,
"Tensor must be 1-D");
return loco::NodeShape{input_shape};
@@ -2340,14 +2340,14 @@ public:
loco::NodeShape visit(const luci::CircleSplit *node) final
{
// We'll set Split output as same as input so that SplitOut can handle it's own shape
- auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>();
+ auto input_shape = luci::shape_get(node->input()).as<loco::TensorShape>();
return loco::NodeShape{input_shape};
}
loco::NodeShape visit(const luci::CircleSplitV *node) final
{
// We'll set SplitV output as same as input so that SplitOut can handle it's own shape
- auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>();
+ auto input_shape = luci::shape_get(node->input()).as<loco::TensorShape>();
return loco::NodeShape{input_shape};
}
@@ -2382,7 +2382,7 @@ public:
loco::NodeShape visit(const luci::CircleTopKV2 *node) final
{
// set shape of this node as same as input
- const auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>();
+ const auto input_shape = luci::shape_get(node->input()).as<loco::TensorShape>();
return loco::NodeShape{input_shape};
}
@@ -2408,13 +2408,13 @@ public:
{
// Shape of CircleWhile is not used. Just use input 0
assert(node->arity() > 0);
- const auto input_shape = loco::shape_get(node->input(0)).as<loco::TensorShape>();
+ const auto input_shape = luci::shape_get(node->input(0)).as<loco::TensorShape>();
return loco::NodeShape{input_shape};
}
loco::NodeShape visit(const luci::CircleZerosLike *node) final
{
- auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>();
+ auto input_shape = luci::shape_get(node->input()).as<loco::TensorShape>();
return loco::NodeShape{input_shape};
}
@@ -2429,7 +2429,7 @@ public:
loco::NodeShape visit(const luci::CircleInstanceNorm *node) final
{
- auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>();
+ auto input_shape = luci::shape_get(node->input()).as<loco::TensorShape>();
return loco::NodeShape{input_shape};
}
@@ -2445,8 +2445,6 @@ public:
loco::NodeShape visit(const luci::CircleCustomOut *node) final { return use_own(node); }
- loco::NodeShape visit(const luci::CircleIfOut *node) final { return infer_if_out(node); }
-
loco::NodeShape visit(const luci::CircleNonMaxSuppressionV4Out *node) final
{
return infer_non_max_suppression_v4_out(node);
diff --git a/compiler/luci/service/src/CircleShapeInferenceRule.test.cpp b/compiler/luci/service/src/CircleShapeInferenceRule.test.cpp
deleted file mode 100644
index ac27db3bd..000000000
--- a/compiler/luci/service/src/CircleShapeInferenceRule.test.cpp
+++ /dev/null
@@ -1,626 +0,0 @@
-/*
- * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "TestGraph.h"
-#include "luci/Service/CircleShapeInferenceRule.h"
-
-#include <luci/IR/CircleNodes.h>
-#include <luci/IR/CircleDialect.h>
-
-#include <loco.h>
-#include <loco/IR/CanonicalDialect.h>
-#include <loco/Service/ShapeInference.h>
-#include <loco/Service/CanonicalShapeInferenceRule.h>
-#include <loco/Service/MultiDialectShapeInferenceRule.h>
-
-#include <oops/InternalExn.h>
-
-#include <gtest/gtest.h>
-
-#include <memory>
-
-namespace
-{
-
-bool shape_pass(loco::Graph *g)
-{
- loco::CanonicalShapeInferenceRule canonical_rule;
- luci::CircleShapeInferenceRule circle_rule;
- loco::MultiDialectShapeInferenceRule rules;
-
- rules.bind(loco::CanonicalDialect::get(), &canonical_rule)
- .bind(luci::CircleDialect::get(), &circle_rule);
-
- return loco::apply(&rules).to(g);
-}
-
-} // namespace
-
-TEST(CircleShapeInferenceRuleTest, minimal_with_CircleRelu)
-{
- // Create a simple network
- luci::test::TestGraph graph;
- auto relu_node = graph.append<luci::CircleRelu>(graph.input_node);
- graph.complete(relu_node);
-
- // set shape
- {
- graph.input_node->rank(2);
- graph.input_node->dim(0) = 3;
- graph.input_node->dim(1) = 4;
-
- graph.output_node->rank(2);
- graph.output_node->dim(0) = 3;
- graph.output_node->dim(1) = 4;
-
- luci::test::graph_input_shape(graph.input_node);
- luci::test::graph_output_shape(graph.output_node);
- }
-
- // pre-check
- ASSERT_FALSE(loco::shape_known(relu_node));
-
- // shape inference
- while (shape_pass(graph.graph()) == true)
- ;
-
- // Verify
- {
- ASSERT_TRUE(loco::shape_known(relu_node));
- ASSERT_EQ(loco::Domain::Tensor, loco::shape_get(relu_node).domain());
-
- auto shape = loco::shape_get(relu_node).as<loco::TensorShape>();
- ASSERT_EQ(2, shape.rank());
- ASSERT_EQ(3, shape.dim(0));
- ASSERT_EQ(4, shape.dim(1));
- }
-}
-
-// based on the case shown in
-// https://www.corvil.com/kb/what-is-the-difference-between-same-and-valid-padding-in-tf-nn-max-pool-of-tensorflow
-TEST(CircleShapeInferenceRuleTest, avgpool2d_valid)
-{
- luci::test::TestGraph graph;
- auto avg_node = graph.append<luci::CircleAveragePool2D>(graph.input_node);
- graph.complete();
-
- auto input_node = graph.input_node;
- {
- input_node->shape({1, 4, 3, 1});
- luci::test::graph_input_shape(input_node);
- }
- auto output_node = graph.output_node;
- {
- output_node->shape({1, 2, 1, 1});
- luci::test::graph_output_shape(output_node);
- }
- // setting CircleAveragePool2D
- {
- avg_node->filter()->h(2);
- avg_node->filter()->w(2);
- avg_node->stride()->h(2);
- avg_node->stride()->w(2);
- avg_node->fusedActivationFunction(luci::FusedActFunc::NONE);
- avg_node->padding(luci::Padding::VALID);
- }
- ASSERT_FALSE(loco::shape_known(avg_node));
-
- // shape inference
- while (shape_pass(graph.graph()) == true)
- ;
-
- // Verify
- {
- ASSERT_TRUE(loco::shape_known(avg_node));
- ASSERT_EQ(loco::Domain::Tensor, loco::shape_get(avg_node).domain());
-
- auto shape = loco::shape_get(avg_node).as<loco::TensorShape>();
- ASSERT_EQ(4, shape.rank());
- ASSERT_EQ(1, shape.dim(0).value());
- ASSERT_EQ(2, shape.dim(1).value());
- ASSERT_EQ(1, shape.dim(2).value());
- ASSERT_EQ(1, shape.dim(3).value());
- }
-}
-
-TEST(CircleShapeInferenceRuleTest, avgpool2d_same)
-{
- luci::test::TestGraph graph;
- auto avg_node = graph.append<luci::CircleAveragePool2D>(graph.input_node);
- graph.complete();
-
- auto input_node = graph.input_node;
- {
- input_node->shape({1, 4, 3, 1});
- luci::test::graph_input_shape(input_node);
- }
- auto output_node = graph.output_node;
- {
- output_node->shape({1, 2, 2, 1});
- luci::test::graph_output_shape(output_node);
- }
-
- // setting CircleAveragePool2D
- {
- avg_node->filter()->h(2);
- avg_node->filter()->w(2);
- avg_node->stride()->h(2);
- avg_node->stride()->w(2);
- avg_node->fusedActivationFunction(luci::FusedActFunc::NONE);
- avg_node->padding(luci::Padding::SAME);
- }
-
- ASSERT_FALSE(loco::shape_known(avg_node));
-
- // shape inference
- while (shape_pass(graph.graph()) == true)
- ;
-
- // Verify
- {
- ASSERT_TRUE(loco::shape_known(avg_node));
- ASSERT_EQ(loco::Domain::Tensor, loco::shape_get(avg_node).domain());
-
- auto shape = loco::shape_get(avg_node).as<loco::TensorShape>();
- ASSERT_EQ(4, shape.rank());
- ASSERT_EQ(1, shape.dim(0).value());
- ASSERT_EQ(2, shape.dim(1).value());
- ASSERT_EQ(2, shape.dim(2).value());
- ASSERT_EQ(1, shape.dim(3).value());
- }
-}
-
-/**
- * @note Function to test: Shape inference of two different input shapes
- *
- * Rank expansion to higher input side
- * x(2,1,5) + y(3,5) --> x(2,1,5) + y(1,3,5)
- * Do output shape inference like numpy
- * x(2,1,5) + y(1,3,5) --> output(2,3,5)
- * For each axis, dim value should be same OR one of them should be 1
- */
-TEST(CircleShapeInferenceRuleTest, TFAdd_shapeinf_different)
-{
- auto g = loco::make_graph();
-
- auto x_node = g->nodes()->create<luci::CircleInput>();
- {
- x_node->rank(3);
- x_node->dim(0) = 2;
- x_node->dim(1) = 1;
- x_node->dim(2) = 5;
- }
- auto y_node = g->nodes()->create<luci::CircleInput>();
- {
- y_node->rank(2);
- y_node->dim(0) = 3;
- y_node->dim(1) = 5;
- }
- auto add_node = g->nodes()->create<luci::CircleAdd>();
- {
- add_node->x(x_node);
- add_node->y(y_node);
- }
- auto output_node = g->nodes()->create<luci::CircleOutput>();
- {
- output_node->from(add_node);
- }
-
- auto x_input = g->inputs()->create();
- {
- x_input->name("x");
- luci::link(x_input, x_node);
- }
- auto y_input = g->inputs()->create();
- {
- y_input->name("y");
- luci::link(y_input, y_node);
- }
- auto output = g->outputs()->create();
- {
- output->name("output");
- luci::link(output, output_node);
- }
-
- luci::test::graph_input_shape(x_node);
- luci::test::graph_input_shape(y_node);
- luci::test::graph_output_shape(output_node);
-
- // pre-check
- ASSERT_FALSE(loco::shape_known(add_node));
-
- // shape inference
- while (shape_pass(g.get()) == true)
- ;
-
- // Verify
- {
- ASSERT_TRUE(loco::shape_known(add_node));
- ASSERT_EQ(loco::Domain::Tensor, loco::shape_get(add_node).domain());
-
- auto shape = loco::shape_get(add_node).as<loco::TensorShape>();
- ASSERT_EQ(3, shape.rank());
- ASSERT_EQ(2, shape.dim(0));
- ASSERT_EQ(3, shape.dim(1));
- ASSERT_EQ(5, shape.dim(2));
- }
-}
-
-TEST(CircleShapeInferenceRuleTest, CircleTranspose_simple)
-{
- luci::test::ExampleGraph<luci::test::ExampleGraphType::CircleTranspose> g;
-
- g.input_node->rank(3);
- g.input_node->dim(0) = 3;
- g.input_node->dim(1) = 8;
- g.input_node->dim(2) = 1;
-
- g.const_perm->dtype(loco::DataType::S32);
- g.const_perm->rank(1);
- g.const_perm->dim(0) = 3;
- g.const_perm->size<loco::DataType::S32>(3);
- g.const_perm->at<loco::DataType::S32>(0) = 1;
- g.const_perm->at<loco::DataType::S32>(1) = 2;
- g.const_perm->at<loco::DataType::S32>(2) = 0;
-
- luci::test::graph_input_shape(g.input_node);
- luci::test::graph_output_shape(g.output_node);
-
- // pre-check
- ASSERT_FALSE(loco::shape_known(g.transpose_node));
-
- // shape inference
- while (shape_pass(g.graph()) == true)
- ;
-
- // Verify
- {
- ASSERT_TRUE(loco::shape_known(g.transpose_node));
-
- auto shape = loco::shape_get(g.transpose_node).as<loco::TensorShape>();
- ASSERT_EQ(3, shape.rank());
- ASSERT_EQ(8, shape.dim(0));
- ASSERT_EQ(1, shape.dim(1));
- ASSERT_EQ(3, shape.dim(2));
- }
-}
-
-TEST(CircleShapeInferenceRuleTest, CircleSqueeze)
-{
- luci::test::TestGraph graph;
- auto squeeze_node = graph.append<luci::CircleSqueeze>(graph.input_node);
- graph.complete();
-
- auto input_node = graph.input_node;
- {
- input_node->shape({1, 4, 3, 1});
- }
- auto output_node = graph.output_node;
- {
- output_node->shape({4, 3, 1});
- }
-
- luci::test::graph_input_shape(input_node);
- luci::test::graph_output_shape(output_node);
-
- squeeze_node->squeeze_dims({0});
-
- // pre-check
- ASSERT_FALSE(loco::shape_known(squeeze_node));
-
- // shape inference
- while (shape_pass(graph.graph()) == true)
- ;
-
- // Verify
- {
- ASSERT_TRUE(loco::shape_known(squeeze_node));
-
- auto shape = loco::shape_get(squeeze_node).as<loco::TensorShape>();
- ASSERT_EQ(3, shape.rank());
- ASSERT_EQ(4, shape.dim(0));
- ASSERT_EQ(3, shape.dim(1));
- ASSERT_EQ(1, shape.dim(2));
- }
-}
-
-TEST(CircleShapeInferenceRuleTest, CircleExpandDims)
-{
- luci::test::TestGraph graph;
- auto axis = graph.append<luci::CircleConst>();
- axis->dtype(loco::DataType::S32);
- axis->rank(0);
- axis->size<loco::DataType::S32>(1);
- axis->at<loco::DataType::S32>(0) = 1;
-
- auto expand_dims = graph.append<luci::CircleExpandDims>(graph.input_node, axis);
- graph.complete();
-
- auto input_node = graph.input_node;
- {
- input_node->shape({4, 3});
- }
-
- auto output_node = graph.output_node;
- {
- output_node->from(expand_dims);
- }
-
- luci::test::graph_input_shape(input_node);
- luci::test::graph_output_shape(output_node);
-
- // shape inference
- while (shape_pass(graph.graph()))
- ;
-
- // validation
- {
- ASSERT_TRUE(loco::shape_known(expand_dims));
-
- auto shape = loco::shape_get(expand_dims).as<loco::TensorShape>();
-
- ASSERT_EQ(3, shape.rank());
- ASSERT_EQ(4, shape.dim(0));
- ASSERT_EQ(1, shape.dim(1));
- ASSERT_EQ(3, shape.dim(2));
- }
-}
-
-TEST(CircleShapeInferenceRuleTest, CircleSqueezeAll)
-{
- luci::test::TestGraph graph;
- auto squeeze_node = graph.append<luci::CircleSqueeze>(graph.input_node);
- graph.complete();
-
- auto input_node = graph.input_node;
- {
- input_node->shape({1, 4, 3, 1});
- }
- auto output_node = graph.output_node;
- {
- input_node->shape({4, 3});
- }
-
- luci::test::graph_input_shape(input_node);
- luci::test::graph_output_shape(output_node);
-
- squeeze_node->squeeze_dims({});
-
- // pre-check
- ASSERT_FALSE(loco::shape_known(squeeze_node));
-
- // shape inference
- while (shape_pass(graph.graph()) == true)
- ;
-
- // Verify
- {
- ASSERT_TRUE(loco::shape_known(squeeze_node));
-
- auto shape = loco::shape_get(squeeze_node).as<loco::TensorShape>();
- ASSERT_EQ(2, shape.rank());
- ASSERT_EQ(4, shape.dim(0));
- ASSERT_EQ(3, shape.dim(1));
- }
-}
-
-TEST(CircleShapeInferenceRuleTest, CircleGatherNd_simple)
-{
- luci::test::TestGraph graph;
- auto indices_const = graph.append<luci::CircleConst>();
- auto gather_nd_node = graph.append<luci::CircleGatherNd>(graph.input_node, indices_const);
- graph.complete();
-
- {
- auto input_node = graph.input_node;
- input_node->shape({1, 4, 4, 3});
- luci::test::graph_input_shape(input_node);
- }
- {
- auto output_node = graph.output_node;
- output_node->shape({1, 2, 2, 3});
- luci::test::graph_output_shape(output_node);
- }
-
- {
- indices_const->shape({1, 2, 3});
- }
-
- // pre-check
- ASSERT_FALSE(loco::shape_known(gather_nd_node));
-
- // shape inference
- while (shape_pass(graph.graph()) == true)
- ;
-
- // Verify
- {
- ASSERT_TRUE(loco::shape_known(gather_nd_node));
-
- auto shape = loco::shape_get(gather_nd_node).as<loco::TensorShape>();
- ASSERT_EQ(3, shape.rank());
- ASSERT_EQ(1, shape.dim(0));
- ASSERT_EQ(2, shape.dim(1));
- ASSERT_EQ(3, shape.dim(2));
- }
-}
-
-TEST(CircleShapeInferenceRuleTest, CircleGatherNd_slices)
-{
- luci::test::TestGraph graph;
- auto indices_const = graph.append<luci::CircleConst>();
- auto gather_nd_node = graph.append<luci::CircleGatherNd>(graph.input_node, indices_const);
- graph.complete();
-
- {
- auto input_node = graph.input_node;
- input_node->shape({1, 4, 4, 3});
- luci::test::graph_input_shape(input_node);
- }
- {
- auto output_node = graph.output_node;
- output_node->shape({1, 2, 4, 4, 3});
- luci::test::graph_output_shape(output_node);
- }
-
- {
- indices_const->shape({1, 2, 1});
- }
-
- // pre-check
- ASSERT_FALSE(loco::shape_known(gather_nd_node));
-
- // shape inference
- while (shape_pass(graph.graph()) == true)
- ;
-
- // Verify
- {
- ASSERT_TRUE(loco::shape_known(gather_nd_node));
-
- auto shape = loco::shape_get(gather_nd_node).as<loco::TensorShape>();
- ASSERT_EQ(5, shape.rank());
- ASSERT_EQ(1, shape.dim(0));
- ASSERT_EQ(2, shape.dim(1));
- ASSERT_EQ(4, shape.dim(2));
- ASSERT_EQ(4, shape.dim(3));
- ASSERT_EQ(3, shape.dim(4));
- }
-}
-
-TEST(CircleShapeInferenceRuleTest, CircleGatherNd_NEG)
-{
- luci::test::TestGraph graph;
- auto indices_const = graph.append<luci::CircleConst>();
- auto gather_nd_node = graph.append<luci::CircleGatherNd>(graph.input_node, indices_const);
- graph.complete();
-
- {
- auto input_node = graph.input_node;
- input_node->shape({1, 4, 4, 3});
- luci::test::graph_input_shape(input_node);
- }
- {
- // Does not matter, because test should fail anyway
- auto output_node = graph.output_node;
- output_node->shape({0, 0, 0});
- luci::test::graph_output_shape(output_node);
- }
-
- {
- indices_const->shape({1, 2, 5});
- }
-
- // pre-check
- ASSERT_FALSE(loco::shape_known(gather_nd_node));
-
- // had to pack into lambda to check throw
- auto lambda = [&]() {
- // shape inference
- while (shape_pass(graph.graph()) == true)
- ;
- };
-
- ASSERT_THROW(lambda(), oops::InternalExn);
-}
-
-TEST(CircleShapeInferenceRuleTest, CircleResizeNearestNeighbor)
-{
- luci::test::TestGraph graph;
- auto size_const = graph.append<luci::CircleConst>();
- size_const->dtype(loco::DataType::S32);
- size_const->rank(1);
- size_const->dim(0) = 2;
- size_const->size<loco::DataType::S32>(2);
- size_const->at<loco::DataType::S32>(0) = 16;
- size_const->at<loco::DataType::S32>(1) = 16;
- auto resize_node = graph.append<luci::CircleResizeNearestNeighbor>(graph.input_node, size_const);
- graph.complete();
-
- {
- auto input_node = graph.input_node;
- input_node->shape({1, 4, 4, 3});
- luci::test::graph_input_shape(input_node);
- }
- {
- auto output_node = graph.output_node;
- output_node->from(resize_node);
- luci::test::graph_output_shape(output_node);
- }
-
- // pre-check
- ASSERT_FALSE(loco::shape_known(resize_node));
-
- // shape inference
- while (shape_pass(graph.graph()) == true)
- ;
-
- // Verify
- {
- ASSERT_TRUE(loco::shape_known(resize_node));
-
- auto shape = loco::shape_get(resize_node).as<loco::TensorShape>();
- ASSERT_EQ(4, shape.rank());
- ASSERT_EQ(1, shape.dim(0));
- ASSERT_EQ(16, shape.dim(1));
- ASSERT_EQ(16, shape.dim(2));
- ASSERT_EQ(3, shape.dim(3));
- }
-}
-
-TEST(CircleShapeInferenceRuleTest, CircleResizeBilinear)
-{
- luci::test::TestGraph graph;
- auto size_const = graph.append<luci::CircleConst>();
- size_const->dtype(loco::DataType::S32);
- size_const->rank(1);
- size_const->dim(0) = 2;
- size_const->size<loco::DataType::S32>(2);
- size_const->at<loco::DataType::S32>(0) = 16;
- size_const->at<loco::DataType::S32>(1) = 16;
- auto resize_node = graph.append<luci::CircleResizeBilinear>(graph.input_node, size_const);
- graph.complete();
-
- {
- auto input_node = graph.input_node;
- input_node->shape({1, 4, 4, 3});
- luci::test::graph_input_shape(input_node);
- }
- {
- auto output_node = graph.output_node;
- output_node->from(resize_node);
- luci::test::graph_output_shape(output_node);
- }
-
- // pre-check
- ASSERT_FALSE(loco::shape_known(resize_node));
-
- // shape inference
- while (shape_pass(graph.graph()) == true)
- ;
-
- // Verify
- {
- ASSERT_TRUE(loco::shape_known(resize_node));
-
- auto shape = loco::shape_get(resize_node).as<loco::TensorShape>();
- ASSERT_EQ(4, shape.rank());
- ASSERT_EQ(1, shape.dim(0));
- ASSERT_EQ(16, shape.dim(1));
- ASSERT_EQ(16, shape.dim(2));
- ASSERT_EQ(3, shape.dim(3));
- }
-}
diff --git a/compiler/luci/service/src/CircleShapeSignatureInference.cpp b/compiler/luci/service/src/CircleShapeSignatureInference.cpp
deleted file mode 100644
index 1ccaa19d5..000000000
--- a/compiler/luci/service/src/CircleShapeSignatureInference.cpp
+++ /dev/null
@@ -1,64 +0,0 @@
-/*
- * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "luci/Service/CircleShapeSignatureInference.h"
-
-#include <luci/Log.h>
-
-namespace
-{
-
-std::ostream &operator<<(std::ostream &os, const luci::ShapeSignature &shape_signature)
-{
- os << "[";
- for (uint32_t r = 0; r < shape_signature.rank(); ++r)
- {
- if (r)
- os << ",";
- os << shape_signature.dim(r);
- }
- os << "]";
- return os;
-}
-
-} // namespace
-
-namespace luci
-{
-
-namespace ssinf
-{
-
-bool Rule::infer(const luci::CircleNode *circle_node, ShapeSignature &shape_signature) const
-{
- LOGGER(l);
-
- // There is nothing to check before ShapeSignatureInference.
-
- Algorithm alg;
-
- shape_signature = circle_node->accept(&alg);
-
- VERBOSE(l, 1) << "[luci] Shape Signature( " << circle_node->name() << " )";
- VERBOSE(l, 1) << " before: " << circle_node->shape_signature();
- VERBOSE(l, 1) << " after: " << shape_signature;
-
- return true;
-}
-
-} // namespace ssinf
-
-} // namespace luci
diff --git a/compiler/luci/service/src/CircleShapeSignatureInferenceHelper.cpp b/compiler/luci/service/src/CircleShapeSignatureInferenceHelper.cpp
deleted file mode 100644
index d7d1a24e8..000000000
--- a/compiler/luci/service/src/CircleShapeSignatureInferenceHelper.cpp
+++ /dev/null
@@ -1,160 +0,0 @@
-/*
- * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "luci/Service/CircleShapeSignatureInferenceHelper.h"
-
-#include <loco.h>
-
-#include <luci/Log.h>
-
-#include <oops/InternalExn.h>
-
-namespace luci
-{
-
-namespace ssinf
-{
-
-luci::ShapeSignature legalized_signature(const luci::ShapeSignature &signature)
-{
- // If shape signature has at least one -1, it is not static.
- for (uint32_t i = 0; i < signature.rank(); ++i)
- if (signature.dim(i) == -1)
- return signature;
-
- // If all dimensions are static, return empty shape signature.
- return luci::ShapeSignature();
-}
-
-ShapeSignature reduced_signature(const loco::Node *node, const loco::Node *indices, bool keep_dims)
-{
- LOGGER(l);
-
- ShapeSignature input_signature;
- ShapeSignature output_signature;
-
- auto circle_node = loco::must_cast<const luci::CircleNode *>(node);
- if (circle_node->shape_signature().rank() > 0)
- input_signature = circle_node->shape_signature();
- else
- {
- input_signature.rank(circle_node->rank());
- for (uint32_t i = 0; i < circle_node->rank(); ++i)
- input_signature.dim(i) = circle_node->dim(i).value();
- }
-
- // If input rank is 0, it means that one of following case is occurred.
- // - Input is scalar : result is always scalar
- // - Input shape signature is not inferenced : cannot infer output shape signauture
- // Therefore, when input signature rank is 0, always return empty signature.
- if (input_signature.rank() == 0)
- return output_signature;
-
- // When reduction_indices is not constant
- auto reduction_indices = dynamic_cast<const luci::CircleConst *>(indices);
- if (reduction_indices == nullptr)
- {
- if (keep_dims)
- {
- // If keep_dims is true, rank is not changed.
- output_signature.rank(input_signature.rank());
- for (uint32_t i = 0; i < output_signature.rank(); ++i)
- output_signature.dim(i) = -1;
- }
- else
- {
- // There is no way to inference for this case.
- // Do nothing to return empty signature.
- INFO(l) << "[CircleShapeSignatureInferenceHelper] " << circle_node->name() << std::endl;
- INFO(l) << " reduced_signature : cannot infer because of non-constant node" << std::endl;
- }
-
- return output_signature;
- }
-
- std::vector<int32_t> reduction_values;
- if (reduction_indices->dtype() == loco::DataType::S32)
- {
- auto reduction_size = reduction_indices->size<loco::DataType::S32>();
- for (uint32_t i = 0; i < reduction_size; ++i)
- {
- int32_t axis = reduction_indices->at<loco::DataType::S32>(i);
- if (axis < 0)
- axis += input_signature.rank();
-
- if (!(0 <= axis && axis < static_cast<int32_t>(input_signature.rank())))
- INTERNAL_EXN_V("Invalid reduction axis for REDUCER", oops::to_uint32(axis));
-
- reduction_values.push_back(axis);
- }
- }
- else if (reduction_indices->dtype() == loco::DataType::S64)
- {
- auto reduction_size = reduction_indices->size<loco::DataType::S64>();
- for (uint32_t i = 0; i < reduction_size; ++i)
- {
- int32_t axis = static_cast<int32_t>(reduction_indices->at<loco::DataType::S64>(i));
- if (axis < 0)
- axis += input_signature.rank();
-
- if (!(0 <= axis && axis < static_cast<int32_t>(input_signature.rank())))
- INTERNAL_EXN_V("Invalid reduction axis for REDUCER", oops::to_uint32(axis));
-
- reduction_values.push_back(axis);
- }
- }
- else
- {
- INTERNAL_EXN("Wrong reduction axis type, Only INT32, INT64 supported.");
- }
-
- if (keep_dims)
- {
- output_signature.rank(input_signature.rank());
- for (uint32_t i = 0; i < input_signature.rank(); ++i)
- output_signature.dim(i) = input_signature.dim(i);
- for (uint32_t i = 0; i < reduction_values.size(); ++i)
- output_signature.dim(reduction_values.at(i)) = 1;
- }
- else
- {
- std::vector<bool> check_reduce(input_signature.rank(), false);
- for (uint32_t i = 0; i < reduction_values.size(); ++i)
- check_reduce.at(reduction_values.at(i)) = true;
-
- uint32_t reduce_cnt = 0;
- for (uint32_t i = 0; i < check_reduce.size(); ++i)
- if (check_reduce.at(i))
- ++reduce_cnt;
-
- output_signature.rank(input_signature.rank() - reduce_cnt);
- for (uint32_t i = 0, j = 0; i < check_reduce.size(); ++i)
- if (check_reduce.at(i) == false)
- output_signature.dim(j++) = input_signature.dim(i);
- }
-
- return output_signature;
-}
-
-ShapeSignature input_arg_signature(const luci::CircleNode *node, uint32_t index)
-{
- auto circle_input = loco::must_cast<luci::CircleNode *>(node->arg(index));
- return circle_input->shape_signature();
-}
-
-} // namespace ssinf
-
-} // namespace luci
diff --git a/compiler/luci/service/src/CircleTypeInference.cpp b/compiler/luci/service/src/CircleTypeInference.cpp
index b4755b51a..db9a37cb0 100644
--- a/compiler/luci/service/src/CircleTypeInference.cpp
+++ b/compiler/luci/service/src/CircleTypeInference.cpp
@@ -15,72 +15,23 @@
*/
#include "luci/Service/CircleTypeInference.h"
+#include "CircleTypeInferenceHelper.h"
#include <luci/Log.h>
#include <loco.h>
-#include <loco/Service/TypeInference.h>
-
-#include <mio/circle/schema_generated.h>
-#include <oops/InternalExn.h>
#include <type_traits>
namespace
{
-circle::TensorType translateLocoTypeToCircle(loco::DataType dtype)
-{
- switch (dtype)
- {
- case loco::DataType::U8:
- return circle::TensorType_UINT8;
- // case loco::DataType::U16: unsupported
- // case loco::DataType::U32: unsupported
- // case loco::DataType::U64: unsupported
- case loco::DataType::S8:
- return circle::TensorType_INT8;
- case loco::DataType::S16:
- return circle::TensorType_INT16;
- case loco::DataType::S32:
- return circle::TensorType_INT32;
- case loco::DataType::S64:
- return circle::TensorType_INT64;
- case loco::DataType::FLOAT16:
- return circle::TensorType_FLOAT16;
- case loco::DataType::FLOAT32:
- return circle::TensorType_FLOAT32;
- // case loco::DataType::FLOAT64: unsupported
- case loco::DataType::BOOL:
- return circle::TensorType_BOOL;
- default:
- break;
- }
-
- INTERNAL_EXN_V("Invalid loco dtype", oops::to_uint32(dtype));
-}
-
-} // namespace
-
-namespace luci
-{
-
-circle::TensorType TypeInference::get(loco::Node *node)
-{
- assert(loco::dtype_known(node));
- return translateLocoTypeToCircle(loco::dtype_get(node));
-}
-
-} // namespace luci
-
-namespace
-{
-
bool inputs_dtype_ready(const luci::CircleNode *node)
{
for (uint32_t arity = 0; arity < node->arity(); ++arity)
{
- if (node->dtype() == loco::DataType::Unknown)
+ auto input_node = loco::must_cast<luci::CircleNode *>(node->arg(arity));
+ if (input_node->dtype() == loco::DataType::Unknown)
return false;
}
diff --git a/compiler/luci/service/src/CircleTypeInferenceHelper.cpp b/compiler/luci/service/src/CircleTypeInferenceHelper.cpp
index 75cd9f7b2..06edd70f2 100644
--- a/compiler/luci/service/src/CircleTypeInferenceHelper.cpp
+++ b/compiler/luci/service/src/CircleTypeInferenceHelper.cpp
@@ -14,7 +14,23 @@
* limitations under the License.
*/
-#include "luci/Service/CircleTypeInferenceHelper.h"
+#include "CircleTypeInferenceHelper.h"
+
+namespace luci
+{
+
+loco::DataType dtype_get(const loco::Node *node)
+{
+ assert(luci::dtype_known(node));
+ return loco::must_cast<const luci::CircleNode *>(node)->dtype();
+}
+
+bool dtype_known(const loco::Node *node)
+{
+ return loco::must_cast<const luci::CircleNode *>(node)->dtype() != loco::DataType::Unknown;
+}
+
+} // namespace luci
namespace luci
{
diff --git a/compiler/luci/service/include/luci/Service/CircleTypeInferenceHelper.h b/compiler/luci/service/src/CircleTypeInferenceHelper.h
index 296f99355..751340cc7 100644
--- a/compiler/luci/service/include/luci/Service/CircleTypeInferenceHelper.h
+++ b/compiler/luci/service/src/CircleTypeInferenceHelper.h
@@ -23,6 +23,20 @@
namespace luci
{
+
+// NOTE Functions in this namespace will be removed after new inference
+// algorithms are fully implemented.
+
+// This function is temporary function for deprecating loco::dtype_get
+loco::DataType dtype_get(const loco::Node *node);
+
+// This function is temporary function for deprecating loco::dtype_known
+bool dtype_known(const loco::Node *node);
+
+} // namespace luci
+
+namespace luci
+{
namespace tinf // Namespace for Type Inference
{
diff --git a/compiler/luci/service/src/CircleTypeInferenceRule.cpp b/compiler/luci/service/src/CircleTypeInferenceRule.cpp
index f738ab5a8..0b8d2af9e 100644
--- a/compiler/luci/service/src/CircleTypeInferenceRule.cpp
+++ b/compiler/luci/service/src/CircleTypeInferenceRule.cpp
@@ -15,6 +15,7 @@
*/
#include "luci/Service/CircleTypeInferenceRule.h"
+#include "CircleTypeInferenceHelper.h"
#include <luci/IR/CircleDialect.h>
#include <luci/IR/CircleNodeVisitor.h>
@@ -29,24 +30,24 @@ struct TypeInferenceAlgorithm final : public luci::CircleNodeVisitor<loco::DataT
{
// TODO Given a tensor x of complex numbers, Abs operation returns a tensor of type float32 or
// float64.
- loco::DataType visit(const luci::CircleAbs *node) final { return loco::dtype_get(node->x()); }
+ loco::DataType visit(const luci::CircleAbs *node) final { return luci::dtype_get(node->x()); }
- loco::DataType visit(const luci::CircleAdd *node) final { return loco::dtype_get(node->x()); }
+ loco::DataType visit(const luci::CircleAdd *node) final { return luci::dtype_get(node->x()); }
loco::DataType visit(const luci::CircleAddN *node) final
{
- auto dtype = loco::dtype_get(node->inputs(0));
+ auto dtype = luci::dtype_get(node->inputs(0));
for (uint32_t idx = 1; idx < node->arity(); ++idx)
{
- auto dtype_idx = loco::dtype_get(node->inputs(idx));
+ auto dtype_idx = luci::dtype_get(node->inputs(idx));
if (dtype != dtype_idx)
{
INTERNAL_EXN_V("ADD_N dtype not same as the first input: ", idx);
}
}
- return loco::dtype_get(node->inputs(0));
+ return luci::dtype_get(node->inputs(0));
}
loco::DataType visit(const luci::CircleArgMax *node) final { return node->output_type(); }
@@ -55,22 +56,22 @@ struct TypeInferenceAlgorithm final : public luci::CircleNodeVisitor<loco::DataT
loco::DataType visit(const luci::CircleAveragePool2D *node) final
{
- return loco::dtype_get(node->value());
+ return luci::dtype_get(node->value());
}
loco::DataType visit(const luci::CircleBatchMatMul *node) final
{
- return loco::dtype_get(node->x());
+ return luci::dtype_get(node->x());
}
loco::DataType visit(const luci::CircleBatchToSpaceND *node) final
{
- return loco::dtype_get(node->input());
+ return luci::dtype_get(node->input());
}
loco::DataType visit(const luci::CircleCast *node) final { return node->dtype(); }
- loco::DataType visit(const luci::CircleCeil *node) final { return loco::dtype_get(node->x()); }
+ loco::DataType visit(const luci::CircleCeil *node) final { return luci::dtype_get(node->x()); }
loco::DataType visit(const luci::CircleConcatenation *node) final
{
@@ -78,87 +79,92 @@ struct TypeInferenceAlgorithm final : public luci::CircleNodeVisitor<loco::DataT
assert(node->numValues() > 0);
for (uint32_t i = 1; i < node->numValues(); ++i)
- assert(loco::dtype_get(node->values(i - 1)) == loco::dtype_get(node->values(i)));
+ assert(luci::dtype_get(node->values(i - 1)) == luci::dtype_get(node->values(i)));
- return loco::dtype_get(node->values(0));
+ return luci::dtype_get(node->values(0));
}
loco::DataType visit(const luci::CircleConst *node) final { return node->dtype(); }
loco::DataType visit(const luci::CircleConv2D *node) final
{
- return loco::dtype_get(node->input());
+ return luci::dtype_get(node->input());
}
- loco::DataType visit(const luci::CircleCos *node) final { return loco::dtype_get(node->x()); }
+ loco::DataType visit(const luci::CircleCos *node) final { return luci::dtype_get(node->x()); }
loco::DataType visit(const luci::CircleCustom *node) final
{
if (node->custom_code() == "BatchMatMulV2")
{
- return loco::dtype_get(node->inputs(0));
+ return luci::dtype_get(node->inputs(0));
}
return node->dtype();
}
loco::DataType visit(const luci::CircleDepthToSpace *node) final
{
- return loco::dtype_get(node->input());
+ return luci::dtype_get(node->input());
}
loco::DataType visit(const luci::CircleDepthwiseConv2D *node) final
{
- return loco::dtype_get(node->input());
+ return luci::dtype_get(node->input());
}
loco::DataType visit(const luci::CircleDequantize *) final { return loco::DataType::FLOAT32; }
- loco::DataType visit(const luci::CircleDiv *node) final { return loco::dtype_get(node->x()); }
+ loco::DataType visit(const luci::CircleDiv *node) final { return luci::dtype_get(node->x()); }
loco::DataType visit(const luci::CircleElu *node) final
{
- return loco::dtype_get(node->features());
+ return luci::dtype_get(node->features());
}
loco::DataType visit(const luci::CircleEqual *) final { return loco::DataType::BOOL; }
- loco::DataType visit(const luci::CircleExp *node) final { return loco::dtype_get(node->x()); }
+ loco::DataType visit(const luci::CircleExp *node) final { return luci::dtype_get(node->x()); }
loco::DataType visit(const luci::CircleExpandDims *node) final
{
- return loco::dtype_get(node->input());
+ return luci::dtype_get(node->input());
+ }
+
+ loco::DataType visit(const luci::CircleFakeQuant *node) final
+ {
+ return luci::dtype_get(node->inputs());
}
loco::DataType visit(const luci::CircleFill *node) final
{
- return loco::dtype_get(node->value());
+ return luci::dtype_get(node->value());
}
- loco::DataType visit(const luci::CircleFloor *node) final { return loco::dtype_get(node->x()); }
+ loco::DataType visit(const luci::CircleFloor *node) final { return luci::dtype_get(node->x()); }
loco::DataType visit(const luci::CircleFloorDiv *node) final
{
- return loco::dtype_get(node->x());
+ return luci::dtype_get(node->x());
}
loco::DataType visit(const luci::CircleFloorMod *node) final
{
- return loco::dtype_get(node->x());
+ return luci::dtype_get(node->x());
}
loco::DataType visit(const luci::CircleFullyConnected *node) final
{
- return loco::dtype_get(node->input());
+ return luci::dtype_get(node->input());
}
loco::DataType visit(const luci::CircleGather *node) final
{
- return loco::dtype_get(node->params());
+ return luci::dtype_get(node->params());
}
loco::DataType visit(const luci::CircleGatherNd *node) final
{
- return loco::dtype_get(node->params());
+ return luci::dtype_get(node->params());
}
loco::DataType visit(const luci::CircleGreater *) final { return loco::DataType::BOOL; }
@@ -169,22 +175,22 @@ struct TypeInferenceAlgorithm final : public luci::CircleNodeVisitor<loco::DataT
{
// Type of If is not used. Just use input 0
assert(node->input_count() > 0);
- return loco::dtype_get(node->input(0));
+ return luci::dtype_get(node->input(0));
}
loco::DataType visit(const luci::CircleL2Normalize *node) final
{
- return loco::dtype_get(node->x());
+ return luci::dtype_get(node->x());
}
loco::DataType visit(const luci::CircleL2Pool2D *node) final
{
- return loco::dtype_get(node->value());
+ return luci::dtype_get(node->value());
}
loco::DataType visit(const luci::CircleLeakyRelu *node) final
{
- return loco::dtype_get(node->features());
+ return luci::dtype_get(node->features());
}
loco::DataType visit(const luci::CircleLess *) final { return loco::DataType::BOOL; }
@@ -193,75 +199,75 @@ struct TypeInferenceAlgorithm final : public luci::CircleNodeVisitor<loco::DataT
loco::DataType visit(const luci::CircleLocalResponseNormalization *node) final
{
- return loco::dtype_get(node->input());
+ return luci::dtype_get(node->input());
}
- loco::DataType visit(const luci::CircleLog *node) final { return loco::dtype_get(node->x()); }
+ loco::DataType visit(const luci::CircleLog *node) final { return luci::dtype_get(node->x()); }
loco::DataType visit(const luci::CircleLogicalAnd *node) final
{
- return loco::dtype_get(node->x());
+ return luci::dtype_get(node->x());
}
loco::DataType visit(const luci::CircleLogicalNot *node) final
{
- return loco::dtype_get(node->x());
+ return luci::dtype_get(node->x());
}
loco::DataType visit(const luci::CircleLogicalOr *node) final
{
- return loco::dtype_get(node->x());
+ return luci::dtype_get(node->x());
}
loco::DataType visit(const luci::CircleLogistic *node) final
{
- return loco::dtype_get(node->x());
+ return luci::dtype_get(node->x());
}
loco::DataType visit(const luci::CircleLogSoftmax *node) final
{
- return loco::dtype_get(node->logits());
+ return luci::dtype_get(node->logits());
}
loco::DataType visit(const luci::CircleMatrixDiag *node) final
{
- return loco::dtype_get(node->diagonal());
+ return luci::dtype_get(node->diagonal());
}
loco::DataType visit(const luci::CircleMatrixSetDiag *node) final
{
- return loco::dtype_get(node->input());
+ return luci::dtype_get(node->input());
}
- loco::DataType visit(const luci::CircleMaximum *node) final { return loco::dtype_get(node->x()); }
+ loco::DataType visit(const luci::CircleMaximum *node) final { return luci::dtype_get(node->x()); }
loco::DataType visit(const luci::CircleMaxPool2D *node) final
{
- return loco::dtype_get(node->value());
+ return luci::dtype_get(node->value());
}
loco::DataType visit(const luci::CircleMean *node) final
{
- return loco::dtype_get(node->input());
+ return luci::dtype_get(node->input());
}
- loco::DataType visit(const luci::CircleMinimum *node) final { return loco::dtype_get(node->x()); }
+ loco::DataType visit(const luci::CircleMinimum *node) final { return luci::dtype_get(node->x()); }
loco::DataType visit(const luci::CircleMirrorPad *node) final
{
- return loco::dtype_get(node->input());
+ return luci::dtype_get(node->input());
}
- loco::DataType visit(const luci::CircleNeg *node) final { return loco::dtype_get(node->x()); }
+ loco::DataType visit(const luci::CircleNeg *node) final { return luci::dtype_get(node->x()); }
loco::DataType visit(const luci::CircleNonMaxSuppressionV4 *node) final
{
- return loco::dtype_get(node->boxes());
+ return luci::dtype_get(node->boxes());
}
loco::DataType visit(const luci::CircleNonMaxSuppressionV5 *node) final
{
- return loco::dtype_get(node->boxes());
+ return luci::dtype_get(node->boxes());
}
loco::DataType visit(const luci::CircleNotEqual *) final { return loco::DataType::BOOL; }
@@ -271,25 +277,25 @@ struct TypeInferenceAlgorithm final : public luci::CircleNodeVisitor<loco::DataT
// Only support CirclePack with one or more inputs
assert(node->values_count() > 0);
- auto first_value_type = loco::dtype_get(node->values(0));
+ auto first_value_type = luci::dtype_get(node->values(0));
for (uint32_t i = 1; i < node->values_count(); ++i)
- assert(first_value_type == loco::dtype_get(node->values(i)));
+ assert(first_value_type == luci::dtype_get(node->values(i)));
return first_value_type;
}
- loco::DataType visit(const luci::CirclePad *node) final { return loco::dtype_get(node->input()); }
+ loco::DataType visit(const luci::CirclePad *node) final { return luci::dtype_get(node->input()); }
loco::DataType visit(const luci::CirclePadV2 *node) final
{
- return loco::dtype_get(node->input());
+ return luci::dtype_get(node->input());
}
loco::DataType visit(const luci::CirclePow *node) final
{
// TODO make sure types cannot differ
- auto x_type = loco::dtype_get(node->x());
- auto y_type = loco::dtype_get(node->y());
+ auto x_type = luci::dtype_get(node->x());
+ auto y_type = luci::dtype_get(node->y());
if (x_type != y_type)
INTERNAL_EXN("Different datatype for x and y are not supported");
@@ -299,8 +305,8 @@ struct TypeInferenceAlgorithm final : public luci::CircleNodeVisitor<loco::DataT
loco::DataType visit(const luci::CirclePRelu *node) final
{
- auto input_type = loco::dtype_get(node->input());
- auto alpha_type = loco::dtype_get(node->alpha());
+ auto input_type = luci::dtype_get(node->input());
+ auto alpha_type = luci::dtype_get(node->alpha());
if (input_type != alpha_type)
INTERNAL_EXN("Different datatype for input and alpha are not supported");
@@ -310,201 +316,201 @@ struct TypeInferenceAlgorithm final : public luci::CircleNodeVisitor<loco::DataT
loco::DataType visit(const luci::CircleRange *node) final
{
- return loco::dtype_get(node->start());
+ return luci::dtype_get(node->start());
}
loco::DataType visit(const luci::CircleRank *) final { return loco::DataType::S32; }
- loco::DataType visit(const luci::CircleMul *node) final { return loco::dtype_get(node->x()); }
+ loco::DataType visit(const luci::CircleMul *node) final { return luci::dtype_get(node->x()); }
loco::DataType visit(const luci::CircleOneHot *node) final
{
- return loco::dtype_get(node->on_value());
+ return luci::dtype_get(node->on_value());
}
loco::DataType visit(const luci::CircleReduceAny *node) final
{
- return loco::dtype_get(node->input());
+ return luci::dtype_get(node->input());
}
loco::DataType visit(const luci::CircleReduceMax *node) final
{
- return loco::dtype_get(node->input());
+ return luci::dtype_get(node->input());
}
loco::DataType visit(const luci::CircleReduceMin *node) final
{
- return loco::dtype_get(node->input());
+ return luci::dtype_get(node->input());
}
loco::DataType visit(const luci::CircleReduceProd *node) final
{
- return loco::dtype_get(node->input());
+ return luci::dtype_get(node->input());
}
loco::DataType visit(const luci::CircleRelu *node) final
{
- return loco::dtype_get(node->features());
+ return luci::dtype_get(node->features());
}
loco::DataType visit(const luci::CircleRelu6 *node) final
{
- return loco::dtype_get(node->features());
+ return luci::dtype_get(node->features());
}
loco::DataType visit(const luci::CircleReluN1To1 *node) final
{
- return loco::dtype_get(node->features());
+ return luci::dtype_get(node->features());
}
loco::DataType visit(const luci::CircleReshape *node) final
{
- return loco::dtype_get(node->tensor());
+ return luci::dtype_get(node->tensor());
}
loco::DataType visit(const luci::CircleResizeBilinear *node) final
{
- return loco::dtype_get(node->input());
+ return luci::dtype_get(node->input());
}
loco::DataType visit(const luci::CircleResizeNearestNeighbor *node) final
{
- return loco::dtype_get(node->input());
+ return luci::dtype_get(node->input());
}
loco::DataType visit(const luci::CircleReverseSequence *node) final
{
- return loco::dtype_get(node->input());
+ return luci::dtype_get(node->input());
}
loco::DataType visit(const luci::CircleReverseV2 *node) final
{
- return loco::dtype_get(node->tensor());
+ return luci::dtype_get(node->tensor());
}
- loco::DataType visit(const luci::CircleRound *node) final { return loco::dtype_get(node->x()); }
+ loco::DataType visit(const luci::CircleRound *node) final { return luci::dtype_get(node->x()); }
- loco::DataType visit(const luci::CircleRsqrt *node) final { return loco::dtype_get(node->x()); }
+ loco::DataType visit(const luci::CircleRsqrt *node) final { return luci::dtype_get(node->x()); }
loco::DataType visit(const luci::CircleScatterNd *node) final
{
- return loco::dtype_get(node->updates());
+ return luci::dtype_get(node->updates());
}
loco::DataType visit(const luci::CircleSegmentSum *node) final
{
- return loco::dtype_get(node->input());
+ return luci::dtype_get(node->input());
}
loco::DataType visit(const luci::CircleSelect *node) final
{
- assert(loco::dtype_get(node->t()) == loco::dtype_get(node->e()));
- return loco::dtype_get(node->t());
+ assert(luci::dtype_get(node->t()) == luci::dtype_get(node->e()));
+ return luci::dtype_get(node->t());
}
loco::DataType visit(const luci::CircleSelectV2 *node) final
{
- assert(loco::dtype_get(node->t()) == loco::dtype_get(node->e()));
- return loco::dtype_get(node->t());
+ assert(luci::dtype_get(node->t()) == luci::dtype_get(node->e()));
+ return luci::dtype_get(node->t());
}
loco::DataType visit(const luci::CircleShape *node) final { return node->out_type(); }
- loco::DataType visit(const luci::CircleSin *node) final { return loco::dtype_get(node->x()); }
+ loco::DataType visit(const luci::CircleSin *node) final { return luci::dtype_get(node->x()); }
loco::DataType visit(const luci::CircleSlice *node) final
{
- return loco::dtype_get(node->input());
+ return luci::dtype_get(node->input());
}
loco::DataType visit(const luci::CircleSoftmax *node) final
{
- return loco::dtype_get(node->logits());
+ return luci::dtype_get(node->logits());
}
loco::DataType visit(const luci::CircleSpaceToBatchND *node) final
{
- return loco::dtype_get(node->input());
+ return luci::dtype_get(node->input());
}
loco::DataType visit(const luci::CircleSpaceToDepth *node) final
{
- return loco::dtype_get(node->input());
+ return luci::dtype_get(node->input());
}
loco::DataType visit(const luci::CircleSparseToDense *node) final
{
- return loco::dtype_get(node->values());
+ return luci::dtype_get(node->values());
}
loco::DataType visit(const luci::CircleSplit *node) final
{
- return loco::dtype_get(node->input());
+ return luci::dtype_get(node->input());
}
loco::DataType visit(const luci::CircleSplitV *node) final
{
- return loco::dtype_get(node->input());
+ return luci::dtype_get(node->input());
}
- loco::DataType visit(const luci::CircleSqrt *node) final { return loco::dtype_get(node->x()); }
+ loco::DataType visit(const luci::CircleSqrt *node) final { return luci::dtype_get(node->x()); }
- loco::DataType visit(const luci::CircleSquare *node) final { return loco::dtype_get(node->x()); }
+ loco::DataType visit(const luci::CircleSquare *node) final { return luci::dtype_get(node->x()); }
loco::DataType visit(const luci::CircleSquaredDifference *node) final
{
- return loco::dtype_get(node->x());
+ return luci::dtype_get(node->x());
}
loco::DataType visit(const luci::CircleSqueeze *node) final
{
- return loco::dtype_get(node->input());
+ return luci::dtype_get(node->input());
}
loco::DataType visit(const luci::CircleStridedSlice *node) final
{
- return loco::dtype_get(node->input());
+ return luci::dtype_get(node->input());
}
- loco::DataType visit(const luci::CircleSub *node) final { return loco::dtype_get(node->x()); }
+ loco::DataType visit(const luci::CircleSub *node) final { return luci::dtype_get(node->x()); }
- loco::DataType visit(const luci::CircleSum *node) final { return loco::dtype_get(node->input()); }
+ loco::DataType visit(const luci::CircleSum *node) final { return luci::dtype_get(node->input()); }
- loco::DataType visit(const luci::CircleTanh *node) final { return loco::dtype_get(node->x()); }
+ loco::DataType visit(const luci::CircleTanh *node) final { return luci::dtype_get(node->x()); }
loco::DataType visit(const luci::CircleTile *node) final
{
- return loco::dtype_get(node->input());
+ return luci::dtype_get(node->input());
}
loco::DataType visit(const luci::CircleTopKV2 *node) final
{
- return loco::dtype_get(node->input());
+ return luci::dtype_get(node->input());
}
loco::DataType visit(const luci::CircleTranspose *node) final
{
- return loco::dtype_get(node->a());
+ return luci::dtype_get(node->a());
}
loco::DataType visit(const luci::CircleTransposeConv *node) final
{
- return loco::dtype_get(node->outBackprop());
+ return luci::dtype_get(node->outBackprop());
}
loco::DataType visit(const luci::CircleUnidirectionalSequenceLSTM *node) final
{
- return loco::dtype_get(node->input());
+ return luci::dtype_get(node->input());
}
loco::DataType visit(const luci::CircleUnique *node) final
{
- return loco::dtype_get(node->input());
+ return luci::dtype_get(node->input());
}
loco::DataType visit(const luci::CircleUnpack *node) final
{
- return loco::dtype_get(node->value());
+ return luci::dtype_get(node->value());
}
loco::DataType visit(const luci::CircleWhere *) final { return loco::DataType::S64; }
@@ -513,12 +519,12 @@ struct TypeInferenceAlgorithm final : public luci::CircleNodeVisitor<loco::DataT
{
// Type of While is not used. Just use input 0
assert(node->input_count() > 0);
- return loco::dtype_get(node->input(0));
+ return luci::dtype_get(node->input(0));
}
loco::DataType visit(const luci::CircleZerosLike *node) final
{
- return loco::dtype_get(node->input());
+ return luci::dtype_get(node->input());
}
// Circle Only
@@ -531,7 +537,7 @@ struct TypeInferenceAlgorithm final : public luci::CircleNodeVisitor<loco::DataT
loco::DataType visit(const luci::CircleInstanceNorm *node) final
{
- return loco::dtype_get(node->input());
+ return luci::dtype_get(node->input());
}
// Virtual
@@ -548,7 +554,7 @@ struct TypeInferenceAlgorithm final : public luci::CircleNodeVisitor<loco::DataT
{
// We don't care for the type if from() is CircleOutputDummy or CircleOutputExclude
// from() type should match that of CircleOutput
- assert(output_dtype == loco::dtype_get(node->from()));
+ assert(output_dtype == luci::dtype_get(node->from()));
}
return output_dtype;
}
@@ -559,46 +565,6 @@ struct TypeInferenceAlgorithm final : public luci::CircleNodeVisitor<loco::DataT
loco::DataType visit(const luci::CircleCustomOut *node) final { return node->dtype(); }
- loco::DataType visit(const luci::CircleIfOut *node) final
- {
- /**
- * @note IF operator type and shape are that of the "then" and "else"
- * Graph Outputs.
- */
- auto circle_if = dynamic_cast<const luci::CircleIf *>(node->input());
- if (circle_if == nullptr)
- {
- INTERNAL_EXN("CircleIf IR is not configured correctly");
- }
-
- auto index = node->index();
- auto then_graph = circle_if->then_graph();
- auto else_graph = circle_if->else_graph();
- assert(then_graph != nullptr);
- assert(else_graph != nullptr);
-
- // shape and type are assumed to be same
- // these are checked at post_import_graph() in Import
- auto then_outputs = loco::output_nodes(then_graph);
- auto else_outputs = loco::output_nodes(else_graph);
- assert(then_outputs.size() == else_outputs.size());
- assert(index < static_cast<int32_t>(then_outputs.size()));
-
- auto then_out = loco::must_cast<luci::CircleOutput *>(then_outputs.at(index));
- auto else_out = loco::must_cast<luci::CircleOutput *>(else_outputs.at(index));
-
- auto then_graph_outputs = then_graph->outputs(); // loco::GraphOutput items
- auto else_graph_outputs = else_graph->outputs();
- assert(then_graph_outputs->size() == else_graph_outputs->size());
-
- auto then_graph_output = then_graph_outputs->at(then_out->index());
- auto else_graph_output = else_graph_outputs->at(else_out->index());
- (void)else_graph_output; // make compiler happy for unused variable warnings
- assert(then_graph_output->dtype() == else_graph_output->dtype());
-
- return then_graph_output->dtype();
- }
-
loco::DataType visit(const luci::CircleNonMaxSuppressionV4Out *node) final
{
(void)node;
@@ -619,19 +585,19 @@ struct TypeInferenceAlgorithm final : public luci::CircleNodeVisitor<loco::DataT
loco::DataType visit(const luci::CircleSplitOut *node) final
{
- return loco::dtype_get(node->input());
+ return luci::dtype_get(node->input());
}
loco::DataType visit(const luci::CircleSplitVOut *node) final
{
- return loco::dtype_get(node->input());
+ return luci::dtype_get(node->input());
}
loco::DataType visit(const luci::CircleTopKV2Out *node) final
{
// First output is same as input
if (node->index() == 0)
- return loco::dtype_get(node->input());
+ return luci::dtype_get(node->input());
// Second outout is always S32
assert(node->index() == 1);
return loco::DataType::S32;
@@ -641,7 +607,7 @@ struct TypeInferenceAlgorithm final : public luci::CircleNodeVisitor<loco::DataT
{
if (node->index() == 0)
{
- return loco::dtype_get(node->input());
+ return luci::dtype_get(node->input());
}
assert(node->index() == 1);
auto unique = loco::must_cast<luci::CircleUnique *>(node->input());
@@ -650,7 +616,7 @@ struct TypeInferenceAlgorithm final : public luci::CircleNodeVisitor<loco::DataT
loco::DataType visit(const luci::CircleUnpackOut *node) final
{
- return loco::dtype_get(node->input());
+ return luci::dtype_get(node->input());
}
loco::DataType visit(const luci::CircleWhileOut *node) final
diff --git a/compiler/luci/service/src/CircleTypeInferenceRule.test.cpp b/compiler/luci/service/src/CircleTypeInferenceRule.test.cpp
deleted file mode 100644
index 711a489af..000000000
--- a/compiler/luci/service/src/CircleTypeInferenceRule.test.cpp
+++ /dev/null
@@ -1,63 +0,0 @@
-/*
- * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "TestGraph.h"
-#include <luci/Service/CircleTypeInferenceRule.h>
-
-#include <luci/IR/CircleNodes.h>
-#include <luci/IR/CircleDialect.h>
-
-#include <loco.h>
-#include <loco/IR/CanonicalDialect.h>
-#include <loco/Service/TypeInference.h>
-
-#include <gtest/gtest.h>
-
-#include <memory>
-
-TEST(CircleTypeInferenceRuleTest, minimal_with_CircleRelu)
-{
- // Create a simple network
- luci::test::TestGraph graph;
- auto relu_node = graph.append<luci::CircleRelu>(graph.input_node);
- graph.complete(relu_node);
-
- // set dtype for nodes; like setting them in import
- graph.input_node->dtype(loco::DataType::S32);
- relu_node->dtype(loco::DataType::S32);
- graph.output_node->dtype(loco::DataType::S32);
-
- luci::test::graph_input_dtype(graph.input_node);
- luci::test::graph_output_dtype(graph.output_node);
-
- // pre-check
- ASSERT_FALSE(loco::dtype_known(relu_node));
-
- // type inference
- luci::CircleTypeInferenceRule circle_rule;
- loco::CanonicalTypeInferenceRule canon_rule;
- loco::MultiDialectTypeInferenceRule rules;
-
- rules.bind(loco::CanonicalDialect::get(), &canon_rule);
- rules.bind(luci::CircleDialect::get(), &circle_rule);
-
- loco::apply(&rules).to(graph.g.get());
-
- // Verify
- ASSERT_TRUE(loco::dtype_known(relu_node));
- auto type = loco::dtype_get(relu_node);
- ASSERT_EQ(loco::DataType::S32, type);
-}
diff --git a/compiler/luci/service/src/Nodes/CircleAbs.cpp b/compiler/luci/service/src/Nodes/CircleAbs.cpp
new file mode 100644
index 000000000..132760957
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleAbs.cpp
@@ -0,0 +1,27 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "CircleCloneNode.h"
+
+namespace luci
+{
+
+luci::CircleNode *CloneNode::visit(const luci::CircleAbs *)
+{
+ return _graph->nodes()->create<luci::CircleAbs>();
+}
+
+} // namespace luci
diff --git a/compiler/luci/service/src/Nodes/CircleAbs.test.cpp b/compiler/luci/service/src/Nodes/CircleAbs.test.cpp
new file mode 100644
index 000000000..885b395b8
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleAbs.test.cpp
@@ -0,0 +1,33 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Service/CircleNodeClone.h"
+
+#include <gtest/gtest.h>
+
+TEST(CloneNodeTest, clone_Abs)
+{
+ auto g = loco::make_graph();
+ auto node_abs = g->nodes()->create<luci::CircleAbs>();
+
+ auto gc = loco::make_graph();
+ auto cloned = luci::clone_node(node_abs, gc.get());
+ ASSERT_NE(nullptr, cloned);
+ ASSERT_EQ(gc.get(), cloned->graph());
+
+ auto cloned_abs = dynamic_cast<luci::CircleAbs *>(cloned);
+ ASSERT_NE(nullptr, cloned_abs);
+}
diff --git a/compiler/luci/pass/include/luci/Pass/TypeInferencePass.h b/compiler/luci/service/src/Nodes/CircleAdd.cpp
index 9d964bdd6..08634320e 100644
--- a/compiler/luci/pass/include/luci/Pass/TypeInferencePass.h
+++ b/compiler/luci/service/src/Nodes/CircleAdd.cpp
@@ -1,6 +1,5 @@
-
/*
- * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -15,29 +14,20 @@
* limitations under the License.
*/
-#ifndef __LUCI_TYPE_INFERENCE_PASS_H__
-#define __LUCI_TYPE_INFERENCE_PASS_H__
-
-#include <loco.h>
-
-#include <luci/ModulePass.h>
+#include "CircleCloneNode.h"
namespace luci
{
-/**
- * @brief Pass to infer type of nodes
- */
-class TypeInferencePass : public luci::Pass
+luci::CircleNode *CloneNode::visit(const luci::CircleAdd *node)
{
-public:
- virtual const char *name(void) const { return "luci::TypeInferencePass"; }
+ if (node->fusedActivationFunction() == luci::FusedActFunc::UNDEFINED)
+ return nullptr;
-public:
- bool run(luci::Module *m);
- bool run(loco::Graph *graph);
-};
+ auto *cloned = _graph->nodes()->create<luci::CircleAdd>();
+ if (cloned != nullptr)
+ cloned->fusedActivationFunction(node->fusedActivationFunction());
+ return cloned;
+}
} // namespace luci
-
-#endif //__LUCI_TYPE_INFERENCE_PASS_H__
diff --git a/compiler/luci/service/src/Nodes/CircleAdd.test.cpp b/compiler/luci/service/src/Nodes/CircleAdd.test.cpp
new file mode 100644
index 000000000..41a818b0a
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleAdd.test.cpp
@@ -0,0 +1,84 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Service/CircleNodeClone.h"
+
+#include <luci/IR/CircleNodes.h>
+#include <luci/Service/CircleShapeInference.h>
+
+#include <loco/IR/TensorShape.h>
+
+#include <gtest/gtest.h>
+
+/**
+ * @note Function to test: Shape inference of two different input shapes
+ *
+ * Rank expansion to higher input side
+ * x(2,1,5) + y(3,5) --> x(2,1,5) + y(1,3,5)
+ * Do output shape inference like numpy
+ * x(2,1,5) + y(1,3,5) --> output(2,3,5)
+ * For each axis, dim value should be same OR one of them should be 1
+ */
+TEST(ShapeRuleTest, different_input_shapes_add)
+{
+ luci::CircleInput input1;
+ luci::CircleInput input2;
+ luci::CircleAdd add;
+
+ input1.shape({2, 1, 5});
+ input1.shape_status(luci::ShapeStatus::VALID);
+ input2.shape({3, 5});
+ input2.shape_status(luci::ShapeStatus::VALID);
+
+ add.x(&input1);
+ add.y(&input2);
+
+ loco::TensorShape shape;
+ luci::sinf::Rule shape_inf_rule;
+
+ ASSERT_TRUE(shape_inf_rule.infer(&add, shape));
+ ASSERT_EQ(3, shape.rank());
+ ASSERT_EQ(2, shape.dim(0).value());
+ ASSERT_EQ(3, shape.dim(1).value());
+ ASSERT_EQ(5, shape.dim(2).value());
+}
+
+TEST(CloneNodeTest, clone_Add)
+{
+ auto g = loco::make_graph();
+ auto node_add = g->nodes()->create<luci::CircleAdd>();
+ node_add->fusedActivationFunction(luci::FusedActFunc::RELU);
+
+ auto gc = loco::make_graph();
+ auto cloned = luci::clone_node(node_add, gc.get());
+ ASSERT_NE(nullptr, cloned);
+ ASSERT_EQ(gc.get(), cloned->graph());
+
+ auto cloned_add = dynamic_cast<luci::CircleAdd *>(cloned);
+ ASSERT_NE(nullptr, cloned_add);
+ ASSERT_EQ(node_add->fusedActivationFunction(), cloned_add->fusedActivationFunction());
+}
+
+TEST(CloneNodeTest, clone_Add_NEG)
+{
+ auto g = loco::make_graph();
+ auto node_add = g->nodes()->create<luci::CircleAdd>();
+ node_add->fusedActivationFunction(luci::FusedActFunc::UNDEFINED);
+
+ auto gc = loco::make_graph();
+ auto cloned = luci::clone_node(node_add, gc.get());
+ ASSERT_EQ(nullptr, cloned);
+}
diff --git a/compiler/luci/service/src/Nodes/CircleAddN.cpp b/compiler/luci/service/src/Nodes/CircleAddN.cpp
new file mode 100644
index 000000000..e536e54bb
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleAddN.cpp
@@ -0,0 +1,28 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "CircleCloneNode.h"
+
+namespace luci
+{
+
+luci::CircleNode *CloneNode::visit(const luci::CircleAddN *node)
+{
+ auto arity = node->arity();
+ return _graph->nodes()->create<luci::CircleAddN>(arity);
+}
+
+} // namespace luci
diff --git a/compiler/luci/service/src/Nodes/CircleAddN.test.cpp b/compiler/luci/service/src/Nodes/CircleAddN.test.cpp
new file mode 100644
index 000000000..5d5b82247
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleAddN.test.cpp
@@ -0,0 +1,34 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Service/CircleNodeClone.h"
+
+#include <gtest/gtest.h>
+
+TEST(CloneNodeTest, clone_AddN)
+{
+ auto g = loco::make_graph();
+ auto node_addn = g->nodes()->create<luci::CircleAddN>(3);
+
+ auto gc = loco::make_graph();
+ auto cloned = luci::clone_node(node_addn, gc.get());
+ ASSERT_NE(nullptr, cloned);
+ ASSERT_EQ(gc.get(), cloned->graph());
+
+ auto cloned_addn = dynamic_cast<luci::CircleAddN *>(cloned);
+ ASSERT_NE(nullptr, cloned_addn);
+ ASSERT_EQ(node_addn->arity(), cloned_addn->arity());
+}
diff --git a/compiler/luci/service/src/Nodes/CircleArgMax.cpp b/compiler/luci/service/src/Nodes/CircleArgMax.cpp
new file mode 100644
index 000000000..1b3bafa86
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleArgMax.cpp
@@ -0,0 +1,30 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "CircleCloneNode.h"
+
+namespace luci
+{
+
+luci::CircleNode *CloneNode::visit(const luci::CircleArgMax *node)
+{
+ auto *cloned = _graph->nodes()->create<luci::CircleArgMax>();
+ if (cloned != nullptr)
+ cloned->output_type(node->output_type());
+ return cloned;
+}
+
+} // namespace luci
diff --git a/compiler/luci/service/src/Nodes/CircleArgMax.test.cpp b/compiler/luci/service/src/Nodes/CircleArgMax.test.cpp
new file mode 100644
index 000000000..bb7588403
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleArgMax.test.cpp
@@ -0,0 +1,35 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Service/CircleNodeClone.h"
+
+#include <gtest/gtest.h>
+
+TEST(CloneNodeTest, clone_ArgMax)
+{
+ auto g = loco::make_graph();
+ auto node_argmax = g->nodes()->create<luci::CircleArgMax>();
+ node_argmax->output_type(loco::DataType::FLOAT32);
+
+ auto gc = loco::make_graph();
+ auto cloned = luci::clone_node(node_argmax, gc.get());
+ ASSERT_NE(nullptr, cloned);
+ ASSERT_EQ(gc.get(), cloned->graph());
+
+ auto cloned_argmax = dynamic_cast<luci::CircleArgMax *>(cloned);
+ ASSERT_NE(nullptr, cloned_argmax);
+ ASSERT_EQ(node_argmax->output_type(), cloned_argmax->output_type());
+}
diff --git a/compiler/luci/service/src/Nodes/CircleArgMin.cpp b/compiler/luci/service/src/Nodes/CircleArgMin.cpp
new file mode 100644
index 000000000..fa54f7b76
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleArgMin.cpp
@@ -0,0 +1,30 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "CircleCloneNode.h"
+
+namespace luci
+{
+
+luci::CircleNode *CloneNode::visit(const luci::CircleArgMin *node)
+{
+ auto *cloned = _graph->nodes()->create<luci::CircleArgMin>();
+ if (cloned != nullptr)
+ cloned->output_type(node->output_type());
+ return cloned;
+}
+
+} // namespace luci
diff --git a/compiler/luci/service/src/Nodes/CircleArgMin.test.cpp b/compiler/luci/service/src/Nodes/CircleArgMin.test.cpp
new file mode 100644
index 000000000..ca57946f9
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleArgMin.test.cpp
@@ -0,0 +1,35 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Service/CircleNodeClone.h"
+
+#include <gtest/gtest.h>
+
+TEST(CloneNodeTest, clone_ArgMin)
+{
+ auto g = loco::make_graph();
+ auto node_argmin = g->nodes()->create<luci::CircleArgMin>();
+ node_argmin->output_type(loco::DataType::FLOAT32);
+
+ auto gc = loco::make_graph();
+ auto cloned = luci::clone_node(node_argmin, gc.get());
+ ASSERT_NE(nullptr, cloned);
+ ASSERT_EQ(gc.get(), cloned->graph());
+
+ auto cloned_argmin = dynamic_cast<luci::CircleArgMin *>(cloned);
+ ASSERT_NE(nullptr, cloned_argmin);
+ ASSERT_EQ(node_argmin->output_type(), cloned_argmin->output_type());
+}
diff --git a/compiler/luci/service/src/Nodes/CircleAveragePool2D.cpp b/compiler/luci/service/src/Nodes/CircleAveragePool2D.cpp
new file mode 100644
index 000000000..4d2791833
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleAveragePool2D.cpp
@@ -0,0 +1,42 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "CircleCloneNode.h"
+
+namespace luci
+{
+
+luci::CircleNode *CloneNode::visit(const luci::CircleAveragePool2D *node)
+{
+ if (node->fusedActivationFunction() == luci::FusedActFunc::UNDEFINED)
+ return nullptr;
+ if (node->padding() == luci::Padding::UNDEFINED)
+ return nullptr;
+
+ auto *cloned = _graph->nodes()->create<luci::CircleAveragePool2D>();
+ if (cloned != nullptr)
+ {
+ cloned->fusedActivationFunction(node->fusedActivationFunction());
+ cloned->padding(node->padding());
+ cloned->filter()->h(node->filter()->h());
+ cloned->filter()->w(node->filter()->w());
+ cloned->stride()->h(node->stride()->h());
+ cloned->stride()->w(node->stride()->w());
+ }
+ return cloned;
+}
+
+} // namespace luci
diff --git a/compiler/luci/service/src/Nodes/CircleAveragePool2D.test.cpp b/compiler/luci/service/src/Nodes/CircleAveragePool2D.test.cpp
new file mode 100644
index 000000000..d048d1426
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleAveragePool2D.test.cpp
@@ -0,0 +1,128 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Service/CircleNodeClone.h"
+
+#include <luci/IR/CircleNodes.h>
+#include <luci/Service/CircleShapeInference.h>
+
+#include <loco/IR/TensorShape.h>
+
+#include <gtest/gtest.h>
+
+TEST(ShapeRuleTest, simple_valid_pad_avgpool2d)
+{
+ luci::CircleInput input;
+ luci::CircleAveragePool2D avgpool_2d;
+
+ input.shape({1, 4, 3, 1});
+ input.shape_status(luci::ShapeStatus::VALID);
+
+ avgpool_2d.value(&input);
+ avgpool_2d.filter()->h(2);
+ avgpool_2d.filter()->w(2);
+ avgpool_2d.stride()->h(2);
+ avgpool_2d.stride()->w(2);
+ avgpool_2d.fusedActivationFunction(luci::FusedActFunc::NONE);
+ avgpool_2d.padding(luci::Padding::VALID);
+
+ loco::TensorShape shape;
+ luci::sinf::Rule shape_inf_rule;
+
+ ASSERT_TRUE(shape_inf_rule.infer(&avgpool_2d, shape));
+ ASSERT_EQ(4, shape.rank());
+ ASSERT_EQ(1, shape.dim(0).value());
+ ASSERT_EQ(2, shape.dim(1).value());
+ ASSERT_EQ(1, shape.dim(2).value());
+ ASSERT_EQ(1, shape.dim(3).value());
+}
+
+TEST(ShapeRuleTest, simple_same_pad_avgpool2d)
+{
+ luci::CircleInput input;
+ luci::CircleAveragePool2D avgpool_2d;
+
+ input.shape({1, 4, 3, 1});
+ input.shape_status(luci::ShapeStatus::VALID);
+
+ avgpool_2d.value(&input);
+ avgpool_2d.filter()->h(2);
+ avgpool_2d.filter()->w(2);
+ avgpool_2d.stride()->h(2);
+ avgpool_2d.stride()->w(2);
+ avgpool_2d.fusedActivationFunction(luci::FusedActFunc::NONE);
+ avgpool_2d.padding(luci::Padding::SAME);
+
+ loco::TensorShape shape;
+ luci::sinf::Rule shape_inf_rule;
+
+ ASSERT_TRUE(shape_inf_rule.infer(&avgpool_2d, shape));
+ ASSERT_EQ(4, shape.rank());
+ ASSERT_EQ(1, shape.dim(0).value());
+ ASSERT_EQ(2, shape.dim(1).value());
+ ASSERT_EQ(2, shape.dim(2).value());
+ ASSERT_EQ(1, shape.dim(3).value());
+}
+
+TEST(CloneNodeTest, clone_AveragePool2D)
+{
+ auto g = loco::make_graph();
+ auto node_avgpool2d = g->nodes()->create<luci::CircleAveragePool2D>();
+ node_avgpool2d->fusedActivationFunction(luci::FusedActFunc::RELU);
+ node_avgpool2d->padding(luci::Padding::SAME);
+ node_avgpool2d->filter()->h(1);
+ node_avgpool2d->filter()->w(2);
+ node_avgpool2d->stride()->h(3);
+ node_avgpool2d->stride()->w(4);
+
+ auto gc = loco::make_graph();
+ auto cloned = luci::clone_node(node_avgpool2d, gc.get());
+ ASSERT_NE(nullptr, cloned);
+ ASSERT_EQ(gc.get(), cloned->graph());
+
+ auto cloned_avgpool2d = dynamic_cast<luci::CircleAveragePool2D *>(cloned);
+ ASSERT_NE(nullptr, cloned_avgpool2d);
+ ASSERT_EQ(node_avgpool2d->fusedActivationFunction(), cloned_avgpool2d->fusedActivationFunction());
+ ASSERT_EQ(node_avgpool2d->padding(), cloned_avgpool2d->padding());
+ ASSERT_EQ(node_avgpool2d->filter()->h(), cloned_avgpool2d->filter()->h());
+ ASSERT_EQ(node_avgpool2d->filter()->w(), cloned_avgpool2d->filter()->w());
+ ASSERT_EQ(node_avgpool2d->stride()->h(), cloned_avgpool2d->stride()->h());
+ ASSERT_EQ(node_avgpool2d->stride()->w(), cloned_avgpool2d->stride()->w());
+}
+
+TEST(CloneNodeTest, clone_AveragePool2D_fusedact_NEG)
+{
+ auto g = loco::make_graph();
+ auto node_avgpool2d = g->nodes()->create<luci::CircleAveragePool2D>();
+ node_avgpool2d->fusedActivationFunction(luci::FusedActFunc::UNDEFINED);
+ node_avgpool2d->padding(luci::Padding::SAME);
+
+ auto gc = loco::make_graph();
+ auto cloned = luci::clone_node(node_avgpool2d, gc.get());
+ ASSERT_EQ(nullptr, cloned);
+}
+
+TEST(CloneNodeTest, clone_AveragePool2D_padding_NEG)
+{
+ auto g = loco::make_graph();
+ auto node_avgpool2d = g->nodes()->create<luci::CircleAveragePool2D>();
+ node_avgpool2d->fusedActivationFunction(luci::FusedActFunc::RELU);
+ node_avgpool2d->padding(luci::Padding::UNDEFINED);
+
+ auto gc = loco::make_graph();
+ auto cloned = luci::clone_node(node_avgpool2d, gc.get());
+ ASSERT_EQ(nullptr, cloned);
+}
diff --git a/compiler/luci/pass/include/luci/Pass/ShapeSignatureInferencePass.h b/compiler/luci/service/src/Nodes/CircleBCQFullyConnected.cpp
index 2c6ffcf4e..3edc06ab8 100644
--- a/compiler/luci/pass/include/luci/Pass/ShapeSignatureInferencePass.h
+++ b/compiler/luci/service/src/Nodes/CircleBCQFullyConnected.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -14,29 +14,23 @@
* limitations under the License.
*/
-#ifndef __LUCI_SHAPE_SIGNATURE_INFERENCE_PASS_H__
-#define __LUCI_SHAPE_SIGNATURE_INFERENCE_PASS_H__
-
-#include <loco.h>
-
-#include <luci/ModulePass.h>
+#include "CircleCloneNode.h"
namespace luci
{
-/**
- * @brief Pass to infer shape_signature of nodes
- */
-class ShapeSignatureInferencePass : public luci::Pass
+luci::CircleNode *CloneNode::visit(const luci::CircleBCQFullyConnected *node)
{
-public:
- virtual const char *name(void) const { return "luci::ShapeSignatureInferencePass"; }
-
-public:
- bool run(luci::Module *m);
- bool run(loco::Graph *graph);
-};
+ if (node->fusedActivationFunction() == luci::FusedActFunc::UNDEFINED)
+ return nullptr;
+
+ auto *cloned = _graph->nodes()->create<luci::CircleBCQFullyConnected>();
+ if (cloned != nullptr)
+ {
+ cloned->fusedActivationFunction(node->fusedActivationFunction());
+ cloned->weights_hidden_size(node->weights_hidden_size());
+ }
+ return cloned;
+}
} // namespace luci
-
-#endif //__LUCI_SHAPE_SIGNATURE_INFERENCE_PASS_H__
diff --git a/compiler/luci/service/src/Nodes/CircleBCQFullyConnected.test.cpp b/compiler/luci/service/src/Nodes/CircleBCQFullyConnected.test.cpp
new file mode 100644
index 000000000..90c192e07
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleBCQFullyConnected.test.cpp
@@ -0,0 +1,48 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Service/CircleNodeClone.h"
+
+#include <gtest/gtest.h>
+
+TEST(CloneNodeTest, clone_BCQFullyConnected)
+{
+ auto g = loco::make_graph();
+ auto node_fc = g->nodes()->create<luci::CircleBCQFullyConnected>();
+ node_fc->fusedActivationFunction(luci::FusedActFunc::RELU);
+ node_fc->weights_hidden_size(3);
+
+ auto gc = loco::make_graph();
+ auto cloned = luci::clone_node(node_fc, gc.get());
+ ASSERT_NE(nullptr, cloned);
+ ASSERT_EQ(gc.get(), cloned->graph());
+
+ auto cloned_fc = dynamic_cast<luci::CircleBCQFullyConnected *>(cloned);
+ ASSERT_NE(nullptr, cloned_fc);
+ ASSERT_EQ(node_fc->fusedActivationFunction(), cloned_fc->fusedActivationFunction());
+ ASSERT_EQ(node_fc->weights_hidden_size(), cloned_fc->weights_hidden_size());
+}
+
+TEST(CloneNodeTest, clone_BCQFullyConnected_fusedact_NEG)
+{
+ auto g = loco::make_graph();
+ auto node_fc = g->nodes()->create<luci::CircleBCQFullyConnected>();
+ node_fc->fusedActivationFunction(luci::FusedActFunc::UNDEFINED);
+
+ auto gc = loco::make_graph();
+ auto cloned = luci::clone_node(node_fc, gc.get());
+ ASSERT_EQ(nullptr, cloned);
+}
diff --git a/compiler/luci/service/src/Nodes/CircleBCQGather.cpp b/compiler/luci/service/src/Nodes/CircleBCQGather.cpp
new file mode 100644
index 000000000..35b6be744
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleBCQGather.cpp
@@ -0,0 +1,33 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "CircleCloneNode.h"
+
+namespace luci
+{
+
+luci::CircleNode *CloneNode::visit(const luci::CircleBCQGather *node)
+{
+ auto *cloned = _graph->nodes()->create<luci::CircleBCQGather>();
+ if (cloned != nullptr)
+ {
+ cloned->axis(node->axis());
+ cloned->input_hidden_size(node->input_hidden_size());
+ }
+ return cloned;
+}
+
+} // namespace luci
diff --git a/compiler/luci/service/src/Nodes/CircleBCQGather.test.cpp b/compiler/luci/service/src/Nodes/CircleBCQGather.test.cpp
new file mode 100644
index 000000000..a3f9e8850
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleBCQGather.test.cpp
@@ -0,0 +1,37 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Service/CircleNodeClone.h"
+
+#include <gtest/gtest.h>
+
+TEST(CloneNodeTest, clone_BCQGather)
+{
+ auto g = loco::make_graph();
+ auto node_gat = g->nodes()->create<luci::CircleBCQGather>();
+ node_gat->axis(3);
+ node_gat->input_hidden_size(5);
+
+ auto gc = loco::make_graph();
+ auto cloned = luci::clone_node(node_gat, gc.get());
+ ASSERT_NE(nullptr, cloned);
+ ASSERT_EQ(gc.get(), cloned->graph());
+
+ auto cloned_gat = dynamic_cast<luci::CircleBCQGather *>(cloned);
+ ASSERT_NE(nullptr, cloned_gat);
+ ASSERT_EQ(node_gat->axis(), cloned_gat->axis());
+ ASSERT_EQ(node_gat->input_hidden_size(), cloned_gat->input_hidden_size());
+}
diff --git a/compiler/luci/service/src/Nodes/CircleBatchMatMul.cpp b/compiler/luci/service/src/Nodes/CircleBatchMatMul.cpp
new file mode 100644
index 000000000..c7a8bbd52
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleBatchMatMul.cpp
@@ -0,0 +1,33 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "CircleCloneNode.h"
+
+namespace luci
+{
+
+luci::CircleNode *CloneNode::visit(const luci::CircleBatchMatMul *node)
+{
+ auto *cloned = _graph->nodes()->create<luci::CircleBatchMatMul>();
+ if (cloned != nullptr)
+ {
+ cloned->adj_x(node->adj_x());
+ cloned->adj_y(node->adj_y());
+ }
+ return cloned;
+}
+
+} // namespace luci
diff --git a/compiler/luci/service/src/Nodes/CircleBatchMatMul.test.cpp b/compiler/luci/service/src/Nodes/CircleBatchMatMul.test.cpp
new file mode 100644
index 000000000..e013feae8
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleBatchMatMul.test.cpp
@@ -0,0 +1,37 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Service/CircleNodeClone.h"
+
+#include <gtest/gtest.h>
+
+TEST(CloneNodeTest, clone_BatchMatMul)
+{
+ auto g = loco::make_graph();
+ auto node_bmm = g->nodes()->create<luci::CircleBatchMatMul>();
+ node_bmm->adj_x(true);
+ node_bmm->adj_y(true);
+
+ auto gc = loco::make_graph();
+ auto cloned = luci::clone_node(node_bmm, gc.get());
+ ASSERT_NE(nullptr, cloned);
+ ASSERT_EQ(gc.get(), cloned->graph());
+
+ auto cloned_bmm = dynamic_cast<luci::CircleBatchMatMul *>(cloned);
+ ASSERT_NE(nullptr, cloned_bmm);
+ ASSERT_EQ(node_bmm->adj_x(), cloned_bmm->adj_x());
+ ASSERT_EQ(node_bmm->adj_y(), cloned_bmm->adj_y());
+}
diff --git a/compiler/luci/service/src/Nodes/CircleBatchToSpaceND.cpp b/compiler/luci/service/src/Nodes/CircleBatchToSpaceND.cpp
new file mode 100644
index 000000000..70aa05f72
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleBatchToSpaceND.cpp
@@ -0,0 +1,27 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "CircleCloneNode.h"
+
+namespace luci
+{
+
+luci::CircleNode *CloneNode::visit(const luci::CircleBatchToSpaceND *)
+{
+ return _graph->nodes()->create<luci::CircleBatchToSpaceND>();
+}
+
+} // namespace luci
diff --git a/compiler/luci/service/src/Nodes/CircleBatchToSpaceND.test.cpp b/compiler/luci/service/src/Nodes/CircleBatchToSpaceND.test.cpp
new file mode 100644
index 000000000..a45039fc7
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleBatchToSpaceND.test.cpp
@@ -0,0 +1,33 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Service/CircleNodeClone.h"
+
+#include <gtest/gtest.h>
+
+TEST(CloneNodeTest, clone_BatchToSpaceND)
+{
+ auto g = loco::make_graph();
+ auto node_b2s = g->nodes()->create<luci::CircleBatchToSpaceND>();
+
+ auto gc = loco::make_graph();
+ auto cloned = luci::clone_node(node_b2s, gc.get());
+ ASSERT_NE(nullptr, cloned);
+ ASSERT_EQ(gc.get(), cloned->graph());
+
+ auto cloned_b2s = dynamic_cast<luci::CircleBatchToSpaceND *>(cloned);
+ ASSERT_NE(nullptr, cloned_b2s);
+}
diff --git a/compiler/luci/service/src/Nodes/CircleCast.cpp b/compiler/luci/service/src/Nodes/CircleCast.cpp
new file mode 100644
index 000000000..75f15f9de
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleCast.cpp
@@ -0,0 +1,33 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "CircleCloneNode.h"
+
+namespace luci
+{
+
+luci::CircleNode *CloneNode::visit(const luci::CircleCast *node)
+{
+ auto *cloned = _graph->nodes()->create<luci::CircleCast>();
+ if (cloned != nullptr)
+ {
+ cloned->in_data_type(node->in_data_type());
+ cloned->out_data_type(node->out_data_type());
+ }
+ return cloned;
+}
+
+} // namespace luci
diff --git a/compiler/luci/service/src/Nodes/CircleCast.test.cpp b/compiler/luci/service/src/Nodes/CircleCast.test.cpp
new file mode 100644
index 000000000..1c4bacb73
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleCast.test.cpp
@@ -0,0 +1,37 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Service/CircleNodeClone.h"
+
+#include <gtest/gtest.h>
+
+TEST(CloneNodeTest, clone_Cast)
+{
+ auto g = loco::make_graph();
+ auto node_cast = g->nodes()->create<luci::CircleCast>();
+ node_cast->in_data_type(loco::DataType::U16);
+ node_cast->out_data_type(loco::DataType::S32);
+
+ auto gc = loco::make_graph();
+ auto cloned = luci::clone_node(node_cast, gc.get());
+ ASSERT_NE(nullptr, cloned);
+ ASSERT_EQ(gc.get(), cloned->graph());
+
+ auto cloned_cast = dynamic_cast<luci::CircleCast *>(cloned);
+ ASSERT_NE(nullptr, cloned_cast);
+ ASSERT_EQ(node_cast->in_data_type(), cloned_cast->in_data_type());
+ ASSERT_EQ(node_cast->out_data_type(), cloned_cast->out_data_type());
+}
diff --git a/compiler/luci/service/src/Nodes/CircleCeil.cpp b/compiler/luci/service/src/Nodes/CircleCeil.cpp
new file mode 100644
index 000000000..92d039a7d
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleCeil.cpp
@@ -0,0 +1,27 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "CircleCloneNode.h"
+
+namespace luci
+{
+
+luci::CircleNode *CloneNode::visit(const luci::CircleCeil *)
+{
+ return _graph->nodes()->create<luci::CircleCeil>();
+}
+
+} // namespace luci
diff --git a/compiler/luci/service/src/Nodes/CircleCeil.test.cpp b/compiler/luci/service/src/Nodes/CircleCeil.test.cpp
new file mode 100644
index 000000000..b182127d9
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleCeil.test.cpp
@@ -0,0 +1,33 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Service/CircleNodeClone.h"
+
+#include <gtest/gtest.h>
+
+TEST(CloneNodeTest, clone_Ceil)
+{
+ auto g = loco::make_graph();
+ auto node_ceil = g->nodes()->create<luci::CircleCeil>();
+
+ auto gc = loco::make_graph();
+ auto cloned = luci::clone_node(node_ceil, gc.get());
+ ASSERT_NE(nullptr, cloned);
+ ASSERT_EQ(gc.get(), cloned->graph());
+
+ auto cloned_ceil = dynamic_cast<luci::CircleCeil *>(cloned);
+ ASSERT_NE(nullptr, cloned_ceil);
+}
diff --git a/compiler/luci/service/src/Nodes/CircleConcatenation.cpp b/compiler/luci/service/src/Nodes/CircleConcatenation.cpp
new file mode 100644
index 000000000..75d6a53e6
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleConcatenation.cpp
@@ -0,0 +1,36 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "CircleCloneNode.h"
+
+namespace luci
+{
+
+luci::CircleNode *CloneNode::visit(const luci::CircleConcatenation *node)
+{
+ if (node->fusedActivationFunction() == luci::FusedActFunc::UNDEFINED)
+ return nullptr;
+
+ auto *cloned = _graph->nodes()->create<luci::CircleConcatenation>(node->numValues());
+ if (cloned != nullptr)
+ {
+ cloned->fusedActivationFunction(node->fusedActivationFunction());
+ cloned->axis(node->axis());
+ }
+ return cloned;
+}
+
+} // namespace luci
diff --git a/compiler/luci/service/src/Nodes/CircleConcatenation.test.cpp b/compiler/luci/service/src/Nodes/CircleConcatenation.test.cpp
new file mode 100644
index 000000000..270068cf0
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleConcatenation.test.cpp
@@ -0,0 +1,49 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Service/CircleNodeClone.h"
+
+#include <gtest/gtest.h>
+
+TEST(CloneNodeTest, clone_Concatenation)
+{
+ auto g = loco::make_graph();
+ auto node_concat = g->nodes()->create<luci::CircleConcatenation>(3);
+ node_concat->fusedActivationFunction(luci::FusedActFunc::RELU);
+ node_concat->axis(7);
+
+ auto gc = loco::make_graph();
+ auto cloned = luci::clone_node(node_concat, gc.get());
+ ASSERT_NE(nullptr, cloned);
+ ASSERT_EQ(gc.get(), cloned->graph());
+
+ auto cloned_concat = dynamic_cast<luci::CircleConcatenation *>(cloned);
+ ASSERT_NE(nullptr, cloned_concat);
+ ASSERT_EQ(node_concat->numValues(), cloned_concat->numValues());
+ ASSERT_EQ(node_concat->fusedActivationFunction(), cloned_concat->fusedActivationFunction());
+ ASSERT_EQ(node_concat->axis(), cloned_concat->axis());
+}
+
+TEST(CloneNodeTest, clone_Concatenation_NEG)
+{
+ auto g = loco::make_graph();
+ auto node_concat = g->nodes()->create<luci::CircleConcatenation>(3);
+ node_concat->fusedActivationFunction(luci::FusedActFunc::UNDEFINED);
+
+ auto gc = loco::make_graph();
+ auto cloned = luci::clone_node(node_concat, gc.get());
+ ASSERT_EQ(nullptr, cloned);
+}
diff --git a/compiler/luci/service/src/Nodes/CircleConst.cpp b/compiler/luci/service/src/Nodes/CircleConst.cpp
new file mode 100644
index 000000000..0306ef4eb
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleConst.cpp
@@ -0,0 +1,118 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "CircleCloneNode.h"
+
+#include "luci/Service/CircleNodeClone.h"
+
+#include <luci/IR/Nodes/CircleConst.h>
+
+#include <loco.h>
+#include <loco/IR/Graph.h>
+
+#include <oops/UserExn.h>
+
+#include <cassert>
+
+namespace
+{
+
+template <loco::DataType T>
+void copy_values(const luci::CircleConst *node, luci::CircleConst *cloned)
+{
+ assert(T == node->dtype());
+ assert(T == cloned->dtype());
+
+ const auto size = node->size<T>();
+ cloned->size<T>(size);
+ for (uint32_t i = 0; i < size; i++)
+ cloned->at<T>(i) = node->at<T>(i);
+}
+
+luci::CircleConst *clone_circleconst(const luci::CircleConst *node, loco::Graph *graph)
+{
+ auto cloned = graph->nodes()->create<luci::CircleConst>();
+
+ if (cloned != nullptr)
+ {
+ // dtype/shape
+ cloned->dtype(node->dtype());
+ cloned->rank(node->rank());
+
+ // values
+ switch (node->dtype())
+ {
+ case loco::DataType::FLOAT32:
+ copy_values<loco::DataType::FLOAT32>(node, cloned);
+ break;
+
+ case loco::DataType::U8:
+ copy_values<loco::DataType::U8>(node, cloned);
+ break;
+
+ case loco::DataType::S8:
+ copy_values<loco::DataType::S8>(node, cloned);
+ break;
+
+ case loco::DataType::S16:
+ copy_values<loco::DataType::S16>(node, cloned);
+ break;
+
+ case loco::DataType::S32:
+ copy_values<loco::DataType::S32>(node, cloned);
+ break;
+
+ case loco::DataType::S64:
+ copy_values<loco::DataType::S64>(node, cloned);
+ break;
+
+ case loco::DataType::BOOL:
+ copy_values<loco::DataType::BOOL>(node, cloned);
+ break;
+
+ default:
+ throw oops::UserExn("Unsupported tensor dtype");
+ }
+ }
+
+ return cloned;
+}
+
+} // namespace
+
+namespace luci
+{
+
+luci::CircleConst *clone(luci::CircleConst *node)
+{
+ auto *cloned = clone_circleconst(node, node->graph());
+
+ copy_common_attributes(node, cloned);
+
+ return cloned;
+}
+
+} // namespace luci
+
+namespace luci
+{
+
+luci::CircleNode *CloneNode::visit(const luci::CircleConst *node)
+{
+ return clone_circleconst(node, _graph);
+}
+
+} // namespace luci
diff --git a/compiler/luci/service/src/Nodes/CircleConst.test.cpp b/compiler/luci/service/src/Nodes/CircleConst.test.cpp
new file mode 100644
index 000000000..5d94798f4
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleConst.test.cpp
@@ -0,0 +1,177 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Service/Nodes/CircleConst.h"
+#include "luci/Service/CircleNodeClone.h"
+
+#include <loco.h>
+#include <loco/IR/Graph.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+luci::CircleConst *new_const_s32(loco::Graph *g)
+{
+ // prepare source CircleConst
+ auto circle_const = g->nodes()->create<luci::CircleConst>();
+
+ const auto size = 2;
+
+ circle_const->dtype(loco::DataType::S32);
+ circle_const->rank(1);
+ circle_const->dim(0).set(size);
+ circle_const->shape_status(luci::ShapeStatus::VALID);
+
+ circle_const->size<loco::DataType::S32>(size);
+ for (uint32_t i = 0; i < size; i++)
+ circle_const->at<loco::DataType::S32>(i) = i;
+
+ // quantparam
+ auto quantparam = std::make_unique<luci::CircleQuantParam>();
+ quantparam->scale = {1.0};
+ quantparam->zerop = {0};
+ quantparam->min = {-127.0};
+ quantparam->max = {127.0};
+ quantparam->quantized_dimension = 1;
+ circle_const->quantparam(std::move(quantparam));
+
+ // sparsityparam
+ auto sparam = std::make_unique<luci::SparsityParam>();
+ sparam->traversal_order = {1};
+ sparam->block_map = {1};
+ sparam->dim_metadata = {};
+ circle_const->sparsityparam(std::move(sparam));
+
+ return circle_const;
+}
+
+template <loco::DataType DT> luci::CircleConst *new_empty_const(loco::Graph *g)
+{
+ auto circle_const = g->nodes()->create<luci::CircleConst>();
+
+ const auto size = 0;
+
+ circle_const->dtype(DT);
+ circle_const->rank(1);
+ circle_const->dim(0).set(size);
+ circle_const->shape_status(luci::ShapeStatus::VALID);
+ circle_const->size<DT>(size);
+
+ return circle_const;
+}
+
+} // namespace
+
+TEST(CircleConstTest, clone)
+{
+ auto g = loco::make_graph();
+
+ // prepare source CircleConst
+ auto circle_const = new_const_s32(g.get());
+
+ // make a clone
+ auto const_cloned = luci::clone(circle_const);
+
+ // check attributes
+ ASSERT_EQ(loco::DataType::S32, const_cloned->dtype());
+ ASSERT_EQ(1, const_cloned->rank());
+ ASSERT_EQ(2, const_cloned->dim(0).value());
+ ASSERT_EQ(2, const_cloned->size<loco::DataType::S32>());
+ ASSERT_EQ(0, const_cloned->at<loco::DataType::S32>(0));
+ ASSERT_EQ(1, const_cloned->at<loco::DataType::S32>(1));
+ ASSERT_NE(nullptr, const_cloned->quantparam());
+ ASSERT_NE(nullptr, const_cloned->sparsityparam());
+}
+
+TEST(CircleConstTest, clone_U8)
+{
+ auto g = loco::make_graph();
+
+ // prepare source CircleConst
+ auto circle_const = new_empty_const<loco::DataType::U8>(g.get());
+
+ // make a clone
+ auto const_cloned = luci::clone(circle_const);
+
+ // check attributes
+ ASSERT_EQ(loco::DataType::U8, const_cloned->dtype());
+}
+
+TEST(CircleConstTest, clone_S8)
+{
+ auto g = loco::make_graph();
+
+ // prepare source CircleConst
+ auto circle_const = new_empty_const<loco::DataType::S8>(g.get());
+
+ // make a clone
+ auto const_cloned = luci::clone(circle_const);
+
+ // check attributes
+ ASSERT_EQ(loco::DataType::S8, const_cloned->dtype());
+}
+
+TEST(CircleConstTest, clone_S64)
+{
+ auto g = loco::make_graph();
+
+ // prepare source CircleConst
+ auto circle_const = new_empty_const<loco::DataType::S64>(g.get());
+
+ // make a clone
+ auto const_cloned = luci::clone(circle_const);
+
+ // check attributes
+ ASSERT_EQ(loco::DataType::S64, const_cloned->dtype());
+}
+
+TEST(CircleConstTest, clone_BOOL)
+{
+ auto g = loco::make_graph();
+
+ // prepare source CircleConst
+ auto circle_const = new_empty_const<loco::DataType::BOOL>(g.get());
+
+ // make a clone
+ auto const_cloned = luci::clone(circle_const);
+
+ // check attributes
+ ASSERT_EQ(loco::DataType::BOOL, const_cloned->dtype());
+}
+
+TEST(CloneNodeTest, clone_Const)
+{
+ auto g = loco::make_graph();
+ auto node_const = new_const_s32(g.get());
+
+ auto gc = loco::make_graph();
+ auto cloned = luci::clone_node(node_const, gc.get());
+ ASSERT_NE(nullptr, cloned);
+ ASSERT_EQ(gc.get(), cloned->graph());
+
+ auto cloned_const = dynamic_cast<luci::CircleConst *>(cloned);
+ ASSERT_NE(nullptr, cloned_const);
+ ASSERT_EQ(loco::DataType::S32, cloned_const->dtype());
+ ASSERT_EQ(1, cloned_const->rank());
+ ASSERT_EQ(2, cloned_const->dim(0).value());
+ ASSERT_EQ(2, cloned_const->size<loco::DataType::S32>());
+ ASSERT_EQ(0, cloned_const->at<loco::DataType::S32>(0));
+ ASSERT_EQ(1, cloned_const->at<loco::DataType::S32>(1));
+ ASSERT_NE(nullptr, cloned_const->quantparam());
+ ASSERT_NE(nullptr, cloned_const->sparsityparam());
+}
diff --git a/compiler/luci/service/src/Nodes/CircleConv2D.cpp b/compiler/luci/service/src/Nodes/CircleConv2D.cpp
new file mode 100644
index 000000000..08cd87ef7
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleConv2D.cpp
@@ -0,0 +1,42 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "CircleCloneNode.h"
+
+namespace luci
+{
+
+luci::CircleNode *CloneNode::visit(const luci::CircleConv2D *node)
+{
+ if (node->fusedActivationFunction() == luci::FusedActFunc::UNDEFINED)
+ return nullptr;
+ if (node->padding() == luci::Padding::UNDEFINED)
+ return nullptr;
+
+ auto *cloned = _graph->nodes()->create<luci::CircleConv2D>();
+ if (cloned != nullptr)
+ {
+ cloned->fusedActivationFunction(node->fusedActivationFunction());
+ cloned->padding(node->padding());
+ cloned->stride()->h(node->stride()->h());
+ cloned->stride()->w(node->stride()->w());
+ cloned->dilation()->h(node->dilation()->h());
+ cloned->dilation()->w(node->dilation()->w());
+ }
+ return cloned;
+}
+
+} // namespace luci
diff --git a/compiler/luci/service/src/Nodes/CircleConv2D.test.cpp b/compiler/luci/service/src/Nodes/CircleConv2D.test.cpp
new file mode 100644
index 000000000..c265d6cd1
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleConv2D.test.cpp
@@ -0,0 +1,61 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Service/CircleNodeClone.h"
+
+#include <gtest/gtest.h>
+
+TEST(CloneNodeTest, clone_Conv2D)
+{
+ auto g = loco::make_graph();
+ auto node_conv2d = g->nodes()->create<luci::CircleConv2D>();
+ node_conv2d->fusedActivationFunction(luci::FusedActFunc::RELU);
+ node_conv2d->padding(luci::Padding::SAME);
+
+ auto gc = loco::make_graph();
+ auto cloned = luci::clone_node(node_conv2d, gc.get());
+ ASSERT_NE(nullptr, cloned);
+ ASSERT_EQ(gc.get(), cloned->graph());
+
+ auto cloned_conv2d = dynamic_cast<luci::CircleConv2D *>(cloned);
+ ASSERT_NE(nullptr, cloned_conv2d);
+ ASSERT_EQ(node_conv2d->fusedActivationFunction(), cloned_conv2d->fusedActivationFunction());
+ ASSERT_EQ(node_conv2d->padding(), cloned_conv2d->padding());
+}
+
+TEST(CloneNodeTest, clone_Conv2D_fusedact_NEG)
+{
+ auto g = loco::make_graph();
+ auto node_conv2d = g->nodes()->create<luci::CircleConv2D>();
+ node_conv2d->fusedActivationFunction(luci::FusedActFunc::UNDEFINED);
+ node_conv2d->padding(luci::Padding::SAME);
+
+ auto gc = loco::make_graph();
+ auto cloned = luci::clone_node(node_conv2d, gc.get());
+ ASSERT_EQ(nullptr, cloned);
+}
+
+TEST(CloneNodeTest, clone_Conv2D_padding_NEG)
+{
+ auto g = loco::make_graph();
+ auto node_conv2d = g->nodes()->create<luci::CircleConv2D>();
+ node_conv2d->fusedActivationFunction(luci::FusedActFunc::RELU);
+ node_conv2d->padding(luci::Padding::UNDEFINED);
+
+ auto gc = loco::make_graph();
+ auto cloned = luci::clone_node(node_conv2d, gc.get());
+ ASSERT_EQ(nullptr, cloned);
+}
diff --git a/compiler/luci/service/src/Nodes/CircleCos.cpp b/compiler/luci/service/src/Nodes/CircleCos.cpp
new file mode 100644
index 000000000..c46e3741b
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleCos.cpp
@@ -0,0 +1,27 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "CircleCloneNode.h"
+
+namespace luci
+{
+
+luci::CircleNode *CloneNode::visit(const luci::CircleCos *)
+{
+ return _graph->nodes()->create<luci::CircleCos>();
+}
+
+} // namespace luci
diff --git a/compiler/luci/service/src/Nodes/CircleCos.test.cpp b/compiler/luci/service/src/Nodes/CircleCos.test.cpp
new file mode 100644
index 000000000..a25943b98
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleCos.test.cpp
@@ -0,0 +1,33 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Service/CircleNodeClone.h"
+
+#include <gtest/gtest.h>
+
+TEST(CloneNodeTest, clone_Cos)
+{
+ auto g = loco::make_graph();
+ auto node_cos = g->nodes()->create<luci::CircleCos>();
+
+ auto gc = loco::make_graph();
+ auto cloned = luci::clone_node(node_cos, gc.get());
+ ASSERT_NE(nullptr, cloned);
+ ASSERT_EQ(gc.get(), cloned->graph());
+
+ auto cloned_cos = dynamic_cast<luci::CircleCos *>(cloned);
+ ASSERT_NE(nullptr, cloned_cos);
+}
diff --git a/compiler/luci/service/src/Nodes/CircleCustom.cpp b/compiler/luci/service/src/Nodes/CircleCustom.cpp
new file mode 100644
index 000000000..a9764c373
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleCustom.cpp
@@ -0,0 +1,35 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "CircleCloneNode.h"
+
+namespace luci
+{
+
+luci::CircleNode *CloneNode::visit(const luci::CircleCustom *node)
+{
+ uint32_t num_in = node->numInputs();
+ uint32_t num_out = node->numOutputs();
+ auto *cloned = _graph->nodes()->create<luci::CircleCustom>(num_in, num_out);
+ if (cloned != nullptr)
+ {
+ cloned->custom_options(node->custom_options());
+ cloned->custom_code(node->custom_code());
+ }
+ return cloned;
+}
+
+} // namespace luci
diff --git a/compiler/luci/service/src/Nodes/CircleCustom.test.cpp b/compiler/luci/service/src/Nodes/CircleCustom.test.cpp
new file mode 100644
index 000000000..6fee68e71
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleCustom.test.cpp
@@ -0,0 +1,46 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Service/CircleNodeClone.h"
+
+#include <gtest/gtest.h>
+
+#include <string>
+#include <vector>
+
+TEST(CloneNodeTest, clone_Custom)
+{
+ auto g = loco::make_graph();
+ auto node_custom = g->nodes()->create<luci::CircleCustom>(2, 3);
+ std::vector<uint8_t> options({0x55, 0x56, 0x57});
+ std::string code = "hello";
+ node_custom->custom_options(options);
+ node_custom->custom_code(code);
+
+ auto gc = loco::make_graph();
+ auto cloned = luci::clone_node(node_custom, gc.get());
+ ASSERT_NE(nullptr, cloned);
+ ASSERT_EQ(gc.get(), cloned->graph());
+
+ auto cloned_custom = dynamic_cast<luci::CircleCustom *>(cloned);
+ ASSERT_NE(nullptr, cloned_custom);
+ auto cloned_options = cloned_custom->custom_options();
+ ASSERT_EQ(options.size(), cloned_options.size());
+ auto size = options.size();
+ for (size_t s = 0; s < size; ++s)
+ ASSERT_EQ(options.at(s), cloned_options.at(s));
+ ASSERT_TRUE(node_custom->custom_code() == cloned_custom->custom_code());
+}
diff --git a/compiler/luci/service/src/Nodes/CircleCustomOut.cpp b/compiler/luci/service/src/Nodes/CircleCustomOut.cpp
new file mode 100644
index 000000000..84577f529
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleCustomOut.cpp
@@ -0,0 +1,30 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "CircleCloneNode.h"
+
+namespace luci
+{
+
+luci::CircleNode *CloneNode::visit(const luci::CircleCustomOut *node)
+{
+ auto *cloned = _graph->nodes()->create<luci::CircleCustomOut>();
+ if (cloned != nullptr)
+ cloned->index(node->index());
+ return cloned;
+}
+
+} // namespace luci
diff --git a/compiler/luci/service/src/Nodes/CircleCustomOut.test.cpp b/compiler/luci/service/src/Nodes/CircleCustomOut.test.cpp
new file mode 100644
index 000000000..15121bab6
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleCustomOut.test.cpp
@@ -0,0 +1,35 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Service/CircleNodeClone.h"
+
+#include <gtest/gtest.h>
+
+TEST(CloneNodeTest, clone_CustomOut)
+{
+ auto g = loco::make_graph();
+ auto node_cout = g->nodes()->create<luci::CircleCustomOut>();
+ node_cout->index(1);
+
+ auto gc = loco::make_graph();
+ auto cloned = luci::clone_node(node_cout, gc.get());
+ ASSERT_NE(nullptr, cloned);
+ ASSERT_EQ(gc.get(), cloned->graph());
+
+ auto cloned_cout = dynamic_cast<luci::CircleCustomOut *>(cloned);
+ ASSERT_NE(nullptr, cloned_cout);
+ ASSERT_EQ(node_cout->index(), cloned_cout->index());
+}
diff --git a/compiler/luci/service/src/Nodes/CircleDepthToSpace.cpp b/compiler/luci/service/src/Nodes/CircleDepthToSpace.cpp
new file mode 100644
index 000000000..7e0bc7d74
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleDepthToSpace.cpp
@@ -0,0 +1,30 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "CircleCloneNode.h"
+
+namespace luci
+{
+
+luci::CircleNode *CloneNode::visit(const luci::CircleDepthToSpace *node)
+{
+ auto *cloned = _graph->nodes()->create<luci::CircleDepthToSpace>();
+ if (cloned != nullptr)
+ cloned->block_size(node->block_size());
+ return cloned;
+}
+
+} // namespace luci
diff --git a/compiler/luci/service/src/Nodes/CircleDepthToSpace.test.cpp b/compiler/luci/service/src/Nodes/CircleDepthToSpace.test.cpp
new file mode 100644
index 000000000..192b10b90
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleDepthToSpace.test.cpp
@@ -0,0 +1,35 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Service/CircleNodeClone.h"
+
+#include <gtest/gtest.h>
+
+TEST(CloneNodeTest, clone_DepthToSpace)
+{
+ auto g = loco::make_graph();
+ auto node_d2s = g->nodes()->create<luci::CircleDepthToSpace>();
+ node_d2s->block_size(32);
+
+ auto gc = loco::make_graph();
+ auto cloned = luci::clone_node(node_d2s, gc.get());
+ ASSERT_NE(nullptr, cloned);
+ ASSERT_EQ(gc.get(), cloned->graph());
+
+ auto cloned_d2s = dynamic_cast<luci::CircleDepthToSpace *>(cloned);
+ ASSERT_NE(nullptr, cloned_d2s);
+ ASSERT_EQ(node_d2s->block_size(), cloned_d2s->block_size());
+}
diff --git a/compiler/luci/service/src/Nodes/CircleDepthwiseConv2D.cpp b/compiler/luci/service/src/Nodes/CircleDepthwiseConv2D.cpp
new file mode 100644
index 000000000..8e0b23d94
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleDepthwiseConv2D.cpp
@@ -0,0 +1,43 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "CircleCloneNode.h"
+
+namespace luci
+{
+
+luci::CircleNode *CloneNode::visit(const luci::CircleDepthwiseConv2D *node)
+{
+ if (node->fusedActivationFunction() == luci::FusedActFunc::UNDEFINED)
+ return nullptr;
+ if (node->padding() == luci::Padding::UNDEFINED)
+ return nullptr;
+
+ auto *cloned = _graph->nodes()->create<luci::CircleDepthwiseConv2D>();
+ if (cloned != nullptr)
+ {
+ cloned->fusedActivationFunction(node->fusedActivationFunction());
+ cloned->padding(node->padding());
+ cloned->stride()->h(node->stride()->h());
+ cloned->stride()->w(node->stride()->w());
+ cloned->depthMultiplier(node->depthMultiplier());
+ cloned->dilation()->h(node->dilation()->h());
+ cloned->dilation()->w(node->dilation()->w());
+ }
+ return cloned;
+}
+
+} // namespace luci
diff --git a/compiler/luci/service/src/Nodes/CircleDepthwiseConv2D.test.cpp b/compiler/luci/service/src/Nodes/CircleDepthwiseConv2D.test.cpp
new file mode 100644
index 000000000..8657464bc
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleDepthwiseConv2D.test.cpp
@@ -0,0 +1,61 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Service/CircleNodeClone.h"
+
+#include <gtest/gtest.h>
+
+TEST(CloneNodeTest, clone_DepthwiseConv2D)
+{
+ auto g = loco::make_graph();
+ auto node_dwconv2d = g->nodes()->create<luci::CircleDepthwiseConv2D>();
+ node_dwconv2d->fusedActivationFunction(luci::FusedActFunc::RELU);
+ node_dwconv2d->padding(luci::Padding::SAME);
+
+ auto gc = loco::make_graph();
+ auto cloned = luci::clone_node(node_dwconv2d, gc.get());
+ ASSERT_NE(nullptr, cloned);
+ ASSERT_EQ(gc.get(), cloned->graph());
+
+ auto cloned_dwconv2d = dynamic_cast<luci::CircleDepthwiseConv2D *>(cloned);
+ ASSERT_NE(nullptr, cloned_dwconv2d);
+ ASSERT_EQ(node_dwconv2d->fusedActivationFunction(), cloned_dwconv2d->fusedActivationFunction());
+ ASSERT_EQ(node_dwconv2d->padding(), cloned_dwconv2d->padding());
+}
+
+TEST(CloneNodeTest, clone_DepthwiseConv2D_fusedact_NEG)
+{
+ auto g = loco::make_graph();
+ auto node_dwconv2d = g->nodes()->create<luci::CircleDepthwiseConv2D>();
+ node_dwconv2d->fusedActivationFunction(luci::FusedActFunc::UNDEFINED);
+ node_dwconv2d->padding(luci::Padding::SAME);
+
+ auto gc = loco::make_graph();
+ auto cloned = luci::clone_node(node_dwconv2d, gc.get());
+ ASSERT_EQ(nullptr, cloned);
+}
+
+TEST(CloneNodeTest, clone_DepthwiseConv2D_padding_NEG)
+{
+ auto g = loco::make_graph();
+ auto node_dwconv2d = g->nodes()->create<luci::CircleDepthwiseConv2D>();
+ node_dwconv2d->fusedActivationFunction(luci::FusedActFunc::RELU);
+ node_dwconv2d->padding(luci::Padding::UNDEFINED);
+
+ auto gc = loco::make_graph();
+ auto cloned = luci::clone_node(node_dwconv2d, gc.get());
+ ASSERT_EQ(nullptr, cloned);
+}
diff --git a/compiler/luci/service/src/Nodes/CircleDequantize.cpp b/compiler/luci/service/src/Nodes/CircleDequantize.cpp
new file mode 100644
index 000000000..79983e4d3
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleDequantize.cpp
@@ -0,0 +1,27 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "CircleCloneNode.h"
+
+namespace luci
+{
+
+luci::CircleNode *CloneNode::visit(const luci::CircleDequantize *)
+{
+ return _graph->nodes()->create<luci::CircleDequantize>();
+}
+
+} // namespace luci
diff --git a/compiler/luci/service/src/Nodes/CircleDequantize.test.cpp b/compiler/luci/service/src/Nodes/CircleDequantize.test.cpp
new file mode 100644
index 000000000..e1c563acf
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleDequantize.test.cpp
@@ -0,0 +1,33 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Service/CircleNodeClone.h"
+
+#include <gtest/gtest.h>
+
+TEST(CloneNodeTest, clone_Dequantize)
+{
+ auto g = loco::make_graph();
+ auto node_dq = g->nodes()->create<luci::CircleDequantize>();
+
+ auto gc = loco::make_graph();
+ auto cloned = luci::clone_node(node_dq, gc.get());
+ ASSERT_NE(nullptr, cloned);
+ ASSERT_EQ(gc.get(), cloned->graph());
+
+ auto cloned_dq = dynamic_cast<luci::CircleDequantize *>(cloned);
+ ASSERT_NE(nullptr, cloned_dq);
+}
diff --git a/compiler/luci/service/src/Nodes/CircleDiv.cpp b/compiler/luci/service/src/Nodes/CircleDiv.cpp
new file mode 100644
index 000000000..7c48d8b76
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleDiv.cpp
@@ -0,0 +1,33 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "CircleCloneNode.h"
+
+namespace luci
+{
+
+luci::CircleNode *CloneNode::visit(const luci::CircleDiv *node)
+{
+ if (node->fusedActivationFunction() == luci::FusedActFunc::UNDEFINED)
+ return nullptr;
+
+ auto *cloned = _graph->nodes()->create<luci::CircleDiv>();
+ if (cloned != nullptr)
+ cloned->fusedActivationFunction(node->fusedActivationFunction());
+ return cloned;
+}
+
+} // namespace luci
diff --git a/compiler/luci/service/src/Nodes/CircleDiv.test.cpp b/compiler/luci/service/src/Nodes/CircleDiv.test.cpp
new file mode 100644
index 000000000..5182ac908
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleDiv.test.cpp
@@ -0,0 +1,46 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Service/CircleNodeClone.h"
+
+#include <gtest/gtest.h>
+
+TEST(CloneNodeTest, clone_Div)
+{
+ auto g = loco::make_graph();
+ auto node_div = g->nodes()->create<luci::CircleDiv>();
+ node_div->fusedActivationFunction(luci::FusedActFunc::RELU);
+
+ auto gc = loco::make_graph();
+ auto cloned = luci::clone_node(node_div, gc.get());
+ ASSERT_NE(nullptr, cloned);
+ ASSERT_EQ(gc.get(), cloned->graph());
+
+ auto cloned_div = dynamic_cast<luci::CircleDiv *>(cloned);
+ ASSERT_NE(nullptr, cloned_div);
+ ASSERT_EQ(node_div->fusedActivationFunction(), cloned_div->fusedActivationFunction());
+}
+
+TEST(CloneNodeTest, clone_Div_NEG)
+{
+ auto g = loco::make_graph();
+ auto node_div = g->nodes()->create<luci::CircleDiv>();
+ node_div->fusedActivationFunction(luci::FusedActFunc::UNDEFINED);
+
+ auto gc = loco::make_graph();
+ auto cloned = luci::clone_node(node_div, gc.get());
+ ASSERT_EQ(nullptr, cloned);
+}
diff --git a/compiler/luci/service/src/Nodes/CircleElu.cpp b/compiler/luci/service/src/Nodes/CircleElu.cpp
new file mode 100644
index 000000000..e2df30285
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleElu.cpp
@@ -0,0 +1,27 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "CircleCloneNode.h"
+
+namespace luci
+{
+
+luci::CircleNode *CloneNode::visit(const luci::CircleElu *)
+{
+ return _graph->nodes()->create<luci::CircleElu>();
+}
+
+} // namespace luci
diff --git a/compiler/luci/service/src/Nodes/CircleElu.test.cpp b/compiler/luci/service/src/Nodes/CircleElu.test.cpp
new file mode 100644
index 000000000..e75b3bcb1
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleElu.test.cpp
@@ -0,0 +1,33 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Service/CircleNodeClone.h"
+
+#include <gtest/gtest.h>
+
+TEST(CloneNodeTest, clone_Elu)
+{
+ auto g = loco::make_graph();
+ auto node_elu = g->nodes()->create<luci::CircleElu>();
+
+ auto gc = loco::make_graph();
+ auto cloned = luci::clone_node(node_elu, gc.get());
+ ASSERT_NE(nullptr, cloned);
+ ASSERT_EQ(gc.get(), cloned->graph());
+
+ auto cloned_elu = dynamic_cast<luci::CircleElu *>(cloned);
+ ASSERT_NE(nullptr, cloned_elu);
+}
diff --git a/compiler/luci/service/src/Nodes/CircleEqual.cpp b/compiler/luci/service/src/Nodes/CircleEqual.cpp
new file mode 100644
index 000000000..5dd382d0b
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleEqual.cpp
@@ -0,0 +1,27 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "CircleCloneNode.h"
+
+namespace luci
+{
+
+luci::CircleNode *CloneNode::visit(const luci::CircleEqual *)
+{
+ return _graph->nodes()->create<luci::CircleEqual>();
+}
+
+} // namespace luci
diff --git a/compiler/luci/service/src/Nodes/CircleEqual.test.cpp b/compiler/luci/service/src/Nodes/CircleEqual.test.cpp
new file mode 100644
index 000000000..99a5535fc
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleEqual.test.cpp
@@ -0,0 +1,33 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Service/CircleNodeClone.h"
+
+#include <gtest/gtest.h>
+
+TEST(CloneNodeTest, clone_Equal)
+{
+ auto g = loco::make_graph();
+ auto node_eq = g->nodes()->create<luci::CircleEqual>();
+
+ auto gc = loco::make_graph();
+ auto cloned = luci::clone_node(node_eq, gc.get());
+ ASSERT_NE(nullptr, cloned);
+ ASSERT_EQ(gc.get(), cloned->graph());
+
+ auto cloned_eq = dynamic_cast<luci::CircleEqual *>(cloned);
+ ASSERT_NE(nullptr, cloned_eq);
+}
diff --git a/compiler/luci/service/src/Nodes/CircleExp.cpp b/compiler/luci/service/src/Nodes/CircleExp.cpp
new file mode 100644
index 000000000..3d4918320
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleExp.cpp
@@ -0,0 +1,27 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "CircleCloneNode.h"
+
+namespace luci
+{
+
+luci::CircleNode *CloneNode::visit(const luci::CircleExp *)
+{
+ return _graph->nodes()->create<luci::CircleExp>();
+}
+
+} // namespace luci
diff --git a/compiler/luci/service/src/Nodes/CircleExp.test.cpp b/compiler/luci/service/src/Nodes/CircleExp.test.cpp
new file mode 100644
index 000000000..ff2bb65db
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleExp.test.cpp
@@ -0,0 +1,33 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Service/CircleNodeClone.h"
+
+#include <gtest/gtest.h>
+
+TEST(CloneNodeTest, clone_Exp)
+{
+ auto g = loco::make_graph();
+ auto node_exp = g->nodes()->create<luci::CircleExp>();
+
+ auto gc = loco::make_graph();
+ auto cloned = luci::clone_node(node_exp, gc.get());
+ ASSERT_NE(nullptr, cloned);
+ ASSERT_EQ(gc.get(), cloned->graph());
+
+ auto cloned_exp = dynamic_cast<luci::CircleExp *>(cloned);
+ ASSERT_NE(nullptr, cloned_exp);
+}
diff --git a/compiler/luci/service/src/Nodes/CircleExpandDims.cpp b/compiler/luci/service/src/Nodes/CircleExpandDims.cpp
new file mode 100644
index 000000000..4dd1cec86
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleExpandDims.cpp
@@ -0,0 +1,27 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "CircleCloneNode.h"
+
+namespace luci
+{
+
+luci::CircleNode *CloneNode::visit(const luci::CircleExpandDims *)
+{
+ return _graph->nodes()->create<luci::CircleExpandDims>();
+}
+
+} // namespace luci
diff --git a/compiler/luci/service/src/Nodes/CircleExpandDims.test.cpp b/compiler/luci/service/src/Nodes/CircleExpandDims.test.cpp
new file mode 100644
index 000000000..e3481bccd
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleExpandDims.test.cpp
@@ -0,0 +1,66 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Service/CircleNodeClone.h"
+
+#include <luci/IR/CircleNodes.h>
+#include <luci/Service/CircleShapeInference.h>
+
+#include <loco/IR/TensorShape.h>
+
+#include <gtest/gtest.h>
+
+TEST(ShapeRuleTest, simple_expand_dims)
+{
+ luci::CircleInput input;
+ luci::CircleConst axis;
+ luci::CircleExpandDims expand_dims;
+
+ input.shape({4, 3});
+ input.shape_status(luci::ShapeStatus::VALID);
+
+ axis.dtype(loco::DataType::S32);
+ axis.rank(0);
+ axis.size<loco::DataType::S32>(1);
+ axis.at<loco::DataType::S32>(0) = 1;
+ axis.shape_status(luci::ShapeStatus::VALID);
+
+ expand_dims.input(&input);
+ expand_dims.axis(&axis);
+
+ loco::TensorShape shape;
+ luci::sinf::Rule shape_inf_rule;
+
+ ASSERT_TRUE(shape_inf_rule.infer(&expand_dims, shape));
+ ASSERT_EQ(3, shape.rank());
+ ASSERT_EQ(4, shape.dim(0).value());
+ ASSERT_EQ(1, shape.dim(1).value());
+ ASSERT_EQ(3, shape.dim(2).value());
+}
+
+TEST(CloneNodeTest, clone_ExpandDims)
+{
+ auto g = loco::make_graph();
+ auto node_ed = g->nodes()->create<luci::CircleExpandDims>();
+
+ auto gc = loco::make_graph();
+ auto cloned = luci::clone_node(node_ed, gc.get());
+ ASSERT_NE(nullptr, cloned);
+ ASSERT_EQ(gc.get(), cloned->graph());
+
+ auto cloned_ed = dynamic_cast<luci::CircleExpandDims *>(cloned);
+ ASSERT_NE(nullptr, cloned_ed);
+}
diff --git a/compiler/luci/service/src/Nodes/CircleFakeQuant.cpp b/compiler/luci/service/src/Nodes/CircleFakeQuant.cpp
new file mode 100644
index 000000000..7abaca685
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleFakeQuant.cpp
@@ -0,0 +1,35 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "CircleCloneNode.h"
+
+namespace luci
+{
+
+luci::CircleNode *CloneNode::visit(const luci::CircleFakeQuant *node)
+{
+ auto *cloned = _graph->nodes()->create<luci::CircleFakeQuant>();
+ if (cloned != nullptr)
+ {
+ cloned->min(node->min());
+ cloned->max(node->max());
+ cloned->num_bits(node->num_bits());
+ cloned->narrow_range(node->narrow_range());
+ }
+ return cloned;
+}
+
+} // namespace luci
diff --git a/compiler/luci/service/src/Nodes/CircleFakeQuant.test.cpp b/compiler/luci/service/src/Nodes/CircleFakeQuant.test.cpp
new file mode 100644
index 000000000..2c4e3b836
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleFakeQuant.test.cpp
@@ -0,0 +1,41 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Service/CircleNodeClone.h"
+
+#include <gtest/gtest.h>
+
+TEST(CloneNodeTest, clone_FakeQuant)
+{
+ auto g = loco::make_graph();
+ auto node_fq = g->nodes()->create<luci::CircleFakeQuant>();
+ node_fq->min(1.0f);
+ node_fq->max(2.0f);
+ node_fq->num_bits(8);
+ node_fq->narrow_range(true);
+
+ auto gc = loco::make_graph();
+ auto cloned = luci::clone_node(node_fq, gc.get());
+ ASSERT_NE(nullptr, cloned);
+ ASSERT_EQ(gc.get(), cloned->graph());
+
+ auto cloned_fq = dynamic_cast<luci::CircleFakeQuant *>(cloned);
+ ASSERT_NE(nullptr, cloned_fq);
+ ASSERT_EQ(node_fq->min(), cloned_fq->min());
+ ASSERT_EQ(node_fq->max(), cloned_fq->max());
+ ASSERT_EQ(node_fq->num_bits(), cloned_fq->num_bits());
+ ASSERT_EQ(node_fq->narrow_range(), cloned_fq->narrow_range());
+}
diff --git a/compiler/luci/service/src/Nodes/CircleFill.cpp b/compiler/luci/service/src/Nodes/CircleFill.cpp
new file mode 100644
index 000000000..d9b74c63a
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleFill.cpp
@@ -0,0 +1,27 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "CircleCloneNode.h"
+
+namespace luci
+{
+
+luci::CircleNode *CloneNode::visit(const luci::CircleFill *)
+{
+ return _graph->nodes()->create<luci::CircleFill>();
+}
+
+} // namespace luci
diff --git a/compiler/luci/service/src/Nodes/CircleFill.test.cpp b/compiler/luci/service/src/Nodes/CircleFill.test.cpp
new file mode 100644
index 000000000..56c807585
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleFill.test.cpp
@@ -0,0 +1,33 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Service/CircleNodeClone.h"
+
+#include <gtest/gtest.h>
+
+TEST(CloneNodeTest, clone_Fill)
+{
+ auto g = loco::make_graph();
+ auto node_fill = g->nodes()->create<luci::CircleFill>();
+
+ auto gc = loco::make_graph();
+ auto cloned = luci::clone_node(node_fill, gc.get());
+ ASSERT_NE(nullptr, cloned);
+ ASSERT_EQ(gc.get(), cloned->graph());
+
+ auto cloned_fill = dynamic_cast<luci::CircleFill *>(cloned);
+ ASSERT_NE(nullptr, cloned_fill);
+}
diff --git a/compiler/luci/service/src/Nodes/CircleFloor.cpp b/compiler/luci/service/src/Nodes/CircleFloor.cpp
new file mode 100644
index 000000000..532808bc8
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleFloor.cpp
@@ -0,0 +1,27 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "CircleCloneNode.h"
+
+namespace luci
+{
+
+luci::CircleNode *CloneNode::visit(const luci::CircleFloor *)
+{
+ return _graph->nodes()->create<luci::CircleFloor>();
+}
+
+} // namespace luci
diff --git a/compiler/luci/service/src/Nodes/CircleFloor.test.cpp b/compiler/luci/service/src/Nodes/CircleFloor.test.cpp
new file mode 100644
index 000000000..3d53fd2c3
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleFloor.test.cpp
@@ -0,0 +1,33 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Service/CircleNodeClone.h"
+
+#include <gtest/gtest.h>
+
+TEST(CloneNodeTest, clone_Floor)
+{
+ auto g = loco::make_graph();
+ auto node_floor = g->nodes()->create<luci::CircleFloor>();
+
+ auto gc = loco::make_graph();
+ auto cloned = luci::clone_node(node_floor, gc.get());
+ ASSERT_NE(nullptr, cloned);
+ ASSERT_EQ(gc.get(), cloned->graph());
+
+ auto cloned_floor = dynamic_cast<luci::CircleFloor *>(cloned);
+ ASSERT_NE(nullptr, cloned_floor);
+}
diff --git a/compiler/luci/service/src/Nodes/CircleFloorDiv.cpp b/compiler/luci/service/src/Nodes/CircleFloorDiv.cpp
new file mode 100644
index 000000000..65be3e868
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleFloorDiv.cpp
@@ -0,0 +1,27 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "CircleCloneNode.h"
+
+namespace luci
+{
+
+luci::CircleNode *CloneNode::visit(const luci::CircleFloorDiv *)
+{
+ return _graph->nodes()->create<luci::CircleFloorDiv>();
+}
+
+} // namespace luci
diff --git a/compiler/luci/service/src/Nodes/CircleFloorDiv.test.cpp b/compiler/luci/service/src/Nodes/CircleFloorDiv.test.cpp
new file mode 100644
index 000000000..6365ccd3b
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleFloorDiv.test.cpp
@@ -0,0 +1,33 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Service/CircleNodeClone.h"
+
+#include <gtest/gtest.h>
+
+TEST(CloneNodeTest, clone_FloorDiv)
+{
+ auto g = loco::make_graph();
+ auto node_floordiv = g->nodes()->create<luci::CircleFloorDiv>();
+
+ auto gc = loco::make_graph();
+ auto cloned = luci::clone_node(node_floordiv, gc.get());
+ ASSERT_NE(nullptr, cloned);
+ ASSERT_EQ(gc.get(), cloned->graph());
+
+ auto cloned_floordiv = dynamic_cast<luci::CircleFloorDiv *>(cloned);
+ ASSERT_NE(nullptr, cloned_floordiv);
+}
diff --git a/compiler/luci/service/src/Nodes/CircleFloorMod.cpp b/compiler/luci/service/src/Nodes/CircleFloorMod.cpp
new file mode 100644
index 000000000..00e6a0499
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleFloorMod.cpp
@@ -0,0 +1,27 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "CircleCloneNode.h"
+
+namespace luci
+{
+
+luci::CircleNode *CloneNode::visit(const luci::CircleFloorMod *)
+{
+ return _graph->nodes()->create<luci::CircleFloorMod>();
+}
+
+} // namespace luci
diff --git a/compiler/luci/service/src/Nodes/CircleFloorMod.test.cpp b/compiler/luci/service/src/Nodes/CircleFloorMod.test.cpp
new file mode 100644
index 000000000..ce91d5881
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleFloorMod.test.cpp
@@ -0,0 +1,33 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Service/CircleNodeClone.h"
+
+#include <gtest/gtest.h>
+
+TEST(CloneNodeTest, clone_FloorMod)
+{
+ auto g = loco::make_graph();
+ auto node_floormod = g->nodes()->create<luci::CircleFloorMod>();
+
+ auto gc = loco::make_graph();
+ auto cloned = luci::clone_node(node_floormod, gc.get());
+ ASSERT_NE(nullptr, cloned);
+ ASSERT_EQ(gc.get(), cloned->graph());
+
+ auto cloned_floormod = dynamic_cast<luci::CircleFloorMod *>(cloned);
+ ASSERT_NE(nullptr, cloned_floormod);
+}
diff --git a/compiler/luci/service/src/Nodes/CircleFullyConnected.cpp b/compiler/luci/service/src/Nodes/CircleFullyConnected.cpp
new file mode 100644
index 000000000..8acb35cbf
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleFullyConnected.cpp
@@ -0,0 +1,38 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "CircleCloneNode.h"
+
+namespace luci
+{
+
+luci::CircleNode *CloneNode::visit(const luci::CircleFullyConnected *node)
+{
+ if (node->fusedActivationFunction() == luci::FusedActFunc::UNDEFINED)
+ return nullptr;
+ if (node->weights_format() == luci::CircleFullyConnected::WeightsFormat::UNDEFINED)
+ return nullptr;
+
+ auto *cloned = _graph->nodes()->create<luci::CircleFullyConnected>();
+ if (cloned != nullptr)
+ {
+ cloned->fusedActivationFunction(node->fusedActivationFunction());
+ cloned->weights_format(node->weights_format());
+ }
+ return cloned;
+}
+
+} // namespace luci
diff --git a/compiler/luci/service/src/Nodes/CircleFullyConnected.test.cpp b/compiler/luci/service/src/Nodes/CircleFullyConnected.test.cpp
new file mode 100644
index 000000000..965b59130
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleFullyConnected.test.cpp
@@ -0,0 +1,61 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Service/CircleNodeClone.h"
+
+#include <gtest/gtest.h>
+
+TEST(CloneNodeTest, clone_FullyConnected)
+{
+ auto g = loco::make_graph();
+ auto node_fc = g->nodes()->create<luci::CircleFullyConnected>();
+ node_fc->fusedActivationFunction(luci::FusedActFunc::RELU);
+ node_fc->weights_format(luci::CircleFullyConnected::WeightsFormat::DEFAULT);
+
+ auto gc = loco::make_graph();
+ auto cloned = luci::clone_node(node_fc, gc.get());
+ ASSERT_NE(nullptr, cloned);
+ ASSERT_EQ(gc.get(), cloned->graph());
+
+ auto cloned_fc = dynamic_cast<luci::CircleFullyConnected *>(cloned);
+ ASSERT_NE(nullptr, cloned_fc);
+ ASSERT_EQ(node_fc->fusedActivationFunction(), cloned_fc->fusedActivationFunction());
+ ASSERT_EQ(node_fc->weights_format(), cloned_fc->weights_format());
+}
+
+TEST(CloneNodeTest, clone_FullyConnected_fusedact_NEG)
+{
+ auto g = loco::make_graph();
+ auto node_fc = g->nodes()->create<luci::CircleFullyConnected>();
+ node_fc->fusedActivationFunction(luci::FusedActFunc::UNDEFINED);
+ node_fc->weights_format(luci::CircleFullyConnected::WeightsFormat::DEFAULT);
+
+ auto gc = loco::make_graph();
+ auto cloned = luci::clone_node(node_fc, gc.get());
+ ASSERT_EQ(nullptr, cloned);
+}
+
+TEST(CloneNodeTest, clone_FullyConnected_wf_NEG)
+{
+ auto g = loco::make_graph();
+ auto node_fc = g->nodes()->create<luci::CircleFullyConnected>();
+ node_fc->fusedActivationFunction(luci::FusedActFunc::RELU);
+ node_fc->weights_format(luci::CircleFullyConnected::WeightsFormat::UNDEFINED);
+
+ auto gc = loco::make_graph();
+ auto cloned = luci::clone_node(node_fc, gc.get());
+ ASSERT_EQ(nullptr, cloned);
+}
diff --git a/compiler/luci/service/src/Nodes/CircleGather.cpp b/compiler/luci/service/src/Nodes/CircleGather.cpp
new file mode 100644
index 000000000..072bdeabc
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleGather.cpp
@@ -0,0 +1,30 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "CircleCloneNode.h"
+
+namespace luci
+{
+
+luci::CircleNode *CloneNode::visit(const luci::CircleGather *node)
+{
+ auto *cloned = _graph->nodes()->create<luci::CircleGather>();
+ if (cloned != nullptr)
+ cloned->axis(node->axis());
+ return cloned;
+}
+
+} // namespace luci
diff --git a/compiler/luci/service/src/Nodes/CircleGather.test.cpp b/compiler/luci/service/src/Nodes/CircleGather.test.cpp
new file mode 100644
index 000000000..f48dbdb67
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleGather.test.cpp
@@ -0,0 +1,35 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Service/CircleNodeClone.h"
+
+#include <gtest/gtest.h>
+
+TEST(CloneNodeTest, clone_Gather)
+{
+ auto g = loco::make_graph();
+ auto node_gat = g->nodes()->create<luci::CircleGather>();
+ node_gat->axis(3);
+
+ auto gc = loco::make_graph();
+ auto cloned = luci::clone_node(node_gat, gc.get());
+ ASSERT_NE(nullptr, cloned);
+ ASSERT_EQ(gc.get(), cloned->graph());
+
+ auto cloned_gat = dynamic_cast<luci::CircleGather *>(cloned);
+ ASSERT_NE(nullptr, cloned_gat);
+ ASSERT_EQ(node_gat->axis(), cloned_gat->axis());
+}
diff --git a/compiler/luci/service/src/Nodes/CircleGatherNd.cpp b/compiler/luci/service/src/Nodes/CircleGatherNd.cpp
new file mode 100644
index 000000000..df7ae6e79
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleGatherNd.cpp
@@ -0,0 +1,27 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "CircleCloneNode.h"
+
+namespace luci
+{
+
+luci::CircleNode *CloneNode::visit(const luci::CircleGatherNd *)
+{
+ return _graph->nodes()->create<luci::CircleGatherNd>();
+}
+
+} // namespace luci
diff --git a/compiler/luci/service/src/Nodes/CircleGatherNd.test.cpp b/compiler/luci/service/src/Nodes/CircleGatherNd.test.cpp
new file mode 100644
index 000000000..3a705710c
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleGatherNd.test.cpp
@@ -0,0 +1,113 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Service/CircleNodeClone.h"
+
+#include <luci/IR/CircleNodes.h>
+#include <luci/Service/CircleShapeInference.h>
+
+#include <loco/IR/TensorShape.h>
+
+#include <oops/InternalExn.h>
+
+#include <gtest/gtest.h>
+
+TEST(ShapeRuleTest, gather_nd_simple)
+{
+ luci::CircleInput input;
+ luci::CircleConst indices_const;
+ luci::CircleGatherNd gather_nd;
+
+ input.shape({1, 4, 4, 3});
+ indices_const.shape({1, 2, 3});
+
+ input.shape_status(luci::ShapeStatus::VALID);
+ indices_const.shape_status(luci::ShapeStatus::VALID);
+
+ gather_nd.params(&input);
+ gather_nd.indices(&indices_const);
+
+ loco::TensorShape shape;
+ luci::sinf::Rule shape_inf_rule;
+
+ ASSERT_TRUE(shape_inf_rule.infer(&gather_nd, shape));
+ ASSERT_EQ(3, shape.rank());
+ ASSERT_EQ(1, shape.dim(0).value());
+ ASSERT_EQ(2, shape.dim(1).value());
+ ASSERT_EQ(3, shape.dim(2).value());
+}
+
+TEST(ShapeRuleTest, gather_nd_slices)
+{
+ luci::CircleInput input;
+ luci::CircleConst indices_const;
+ luci::CircleGatherNd gather_nd;
+
+ input.shape({1, 4, 4, 3});
+ indices_const.shape({1, 2, 1});
+
+ input.shape_status(luci::ShapeStatus::VALID);
+ indices_const.shape_status(luci::ShapeStatus::VALID);
+
+ gather_nd.params(&input);
+ gather_nd.indices(&indices_const);
+
+ loco::TensorShape shape;
+ luci::sinf::Rule shape_inf_rule;
+
+ ASSERT_TRUE(shape_inf_rule.infer(&gather_nd, shape));
+ ASSERT_EQ(5, shape.rank());
+ ASSERT_EQ(1, shape.dim(0).value());
+ ASSERT_EQ(2, shape.dim(1).value());
+ ASSERT_EQ(4, shape.dim(2).value());
+ ASSERT_EQ(4, shape.dim(3).value());
+ ASSERT_EQ(3, shape.dim(4).value());
+}
+
+TEST(ShapeRuleTest, gather_nd_NEG)
+{
+ luci::CircleInput input;
+ luci::CircleConst indices_const;
+ luci::CircleGatherNd gather_nd;
+
+ input.shape({1, 4, 4, 3});
+ indices_const.shape({1, 2, 5});
+
+ input.shape_status(luci::ShapeStatus::VALID);
+ indices_const.shape_status(luci::ShapeStatus::VALID);
+
+ gather_nd.params(&input);
+ gather_nd.indices(&indices_const);
+
+ loco::TensorShape shape;
+ luci::sinf::Rule shape_inf_rule;
+
+ ASSERT_THROW(shape_inf_rule.infer(&gather_nd, shape), oops::InternalExn);
+}
+
+TEST(CloneNodeTest, clone_GatherNd)
+{
+ auto g = loco::make_graph();
+ auto node_gtnd = g->nodes()->create<luci::CircleGatherNd>();
+
+ auto gc = loco::make_graph();
+ auto cloned = luci::clone_node(node_gtnd, gc.get());
+ ASSERT_NE(nullptr, cloned);
+ ASSERT_EQ(gc.get(), cloned->graph());
+
+ auto cloned_gtnd = dynamic_cast<luci::CircleGatherNd *>(cloned);
+ ASSERT_NE(nullptr, cloned_gtnd);
+}
diff --git a/compiler/luci/service/src/Nodes/CircleGreater.cpp b/compiler/luci/service/src/Nodes/CircleGreater.cpp
new file mode 100644
index 000000000..366d955bf
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleGreater.cpp
@@ -0,0 +1,27 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "CircleCloneNode.h"
+
+namespace luci
+{
+
+luci::CircleNode *CloneNode::visit(const luci::CircleGreater *)
+{
+ return _graph->nodes()->create<luci::CircleGreater>();
+}
+
+} // namespace luci
diff --git a/compiler/luci/service/src/Nodes/CircleGreater.test.cpp b/compiler/luci/service/src/Nodes/CircleGreater.test.cpp
new file mode 100644
index 000000000..6d2df61f0
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleGreater.test.cpp
@@ -0,0 +1,33 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Service/CircleNodeClone.h"
+
+#include <gtest/gtest.h>
+
+TEST(CloneNodeTest, clone_Greater)
+{
+ auto g = loco::make_graph();
+ auto node_gt = g->nodes()->create<luci::CircleGreater>();
+
+ auto gc = loco::make_graph();
+ auto cloned = luci::clone_node(node_gt, gc.get());
+ ASSERT_NE(nullptr, cloned);
+ ASSERT_EQ(gc.get(), cloned->graph());
+
+ auto cloned_gt = dynamic_cast<luci::CircleGreater *>(cloned);
+ ASSERT_NE(nullptr, cloned_gt);
+}
diff --git a/compiler/luci/service/src/Nodes/CircleGreaterEqual.cpp b/compiler/luci/service/src/Nodes/CircleGreaterEqual.cpp
new file mode 100644
index 000000000..9705bbe1e
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleGreaterEqual.cpp
@@ -0,0 +1,27 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "CircleCloneNode.h"
+
+namespace luci
+{
+
+luci::CircleNode *CloneNode::visit(const luci::CircleGreaterEqual *)
+{
+ return _graph->nodes()->create<luci::CircleGreaterEqual>();
+}
+
+} // namespace luci
diff --git a/compiler/luci/service/src/Nodes/CircleGreaterEqual.test.cpp b/compiler/luci/service/src/Nodes/CircleGreaterEqual.test.cpp
new file mode 100644
index 000000000..10387df3a
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleGreaterEqual.test.cpp
@@ -0,0 +1,33 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Service/CircleNodeClone.h"
+
+#include <gtest/gtest.h>
+
+TEST(CloneNodeTest, clone_GreaterEqual)
+{
+ auto g = loco::make_graph();
+ auto node_ge = g->nodes()->create<luci::CircleGreaterEqual>();
+
+ auto gc = loco::make_graph();
+ auto cloned = luci::clone_node(node_ge, gc.get());
+ ASSERT_NE(nullptr, cloned);
+ ASSERT_EQ(gc.get(), cloned->graph());
+
+ auto cloned_ge = dynamic_cast<luci::CircleGreaterEqual *>(cloned);
+ ASSERT_NE(nullptr, cloned_ge);
+}
diff --git a/compiler/luci/service/src/Nodes/CircleIfOut.cpp b/compiler/luci/service/src/Nodes/CircleIfOut.cpp
new file mode 100644
index 000000000..31ad7203f
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleIfOut.cpp
@@ -0,0 +1,89 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <luci/Service/CircleShapeInference.h>
+#include <luci/Service/CircleTypeInference.h>
+
+namespace
+{
+
+struct CircleIfOutGraphs
+{
+ loco::GraphOutput *then_graph_output;
+ loco::GraphOutput *else_graph_output;
+};
+
+} // namespace
+
+namespace
+{
+
+CircleIfOutGraphs get_out_graphs(const luci::CircleIfOut *node)
+{
+ CircleIfOutGraphs ret_out;
+
+ /**
+ * @note IF operator type and shape are that of the "then" and "else"
+ * Graph Outputs.
+ */
+ auto circle_if = loco::must_cast<const luci::CircleIf *>(node->input());
+
+ auto index = node->index();
+ auto then_graph = circle_if->then_graph();
+ auto else_graph = circle_if->else_graph();
+ assert(then_graph != nullptr);
+ assert(else_graph != nullptr);
+
+ // shape and type are assumed to be same
+ // these are checked at post_import_graph() in Import
+ auto then_outputs = loco::output_nodes(then_graph);
+ auto else_outputs = loco::output_nodes(else_graph);
+ assert(then_outputs.size() == else_outputs.size());
+ assert(index < static_cast<int32_t>(then_outputs.size()));
+
+ auto then_out = loco::must_cast<luci::CircleOutput *>(then_outputs.at(index));
+ auto else_out = loco::must_cast<luci::CircleOutput *>(else_outputs.at(index));
+
+ auto then_graph_outputs = then_graph->outputs(); // loco::GraphOutput items
+ auto else_graph_outputs = else_graph->outputs();
+ assert(then_graph_outputs->size() == else_graph_outputs->size());
+
+ ret_out.then_graph_output = then_graph_outputs->at(then_out->index());
+ ret_out.else_graph_output = else_graph_outputs->at(else_out->index());
+
+ return ret_out;
+}
+
+} // namespace
+
+namespace luci
+{
+
+loco::TensorShape sinf::Algorithm::visit(const luci::CircleIfOut *node)
+{
+ auto graphs = get_out_graphs(node);
+ assert(*graphs.then_graph_output->shape() == *graphs.else_graph_output->shape());
+ return *graphs.then_graph_output->shape();
+}
+
+loco::DataType tinf::Algorithm::visit(const luci::CircleIfOut *node)
+{
+ auto graphs = get_out_graphs(node);
+ assert(graphs.then_graph_output->dtype() == graphs.else_graph_output->dtype());
+ return graphs.then_graph_output->dtype();
+}
+
+} // namespace luci
diff --git a/compiler/luci/service/src/Nodes/CircleInstanceNorm.cpp b/compiler/luci/service/src/Nodes/CircleInstanceNorm.cpp
new file mode 100644
index 000000000..d9e49d8ed
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleInstanceNorm.cpp
@@ -0,0 +1,36 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "CircleCloneNode.h"
+
+namespace luci
+{
+
+luci::CircleNode *CloneNode::visit(const luci::CircleInstanceNorm *node)
+{
+ if (node->fusedActivationFunction() == luci::FusedActFunc::UNDEFINED)
+ return nullptr;
+
+ auto *cloned = _graph->nodes()->create<luci::CircleInstanceNorm>();
+ if (cloned != nullptr)
+ {
+ cloned->fusedActivationFunction(node->fusedActivationFunction());
+ cloned->epsilon(node->epsilon());
+ }
+ return cloned;
+}
+
+} // namespace luci
diff --git a/compiler/luci/service/src/Nodes/CircleInstanceNorm.test.cpp b/compiler/luci/service/src/Nodes/CircleInstanceNorm.test.cpp
new file mode 100644
index 000000000..bae92b1ae
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleInstanceNorm.test.cpp
@@ -0,0 +1,48 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Service/CircleNodeClone.h"
+
+#include <gtest/gtest.h>
+
+TEST(CloneNodeTest, clone_InstanceNorm)
+{
+ auto g = loco::make_graph();
+ auto node_fc = g->nodes()->create<luci::CircleInstanceNorm>();
+ node_fc->fusedActivationFunction(luci::FusedActFunc::RELU);
+ node_fc->epsilon(3);
+
+ auto gc = loco::make_graph();
+ auto cloned = luci::clone_node(node_fc, gc.get());
+ ASSERT_NE(nullptr, cloned);
+ ASSERT_EQ(gc.get(), cloned->graph());
+
+ auto cloned_fc = dynamic_cast<luci::CircleInstanceNorm *>(cloned);
+ ASSERT_NE(nullptr, cloned_fc);
+ ASSERT_EQ(node_fc->fusedActivationFunction(), cloned_fc->fusedActivationFunction());
+ ASSERT_EQ(node_fc->epsilon(), cloned_fc->epsilon());
+}
+
+TEST(CloneNodeTest, clone_InstanceNorm_fusedact_NEG)
+{
+ auto g = loco::make_graph();
+ auto node_fc = g->nodes()->create<luci::CircleInstanceNorm>();
+ node_fc->fusedActivationFunction(luci::FusedActFunc::UNDEFINED);
+
+ auto gc = loco::make_graph();
+ auto cloned = luci::clone_node(node_fc, gc.get());
+ ASSERT_EQ(nullptr, cloned);
+}
diff --git a/compiler/luci/service/src/Nodes/CircleL2Normalize.cpp b/compiler/luci/service/src/Nodes/CircleL2Normalize.cpp
new file mode 100644
index 000000000..afa2a6acb
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleL2Normalize.cpp
@@ -0,0 +1,33 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "CircleCloneNode.h"
+
+namespace luci
+{
+
+luci::CircleNode *CloneNode::visit(const luci::CircleL2Normalize *node)
+{
+ if (node->fusedActivationFunction() == luci::FusedActFunc::UNDEFINED)
+ return nullptr;
+
+ auto *cloned = _graph->nodes()->create<luci::CircleL2Normalize>();
+ if (cloned != nullptr)
+ cloned->fusedActivationFunction(node->fusedActivationFunction());
+ return cloned;
+}
+
+} // namespace luci
diff --git a/compiler/luci/service/src/Nodes/CircleL2Normalize.test.cpp b/compiler/luci/service/src/Nodes/CircleL2Normalize.test.cpp
new file mode 100644
index 000000000..0f148797e
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleL2Normalize.test.cpp
@@ -0,0 +1,46 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Service/CircleNodeClone.h"
+
+#include <gtest/gtest.h>
+
+TEST(CloneNodeTest, clone_L2Normalize)
+{
+ auto g = loco::make_graph();
+ auto node_l2n = g->nodes()->create<luci::CircleL2Normalize>();
+ node_l2n->fusedActivationFunction(luci::FusedActFunc::RELU);
+
+ auto gc = loco::make_graph();
+ auto cloned = luci::clone_node(node_l2n, gc.get());
+ ASSERT_NE(nullptr, cloned);
+ ASSERT_EQ(gc.get(), cloned->graph());
+
+ auto cloned_l2n = dynamic_cast<luci::CircleL2Normalize *>(cloned);
+ ASSERT_NE(nullptr, cloned_l2n);
+ ASSERT_EQ(node_l2n->fusedActivationFunction(), cloned_l2n->fusedActivationFunction());
+}
+
+TEST(CloneNodeTest, clone_L2Normalize_NEG)
+{
+ auto g = loco::make_graph();
+ auto node_l2n = g->nodes()->create<luci::CircleL2Normalize>();
+ node_l2n->fusedActivationFunction(luci::FusedActFunc::UNDEFINED);
+
+ auto gc = loco::make_graph();
+ auto cloned = luci::clone_node(node_l2n, gc.get());
+ ASSERT_EQ(nullptr, cloned);
+}
diff --git a/compiler/luci/service/src/Nodes/CircleL2Pool2D.cpp b/compiler/luci/service/src/Nodes/CircleL2Pool2D.cpp
new file mode 100644
index 000000000..2d876c5bc
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleL2Pool2D.cpp
@@ -0,0 +1,42 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "CircleCloneNode.h"
+
+namespace luci
+{
+
+luci::CircleNode *CloneNode::visit(const luci::CircleL2Pool2D *node)
+{
+ if (node->fusedActivationFunction() == luci::FusedActFunc::UNDEFINED)
+ return nullptr;
+ if (node->padding() == luci::Padding::UNDEFINED)
+ return nullptr;
+
+ auto *cloned = _graph->nodes()->create<luci::CircleL2Pool2D>();
+ if (cloned != nullptr)
+ {
+ cloned->fusedActivationFunction(node->fusedActivationFunction());
+ cloned->padding(node->padding());
+ cloned->filter()->h(node->filter()->h());
+ cloned->filter()->w(node->filter()->w());
+ cloned->stride()->h(node->stride()->h());
+ cloned->stride()->w(node->stride()->w());
+ }
+ return cloned;
+}
+
+} // namespace luci
diff --git a/compiler/luci/service/src/Nodes/CircleL2Pool2D.test.cpp b/compiler/luci/service/src/Nodes/CircleL2Pool2D.test.cpp
new file mode 100644
index 000000000..37344fd9a
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleL2Pool2D.test.cpp
@@ -0,0 +1,61 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Service/CircleNodeClone.h"
+
+#include <gtest/gtest.h>
+
+TEST(CloneNodeTest, clone_L2Pool2D)
+{
+ auto g = loco::make_graph();
+ auto node_l2n = g->nodes()->create<luci::CircleL2Pool2D>();
+ node_l2n->fusedActivationFunction(luci::FusedActFunc::RELU);
+ node_l2n->padding(luci::Padding::SAME);
+
+ auto gc = loco::make_graph();
+ auto cloned = luci::clone_node(node_l2n, gc.get());
+ ASSERT_NE(nullptr, cloned);
+ ASSERT_EQ(gc.get(), cloned->graph());
+
+ auto cloned_l2n = dynamic_cast<luci::CircleL2Pool2D *>(cloned);
+ ASSERT_NE(nullptr, cloned_l2n);
+ ASSERT_EQ(node_l2n->fusedActivationFunction(), cloned_l2n->fusedActivationFunction());
+ ASSERT_EQ(node_l2n->padding(), cloned_l2n->padding());
+}
+
+TEST(CloneNodeTest, clone_L2Normalize_fusedact_NEG)
+{
+ auto g = loco::make_graph();
+ auto node_l2n = g->nodes()->create<luci::CircleL2Pool2D>();
+ node_l2n->fusedActivationFunction(luci::FusedActFunc::UNDEFINED);
+ node_l2n->padding(luci::Padding::SAME);
+
+ auto gc = loco::make_graph();
+ auto cloned = luci::clone_node(node_l2n, gc.get());
+ ASSERT_EQ(nullptr, cloned);
+}
+
+TEST(CloneNodeTest, clone_L2Normalize_padding_NEG)
+{
+ auto g = loco::make_graph();
+ auto node_l2n = g->nodes()->create<luci::CircleL2Pool2D>();
+ node_l2n->fusedActivationFunction(luci::FusedActFunc::RELU);
+ node_l2n->padding(luci::Padding::UNDEFINED);
+
+ auto gc = loco::make_graph();
+ auto cloned = luci::clone_node(node_l2n, gc.get());
+ ASSERT_EQ(nullptr, cloned);
+}
diff --git a/compiler/luci/service/src/Nodes/CircleLeakyRelu.cpp b/compiler/luci/service/src/Nodes/CircleLeakyRelu.cpp
new file mode 100644
index 000000000..91030618c
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleLeakyRelu.cpp
@@ -0,0 +1,30 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "CircleCloneNode.h"
+
+namespace luci
+{
+
+luci::CircleNode *CloneNode::visit(const luci::CircleLeakyRelu *node)
+{
+ auto *cloned = _graph->nodes()->create<luci::CircleLeakyRelu>();
+ if (cloned != nullptr)
+ cloned->alpha(node->alpha());
+ return cloned;
+}
+
+} // namespace luci
diff --git a/compiler/luci/service/src/Nodes/CircleLeakyRelu.test.cpp b/compiler/luci/service/src/Nodes/CircleLeakyRelu.test.cpp
new file mode 100644
index 000000000..17fc1442a
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleLeakyRelu.test.cpp
@@ -0,0 +1,35 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Service/CircleNodeClone.h"
+
+#include <gtest/gtest.h>
+
+TEST(CloneNodeTest, clone_LeakyRelu)
+{
+ auto g = loco::make_graph();
+ auto node_lr = g->nodes()->create<luci::CircleLeakyRelu>();
+ node_lr->alpha(1.2f);
+
+ auto gc = loco::make_graph();
+ auto cloned = luci::clone_node(node_lr, gc.get());
+ ASSERT_NE(nullptr, cloned);
+ ASSERT_EQ(gc.get(), cloned->graph());
+
+ auto cloned_lr = dynamic_cast<luci::CircleLeakyRelu *>(cloned);
+ ASSERT_NE(nullptr, cloned_lr);
+ ASSERT_EQ(node_lr->alpha(), cloned_lr->alpha());
+}
diff --git a/compiler/luci/service/src/Nodes/CircleLess.cpp b/compiler/luci/service/src/Nodes/CircleLess.cpp
new file mode 100644
index 000000000..33b70b735
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleLess.cpp
@@ -0,0 +1,27 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "CircleCloneNode.h"
+
+namespace luci
+{
+
+luci::CircleNode *CloneNode::visit(const luci::CircleLess *)
+{
+ return _graph->nodes()->create<luci::CircleLess>();
+}
+
+} // namespace luci
diff --git a/compiler/luci/service/src/Nodes/CircleLess.test.cpp b/compiler/luci/service/src/Nodes/CircleLess.test.cpp
new file mode 100644
index 000000000..43248948d
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleLess.test.cpp
@@ -0,0 +1,33 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Service/CircleNodeClone.h"
+
+#include <gtest/gtest.h>
+
+TEST(CloneNodeTest, clone_Less)
+{
+ auto g = loco::make_graph();
+ auto node_less = g->nodes()->create<luci::CircleLess>();
+
+ auto gc = loco::make_graph();
+ auto cloned = luci::clone_node(node_less, gc.get());
+ ASSERT_NE(nullptr, cloned);
+ ASSERT_EQ(gc.get(), cloned->graph());
+
+ auto cloned_less = dynamic_cast<luci::CircleLess *>(cloned);
+ ASSERT_NE(nullptr, cloned_less);
+}
diff --git a/compiler/luci/service/src/Nodes/CircleLessEqual.cpp b/compiler/luci/service/src/Nodes/CircleLessEqual.cpp
new file mode 100644
index 000000000..22491365a
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleLessEqual.cpp
@@ -0,0 +1,27 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "CircleCloneNode.h"
+
+namespace luci
+{
+
+luci::CircleNode *CloneNode::visit(const luci::CircleLessEqual *)
+{
+ return _graph->nodes()->create<luci::CircleLessEqual>();
+}
+
+} // namespace luci
diff --git a/compiler/luci/service/src/Nodes/CircleLessEqual.test.cpp b/compiler/luci/service/src/Nodes/CircleLessEqual.test.cpp
new file mode 100644
index 000000000..0a87daf5d
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleLessEqual.test.cpp
@@ -0,0 +1,33 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Service/CircleNodeClone.h"
+
+#include <gtest/gtest.h>
+
+TEST(CloneNodeTest, clone_LessEqual)
+{
+ auto g = loco::make_graph();
+ auto node_le = g->nodes()->create<luci::CircleLessEqual>();
+
+ auto gc = loco::make_graph();
+ auto cloned = luci::clone_node(node_le, gc.get());
+ ASSERT_NE(nullptr, cloned);
+ ASSERT_EQ(gc.get(), cloned->graph());
+
+ auto cloned_le = dynamic_cast<luci::CircleLessEqual *>(cloned);
+ ASSERT_NE(nullptr, cloned_le);
+}
diff --git a/compiler/luci/service/src/Nodes/CircleLocalResponseNormalization.cpp b/compiler/luci/service/src/Nodes/CircleLocalResponseNormalization.cpp
new file mode 100644
index 000000000..bf69b5ef5
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleLocalResponseNormalization.cpp
@@ -0,0 +1,35 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "CircleCloneNode.h"
+
+namespace luci
+{
+
+luci::CircleNode *CloneNode::visit(const luci::CircleLocalResponseNormalization *node)
+{
+ auto *cloned = _graph->nodes()->create<luci::CircleLocalResponseNormalization>();
+ if (cloned != nullptr)
+ {
+ cloned->radius(node->radius());
+ cloned->bias(node->bias());
+ cloned->alpha(node->alpha());
+ cloned->beta(node->beta());
+ }
+ return cloned;
+}
+
+} // namespace luci
diff --git a/compiler/luci/service/src/Nodes/CircleLocalResponseNormalization.test.cpp b/compiler/luci/service/src/Nodes/CircleLocalResponseNormalization.test.cpp
new file mode 100644
index 000000000..262b119bb
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleLocalResponseNormalization.test.cpp
@@ -0,0 +1,41 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Service/CircleNodeClone.h"
+
+#include <gtest/gtest.h>
+
+TEST(CloneNodeTest, clone_LocalResponseNormalization)
+{
+ auto g = loco::make_graph();
+ auto node_lrn = g->nodes()->create<luci::CircleLocalResponseNormalization>();
+ node_lrn->radius(32);
+ node_lrn->bias(1.2f);
+ node_lrn->alpha(3.4f);
+ node_lrn->beta(5.7f);
+
+ auto gc = loco::make_graph();
+ auto cloned = luci::clone_node(node_lrn, gc.get());
+ ASSERT_NE(nullptr, cloned);
+ ASSERT_EQ(gc.get(), cloned->graph());
+
+ auto cloned_lrn = dynamic_cast<luci::CircleLocalResponseNormalization *>(cloned);
+ ASSERT_NE(nullptr, cloned_lrn);
+ ASSERT_EQ(node_lrn->radius(), cloned_lrn->radius());
+ ASSERT_EQ(node_lrn->bias(), cloned_lrn->bias());
+ ASSERT_EQ(node_lrn->alpha(), cloned_lrn->alpha());
+ ASSERT_EQ(node_lrn->beta(), cloned_lrn->beta());
+}
diff --git a/compiler/luci/service/src/Nodes/CircleLog.cpp b/compiler/luci/service/src/Nodes/CircleLog.cpp
new file mode 100644
index 000000000..5788f129f
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleLog.cpp
@@ -0,0 +1,27 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "CircleCloneNode.h"
+
+namespace luci
+{
+
+luci::CircleNode *CloneNode::visit(const luci::CircleLog *)
+{
+ return _graph->nodes()->create<luci::CircleLog>();
+}
+
+} // namespace luci
diff --git a/compiler/luci/service/src/Nodes/CircleLog.test.cpp b/compiler/luci/service/src/Nodes/CircleLog.test.cpp
new file mode 100644
index 000000000..d1ee1428e
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleLog.test.cpp
@@ -0,0 +1,33 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Service/CircleNodeClone.h"
+
+#include <gtest/gtest.h>
+
+TEST(CloneNodeTest, clone_Log)
+{
+ auto g = loco::make_graph();
+ auto node_log = g->nodes()->create<luci::CircleLog>();
+
+ auto gc = loco::make_graph();
+ auto cloned = luci::clone_node(node_log, gc.get());
+ ASSERT_NE(nullptr, cloned);
+ ASSERT_EQ(gc.get(), cloned->graph());
+
+ auto cloned_log = dynamic_cast<luci::CircleLog *>(cloned);
+ ASSERT_NE(nullptr, cloned_log);
+}
diff --git a/compiler/luci/service/src/Nodes/CircleLogSoftmax.cpp b/compiler/luci/service/src/Nodes/CircleLogSoftmax.cpp
new file mode 100644
index 000000000..352160aff
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleLogSoftmax.cpp
@@ -0,0 +1,27 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "CircleCloneNode.h"
+
+namespace luci
+{
+
+luci::CircleNode *CloneNode::visit(const luci::CircleLogSoftmax *)
+{
+ return _graph->nodes()->create<luci::CircleLogSoftmax>();
+}
+
+} // namespace luci
diff --git a/compiler/luci/service/src/Nodes/CircleLogSoftmax.test.cpp b/compiler/luci/service/src/Nodes/CircleLogSoftmax.test.cpp
new file mode 100644
index 000000000..feebb79cb
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleLogSoftmax.test.cpp
@@ -0,0 +1,33 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Service/CircleNodeClone.h"
+
+#include <gtest/gtest.h>
+
+TEST(CloneNodeTest, clone_LogSoftmax)
+{
+ auto g = loco::make_graph();
+ auto node_logs = g->nodes()->create<luci::CircleLogSoftmax>();
+
+ auto gc = loco::make_graph();
+ auto cloned = luci::clone_node(node_logs, gc.get());
+ ASSERT_NE(nullptr, cloned);
+ ASSERT_EQ(gc.get(), cloned->graph());
+
+ auto cloned_logs = dynamic_cast<luci::CircleLogSoftmax *>(cloned);
+ ASSERT_NE(nullptr, cloned_logs);
+}
diff --git a/compiler/luci/service/src/Nodes/CircleLogicalAnd.cpp b/compiler/luci/service/src/Nodes/CircleLogicalAnd.cpp
new file mode 100644
index 000000000..5df62b951
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleLogicalAnd.cpp
@@ -0,0 +1,27 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "CircleCloneNode.h"
+
+namespace luci
+{
+
+luci::CircleNode *CloneNode::visit(const luci::CircleLogicalAnd *)
+{
+ return _graph->nodes()->create<luci::CircleLogicalAnd>();
+}
+
+} // namespace luci
diff --git a/compiler/luci/service/src/Nodes/CircleLogicalAnd.test.cpp b/compiler/luci/service/src/Nodes/CircleLogicalAnd.test.cpp
new file mode 100644
index 000000000..aa811edfa
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleLogicalAnd.test.cpp
@@ -0,0 +1,33 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Service/CircleNodeClone.h"
+
+#include <gtest/gtest.h>
+
+TEST(CloneNodeTest, clone_LogicalAnd)
+{
+ auto g = loco::make_graph();
+ auto node_logand = g->nodes()->create<luci::CircleLogicalAnd>();
+
+ auto gc = loco::make_graph();
+ auto cloned = luci::clone_node(node_logand, gc.get());
+ ASSERT_NE(nullptr, cloned);
+ ASSERT_EQ(gc.get(), cloned->graph());
+
+ auto cloned_logand = dynamic_cast<luci::CircleLogicalAnd *>(cloned);
+ ASSERT_NE(nullptr, cloned_logand);
+}
diff --git a/compiler/luci/service/src/Nodes/CircleLogicalNot.cpp b/compiler/luci/service/src/Nodes/CircleLogicalNot.cpp
new file mode 100644
index 000000000..ac982829d
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleLogicalNot.cpp
@@ -0,0 +1,27 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "CircleCloneNode.h"
+
+namespace luci
+{
+
+luci::CircleNode *CloneNode::visit(const luci::CircleLogicalNot *)
+{
+ return _graph->nodes()->create<luci::CircleLogicalNot>();
+}
+
+} // namespace luci
diff --git a/compiler/luci/service/src/Nodes/CircleLogicalNot.test.cpp b/compiler/luci/service/src/Nodes/CircleLogicalNot.test.cpp
new file mode 100644
index 000000000..9e55be944
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleLogicalNot.test.cpp
@@ -0,0 +1,33 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Service/CircleNodeClone.h"
+
+#include <gtest/gtest.h>
+
+TEST(CloneNodeTest, clone_LogicalNot)
+{
+ auto g = loco::make_graph();
+ auto node_lognot = g->nodes()->create<luci::CircleLogicalNot>();
+
+ auto gc = loco::make_graph();
+ auto cloned = luci::clone_node(node_lognot, gc.get());
+ ASSERT_NE(nullptr, cloned);
+ ASSERT_EQ(gc.get(), cloned->graph());
+
+ auto cloned_lognot = dynamic_cast<luci::CircleLogicalNot *>(cloned);
+ ASSERT_NE(nullptr, cloned_lognot);
+}
diff --git a/compiler/luci/service/src/Nodes/CircleLogicalOr.cpp b/compiler/luci/service/src/Nodes/CircleLogicalOr.cpp
new file mode 100644
index 000000000..1201d6f34
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleLogicalOr.cpp
@@ -0,0 +1,27 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "CircleCloneNode.h"
+
+namespace luci
+{
+
+luci::CircleNode *CloneNode::visit(const luci::CircleLogicalOr *)
+{
+ return _graph->nodes()->create<luci::CircleLogicalOr>();
+}
+
+} // namespace luci
diff --git a/compiler/luci/service/src/Nodes/CircleLogicalOr.test.cpp b/compiler/luci/service/src/Nodes/CircleLogicalOr.test.cpp
new file mode 100644
index 000000000..19b706dcd
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleLogicalOr.test.cpp
@@ -0,0 +1,33 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Service/CircleNodeClone.h"
+
+#include <gtest/gtest.h>
+
+TEST(CloneNodeTest, clone_LogicalOr)
+{
+ auto g = loco::make_graph();
+ auto node_logor = g->nodes()->create<luci::CircleLogicalOr>();
+
+ auto gc = loco::make_graph();
+ auto cloned = luci::clone_node(node_logor, gc.get());
+ ASSERT_NE(nullptr, cloned);
+ ASSERT_EQ(gc.get(), cloned->graph());
+
+ auto cloned_logor = dynamic_cast<luci::CircleLogicalOr *>(cloned);
+ ASSERT_NE(nullptr, cloned_logor);
+}
diff --git a/compiler/luci/service/src/Nodes/CircleLogistic.cpp b/compiler/luci/service/src/Nodes/CircleLogistic.cpp
new file mode 100644
index 000000000..b21b187e9
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleLogistic.cpp
@@ -0,0 +1,27 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "CircleCloneNode.h"
+
+namespace luci
+{
+
+luci::CircleNode *CloneNode::visit(const luci::CircleLogistic *)
+{
+ return _graph->nodes()->create<luci::CircleLogistic>();
+}
+
+} // namespace luci
diff --git a/compiler/luci/service/src/Nodes/CircleLogistic.test.cpp b/compiler/luci/service/src/Nodes/CircleLogistic.test.cpp
new file mode 100644
index 000000000..05dbe46e4
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleLogistic.test.cpp
@@ -0,0 +1,33 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Service/CircleNodeClone.h"
+
+#include <gtest/gtest.h>
+
+TEST(CloneNodeTest, clone_Logistic)
+{
+ auto g = loco::make_graph();
+ auto node_log = g->nodes()->create<luci::CircleLogistic>();
+
+ auto gc = loco::make_graph();
+ auto cloned = luci::clone_node(node_log, gc.get());
+ ASSERT_NE(nullptr, cloned);
+ ASSERT_EQ(gc.get(), cloned->graph());
+
+ auto cloned_log = dynamic_cast<luci::CircleLogistic *>(cloned);
+ ASSERT_NE(nullptr, cloned_log);
+}
diff --git a/compiler/luci/service/src/Nodes/CircleMatrixDiag.cpp b/compiler/luci/service/src/Nodes/CircleMatrixDiag.cpp
new file mode 100644
index 000000000..2bffa07b1
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleMatrixDiag.cpp
@@ -0,0 +1,27 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "CircleCloneNode.h"
+
+namespace luci
+{
+
+luci::CircleNode *CloneNode::visit(const luci::CircleMatrixDiag *)
+{
+ return _graph->nodes()->create<luci::CircleMatrixDiag>();
+}
+
+} // namespace luci
diff --git a/compiler/luci/service/src/Nodes/CircleMatrixDiag.test.cpp b/compiler/luci/service/src/Nodes/CircleMatrixDiag.test.cpp
new file mode 100644
index 000000000..c08c4cb94
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleMatrixDiag.test.cpp
@@ -0,0 +1,33 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Service/CircleNodeClone.h"
+
+#include <gtest/gtest.h>
+
+TEST(CloneNodeTest, clone_MatrixDiag)
+{
+ auto g = loco::make_graph();
+ auto node_md = g->nodes()->create<luci::CircleMatrixDiag>();
+
+ auto gc = loco::make_graph();
+ auto cloned = luci::clone_node(node_md, gc.get());
+ ASSERT_NE(nullptr, cloned);
+ ASSERT_EQ(gc.get(), cloned->graph());
+
+ auto cloned_md = dynamic_cast<luci::CircleMatrixDiag *>(cloned);
+ ASSERT_NE(nullptr, cloned_md);
+}
diff --git a/compiler/luci/service/src/Nodes/CircleMatrixSetDiag.cpp b/compiler/luci/service/src/Nodes/CircleMatrixSetDiag.cpp
new file mode 100644
index 000000000..5ea2a5339
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleMatrixSetDiag.cpp
@@ -0,0 +1,27 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "CircleCloneNode.h"
+
+namespace luci
+{
+
+luci::CircleNode *CloneNode::visit(const luci::CircleMatrixSetDiag *)
+{
+ return _graph->nodes()->create<luci::CircleMatrixSetDiag>();
+}
+
+} // namespace luci
diff --git a/compiler/luci/service/src/Nodes/CircleMatrixSetDiag.test.cpp b/compiler/luci/service/src/Nodes/CircleMatrixSetDiag.test.cpp
new file mode 100644
index 000000000..5ea77ba75
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleMatrixSetDiag.test.cpp
@@ -0,0 +1,33 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Service/CircleNodeClone.h"
+
+#include <gtest/gtest.h>
+
+TEST(CloneNodeTest, clone_MatrixSetDiag)
+{
+ auto g = loco::make_graph();
+ auto node_msd = g->nodes()->create<luci::CircleMatrixSetDiag>();
+
+ auto gc = loco::make_graph();
+ auto cloned = luci::clone_node(node_msd, gc.get());
+ ASSERT_NE(nullptr, cloned);
+ ASSERT_EQ(gc.get(), cloned->graph());
+
+ auto cloned_msd = dynamic_cast<luci::CircleMatrixSetDiag *>(cloned);
+ ASSERT_NE(nullptr, cloned_msd);
+}
diff --git a/compiler/luci/service/src/Nodes/CircleMaxPool2D.cpp b/compiler/luci/service/src/Nodes/CircleMaxPool2D.cpp
new file mode 100644
index 000000000..b21610c7f
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleMaxPool2D.cpp
@@ -0,0 +1,42 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "CircleCloneNode.h"
+
+namespace luci
+{
+
+luci::CircleNode *CloneNode::visit(const luci::CircleMaxPool2D *node)
+{
+ if (node->fusedActivationFunction() == luci::FusedActFunc::UNDEFINED)
+ return nullptr;
+ if (node->padding() == luci::Padding::UNDEFINED)
+ return nullptr;
+
+ auto *cloned = _graph->nodes()->create<luci::CircleMaxPool2D>();
+ if (cloned != nullptr)
+ {
+ cloned->fusedActivationFunction(node->fusedActivationFunction());
+ cloned->padding(node->padding());
+ cloned->filter()->h(node->filter()->h());
+ cloned->filter()->w(node->filter()->w());
+ cloned->stride()->h(node->stride()->h());
+ cloned->stride()->w(node->stride()->w());
+ }
+ return cloned;
+}
+
+} // namespace luci
diff --git a/compiler/luci/service/src/Nodes/CircleMaxPool2D.test.cpp b/compiler/luci/service/src/Nodes/CircleMaxPool2D.test.cpp
new file mode 100644
index 000000000..415cf7c44
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleMaxPool2D.test.cpp
@@ -0,0 +1,69 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Service/CircleNodeClone.h"
+
+#include <gtest/gtest.h>
+
+TEST(CloneNodeTest, clone_MaxPool2D)
+{
+ auto g = loco::make_graph();
+ auto node_mp = g->nodes()->create<luci::CircleMaxPool2D>();
+ node_mp->fusedActivationFunction(luci::FusedActFunc::RELU);
+ node_mp->padding(luci::Padding::SAME);
+ node_mp->filter()->h(1);
+ node_mp->filter()->w(2);
+ node_mp->stride()->h(3);
+ node_mp->stride()->w(4);
+
+ auto gc = loco::make_graph();
+ auto cloned = luci::clone_node(node_mp, gc.get());
+ ASSERT_NE(nullptr, cloned);
+ ASSERT_EQ(gc.get(), cloned->graph());
+
+ auto cloned_mp = dynamic_cast<luci::CircleMaxPool2D *>(cloned);
+ ASSERT_NE(nullptr, cloned_mp);
+ ASSERT_EQ(node_mp->fusedActivationFunction(), cloned_mp->fusedActivationFunction());
+ ASSERT_EQ(node_mp->padding(), cloned_mp->padding());
+ ASSERT_EQ(node_mp->filter()->h(), cloned_mp->filter()->h());
+ ASSERT_EQ(node_mp->filter()->w(), cloned_mp->filter()->w());
+ ASSERT_EQ(node_mp->stride()->h(), cloned_mp->stride()->h());
+ ASSERT_EQ(node_mp->stride()->w(), cloned_mp->stride()->w());
+}
+
+TEST(CloneNodeTest, clone_MaxPool2D_fusedact_NEG)
+{
+ auto g = loco::make_graph();
+ auto node_mp = g->nodes()->create<luci::CircleMaxPool2D>();
+ node_mp->fusedActivationFunction(luci::FusedActFunc::UNDEFINED);
+ node_mp->padding(luci::Padding::SAME);
+
+ auto gc = loco::make_graph();
+ auto cloned = luci::clone_node(node_mp, gc.get());
+ ASSERT_EQ(nullptr, cloned);
+}
+
+TEST(CloneNodeTest, clone_MaxPool2D_padding_NEG)
+{
+ auto g = loco::make_graph();
+ auto node_mp = g->nodes()->create<luci::CircleMaxPool2D>();
+ node_mp->fusedActivationFunction(luci::FusedActFunc::RELU);
+ node_mp->padding(luci::Padding::UNDEFINED);
+
+ auto gc = loco::make_graph();
+ auto cloned = luci::clone_node(node_mp, gc.get());
+ ASSERT_EQ(nullptr, cloned);
+}
diff --git a/compiler/luci/service/src/Nodes/CircleMaximum.cpp b/compiler/luci/service/src/Nodes/CircleMaximum.cpp
new file mode 100644
index 000000000..545f4ca21
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleMaximum.cpp
@@ -0,0 +1,27 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "CircleCloneNode.h"
+
+namespace luci
+{
+
+luci::CircleNode *CloneNode::visit(const luci::CircleMaximum *)
+{
+ return _graph->nodes()->create<luci::CircleMaximum>();
+}
+
+} // namespace luci
diff --git a/compiler/luci/service/src/Nodes/CircleMaximum.test.cpp b/compiler/luci/service/src/Nodes/CircleMaximum.test.cpp
new file mode 100644
index 000000000..6f1ada060
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleMaximum.test.cpp
@@ -0,0 +1,33 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Service/CircleNodeClone.h"
+
+#include <gtest/gtest.h>
+
+TEST(CloneNodeTest, clone_Maximum)
+{
+ auto g = loco::make_graph();
+ auto node_max = g->nodes()->create<luci::CircleMaximum>();
+
+ auto gc = loco::make_graph();
+ auto cloned = luci::clone_node(node_max, gc.get());
+ ASSERT_NE(nullptr, cloned);
+ ASSERT_EQ(gc.get(), cloned->graph());
+
+ auto cloned_max = dynamic_cast<luci::CircleMaximum *>(cloned);
+ ASSERT_NE(nullptr, cloned_max);
+}
diff --git a/compiler/luci/service/src/Nodes/CircleMean.cpp b/compiler/luci/service/src/Nodes/CircleMean.cpp
index a78713698..95bc54532 100644
--- a/compiler/luci/service/src/Nodes/CircleMean.cpp
+++ b/compiler/luci/service/src/Nodes/CircleMean.cpp
@@ -1,11 +1,11 @@
/*
- * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
- * http://www.apache.org/licenses/LICENSE-2.0
+ * http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
@@ -14,15 +14,17 @@
* limitations under the License.
*/
-#include <luci/Service/CircleShapeSignatureInference.h>
+#include "CircleCloneNode.h"
namespace luci
{
-ShapeSignature ssinf::Algorithm::visit(const luci::CircleMean *node)
+luci::CircleNode *CloneNode::visit(const luci::CircleMean *node)
{
- return legalized_signature(
- reduced_signature(node->input(), node->reduction_indices(), node->keep_dims()));
+ auto *cloned = _graph->nodes()->create<luci::CircleMean>();
+ if (cloned != nullptr)
+ cloned->keep_dims(node->keep_dims());
+ return cloned;
}
} // namespace luci
diff --git a/compiler/luci/service/src/Nodes/CircleMean.test.cpp b/compiler/luci/service/src/Nodes/CircleMean.test.cpp
new file mode 100644
index 000000000..aa1b88f13
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleMean.test.cpp
@@ -0,0 +1,35 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Service/CircleNodeClone.h"
+
+#include <gtest/gtest.h>
+
+TEST(CloneNodeTest, clone_Mean)
+{
+ auto g = loco::make_graph();
+ auto node_mean = g->nodes()->create<luci::CircleMean>();
+ node_mean->keep_dims(true);
+
+ auto gc = loco::make_graph();
+ auto cloned = luci::clone_node(node_mean, gc.get());
+ ASSERT_NE(nullptr, cloned);
+ ASSERT_EQ(gc.get(), cloned->graph());
+
+ auto cloned_mean = dynamic_cast<luci::CircleMean *>(cloned);
+ ASSERT_NE(nullptr, cloned_mean);
+ ASSERT_EQ(node_mean->keep_dims(), cloned_mean->keep_dims());
+}
diff --git a/compiler/luci/service/src/Nodes/CircleMinimum.cpp b/compiler/luci/service/src/Nodes/CircleMinimum.cpp
new file mode 100644
index 000000000..2c2755c55
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleMinimum.cpp
@@ -0,0 +1,27 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "CircleCloneNode.h"
+
+namespace luci
+{
+
+luci::CircleNode *CloneNode::visit(const luci::CircleMinimum *)
+{
+ return _graph->nodes()->create<luci::CircleMinimum>();
+}
+
+} // namespace luci
diff --git a/compiler/luci/service/src/Nodes/CircleMinimum.test.cpp b/compiler/luci/service/src/Nodes/CircleMinimum.test.cpp
new file mode 100644
index 000000000..0a54be71c
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleMinimum.test.cpp
@@ -0,0 +1,33 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Service/CircleNodeClone.h"
+
+#include <gtest/gtest.h>
+
+TEST(CloneNodeTest, clone_Minimum)
+{
+ auto g = loco::make_graph();
+ auto node_min = g->nodes()->create<luci::CircleMinimum>();
+
+ auto gc = loco::make_graph();
+ auto cloned = luci::clone_node(node_min, gc.get());
+ ASSERT_NE(nullptr, cloned);
+ ASSERT_EQ(gc.get(), cloned->graph());
+
+ auto cloned_min = dynamic_cast<luci::CircleMinimum *>(cloned);
+ ASSERT_NE(nullptr, cloned_min);
+}
diff --git a/compiler/luci/service/src/Nodes/CircleMirrorPad.cpp b/compiler/luci/service/src/Nodes/CircleMirrorPad.cpp
new file mode 100644
index 000000000..919221a0b
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleMirrorPad.cpp
@@ -0,0 +1,33 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "CircleCloneNode.h"
+
+namespace luci
+{
+
+luci::CircleNode *CloneNode::visit(const luci::CircleMirrorPad *node)
+{
+ if (node->mode() == luci::MirrorPadMode::UNDEFINED)
+ return nullptr;
+
+ auto *cloned = _graph->nodes()->create<luci::CircleMirrorPad>();
+ if (cloned != nullptr)
+ cloned->mode(node->mode());
+ return cloned;
+}
+
+} // namespace luci
diff --git a/compiler/luci/service/src/Nodes/CircleMirrorPad.test.cpp b/compiler/luci/service/src/Nodes/CircleMirrorPad.test.cpp
new file mode 100644
index 000000000..911cf6d3b
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleMirrorPad.test.cpp
@@ -0,0 +1,46 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Service/CircleNodeClone.h"
+
+#include <gtest/gtest.h>
+
+TEST(CloneNodeTest, clone_MirrorPad)
+{
+ auto g = loco::make_graph();
+ auto node_mp = g->nodes()->create<luci::CircleMirrorPad>();
+ node_mp->mode(luci::MirrorPadMode::REFLECT);
+
+ auto gc = loco::make_graph();
+ auto cloned = luci::clone_node(node_mp, gc.get());
+ ASSERT_NE(nullptr, cloned);
+ ASSERT_EQ(gc.get(), cloned->graph());
+
+ auto cloned_mp = dynamic_cast<luci::CircleMirrorPad *>(cloned);
+ ASSERT_NE(nullptr, cloned_mp);
+ ASSERT_EQ(node_mp->mode(), cloned_mp->mode());
+}
+
+TEST(CloneNodeTest, clone_MirrorPad_mode_NEG)
+{
+ auto g = loco::make_graph();
+ auto node_mp = g->nodes()->create<luci::CircleMirrorPad>();
+ node_mp->mode(luci::MirrorPadMode::UNDEFINED);
+
+ auto gc = loco::make_graph();
+ auto cloned = luci::clone_node(node_mp, gc.get());
+ ASSERT_EQ(nullptr, cloned);
+}
diff --git a/compiler/luci/service/src/Nodes/CircleMul.cpp b/compiler/luci/service/src/Nodes/CircleMul.cpp
new file mode 100644
index 000000000..096aed196
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleMul.cpp
@@ -0,0 +1,33 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "CircleCloneNode.h"
+
+namespace luci
+{
+
+luci::CircleNode *CloneNode::visit(const luci::CircleMul *node)
+{
+ if (node->fusedActivationFunction() == luci::FusedActFunc::UNDEFINED)
+ return nullptr;
+
+ auto *cloned = _graph->nodes()->create<luci::CircleMul>();
+ if (cloned != nullptr)
+ cloned->fusedActivationFunction(node->fusedActivationFunction());
+ return cloned;
+}
+
+} // namespace luci
diff --git a/compiler/luci/service/src/Nodes/CircleMul.test.cpp b/compiler/luci/service/src/Nodes/CircleMul.test.cpp
new file mode 100644
index 000000000..dc5565f11
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleMul.test.cpp
@@ -0,0 +1,46 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Service/CircleNodeClone.h"
+
+#include <gtest/gtest.h>
+
+TEST(CloneNodeTest, clone_Mul)
+{
+ auto g = loco::make_graph();
+ auto node_mul = g->nodes()->create<luci::CircleMul>();
+ node_mul->fusedActivationFunction(luci::FusedActFunc::RELU);
+
+ auto gc = loco::make_graph();
+ auto cloned = luci::clone_node(node_mul, gc.get());
+ ASSERT_NE(nullptr, cloned);
+ ASSERT_EQ(gc.get(), cloned->graph());
+
+ auto cloned_mul = dynamic_cast<luci::CircleMul *>(cloned);
+ ASSERT_NE(nullptr, cloned_mul);
+ ASSERT_EQ(node_mul->fusedActivationFunction(), cloned_mul->fusedActivationFunction());
+}
+
+TEST(CloneNodeTest, clone_Mul_NEG)
+{
+ auto g = loco::make_graph();
+ auto node_mul = g->nodes()->create<luci::CircleMul>();
+ node_mul->fusedActivationFunction(luci::FusedActFunc::UNDEFINED);
+
+ auto gc = loco::make_graph();
+ auto cloned = luci::clone_node(node_mul, gc.get());
+ ASSERT_EQ(nullptr, cloned);
+}
diff --git a/compiler/luci/service/src/Nodes/CircleNeg.cpp b/compiler/luci/service/src/Nodes/CircleNeg.cpp
new file mode 100644
index 000000000..312189e77
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleNeg.cpp
@@ -0,0 +1,27 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "CircleCloneNode.h"
+
+namespace luci
+{
+
+luci::CircleNode *CloneNode::visit(const luci::CircleNeg *)
+{
+ return _graph->nodes()->create<luci::CircleNeg>();
+}
+
+} // namespace luci
diff --git a/compiler/luci/service/src/Nodes/CircleNeg.test.cpp b/compiler/luci/service/src/Nodes/CircleNeg.test.cpp
new file mode 100644
index 000000000..8c2880324
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleNeg.test.cpp
@@ -0,0 +1,33 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Service/CircleNodeClone.h"
+
+#include <gtest/gtest.h>
+
+TEST(CloneNodeTest, clone_Neg)
+{
+ auto g = loco::make_graph();
+ auto node_neg = g->nodes()->create<luci::CircleNeg>();
+
+ auto gc = loco::make_graph();
+ auto cloned = luci::clone_node(node_neg, gc.get());
+ ASSERT_NE(nullptr, cloned);
+ ASSERT_EQ(gc.get(), cloned->graph());
+
+ auto cloned_neg = dynamic_cast<luci::CircleNeg *>(cloned);
+ ASSERT_NE(nullptr, cloned_neg);
+}
diff --git a/compiler/luci/service/src/Nodes/CircleNonMaxSuppressionV4.cpp b/compiler/luci/service/src/Nodes/CircleNonMaxSuppressionV4.cpp
new file mode 100644
index 000000000..4757e8314
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleNonMaxSuppressionV4.cpp
@@ -0,0 +1,27 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "CircleCloneNode.h"
+
+namespace luci
+{
+
+luci::CircleNode *CloneNode::visit(const luci::CircleNonMaxSuppressionV4 *)
+{
+ return _graph->nodes()->create<luci::CircleNonMaxSuppressionV4>();
+}
+
+} // namespace luci
diff --git a/compiler/luci/service/src/Nodes/CircleNonMaxSuppressionV4.test.cpp b/compiler/luci/service/src/Nodes/CircleNonMaxSuppressionV4.test.cpp
new file mode 100644
index 000000000..34f5b0325
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleNonMaxSuppressionV4.test.cpp
@@ -0,0 +1,33 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Service/CircleNodeClone.h"
+
+#include <gtest/gtest.h>
+
+TEST(CloneNodeTest, clone_NonMaxSuppressionV4)
+{
+ auto g = loco::make_graph();
+ auto node_nms = g->nodes()->create<luci::CircleNonMaxSuppressionV4>();
+
+ auto gc = loco::make_graph();
+ auto cloned = luci::clone_node(node_nms, gc.get());
+ ASSERT_NE(nullptr, cloned);
+ ASSERT_EQ(gc.get(), cloned->graph());
+
+ auto cloned_nms = dynamic_cast<luci::CircleNonMaxSuppressionV4 *>(cloned);
+ ASSERT_NE(nullptr, cloned_nms);
+}
diff --git a/compiler/luci/service/src/Nodes/CircleNonMaxSuppressionV4Out.cpp b/compiler/luci/service/src/Nodes/CircleNonMaxSuppressionV4Out.cpp
new file mode 100644
index 000000000..2a12f2a45
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleNonMaxSuppressionV4Out.cpp
@@ -0,0 +1,30 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "CircleCloneNode.h"
+
+namespace luci
+{
+
+luci::CircleNode *CloneNode::visit(const luci::CircleNonMaxSuppressionV4Out *node)
+{
+ auto *cloned = _graph->nodes()->create<luci::CircleNonMaxSuppressionV4Out>();
+ if (cloned != nullptr)
+ cloned->index(node->index());
+ return cloned;
+}
+
+} // namespace luci
diff --git a/compiler/luci/service/src/Nodes/CircleNonMaxSuppressionV4Out.test.cpp b/compiler/luci/service/src/Nodes/CircleNonMaxSuppressionV4Out.test.cpp
new file mode 100644
index 000000000..ed9e0e019
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleNonMaxSuppressionV4Out.test.cpp
@@ -0,0 +1,35 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Service/CircleNodeClone.h"
+
+#include <gtest/gtest.h>
+
+TEST(CloneNodeTest, clone_NonMaxSuppressionV4Out)
+{
+ auto g = loco::make_graph();
+ auto node_nout = g->nodes()->create<luci::CircleNonMaxSuppressionV4Out>();
+ node_nout->index(1);
+
+ auto gc = loco::make_graph();
+ auto cloned = luci::clone_node(node_nout, gc.get());
+ ASSERT_NE(nullptr, cloned);
+ ASSERT_EQ(gc.get(), cloned->graph());
+
+ auto cloned_nout = dynamic_cast<luci::CircleNonMaxSuppressionV4Out *>(cloned);
+ ASSERT_NE(nullptr, cloned_nout);
+ ASSERT_EQ(node_nout->index(), cloned_nout->index());
+}
diff --git a/compiler/luci/service/src/Nodes/CircleNonMaxSuppressionV5.cpp b/compiler/luci/service/src/Nodes/CircleNonMaxSuppressionV5.cpp
new file mode 100644
index 000000000..34d128072
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleNonMaxSuppressionV5.cpp
@@ -0,0 +1,27 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "CircleCloneNode.h"
+
+namespace luci
+{
+
+luci::CircleNode *CloneNode::visit(const luci::CircleNonMaxSuppressionV5 *)
+{
+ return _graph->nodes()->create<luci::CircleNonMaxSuppressionV5>();
+}
+
+} // namespace luci
diff --git a/compiler/luci/service/src/Nodes/CircleNonMaxSuppressionV5.test.cpp b/compiler/luci/service/src/Nodes/CircleNonMaxSuppressionV5.test.cpp
new file mode 100644
index 000000000..faaee969e
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleNonMaxSuppressionV5.test.cpp
@@ -0,0 +1,33 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Service/CircleNodeClone.h"
+
+#include <gtest/gtest.h>
+
+TEST(CloneNodeTest, clone_NonMaxSuppressionV5)
+{
+ auto g = loco::make_graph();
+ auto node_nms = g->nodes()->create<luci::CircleNonMaxSuppressionV5>();
+
+ auto gc = loco::make_graph();
+ auto cloned = luci::clone_node(node_nms, gc.get());
+ ASSERT_NE(nullptr, cloned);
+ ASSERT_EQ(gc.get(), cloned->graph());
+
+ auto cloned_nms = dynamic_cast<luci::CircleNonMaxSuppressionV5 *>(cloned);
+ ASSERT_NE(nullptr, cloned_nms);
+}
diff --git a/compiler/luci/service/src/Nodes/CircleNonMaxSuppressionV5Out.cpp b/compiler/luci/service/src/Nodes/CircleNonMaxSuppressionV5Out.cpp
new file mode 100644
index 000000000..e1d7875e7
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleNonMaxSuppressionV5Out.cpp
@@ -0,0 +1,30 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "CircleCloneNode.h"
+
+namespace luci
+{
+
+luci::CircleNode *CloneNode::visit(const luci::CircleNonMaxSuppressionV5Out *node)
+{
+ auto *cloned = _graph->nodes()->create<luci::CircleNonMaxSuppressionV5Out>();
+ if (cloned != nullptr)
+ cloned->index(node->index());
+ return cloned;
+}
+
+} // namespace luci
diff --git a/compiler/luci/service/src/Nodes/CircleNonMaxSuppressionV5Out.test.cpp b/compiler/luci/service/src/Nodes/CircleNonMaxSuppressionV5Out.test.cpp
new file mode 100644
index 000000000..ef0f766b9
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleNonMaxSuppressionV5Out.test.cpp
@@ -0,0 +1,35 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Service/CircleNodeClone.h"
+
+#include <gtest/gtest.h>
+
+TEST(CloneNodeTest, clone_NonMaxSuppressionV5Out)
+{
+ auto g = loco::make_graph();
+ auto node_nout = g->nodes()->create<luci::CircleNonMaxSuppressionV5Out>();
+ node_nout->index(1);
+
+ auto gc = loco::make_graph();
+ auto cloned = luci::clone_node(node_nout, gc.get());
+ ASSERT_NE(nullptr, cloned);
+ ASSERT_EQ(gc.get(), cloned->graph());
+
+ auto cloned_nout = dynamic_cast<luci::CircleNonMaxSuppressionV5Out *>(cloned);
+ ASSERT_NE(nullptr, cloned_nout);
+ ASSERT_EQ(node_nout->index(), cloned_nout->index());
+}
diff --git a/compiler/luci/service/src/Nodes/CircleNotEqual.cpp b/compiler/luci/service/src/Nodes/CircleNotEqual.cpp
new file mode 100644
index 000000000..4cb5320e8
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleNotEqual.cpp
@@ -0,0 +1,27 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "CircleCloneNode.h"
+
+namespace luci
+{
+
+luci::CircleNode *CloneNode::visit(const luci::CircleNotEqual *)
+{
+ return _graph->nodes()->create<luci::CircleNotEqual>();
+}
+
+} // namespace luci
diff --git a/compiler/luci/service/src/Nodes/CircleNotEqual.test.cpp b/compiler/luci/service/src/Nodes/CircleNotEqual.test.cpp
new file mode 100644
index 000000000..20f7dbc4b
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleNotEqual.test.cpp
@@ -0,0 +1,33 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Service/CircleNodeClone.h"
+
+#include <gtest/gtest.h>
+
+TEST(CloneNodeTest, clone_NotEqual)
+{
+ auto g = loco::make_graph();
+ auto node_ne = g->nodes()->create<luci::CircleNotEqual>();
+
+ auto gc = loco::make_graph();
+ auto cloned = luci::clone_node(node_ne, gc.get());
+ ASSERT_NE(nullptr, cloned);
+ ASSERT_EQ(gc.get(), cloned->graph());
+
+ auto cloned_ne = dynamic_cast<luci::CircleNotEqual *>(cloned);
+ ASSERT_NE(nullptr, cloned_ne);
+}
diff --git a/compiler/luci/service/src/Nodes/CircleOneHot.cpp b/compiler/luci/service/src/Nodes/CircleOneHot.cpp
new file mode 100644
index 000000000..a33c8ff26
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleOneHot.cpp
@@ -0,0 +1,30 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "CircleCloneNode.h"
+
+namespace luci
+{
+
+luci::CircleNode *CloneNode::visit(const luci::CircleOneHot *node)
+{
+ auto *cloned = _graph->nodes()->create<luci::CircleOneHot>();
+ if (cloned != nullptr)
+ cloned->axis(node->axis());
+ return cloned;
+}
+
+} // namespace luci
diff --git a/compiler/luci/service/src/Nodes/CircleOneHot.test.cpp b/compiler/luci/service/src/Nodes/CircleOneHot.test.cpp
new file mode 100644
index 000000000..dea927d1b
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleOneHot.test.cpp
@@ -0,0 +1,35 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Service/CircleNodeClone.h"
+
+#include <gtest/gtest.h>
+
+TEST(CloneNodeTest, clone_OneHot)
+{
+ auto g = loco::make_graph();
+ auto node_oh = g->nodes()->create<luci::CircleOneHot>();
+ node_oh->axis(3);
+
+ auto gc = loco::make_graph();
+ auto cloned = luci::clone_node(node_oh, gc.get());
+ ASSERT_NE(nullptr, cloned);
+ ASSERT_EQ(gc.get(), cloned->graph());
+
+ auto cloned_oh = dynamic_cast<luci::CircleOneHot *>(cloned);
+ ASSERT_NE(nullptr, cloned_oh);
+ ASSERT_EQ(node_oh->axis(), cloned_oh->axis());
+}
diff --git a/compiler/luci/service/src/Nodes/CircleOutputDummy.cpp b/compiler/luci/service/src/Nodes/CircleOutputDummy.cpp
index e0f13c439..ce94dff94 100644
--- a/compiler/luci/service/src/Nodes/CircleOutputDummy.cpp
+++ b/compiler/luci/service/src/Nodes/CircleOutputDummy.cpp
@@ -1,11 +1,11 @@
/*
- * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
- * http://www.apache.org/licenses/LICENSE-2.0
+ * http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
@@ -14,11 +14,14 @@
* limitations under the License.
*/
-#include <luci/Service/CircleShapeSignatureInference.h>
+#include "CircleCloneNode.h"
namespace luci
{
-ShapeSignature ssinf::Algorithm::visit(const luci::CircleOutputDummy *) { return ShapeSignature(); }
+luci::CircleNode *CloneNode::visit(const luci::CircleOutputDummy *)
+{
+ return _graph->nodes()->create<luci::CircleOutputDummy>();
+}
} // namespace luci
diff --git a/compiler/luci/service/src/Nodes/CircleOutputDummy.test.cpp b/compiler/luci/service/src/Nodes/CircleOutputDummy.test.cpp
new file mode 100644
index 000000000..6170c7c41
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleOutputDummy.test.cpp
@@ -0,0 +1,33 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Service/CircleNodeClone.h"
+
+#include <gtest/gtest.h>
+
+TEST(CloneNodeTest, clone_OutputDummy)
+{
+ auto g = loco::make_graph();
+ auto node_dummy = g->nodes()->create<luci::CircleOutputDummy>();
+
+ auto gc = loco::make_graph();
+ auto cloned = luci::clone_node(node_dummy, gc.get());
+ ASSERT_NE(nullptr, cloned);
+ ASSERT_EQ(gc.get(), cloned->graph());
+
+ auto cloned_dummy = dynamic_cast<luci::CircleOutputDummy *>(cloned);
+ ASSERT_NE(nullptr, cloned_dummy);
+}
diff --git a/compiler/luci/service/src/Nodes/CircleOutputExclude.cpp b/compiler/luci/service/src/Nodes/CircleOutputExclude.cpp
index 75bbbb3c0..1b0f919c3 100644
--- a/compiler/luci/service/src/Nodes/CircleOutputExclude.cpp
+++ b/compiler/luci/service/src/Nodes/CircleOutputExclude.cpp
@@ -1,11 +1,11 @@
/*
- * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
- * http://www.apache.org/licenses/LICENSE-2.0
+ * http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
@@ -14,14 +14,14 @@
* limitations under the License.
*/
-#include <luci/Service/CircleShapeSignatureInference.h>
+#include "CircleCloneNode.h"
namespace luci
{
-ShapeSignature ssinf::Algorithm::visit(const luci::CircleOutputExclude *)
+luci::CircleNode *CloneNode::visit(const luci::CircleOutputExclude *)
{
- return ShapeSignature();
+ return _graph->nodes()->create<luci::CircleOutputExclude>();
}
} // namespace luci
diff --git a/compiler/luci/service/src/Nodes/CircleOutputExclude.test.cpp b/compiler/luci/service/src/Nodes/CircleOutputExclude.test.cpp
new file mode 100644
index 000000000..120ffe86b
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleOutputExclude.test.cpp
@@ -0,0 +1,33 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Service/CircleNodeClone.h"
+
+#include <gtest/gtest.h>
+
+TEST(CloneNodeTest, clone_OutputExclude)
+{
+ auto g = loco::make_graph();
+ auto node_outex = g->nodes()->create<luci::CircleOutputExclude>();
+
+ auto gc = loco::make_graph();
+ auto cloned = luci::clone_node(node_outex, gc.get());
+ ASSERT_NE(nullptr, cloned);
+ ASSERT_EQ(gc.get(), cloned->graph());
+
+ auto cloned_outex = dynamic_cast<luci::CircleOutputExclude *>(cloned);
+ ASSERT_NE(nullptr, cloned_outex);
+}
diff --git a/compiler/luci/service/src/Nodes/CirclePRelu.cpp b/compiler/luci/service/src/Nodes/CirclePRelu.cpp
new file mode 100644
index 000000000..8a34e507e
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CirclePRelu.cpp
@@ -0,0 +1,27 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "CircleCloneNode.h"
+
+namespace luci
+{
+
+luci::CircleNode *CloneNode::visit(const luci::CirclePRelu *)
+{
+ return _graph->nodes()->create<luci::CirclePRelu>();
+}
+
+} // namespace luci
diff --git a/compiler/luci/service/src/Nodes/CirclePRelu.test.cpp b/compiler/luci/service/src/Nodes/CirclePRelu.test.cpp
new file mode 100644
index 000000000..1150e3fa4
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CirclePRelu.test.cpp
@@ -0,0 +1,33 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Service/CircleNodeClone.h"
+
+#include <gtest/gtest.h>
+
+TEST(CloneNodeTest, clone_PRelu)
+{
+ auto g = loco::make_graph();
+ auto node_pr = g->nodes()->create<luci::CirclePRelu>();
+
+ auto gc = loco::make_graph();
+ auto cloned = luci::clone_node(node_pr, gc.get());
+ ASSERT_NE(nullptr, cloned);
+ ASSERT_EQ(gc.get(), cloned->graph());
+
+ auto cloned_pr = dynamic_cast<luci::CirclePRelu *>(cloned);
+ ASSERT_NE(nullptr, cloned_pr);
+}
diff --git a/compiler/luci/service/src/Nodes/CirclePack.cpp b/compiler/luci/service/src/Nodes/CirclePack.cpp
new file mode 100644
index 000000000..a3cee0bfd
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CirclePack.cpp
@@ -0,0 +1,30 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "CircleCloneNode.h"
+
+namespace luci
+{
+
+luci::CircleNode *CloneNode::visit(const luci::CirclePack *node)
+{
+ auto *cloned = _graph->nodes()->create<luci::CirclePack>(node->values_count());
+ if (cloned != nullptr)
+ cloned->axis(node->axis());
+ return cloned;
+}
+
+} // namespace luci
diff --git a/compiler/luci/service/src/Nodes/CirclePack.test.cpp b/compiler/luci/service/src/Nodes/CirclePack.test.cpp
new file mode 100644
index 000000000..b808956dc
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CirclePack.test.cpp
@@ -0,0 +1,36 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Service/CircleNodeClone.h"
+
+#include <gtest/gtest.h>
+
+TEST(CloneNodeTest, clone_Pack)
+{
+ auto g = loco::make_graph();
+ auto node_pack = g->nodes()->create<luci::CirclePack>(3);
+ node_pack->axis(7);
+
+ auto gc = loco::make_graph();
+ auto cloned = luci::clone_node(node_pack, gc.get());
+ ASSERT_NE(nullptr, cloned);
+ ASSERT_EQ(gc.get(), cloned->graph());
+
+ auto cloned_pack = dynamic_cast<luci::CirclePack *>(cloned);
+ ASSERT_NE(nullptr, cloned_pack);
+ ASSERT_EQ(node_pack->values_count(), cloned_pack->values_count());
+ ASSERT_EQ(node_pack->axis(), cloned_pack->axis());
+}
diff --git a/compiler/luci/service/src/Nodes/CirclePad.cpp b/compiler/luci/service/src/Nodes/CirclePad.cpp
new file mode 100644
index 000000000..425bdce4d
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CirclePad.cpp
@@ -0,0 +1,27 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "CircleCloneNode.h"
+
+namespace luci
+{
+
+luci::CircleNode *CloneNode::visit(const luci::CirclePad *)
+{
+ return _graph->nodes()->create<luci::CirclePad>();
+}
+
+} // namespace luci
diff --git a/compiler/luci/service/src/Nodes/CirclePad.test.cpp b/compiler/luci/service/src/Nodes/CirclePad.test.cpp
new file mode 100644
index 000000000..1d5f8375e
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CirclePad.test.cpp
@@ -0,0 +1,33 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Service/CircleNodeClone.h"
+
+#include <gtest/gtest.h>
+
+TEST(CloneNodeTest, clone_Pad)
+{
+ auto g = loco::make_graph();
+ auto node_pad = g->nodes()->create<luci::CirclePad>();
+
+ auto gc = loco::make_graph();
+ auto cloned = luci::clone_node(node_pad, gc.get());
+ ASSERT_NE(nullptr, cloned);
+ ASSERT_EQ(gc.get(), cloned->graph());
+
+ auto cloned_pad = dynamic_cast<luci::CirclePad *>(cloned);
+ ASSERT_NE(nullptr, cloned_pad);
+}
diff --git a/compiler/luci/service/src/Nodes/CirclePadV2.cpp b/compiler/luci/service/src/Nodes/CirclePadV2.cpp
new file mode 100644
index 000000000..0e93869b6
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CirclePadV2.cpp
@@ -0,0 +1,27 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "CircleCloneNode.h"
+
+namespace luci
+{
+
+luci::CircleNode *CloneNode::visit(const luci::CirclePadV2 *)
+{
+ return _graph->nodes()->create<luci::CirclePadV2>();
+}
+
+} // namespace luci
diff --git a/compiler/luci/service/src/Nodes/CirclePadV2.test.cpp b/compiler/luci/service/src/Nodes/CirclePadV2.test.cpp
new file mode 100644
index 000000000..d011f69f8
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CirclePadV2.test.cpp
@@ -0,0 +1,33 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Service/CircleNodeClone.h"
+
+#include <gtest/gtest.h>
+
+TEST(CloneNodeTest, clone_PadV2)
+{
+ auto g = loco::make_graph();
+ auto node_pad = g->nodes()->create<luci::CirclePadV2>();
+
+ auto gc = loco::make_graph();
+ auto cloned = luci::clone_node(node_pad, gc.get());
+ ASSERT_NE(nullptr, cloned);
+ ASSERT_EQ(gc.get(), cloned->graph());
+
+ auto cloned_pad = dynamic_cast<luci::CirclePadV2 *>(cloned);
+ ASSERT_NE(nullptr, cloned_pad);
+}
diff --git a/compiler/luci/service/src/Nodes/CirclePow.cpp b/compiler/luci/service/src/Nodes/CirclePow.cpp
new file mode 100644
index 000000000..bf9388913
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CirclePow.cpp
@@ -0,0 +1,27 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "CircleCloneNode.h"
+
+namespace luci
+{
+
+luci::CircleNode *CloneNode::visit(const luci::CirclePow *)
+{
+ return _graph->nodes()->create<luci::CirclePow>();
+}
+
+} // namespace luci
diff --git a/compiler/luci/service/src/Nodes/CirclePow.test.cpp b/compiler/luci/service/src/Nodes/CirclePow.test.cpp
new file mode 100644
index 000000000..946298932
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CirclePow.test.cpp
@@ -0,0 +1,33 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Service/CircleNodeClone.h"
+
+#include <gtest/gtest.h>
+
+TEST(CloneNodeTest, clone_Pow)
+{
+ auto g = loco::make_graph();
+ auto node_pow = g->nodes()->create<luci::CirclePow>();
+
+ auto gc = loco::make_graph();
+ auto cloned = luci::clone_node(node_pow, gc.get());
+ ASSERT_NE(nullptr, cloned);
+ ASSERT_EQ(gc.get(), cloned->graph());
+
+ auto cloned_pow = dynamic_cast<luci::CirclePow *>(cloned);
+ ASSERT_NE(nullptr, cloned_pow);
+}
diff --git a/compiler/luci/service/src/Nodes/CircleRange.cpp b/compiler/luci/service/src/Nodes/CircleRange.cpp
new file mode 100644
index 000000000..9c6f7b494
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleRange.cpp
@@ -0,0 +1,27 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "CircleCloneNode.h"
+
+namespace luci
+{
+
+luci::CircleNode *CloneNode::visit(const luci::CircleRange *)
+{
+ return _graph->nodes()->create<luci::CircleRange>();
+}
+
+} // namespace luci
diff --git a/compiler/luci/service/src/Nodes/CircleRange.test.cpp b/compiler/luci/service/src/Nodes/CircleRange.test.cpp
new file mode 100644
index 000000000..b2fb29617
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleRange.test.cpp
@@ -0,0 +1,33 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Service/CircleNodeClone.h"
+
+#include <gtest/gtest.h>
+
+TEST(CloneNodeTest, clone_Range)
+{
+ auto g = loco::make_graph();
+ auto node_range = g->nodes()->create<luci::CircleRange>();
+
+ auto gc = loco::make_graph();
+ auto cloned = luci::clone_node(node_range, gc.get());
+ ASSERT_NE(nullptr, cloned);
+ ASSERT_EQ(gc.get(), cloned->graph());
+
+ auto cloned_range = dynamic_cast<luci::CircleRange *>(cloned);
+ ASSERT_NE(nullptr, cloned_range);
+}
diff --git a/compiler/luci/service/src/Nodes/CircleRank.cpp b/compiler/luci/service/src/Nodes/CircleRank.cpp
new file mode 100644
index 000000000..db8171c51
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleRank.cpp
@@ -0,0 +1,27 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "CircleCloneNode.h"
+
+namespace luci
+{
+
+luci::CircleNode *CloneNode::visit(const luci::CircleRank *)
+{
+ return _graph->nodes()->create<luci::CircleRank>();
+}
+
+} // namespace luci
diff --git a/compiler/luci/service/src/Nodes/CircleRank.test.cpp b/compiler/luci/service/src/Nodes/CircleRank.test.cpp
new file mode 100644
index 000000000..0e81fb254
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleRank.test.cpp
@@ -0,0 +1,33 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Service/CircleNodeClone.h"
+
+#include <gtest/gtest.h>
+
+TEST(CloneNodeTest, clone_Rank)
+{
+ auto g = loco::make_graph();
+ auto node_rank = g->nodes()->create<luci::CircleRank>();
+
+ auto gc = loco::make_graph();
+ auto cloned = luci::clone_node(node_rank, gc.get());
+ ASSERT_NE(nullptr, cloned);
+ ASSERT_EQ(gc.get(), cloned->graph());
+
+ auto cloned_rank = dynamic_cast<luci::CircleRank *>(cloned);
+ ASSERT_NE(nullptr, cloned_rank);
+}
diff --git a/compiler/luci/service/src/Nodes/CircleReduceAny.cpp b/compiler/luci/service/src/Nodes/CircleReduceAny.cpp
index 27da81466..3ab0b3b59 100644
--- a/compiler/luci/service/src/Nodes/CircleReduceAny.cpp
+++ b/compiler/luci/service/src/Nodes/CircleReduceAny.cpp
@@ -1,11 +1,11 @@
/*
- * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
- * http://www.apache.org/licenses/LICENSE-2.0
+ * http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
@@ -14,15 +14,17 @@
* limitations under the License.
*/
-#include <luci/Service/CircleShapeSignatureInference.h>
+#include "CircleCloneNode.h"
namespace luci
{
-ShapeSignature ssinf::Algorithm::visit(const luci::CircleReduceAny *node)
+luci::CircleNode *CloneNode::visit(const luci::CircleReduceAny *node)
{
- return legalized_signature(
- reduced_signature(node->input(), node->reduction_indices(), node->keep_dims()));
+ auto *cloned = _graph->nodes()->create<luci::CircleReduceAny>();
+ if (cloned != nullptr)
+ cloned->keep_dims(node->keep_dims());
+ return cloned;
}
} // namespace luci
diff --git a/compiler/luci/service/src/Nodes/CircleReduceAny.test.cpp b/compiler/luci/service/src/Nodes/CircleReduceAny.test.cpp
new file mode 100644
index 000000000..904b5a139
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleReduceAny.test.cpp
@@ -0,0 +1,35 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Service/CircleNodeClone.h"
+
+#include <gtest/gtest.h>
+
+TEST(CloneNodeTest, clone_ReduceAny)
+{
+ auto g = loco::make_graph();
+ auto node_ra = g->nodes()->create<luci::CircleReduceAny>();
+ node_ra->keep_dims(true);
+
+ auto gc = loco::make_graph();
+ auto cloned = luci::clone_node(node_ra, gc.get());
+ ASSERT_NE(nullptr, cloned);
+ ASSERT_EQ(gc.get(), cloned->graph());
+
+ auto cloned_ra = dynamic_cast<luci::CircleReduceAny *>(cloned);
+ ASSERT_NE(nullptr, cloned_ra);
+ ASSERT_EQ(node_ra->keep_dims(), cloned_ra->keep_dims());
+}
diff --git a/compiler/luci/service/src/Nodes/CircleReduceMax.cpp b/compiler/luci/service/src/Nodes/CircleReduceMax.cpp
index 48d9cb970..c026905ca 100644
--- a/compiler/luci/service/src/Nodes/CircleReduceMax.cpp
+++ b/compiler/luci/service/src/Nodes/CircleReduceMax.cpp
@@ -1,11 +1,11 @@
/*
- * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
- * http://www.apache.org/licenses/LICENSE-2.0
+ * http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
@@ -14,15 +14,17 @@
* limitations under the License.
*/
-#include <luci/Service/CircleShapeSignatureInference.h>
+#include "CircleCloneNode.h"
namespace luci
{
-ShapeSignature ssinf::Algorithm::visit(const luci::CircleReduceMax *node)
+luci::CircleNode *CloneNode::visit(const luci::CircleReduceMax *node)
{
- return legalized_signature(
- reduced_signature(node->input(), node->reduction_indices(), node->keep_dims()));
+ auto *cloned = _graph->nodes()->create<luci::CircleReduceMax>();
+ if (cloned != nullptr)
+ cloned->keep_dims(node->keep_dims());
+ return cloned;
}
} // namespace luci
diff --git a/compiler/luci/service/src/Nodes/CircleReduceMax.test.cpp b/compiler/luci/service/src/Nodes/CircleReduceMax.test.cpp
new file mode 100644
index 000000000..b3f3c881e
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleReduceMax.test.cpp
@@ -0,0 +1,35 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Service/CircleNodeClone.h"
+
+#include <gtest/gtest.h>
+
+TEST(CloneNodeTest, clone_ReduceMax)
+{
+ auto g = loco::make_graph();
+ auto node_rmax = g->nodes()->create<luci::CircleReduceMax>();
+ node_rmax->keep_dims(true);
+
+ auto gc = loco::make_graph();
+ auto cloned = luci::clone_node(node_rmax, gc.get());
+ ASSERT_NE(nullptr, cloned);
+ ASSERT_EQ(gc.get(), cloned->graph());
+
+ auto cloned_rmax = dynamic_cast<luci::CircleReduceMax *>(cloned);
+ ASSERT_NE(nullptr, cloned_rmax);
+ ASSERT_EQ(node_rmax->keep_dims(), cloned_rmax->keep_dims());
+}
diff --git a/compiler/luci/service/src/Nodes/CircleReduceMin.cpp b/compiler/luci/service/src/Nodes/CircleReduceMin.cpp
index 9a9997118..3dfa19680 100644
--- a/compiler/luci/service/src/Nodes/CircleReduceMin.cpp
+++ b/compiler/luci/service/src/Nodes/CircleReduceMin.cpp
@@ -1,11 +1,11 @@
/*
- * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
- * http://www.apache.org/licenses/LICENSE-2.0
+ * http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
@@ -14,15 +14,17 @@
* limitations under the License.
*/
-#include <luci/Service/CircleShapeSignatureInference.h>
+#include "CircleCloneNode.h"
namespace luci
{
-ShapeSignature ssinf::Algorithm::visit(const luci::CircleReduceMin *node)
+luci::CircleNode *CloneNode::visit(const luci::CircleReduceMin *node)
{
- return legalized_signature(
- reduced_signature(node->input(), node->reduction_indices(), node->keep_dims()));
+ auto *cloned = _graph->nodes()->create<luci::CircleReduceMin>();
+ if (cloned != nullptr)
+ cloned->keep_dims(node->keep_dims());
+ return cloned;
}
} // namespace luci
diff --git a/compiler/luci/service/src/Nodes/CircleReduceMin.test.cpp b/compiler/luci/service/src/Nodes/CircleReduceMin.test.cpp
new file mode 100644
index 000000000..b3faa68da
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleReduceMin.test.cpp
@@ -0,0 +1,35 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Service/CircleNodeClone.h"
+
+#include <gtest/gtest.h>
+
+TEST(CloneNodeTest, clone_ReduceMin)
+{
+ auto g = loco::make_graph();
+ auto node_rmin = g->nodes()->create<luci::CircleReduceMin>();
+ node_rmin->keep_dims(true);
+
+ auto gc = loco::make_graph();
+ auto cloned = luci::clone_node(node_rmin, gc.get());
+ ASSERT_NE(nullptr, cloned);
+ ASSERT_EQ(gc.get(), cloned->graph());
+
+ auto cloned_rmin = dynamic_cast<luci::CircleReduceMin *>(cloned);
+ ASSERT_NE(nullptr, cloned_rmin);
+ ASSERT_EQ(node_rmin->keep_dims(), cloned_rmin->keep_dims());
+}
diff --git a/compiler/luci/service/src/Nodes/CircleReduceProd.cpp b/compiler/luci/service/src/Nodes/CircleReduceProd.cpp
index a9d381a74..418a8ce32 100644
--- a/compiler/luci/service/src/Nodes/CircleReduceProd.cpp
+++ b/compiler/luci/service/src/Nodes/CircleReduceProd.cpp
@@ -1,11 +1,11 @@
/*
- * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
- * http://www.apache.org/licenses/LICENSE-2.0
+ * http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
@@ -14,15 +14,17 @@
* limitations under the License.
*/
-#include <luci/Service/CircleShapeSignatureInference.h>
+#include "CircleCloneNode.h"
namespace luci
{
-ShapeSignature ssinf::Algorithm::visit(const luci::CircleReduceProd *node)
+luci::CircleNode *CloneNode::visit(const luci::CircleReduceProd *node)
{
- return legalized_signature(
- reduced_signature(node->input(), node->reduction_indices(), node->keep_dims()));
+ auto *cloned = _graph->nodes()->create<luci::CircleReduceProd>();
+ if (cloned != nullptr)
+ cloned->keep_dims(node->keep_dims());
+ return cloned;
}
} // namespace luci
diff --git a/compiler/luci/service/src/Nodes/CircleReduceProd.test.cpp b/compiler/luci/service/src/Nodes/CircleReduceProd.test.cpp
new file mode 100644
index 000000000..8caf8e91f
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleReduceProd.test.cpp
@@ -0,0 +1,35 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Service/CircleNodeClone.h"
+
+#include <gtest/gtest.h>
+
+TEST(CloneNodeTest, clone_ReduceProd)
+{
+ auto g = loco::make_graph();
+ auto node_rp = g->nodes()->create<luci::CircleReduceProd>();
+ node_rp->keep_dims(true);
+
+ auto gc = loco::make_graph();
+ auto cloned = luci::clone_node(node_rp, gc.get());
+ ASSERT_NE(nullptr, cloned);
+ ASSERT_EQ(gc.get(), cloned->graph());
+
+ auto cloned_rp = dynamic_cast<luci::CircleReduceProd *>(cloned);
+ ASSERT_NE(nullptr, cloned_rp);
+ ASSERT_EQ(node_rp->keep_dims(), cloned_rp->keep_dims());
+}
diff --git a/compiler/luci/service/src/Nodes/CircleRelu.cpp b/compiler/luci/service/src/Nodes/CircleRelu.cpp
index a7a7f6f0a..7447eea0c 100644
--- a/compiler/luci/service/src/Nodes/CircleRelu.cpp
+++ b/compiler/luci/service/src/Nodes/CircleRelu.cpp
@@ -1,11 +1,11 @@
/*
- * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
- * http://www.apache.org/licenses/LICENSE-2.0
+ * http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
@@ -14,14 +14,14 @@
* limitations under the License.
*/
-#include <luci/Service/CircleShapeSignatureInference.h>
+#include "CircleCloneNode.h"
namespace luci
{
-ShapeSignature ssinf::Algorithm::visit(const luci::CircleRelu *node)
+luci::CircleNode *CloneNode::visit(const luci::CircleRelu *)
{
- return input_arg_signature(node, 0);
+ return _graph->nodes()->create<luci::CircleRelu>();
}
} // namespace luci
diff --git a/compiler/luci/service/src/Nodes/CircleRelu.test.cpp b/compiler/luci/service/src/Nodes/CircleRelu.test.cpp
new file mode 100644
index 000000000..6154376ba
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleRelu.test.cpp
@@ -0,0 +1,74 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Service/CircleNodeClone.h"
+
+#include <luci/IR/CircleNodes.h>
+#include <luci/Service/CircleShapeInference.h>
+#include <luci/Service/CircleTypeInference.h>
+
+#include <loco/IR/TensorShape.h>
+
+#include <gtest/gtest.h>
+
+TEST(ShapeRuleTest, simple_relu)
+{
+ luci::CircleInput input;
+ luci::CircleRelu relu;
+
+ input.shape({3, 4});
+ input.shape_status(luci::ShapeStatus::VALID);
+
+ relu.features(&input);
+
+ loco::TensorShape shape;
+ luci::sinf::Rule shape_inf_rule;
+
+ ASSERT_TRUE(shape_inf_rule.infer(&relu, shape));
+ ASSERT_EQ(2, shape.rank());
+ ASSERT_EQ(3, shape.dim(0).value());
+ ASSERT_EQ(4, shape.dim(1).value());
+}
+
+TEST(DataTypeRuleTest, simple_relu)
+{
+ luci::CircleInput input;
+ luci::CircleRelu relu;
+
+ input.dtype(loco::DataType::S32);
+
+ relu.features(&input);
+
+ loco::DataType dtype;
+ luci::tinf::Rule type_inf_rule;
+
+ ASSERT_TRUE(type_inf_rule.infer(&relu, dtype));
+ ASSERT_EQ(loco::DataType::S32, dtype);
+}
+
+TEST(CloneNodeTest, clone_Relu)
+{
+ auto g = loco::make_graph();
+ auto node_relu = g->nodes()->create<luci::CircleRelu>();
+
+ auto gc = loco::make_graph();
+ auto cloned = luci::clone_node(node_relu, gc.get());
+ ASSERT_NE(nullptr, cloned);
+ ASSERT_EQ(gc.get(), cloned->graph());
+
+ auto cloned_relu = dynamic_cast<luci::CircleRelu *>(cloned);
+ ASSERT_NE(nullptr, cloned_relu);
+}
diff --git a/compiler/luci/service/src/Nodes/CircleRelu6.cpp b/compiler/luci/service/src/Nodes/CircleRelu6.cpp
index 92a596d08..7b98311ed 100644
--- a/compiler/luci/service/src/Nodes/CircleRelu6.cpp
+++ b/compiler/luci/service/src/Nodes/CircleRelu6.cpp
@@ -1,11 +1,11 @@
/*
- * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
- * http://www.apache.org/licenses/LICENSE-2.0
+ * http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
@@ -14,14 +14,14 @@
* limitations under the License.
*/
-#include <luci/Service/CircleShapeSignatureInference.h>
+#include "CircleCloneNode.h"
namespace luci
{
-ShapeSignature ssinf::Algorithm::visit(const luci::CircleRelu6 *node)
+luci::CircleNode *CloneNode::visit(const luci::CircleRelu6 *)
{
- return input_arg_signature(node, 0);
+ return _graph->nodes()->create<luci::CircleRelu6>();
}
} // namespace luci
diff --git a/compiler/luci/service/src/Nodes/CircleRelu6.test.cpp b/compiler/luci/service/src/Nodes/CircleRelu6.test.cpp
new file mode 100644
index 000000000..213dbcb09
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleRelu6.test.cpp
@@ -0,0 +1,35 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Service/CircleNodeClone.h"
+
+#include <loco/IR/TensorShape.h>
+
+#include <gtest/gtest.h>
+
+TEST(CloneNodeTest, clone_Relu6)
+{
+ auto g = loco::make_graph();
+ auto node_relu6 = g->nodes()->create<luci::CircleRelu6>();
+
+ auto gc = loco::make_graph();
+ auto cloned = luci::clone_node(node_relu6, gc.get());
+ ASSERT_NE(nullptr, cloned);
+ ASSERT_EQ(gc.get(), cloned->graph());
+
+ auto cloned_relu6 = dynamic_cast<luci::CircleRelu6 *>(cloned);
+ ASSERT_NE(nullptr, cloned_relu6);
+}
diff --git a/compiler/luci/service/src/Nodes/CircleReluN1To1.cpp b/compiler/luci/service/src/Nodes/CircleReluN1To1.cpp
index 1e8d9971d..4efedb9fc 100644
--- a/compiler/luci/service/src/Nodes/CircleReluN1To1.cpp
+++ b/compiler/luci/service/src/Nodes/CircleReluN1To1.cpp
@@ -1,11 +1,11 @@
/*
- * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
- * http://www.apache.org/licenses/LICENSE-2.0
+ * http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
@@ -14,14 +14,14 @@
* limitations under the License.
*/
-#include <luci/Service/CircleShapeSignatureInference.h>
+#include "CircleCloneNode.h"
namespace luci
{
-ShapeSignature ssinf::Algorithm::visit(const luci::CircleReluN1To1 *node)
+luci::CircleNode *CloneNode::visit(const luci::CircleReluN1To1 *)
{
- return input_arg_signature(node, 0);
+ return _graph->nodes()->create<luci::CircleReluN1To1>();
}
} // namespace luci
diff --git a/compiler/luci/service/src/Nodes/CircleReluN1To1.test.cpp b/compiler/luci/service/src/Nodes/CircleReluN1To1.test.cpp
new file mode 100644
index 000000000..b828e795c
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleReluN1To1.test.cpp
@@ -0,0 +1,35 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Service/CircleNodeClone.h"
+
+#include <loco/IR/TensorShape.h>
+
+#include <gtest/gtest.h>
+
+TEST(CloneNodeTest, clone_ReluN1To1)
+{
+ auto g = loco::make_graph();
+ auto node_relun1 = g->nodes()->create<luci::CircleReluN1To1>();
+
+ auto gc = loco::make_graph();
+ auto cloned = luci::clone_node(node_relun1, gc.get());
+ ASSERT_NE(nullptr, cloned);
+ ASSERT_EQ(gc.get(), cloned->graph());
+
+ auto cloned_relun1 = dynamic_cast<luci::CircleReluN1To1 *>(cloned);
+ ASSERT_NE(nullptr, cloned_relun1);
+}
diff --git a/compiler/luci/service/src/Nodes/CircleReshape.cpp b/compiler/luci/service/src/Nodes/CircleReshape.cpp
new file mode 100644
index 000000000..07a81b306
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleReshape.cpp
@@ -0,0 +1,37 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "CircleCloneNode.h"
+
+namespace luci
+{
+
+luci::CircleNode *CloneNode::visit(const luci::CircleReshape *node)
+{
+ auto *cloned = _graph->nodes()->create<luci::CircleReshape>();
+ if (cloned != nullptr)
+ {
+ uint32_t rank = node->newShape()->rank();
+ cloned->newShape()->rank(rank);
+ for (uint32_t r = 0; r < rank; ++r)
+ {
+ cloned->newShape()->dim(r) = node->newShape()->dim(r);
+ }
+ }
+ return cloned;
+}
+
+} // namespace luci
diff --git a/compiler/luci/service/src/Nodes/CircleReshape.test.cpp b/compiler/luci/service/src/Nodes/CircleReshape.test.cpp
new file mode 100644
index 000000000..ca92b717d
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleReshape.test.cpp
@@ -0,0 +1,39 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Service/CircleNodeClone.h"
+
+#include <gtest/gtest.h>
+
+TEST(CloneNodeTest, clone_Reshape)
+{
+ auto g = loco::make_graph();
+ auto node_reshape = g->nodes()->create<luci::CircleReshape>();
+ node_reshape->newShape()->rank(2);
+ node_reshape->newShape()->dim(0) = 3;
+ node_reshape->newShape()->dim(1) = 4;
+
+ auto gc = loco::make_graph();
+ auto cloned = luci::clone_node(node_reshape, gc.get());
+ ASSERT_NE(nullptr, cloned);
+ ASSERT_EQ(gc.get(), cloned->graph());
+
+ auto cloned_reshape = dynamic_cast<luci::CircleReshape *>(cloned);
+ ASSERT_NE(nullptr, cloned_reshape);
+ ASSERT_EQ(node_reshape->newShape()->rank(), cloned_reshape->newShape()->rank());
+ ASSERT_EQ(node_reshape->newShape()->dim(0), cloned_reshape->newShape()->dim(0));
+ ASSERT_EQ(node_reshape->newShape()->dim(1), cloned_reshape->newShape()->dim(1));
+}
diff --git a/compiler/luci/service/src/Nodes/CircleResizeBilinear.cpp b/compiler/luci/service/src/Nodes/CircleResizeBilinear.cpp
new file mode 100644
index 000000000..55d21af45
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleResizeBilinear.cpp
@@ -0,0 +1,33 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "CircleCloneNode.h"
+
+namespace luci
+{
+
+luci::CircleNode *CloneNode::visit(const luci::CircleResizeBilinear *node)
+{
+ auto *cloned = _graph->nodes()->create<luci::CircleResizeBilinear>();
+ if (cloned != nullptr)
+ {
+ cloned->align_corners(node->align_corners());
+ cloned->half_pixel_centers(node->half_pixel_centers());
+ }
+ return cloned;
+}
+
+} // namespace luci
diff --git a/compiler/luci/service/src/Nodes/CircleResizeBilinear.test.cpp b/compiler/luci/service/src/Nodes/CircleResizeBilinear.test.cpp
new file mode 100644
index 000000000..bff71261d
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleResizeBilinear.test.cpp
@@ -0,0 +1,73 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Service/CircleNodeClone.h"
+
+#include <luci/IR/CircleNodes.h>
+#include <luci/Service/CircleShapeInference.h>
+
+#include <loco/IR/TensorShape.h>
+
+#include <gtest/gtest.h>
+
+TEST(ShapeRuleTest, resize_bilinear_simple)
+{
+ luci::CircleInput input;
+ luci::CircleConst rb_size;
+ luci::CircleResizeBilinear rb;
+
+ input.shape({1, 4, 4, 3});
+ input.shape_status(luci::ShapeStatus::VALID);
+
+ rb_size.dtype(loco::DataType::S32);
+ rb_size.rank(1);
+ rb_size.dim(0).set(2);
+ rb_size.size<loco::DataType::S32>(2);
+ rb_size.at<loco::DataType::S32>(0) = 16;
+ rb_size.at<loco::DataType::S32>(1) = 16;
+ rb_size.shape_status(luci::ShapeStatus::VALID);
+
+ rb.input(&input);
+ rb.size(&rb_size);
+
+ loco::TensorShape shape;
+ luci::sinf::Rule shape_inf_rule;
+
+ ASSERT_TRUE(shape_inf_rule.infer(&rb, shape));
+ ASSERT_EQ(4, shape.rank());
+ ASSERT_EQ(1, shape.dim(0).value());
+ ASSERT_EQ(16, shape.dim(1).value());
+ ASSERT_EQ(16, shape.dim(2).value());
+ ASSERT_EQ(3, shape.dim(3).value());
+}
+
+TEST(CloneNodeTest, clone_ResizeBilinear)
+{
+ auto g = loco::make_graph();
+ auto node_rb = g->nodes()->create<luci::CircleResizeBilinear>();
+ node_rb->align_corners(true);
+ node_rb->half_pixel_centers(true);
+
+ auto gc = loco::make_graph();
+ auto cloned = luci::clone_node(node_rb, gc.get());
+ ASSERT_NE(nullptr, cloned);
+ ASSERT_EQ(gc.get(), cloned->graph());
+
+ auto cloned_rb = dynamic_cast<luci::CircleResizeBilinear *>(cloned);
+ ASSERT_NE(nullptr, cloned_rb);
+ ASSERT_EQ(node_rb->align_corners(), cloned_rb->align_corners());
+ ASSERT_EQ(node_rb->half_pixel_centers(), cloned_rb->half_pixel_centers());
+}
diff --git a/compiler/luci/service/src/Nodes/CircleResizeNearestNeighbor.cpp b/compiler/luci/service/src/Nodes/CircleResizeNearestNeighbor.cpp
new file mode 100644
index 000000000..5727786a7
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleResizeNearestNeighbor.cpp
@@ -0,0 +1,30 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "CircleCloneNode.h"
+
+namespace luci
+{
+
+luci::CircleNode *CloneNode::visit(const luci::CircleResizeNearestNeighbor *node)
+{
+ auto *cloned = _graph->nodes()->create<luci::CircleResizeNearestNeighbor>();
+ if (cloned != nullptr)
+ cloned->align_corners(node->align_corners());
+ return cloned;
+}
+
+} // namespace luci
diff --git a/compiler/luci/service/src/Nodes/CircleResizeNearestNeighbor.test.cpp b/compiler/luci/service/src/Nodes/CircleResizeNearestNeighbor.test.cpp
new file mode 100644
index 000000000..a1d781c65
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleResizeNearestNeighbor.test.cpp
@@ -0,0 +1,71 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Service/CircleNodeClone.h"
+
+#include <luci/IR/CircleNodes.h>
+#include <luci/Service/CircleShapeInference.h>
+
+#include <loco/IR/TensorShape.h>
+
+#include <gtest/gtest.h>
+
+TEST(ShapeRuleTest, resize_nearest_neighbor_simple)
+{
+ luci::CircleInput input;
+ luci::CircleConst rnn_size;
+ luci::CircleResizeNearestNeighbor rnn;
+
+ input.shape({1, 4, 4, 3});
+ input.shape_status(luci::ShapeStatus::VALID);
+
+ rnn_size.dtype(loco::DataType::S32);
+ rnn_size.rank(1);
+ rnn_size.dim(0).set(2);
+ rnn_size.size<loco::DataType::S32>(2);
+ rnn_size.at<loco::DataType::S32>(0) = 16;
+ rnn_size.at<loco::DataType::S32>(1) = 16;
+ rnn_size.shape_status(luci::ShapeStatus::VALID);
+
+ rnn.input(&input);
+ rnn.size(&rnn_size);
+
+ loco::TensorShape shape;
+ luci::sinf::Rule shape_inf_rule;
+
+ ASSERT_TRUE(shape_inf_rule.infer(&rnn, shape));
+ ASSERT_EQ(4, shape.rank());
+ ASSERT_EQ(1, shape.dim(0).value());
+ ASSERT_EQ(16, shape.dim(1).value());
+ ASSERT_EQ(16, shape.dim(2).value());
+ ASSERT_EQ(3, shape.dim(3).value());
+}
+
+TEST(CloneNodeTest, clone_ResizeNearestNeighbor)
+{
+ auto g = loco::make_graph();
+ auto node_rnn = g->nodes()->create<luci::CircleResizeNearestNeighbor>();
+ node_rnn->align_corners(true);
+
+ auto gc = loco::make_graph();
+ auto cloned = luci::clone_node(node_rnn, gc.get());
+ ASSERT_NE(nullptr, cloned);
+ ASSERT_EQ(gc.get(), cloned->graph());
+
+ auto cloned_rnn = dynamic_cast<luci::CircleResizeNearestNeighbor *>(cloned);
+ ASSERT_NE(nullptr, cloned_rnn);
+ ASSERT_EQ(node_rnn->align_corners(), cloned_rnn->align_corners());
+}
diff --git a/compiler/luci/service/src/Nodes/CircleReverseSequence.cpp b/compiler/luci/service/src/Nodes/CircleReverseSequence.cpp
new file mode 100644
index 000000000..6e6919b0c
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleReverseSequence.cpp
@@ -0,0 +1,33 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "CircleCloneNode.h"
+
+namespace luci
+{
+
+luci::CircleNode *CloneNode::visit(const luci::CircleReverseSequence *node)
+{
+ auto *cloned = _graph->nodes()->create<luci::CircleReverseSequence>();
+ if (cloned != nullptr)
+ {
+ cloned->seq_axis(node->seq_axis());
+ cloned->batch_axis(node->batch_axis());
+ }
+ return cloned;
+}
+
+} // namespace luci
diff --git a/compiler/luci/service/src/Nodes/CircleReverseSequence.test.cpp b/compiler/luci/service/src/Nodes/CircleReverseSequence.test.cpp
new file mode 100644
index 000000000..a7a8e3949
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleReverseSequence.test.cpp
@@ -0,0 +1,37 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Service/CircleNodeClone.h"
+
+#include <gtest/gtest.h>
+
+TEST(CloneNodeTest, clone_ReverseSequence)
+{
+ auto g = loco::make_graph();
+ auto node_rs = g->nodes()->create<luci::CircleReverseSequence>();
+ node_rs->seq_axis(1);
+ node_rs->batch_axis(2);
+
+ auto gc = loco::make_graph();
+ auto cloned = luci::clone_node(node_rs, gc.get());
+ ASSERT_NE(nullptr, cloned);
+ ASSERT_EQ(gc.get(), cloned->graph());
+
+ auto cloned_rs = dynamic_cast<luci::CircleReverseSequence *>(cloned);
+ ASSERT_NE(nullptr, cloned_rs);
+ ASSERT_EQ(node_rs->seq_axis(), cloned_rs->seq_axis());
+ ASSERT_EQ(node_rs->batch_axis(), cloned_rs->batch_axis());
+}
diff --git a/compiler/luci/service/src/Nodes/CircleReverseV2.cpp b/compiler/luci/service/src/Nodes/CircleReverseV2.cpp
new file mode 100644
index 000000000..e8fee6c3e
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleReverseV2.cpp
@@ -0,0 +1,27 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "CircleCloneNode.h"
+
+namespace luci
+{
+
+luci::CircleNode *CloneNode::visit(const luci::CircleReverseV2 *)
+{
+ return _graph->nodes()->create<luci::CircleReverseV2>();
+}
+
+} // namespace luci
diff --git a/compiler/luci/service/src/Nodes/CircleReverseV2.test.cpp b/compiler/luci/service/src/Nodes/CircleReverseV2.test.cpp
new file mode 100644
index 000000000..0e5ff933c
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleReverseV2.test.cpp
@@ -0,0 +1,33 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Service/CircleNodeClone.h"
+
+#include <gtest/gtest.h>
+
+TEST(CloneNodeTest, clone_ReverseV2)
+{
+ auto g = loco::make_graph();
+ auto node_rev = g->nodes()->create<luci::CircleReverseV2>();
+
+ auto gc = loco::make_graph();
+ auto cloned = luci::clone_node(node_rev, gc.get());
+ ASSERT_NE(nullptr, cloned);
+ ASSERT_EQ(gc.get(), cloned->graph());
+
+ auto cloned_rev = dynamic_cast<luci::CircleReverseV2 *>(cloned);
+ ASSERT_NE(nullptr, cloned_rev);
+}
diff --git a/compiler/luci/service/src/Nodes/CircleRound.cpp b/compiler/luci/service/src/Nodes/CircleRound.cpp
new file mode 100644
index 000000000..2c23f2df6
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleRound.cpp
@@ -0,0 +1,27 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "CircleCloneNode.h"
+
+namespace luci
+{
+
+luci::CircleNode *CloneNode::visit(const luci::CircleRound *)
+{
+ return _graph->nodes()->create<luci::CircleRound>();
+}
+
+} // namespace luci
diff --git a/compiler/luci/service/src/Nodes/CircleRound.test.cpp b/compiler/luci/service/src/Nodes/CircleRound.test.cpp
new file mode 100644
index 000000000..2c2c3a9d0
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleRound.test.cpp
@@ -0,0 +1,33 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Service/CircleNodeClone.h"
+
+#include <gtest/gtest.h>
+
+TEST(CloneNodeTest, clone_Round)
+{
+ auto g = loco::make_graph();
+ auto node_rnd = g->nodes()->create<luci::CircleRound>();
+
+ auto gc = loco::make_graph();
+ auto cloned = luci::clone_node(node_rnd, gc.get());
+ ASSERT_NE(nullptr, cloned);
+ ASSERT_EQ(gc.get(), cloned->graph());
+
+ auto cloned_rnd = dynamic_cast<luci::CircleRound *>(cloned);
+ ASSERT_NE(nullptr, cloned_rnd);
+}
diff --git a/compiler/luci/service/src/Nodes/CircleRsqrt.cpp b/compiler/luci/service/src/Nodes/CircleRsqrt.cpp
new file mode 100644
index 000000000..aca702fe1
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleRsqrt.cpp
@@ -0,0 +1,27 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "CircleCloneNode.h"
+
+namespace luci
+{
+
+luci::CircleNode *CloneNode::visit(const luci::CircleRsqrt *)
+{
+ return _graph->nodes()->create<luci::CircleRsqrt>();
+}
+
+} // namespace luci
diff --git a/compiler/luci/service/src/Nodes/CircleRsqrt.test.cpp b/compiler/luci/service/src/Nodes/CircleRsqrt.test.cpp
new file mode 100644
index 000000000..3e4ced562
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleRsqrt.test.cpp
@@ -0,0 +1,33 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Service/CircleNodeClone.h"
+
+#include <gtest/gtest.h>
+
+TEST(CloneNodeTest, clone_Rsqrt)
+{
+ auto g = loco::make_graph();
+ auto node_rsqrt = g->nodes()->create<luci::CircleRsqrt>();
+
+ auto gc = loco::make_graph();
+ auto cloned = luci::clone_node(node_rsqrt, gc.get());
+ ASSERT_NE(nullptr, cloned);
+ ASSERT_EQ(gc.get(), cloned->graph());
+
+ auto cloned_rsqrt = dynamic_cast<luci::CircleRsqrt *>(cloned);
+ ASSERT_NE(nullptr, cloned_rsqrt);
+}
diff --git a/compiler/luci/service/src/Nodes/CircleScatterNd.cpp b/compiler/luci/service/src/Nodes/CircleScatterNd.cpp
new file mode 100644
index 000000000..6c477a598
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleScatterNd.cpp
@@ -0,0 +1,27 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "CircleCloneNode.h"
+
+namespace luci
+{
+
+luci::CircleNode *CloneNode::visit(const luci::CircleScatterNd *)
+{
+ return _graph->nodes()->create<luci::CircleScatterNd>();
+}
+
+} // namespace luci
diff --git a/compiler/luci/service/src/Nodes/CircleScatterNd.test.cpp b/compiler/luci/service/src/Nodes/CircleScatterNd.test.cpp
new file mode 100644
index 000000000..ce63603cc
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleScatterNd.test.cpp
@@ -0,0 +1,33 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Service/CircleNodeClone.h"
+
+#include <gtest/gtest.h>
+
+TEST(CloneNodeTest, clone_ScatterNd)
+{
+ auto g = loco::make_graph();
+ auto node_snd = g->nodes()->create<luci::CircleScatterNd>();
+
+ auto gc = loco::make_graph();
+ auto cloned = luci::clone_node(node_snd, gc.get());
+ ASSERT_NE(nullptr, cloned);
+ ASSERT_EQ(gc.get(), cloned->graph());
+
+ auto cloned_snd = dynamic_cast<luci::CircleScatterNd *>(cloned);
+ ASSERT_NE(nullptr, cloned_snd);
+}
diff --git a/compiler/luci/service/src/Nodes/CircleSegmentSum.cpp b/compiler/luci/service/src/Nodes/CircleSegmentSum.cpp
new file mode 100644
index 000000000..aa4001f57
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleSegmentSum.cpp
@@ -0,0 +1,27 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "CircleCloneNode.h"
+
+namespace luci
+{
+
+luci::CircleNode *CloneNode::visit(const luci::CircleSegmentSum *)
+{
+ return _graph->nodes()->create<luci::CircleSegmentSum>();
+}
+
+} // namespace luci
diff --git a/compiler/luci/service/src/Nodes/CircleSegmentSum.test.cpp b/compiler/luci/service/src/Nodes/CircleSegmentSum.test.cpp
new file mode 100644
index 000000000..ff17b0745
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleSegmentSum.test.cpp
@@ -0,0 +1,33 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Service/CircleNodeClone.h"
+
+#include <gtest/gtest.h>
+
+TEST(CloneNodeTest, clone_SegmentSum)
+{
+ auto g = loco::make_graph();
+ auto node_ss = g->nodes()->create<luci::CircleSegmentSum>();
+
+ auto gc = loco::make_graph();
+ auto cloned = luci::clone_node(node_ss, gc.get());
+ ASSERT_NE(nullptr, cloned);
+ ASSERT_EQ(gc.get(), cloned->graph());
+
+ auto cloned_ss = dynamic_cast<luci::CircleSegmentSum *>(cloned);
+ ASSERT_NE(nullptr, cloned_ss);
+}
diff --git a/compiler/luci/service/src/Nodes/CircleSelect.cpp b/compiler/luci/service/src/Nodes/CircleSelect.cpp
new file mode 100644
index 000000000..71b31d33f
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleSelect.cpp
@@ -0,0 +1,27 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "CircleCloneNode.h"
+
+namespace luci
+{
+
+luci::CircleNode *CloneNode::visit(const luci::CircleSelect *)
+{
+ return _graph->nodes()->create<luci::CircleSelect>();
+}
+
+} // namespace luci
diff --git a/compiler/luci/service/src/Nodes/CircleSelect.test.cpp b/compiler/luci/service/src/Nodes/CircleSelect.test.cpp
new file mode 100644
index 000000000..e8d631618
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleSelect.test.cpp
@@ -0,0 +1,33 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Service/CircleNodeClone.h"
+
+#include <gtest/gtest.h>
+
+TEST(CloneNodeTest, clone_Select)
+{
+ auto g = loco::make_graph();
+ auto node_sel = g->nodes()->create<luci::CircleSelect>();
+
+ auto gc = loco::make_graph();
+ auto cloned = luci::clone_node(node_sel, gc.get());
+ ASSERT_NE(nullptr, cloned);
+ ASSERT_EQ(gc.get(), cloned->graph());
+
+ auto cloned_sel = dynamic_cast<luci::CircleSelect *>(cloned);
+ ASSERT_NE(nullptr, cloned_sel);
+}
diff --git a/compiler/luci/service/src/Nodes/CircleSelectV2.cpp b/compiler/luci/service/src/Nodes/CircleSelectV2.cpp
new file mode 100644
index 000000000..07af40c40
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleSelectV2.cpp
@@ -0,0 +1,27 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "CircleCloneNode.h"
+
+namespace luci
+{
+
+luci::CircleNode *CloneNode::visit(const luci::CircleSelectV2 *)
+{
+ return _graph->nodes()->create<luci::CircleSelectV2>();
+}
+
+} // namespace luci
diff --git a/compiler/luci/service/src/Nodes/CircleSelectV2.test.cpp b/compiler/luci/service/src/Nodes/CircleSelectV2.test.cpp
new file mode 100644
index 000000000..253dba555
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleSelectV2.test.cpp
@@ -0,0 +1,33 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Service/CircleNodeClone.h"
+
+#include <gtest/gtest.h>
+
+TEST(CloneNodeTest, clone_SelectV2)
+{
+ auto g = loco::make_graph();
+ auto node_sel = g->nodes()->create<luci::CircleSelectV2>();
+
+ auto gc = loco::make_graph();
+ auto cloned = luci::clone_node(node_sel, gc.get());
+ ASSERT_NE(nullptr, cloned);
+ ASSERT_EQ(gc.get(), cloned->graph());
+
+ auto cloned_sel = dynamic_cast<luci::CircleSelectV2 *>(cloned);
+ ASSERT_NE(nullptr, cloned_sel);
+}
diff --git a/compiler/luci/service/src/Nodes/CircleShape.cpp b/compiler/luci/service/src/Nodes/CircleShape.cpp
new file mode 100644
index 000000000..e5b5fa28f
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleShape.cpp
@@ -0,0 +1,30 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "CircleCloneNode.h"
+
+namespace luci
+{
+
+luci::CircleNode *CloneNode::visit(const luci::CircleShape *node)
+{
+ auto *cloned = _graph->nodes()->create<luci::CircleShape>();
+ if (cloned != nullptr)
+ cloned->out_type(node->out_type());
+ return cloned;
+}
+
+} // namespace luci
diff --git a/compiler/luci/service/src/Nodes/CircleShape.test.cpp b/compiler/luci/service/src/Nodes/CircleShape.test.cpp
new file mode 100644
index 000000000..ec057bd05
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleShape.test.cpp
@@ -0,0 +1,35 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Service/CircleNodeClone.h"
+
+#include <gtest/gtest.h>
+
+TEST(CloneNodeTest, clone_Shape)
+{
+ auto g = loco::make_graph();
+ auto node_shape = g->nodes()->create<luci::CircleShape>();
+ node_shape->out_type(loco::DataType::S32);
+
+ auto gc = loco::make_graph();
+ auto cloned = luci::clone_node(node_shape, gc.get());
+ ASSERT_NE(nullptr, cloned);
+ ASSERT_EQ(gc.get(), cloned->graph());
+
+ auto cloned_shape = dynamic_cast<luci::CircleShape *>(cloned);
+ ASSERT_NE(nullptr, cloned_shape);
+ ASSERT_EQ(node_shape->out_type(), cloned_shape->out_type());
+}
diff --git a/compiler/luci/service/src/Nodes/CircleSin.cpp b/compiler/luci/service/src/Nodes/CircleSin.cpp
new file mode 100644
index 000000000..46a07d21d
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleSin.cpp
@@ -0,0 +1,27 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "CircleCloneNode.h"
+
+namespace luci
+{
+
+luci::CircleNode *CloneNode::visit(const luci::CircleSin *)
+{
+ return _graph->nodes()->create<luci::CircleSin>();
+}
+
+} // namespace luci
diff --git a/compiler/luci/service/src/Nodes/CircleSin.test.cpp b/compiler/luci/service/src/Nodes/CircleSin.test.cpp
new file mode 100644
index 000000000..b072e7e2c
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleSin.test.cpp
@@ -0,0 +1,33 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Service/CircleNodeClone.h"
+
+#include <gtest/gtest.h>
+
+TEST(CloneNodeTest, clone_Sin)
+{
+ auto g = loco::make_graph();
+ auto node_sin = g->nodes()->create<luci::CircleSin>();
+
+ auto gc = loco::make_graph();
+ auto cloned = luci::clone_node(node_sin, gc.get());
+ ASSERT_NE(nullptr, cloned);
+ ASSERT_EQ(gc.get(), cloned->graph());
+
+ auto cloned_sin = dynamic_cast<luci::CircleSin *>(cloned);
+ ASSERT_NE(nullptr, cloned_sin);
+}
diff --git a/compiler/luci/service/src/Nodes/CircleSlice.cpp b/compiler/luci/service/src/Nodes/CircleSlice.cpp
new file mode 100644
index 000000000..6b2f4a591
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleSlice.cpp
@@ -0,0 +1,27 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "CircleCloneNode.h"
+
+namespace luci
+{
+
+luci::CircleNode *CloneNode::visit(const luci::CircleSlice *)
+{
+ return _graph->nodes()->create<luci::CircleSlice>();
+}
+
+} // namespace luci
diff --git a/compiler/luci/service/src/Nodes/CircleSlice.test.cpp b/compiler/luci/service/src/Nodes/CircleSlice.test.cpp
new file mode 100644
index 000000000..48ec20304
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleSlice.test.cpp
@@ -0,0 +1,33 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Service/CircleNodeClone.h"
+
+#include <gtest/gtest.h>
+
+TEST(CloneNodeTest, clone_Slice)
+{
+ auto g = loco::make_graph();
+ auto node_slice = g->nodes()->create<luci::CircleSlice>();
+
+ auto gc = loco::make_graph();
+ auto cloned = luci::clone_node(node_slice, gc.get());
+ ASSERT_NE(nullptr, cloned);
+ ASSERT_EQ(gc.get(), cloned->graph());
+
+ auto cloned_slice = dynamic_cast<luci::CircleSlice *>(cloned);
+ ASSERT_NE(nullptr, cloned_slice);
+}
diff --git a/compiler/luci/service/src/Nodes/CircleSoftmax.cpp b/compiler/luci/service/src/Nodes/CircleSoftmax.cpp
new file mode 100644
index 000000000..359d1000c
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleSoftmax.cpp
@@ -0,0 +1,30 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "CircleCloneNode.h"
+
+namespace luci
+{
+
+luci::CircleNode *CloneNode::visit(const luci::CircleSoftmax *node)
+{
+ auto *cloned = _graph->nodes()->create<luci::CircleSoftmax>();
+ if (cloned != nullptr)
+ cloned->beta(node->beta());
+ return cloned;
+}
+
+} // namespace luci
diff --git a/compiler/luci/service/src/Nodes/CircleSoftmax.test.cpp b/compiler/luci/service/src/Nodes/CircleSoftmax.test.cpp
new file mode 100644
index 000000000..c80b44d69
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleSoftmax.test.cpp
@@ -0,0 +1,35 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Service/CircleNodeClone.h"
+
+#include <gtest/gtest.h>
+
+TEST(CloneNodeTest, clone_Softmax)
+{
+ auto g = loco::make_graph();
+ auto node_sm = g->nodes()->create<luci::CircleSoftmax>();
+ node_sm->beta(2.3f);
+
+ auto gc = loco::make_graph();
+ auto cloned = luci::clone_node(node_sm, gc.get());
+ ASSERT_NE(nullptr, cloned);
+ ASSERT_EQ(gc.get(), cloned->graph());
+
+ auto cloned_sm = dynamic_cast<luci::CircleSoftmax *>(cloned);
+ ASSERT_NE(nullptr, cloned_sm);
+ ASSERT_EQ(node_sm->beta(), cloned_sm->beta());
+}
diff --git a/compiler/luci/service/src/Nodes/CircleSpaceToBatchND.cpp b/compiler/luci/service/src/Nodes/CircleSpaceToBatchND.cpp
new file mode 100644
index 000000000..feb4f3e37
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleSpaceToBatchND.cpp
@@ -0,0 +1,27 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "CircleCloneNode.h"
+
+namespace luci
+{
+
+luci::CircleNode *CloneNode::visit(const luci::CircleSpaceToBatchND *)
+{
+ return _graph->nodes()->create<luci::CircleSpaceToBatchND>();
+}
+
+} // namespace luci
diff --git a/compiler/luci/service/src/Nodes/CircleSpaceToBatchND.test.cpp b/compiler/luci/service/src/Nodes/CircleSpaceToBatchND.test.cpp
new file mode 100644
index 000000000..eb743795d
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleSpaceToBatchND.test.cpp
@@ -0,0 +1,33 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Service/CircleNodeClone.h"
+
+#include <gtest/gtest.h>
+
+TEST(CloneNodeTest, clone_SpaceToBatchND)
+{
+ auto g = loco::make_graph();
+ auto node_s2bnd = g->nodes()->create<luci::CircleSpaceToBatchND>();
+
+ auto gc = loco::make_graph();
+ auto cloned = luci::clone_node(node_s2bnd, gc.get());
+ ASSERT_NE(nullptr, cloned);
+ ASSERT_EQ(gc.get(), cloned->graph());
+
+ auto cloned_s2bnd = dynamic_cast<luci::CircleSpaceToBatchND *>(cloned);
+ ASSERT_NE(nullptr, cloned_s2bnd);
+}
diff --git a/compiler/luci/service/src/Nodes/CircleSpaceToDepth.cpp b/compiler/luci/service/src/Nodes/CircleSpaceToDepth.cpp
new file mode 100644
index 000000000..3a82f5c7a
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleSpaceToDepth.cpp
@@ -0,0 +1,30 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "CircleCloneNode.h"
+
+namespace luci
+{
+
+luci::CircleNode *CloneNode::visit(const luci::CircleSpaceToDepth *node)
+{
+ auto *cloned = _graph->nodes()->create<luci::CircleSpaceToDepth>();
+ if (cloned != nullptr)
+ cloned->block_size(node->block_size());
+ return cloned;
+}
+
+} // namespace luci
diff --git a/compiler/luci/service/src/Nodes/CircleSpaceToDepth.test.cpp b/compiler/luci/service/src/Nodes/CircleSpaceToDepth.test.cpp
new file mode 100644
index 000000000..fb544e6d7
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleSpaceToDepth.test.cpp
@@ -0,0 +1,35 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Service/CircleNodeClone.h"
+
+#include <gtest/gtest.h>
+
+TEST(CloneNodeTest, clone_SpaceToDepth)
+{
+ auto g = loco::make_graph();
+ auto node_s2d = g->nodes()->create<luci::CircleSpaceToDepth>();
+ node_s2d->block_size(32);
+
+ auto gc = loco::make_graph();
+ auto cloned = luci::clone_node(node_s2d, gc.get());
+ ASSERT_NE(nullptr, cloned);
+ ASSERT_EQ(gc.get(), cloned->graph());
+
+ auto cloned_s2d = dynamic_cast<luci::CircleSpaceToDepth *>(cloned);
+ ASSERT_NE(nullptr, cloned_s2d);
+ ASSERT_EQ(node_s2d->block_size(), cloned_s2d->block_size());
+}
diff --git a/compiler/luci/service/src/Nodes/CircleSparseToDense.cpp b/compiler/luci/service/src/Nodes/CircleSparseToDense.cpp
new file mode 100644
index 000000000..3dba1a542
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleSparseToDense.cpp
@@ -0,0 +1,30 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "CircleCloneNode.h"
+
+namespace luci
+{
+
+luci::CircleNode *CloneNode::visit(const luci::CircleSparseToDense *node)
+{
+ auto *cloned = _graph->nodes()->create<luci::CircleSparseToDense>();
+ if (cloned != nullptr)
+ cloned->validate_indices(node->validate_indices());
+ return cloned;
+}
+
+} // namespace luci
diff --git a/compiler/luci/service/src/Nodes/CircleSparseToDense.test.cpp b/compiler/luci/service/src/Nodes/CircleSparseToDense.test.cpp
new file mode 100644
index 000000000..177a469cd
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleSparseToDense.test.cpp
@@ -0,0 +1,35 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Service/CircleNodeClone.h"
+
+#include <gtest/gtest.h>
+
+TEST(CloneNodeTest, clone_SparseToDense)
+{
+ auto g = loco::make_graph();
+ auto node_s2d = g->nodes()->create<luci::CircleSparseToDense>();
+ node_s2d->validate_indices(true);
+
+ auto gc = loco::make_graph();
+ auto cloned = luci::clone_node(node_s2d, gc.get());
+ ASSERT_NE(nullptr, cloned);
+ ASSERT_EQ(gc.get(), cloned->graph());
+
+ auto cloned_s2d = dynamic_cast<luci::CircleSparseToDense *>(cloned);
+ ASSERT_NE(nullptr, cloned_s2d);
+ ASSERT_EQ(node_s2d->validate_indices(), cloned_s2d->validate_indices());
+}
diff --git a/compiler/luci/service/src/Nodes/CircleSplit.cpp b/compiler/luci/service/src/Nodes/CircleSplit.cpp
new file mode 100644
index 000000000..e68a24a1f
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleSplit.cpp
@@ -0,0 +1,30 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "CircleCloneNode.h"
+
+namespace luci
+{
+
+luci::CircleNode *CloneNode::visit(const luci::CircleSplit *node)
+{
+ auto *cloned = _graph->nodes()->create<luci::CircleSplit>();
+ if (cloned != nullptr)
+ cloned->num_split(node->num_split());
+ return cloned;
+}
+
+} // namespace luci
diff --git a/compiler/luci/service/src/Nodes/CircleSplit.test.cpp b/compiler/luci/service/src/Nodes/CircleSplit.test.cpp
new file mode 100644
index 000000000..9ee26b425
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleSplit.test.cpp
@@ -0,0 +1,35 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Service/CircleNodeClone.h"
+
+#include <gtest/gtest.h>
+
+TEST(CloneNodeTest, clone_Split)
+{
+ auto g = loco::make_graph();
+ auto node_split = g->nodes()->create<luci::CircleSplit>();
+ node_split->num_split(5);
+
+ auto gc = loco::make_graph();
+ auto cloned = luci::clone_node(node_split, gc.get());
+ ASSERT_NE(nullptr, cloned);
+ ASSERT_EQ(gc.get(), cloned->graph());
+
+ auto cloned_split = dynamic_cast<luci::CircleSplit *>(cloned);
+ ASSERT_NE(nullptr, cloned_split);
+ ASSERT_EQ(node_split->num_split(), cloned_split->num_split());
+}
diff --git a/compiler/luci/service/src/Nodes/CircleSplitOut.cpp b/compiler/luci/service/src/Nodes/CircleSplitOut.cpp
new file mode 100644
index 000000000..024598892
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleSplitOut.cpp
@@ -0,0 +1,30 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "CircleCloneNode.h"
+
+namespace luci
+{
+
+luci::CircleNode *CloneNode::visit(const luci::CircleSplitOut *node)
+{
+ auto *cloned = _graph->nodes()->create<luci::CircleSplitOut>();
+ if (cloned != nullptr)
+ cloned->index(node->index());
+ return cloned;
+}
+
+} // namespace luci
diff --git a/compiler/luci/service/src/Nodes/CircleSplitOut.test.cpp b/compiler/luci/service/src/Nodes/CircleSplitOut.test.cpp
new file mode 100644
index 000000000..deec08804
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleSplitOut.test.cpp
@@ -0,0 +1,35 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Service/CircleNodeClone.h"
+
+#include <gtest/gtest.h>
+
+TEST(CloneNodeTest, clone_SplitOut)
+{
+ auto g = loco::make_graph();
+ auto node_sout = g->nodes()->create<luci::CircleSplitOut>();
+ node_sout->index(1);
+
+ auto gc = loco::make_graph();
+ auto cloned = luci::clone_node(node_sout, gc.get());
+ ASSERT_NE(nullptr, cloned);
+ ASSERT_EQ(gc.get(), cloned->graph());
+
+ auto cloned_sout = dynamic_cast<luci::CircleSplitOut *>(cloned);
+ ASSERT_NE(nullptr, cloned_sout);
+ ASSERT_EQ(node_sout->index(), cloned_sout->index());
+}
diff --git a/compiler/luci/service/src/Nodes/CircleSplitV.cpp b/compiler/luci/service/src/Nodes/CircleSplitV.cpp
new file mode 100644
index 000000000..de6c6cce6
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleSplitV.cpp
@@ -0,0 +1,30 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "CircleCloneNode.h"
+
+namespace luci
+{
+
+luci::CircleNode *CloneNode::visit(const luci::CircleSplitV *node)
+{
+ auto *cloned = _graph->nodes()->create<luci::CircleSplitV>();
+ if (cloned != nullptr)
+ cloned->num_split(node->num_split());
+ return cloned;
+}
+
+} // namespace luci
diff --git a/compiler/luci/service/src/Nodes/CircleSplitV.test.cpp b/compiler/luci/service/src/Nodes/CircleSplitV.test.cpp
new file mode 100644
index 000000000..d109a64aa
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleSplitV.test.cpp
@@ -0,0 +1,35 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Service/CircleNodeClone.h"
+
+#include <gtest/gtest.h>
+
+TEST(CloneNodeTest, clone_SplitV)
+{
+ auto g = loco::make_graph();
+ auto node_split = g->nodes()->create<luci::CircleSplitV>();
+ node_split->num_split(5);
+
+ auto gc = loco::make_graph();
+ auto cloned = luci::clone_node(node_split, gc.get());
+ ASSERT_NE(nullptr, cloned);
+ ASSERT_EQ(gc.get(), cloned->graph());
+
+ auto cloned_split = dynamic_cast<luci::CircleSplitV *>(cloned);
+ ASSERT_NE(nullptr, cloned_split);
+ ASSERT_EQ(node_split->num_split(), cloned_split->num_split());
+}
diff --git a/compiler/luci/service/src/Nodes/CircleSplitVOut.cpp b/compiler/luci/service/src/Nodes/CircleSplitVOut.cpp
new file mode 100644
index 000000000..f40eb0a47
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleSplitVOut.cpp
@@ -0,0 +1,30 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "CircleCloneNode.h"
+
+namespace luci
+{
+
+luci::CircleNode *CloneNode::visit(const luci::CircleSplitVOut *node)
+{
+ auto *cloned = _graph->nodes()->create<luci::CircleSplitVOut>();
+ if (cloned != nullptr)
+ cloned->index(node->index());
+ return cloned;
+}
+
+} // namespace luci
diff --git a/compiler/luci/service/src/Nodes/CircleSplitVOut.test.cpp b/compiler/luci/service/src/Nodes/CircleSplitVOut.test.cpp
new file mode 100644
index 000000000..ab5e9d6be
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleSplitVOut.test.cpp
@@ -0,0 +1,35 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Service/CircleNodeClone.h"
+
+#include <gtest/gtest.h>
+
+TEST(CloneNodeTest, clone_SplitVOut)
+{
+ auto g = loco::make_graph();
+ auto node_sout = g->nodes()->create<luci::CircleSplitVOut>();
+ node_sout->index(1);
+
+ auto gc = loco::make_graph();
+ auto cloned = luci::clone_node(node_sout, gc.get());
+ ASSERT_NE(nullptr, cloned);
+ ASSERT_EQ(gc.get(), cloned->graph());
+
+ auto cloned_sout = dynamic_cast<luci::CircleSplitVOut *>(cloned);
+ ASSERT_NE(nullptr, cloned_sout);
+ ASSERT_EQ(node_sout->index(), cloned_sout->index());
+}
diff --git a/compiler/luci/service/src/Nodes/CircleSqrt.cpp b/compiler/luci/service/src/Nodes/CircleSqrt.cpp
new file mode 100644
index 000000000..a3e63684b
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleSqrt.cpp
@@ -0,0 +1,27 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "CircleCloneNode.h"
+
+namespace luci
+{
+
+luci::CircleNode *CloneNode::visit(const luci::CircleSqrt *)
+{
+ return _graph->nodes()->create<luci::CircleSqrt>();
+}
+
+} // namespace luci
diff --git a/compiler/luci/service/src/Nodes/CircleSqrt.test.cpp b/compiler/luci/service/src/Nodes/CircleSqrt.test.cpp
new file mode 100644
index 000000000..dbef839d6
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleSqrt.test.cpp
@@ -0,0 +1,33 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Service/CircleNodeClone.h"
+
+#include <gtest/gtest.h>
+
+TEST(CloneNodeTest, clone_Sqrt)
+{
+ auto g = loco::make_graph();
+ auto node_sqrt = g->nodes()->create<luci::CircleSqrt>();
+
+ auto gc = loco::make_graph();
+ auto cloned = luci::clone_node(node_sqrt, gc.get());
+ ASSERT_NE(nullptr, cloned);
+ ASSERT_EQ(gc.get(), cloned->graph());
+
+ auto cloned_sqrt = dynamic_cast<luci::CircleSqrt *>(cloned);
+ ASSERT_NE(nullptr, cloned_sqrt);
+}
diff --git a/compiler/luci/service/src/Nodes/CircleSquare.cpp b/compiler/luci/service/src/Nodes/CircleSquare.cpp
new file mode 100644
index 000000000..88bbed76c
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleSquare.cpp
@@ -0,0 +1,27 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "CircleCloneNode.h"
+
+namespace luci
+{
+
+luci::CircleNode *CloneNode::visit(const luci::CircleSquare *)
+{
+ return _graph->nodes()->create<luci::CircleSquare>();
+}
+
+} // namespace luci
diff --git a/compiler/luci/service/src/Nodes/CircleSquare.test.cpp b/compiler/luci/service/src/Nodes/CircleSquare.test.cpp
new file mode 100644
index 000000000..67ac21210
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleSquare.test.cpp
@@ -0,0 +1,33 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Service/CircleNodeClone.h"
+
+#include <gtest/gtest.h>
+
+TEST(CloneNodeTest, clone_Square)
+{
+ auto g = loco::make_graph();
+ auto node_squ = g->nodes()->create<luci::CircleSquare>();
+
+ auto gc = loco::make_graph();
+ auto cloned = luci::clone_node(node_squ, gc.get());
+ ASSERT_NE(nullptr, cloned);
+ ASSERT_EQ(gc.get(), cloned->graph());
+
+ auto cloned_squ = dynamic_cast<luci::CircleSquare *>(cloned);
+ ASSERT_NE(nullptr, cloned_squ);
+}
diff --git a/compiler/luci/service/src/Nodes/CircleSquaredDifference.cpp b/compiler/luci/service/src/Nodes/CircleSquaredDifference.cpp
new file mode 100644
index 000000000..6becdf1c9
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleSquaredDifference.cpp
@@ -0,0 +1,27 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "CircleCloneNode.h"
+
+namespace luci
+{
+
+luci::CircleNode *CloneNode::visit(const luci::CircleSquaredDifference *)
+{
+ return _graph->nodes()->create<luci::CircleSquaredDifference>();
+}
+
+} // namespace luci
diff --git a/compiler/luci/service/src/Nodes/CircleSquaredDifference.test.cpp b/compiler/luci/service/src/Nodes/CircleSquaredDifference.test.cpp
new file mode 100644
index 000000000..26099612b
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleSquaredDifference.test.cpp
@@ -0,0 +1,33 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Service/CircleNodeClone.h"
+
+#include <gtest/gtest.h>
+
+TEST(CloneNodeTest, clone_SquaredDifference)
+{
+ auto g = loco::make_graph();
+ auto node_sd = g->nodes()->create<luci::CircleSquaredDifference>();
+
+ auto gc = loco::make_graph();
+ auto cloned = luci::clone_node(node_sd, gc.get());
+ ASSERT_NE(nullptr, cloned);
+ ASSERT_EQ(gc.get(), cloned->graph());
+
+ auto cloned_sd = dynamic_cast<luci::CircleSquaredDifference *>(cloned);
+ ASSERT_NE(nullptr, cloned_sd);
+}
diff --git a/compiler/luci/service/src/Nodes/CircleSqueeze.cpp b/compiler/luci/service/src/Nodes/CircleSqueeze.cpp
new file mode 100644
index 000000000..02ba5020c
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleSqueeze.cpp
@@ -0,0 +1,30 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "CircleCloneNode.h"
+
+namespace luci
+{
+
+luci::CircleNode *CloneNode::visit(const luci::CircleSqueeze *node)
+{
+ auto *cloned = _graph->nodes()->create<luci::CircleSqueeze>();
+ if (cloned != nullptr)
+ cloned->squeeze_dims(node->squeeze_dims());
+ return cloned;
+}
+
+} // namespace luci
diff --git a/compiler/luci/service/src/Nodes/CircleSqueeze.test.cpp b/compiler/luci/service/src/Nodes/CircleSqueeze.test.cpp
new file mode 100644
index 000000000..bc73eafa7
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleSqueeze.test.cpp
@@ -0,0 +1,83 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Service/CircleNodeClone.h"
+
+#include <luci/IR/CircleNodes.h>
+#include <luci/Service/CircleShapeInference.h>
+
+#include <loco/IR/TensorShape.h>
+
+#include <gtest/gtest.h>
+
+TEST(ShapeRuleTest, squeeze_simple)
+{
+ luci::CircleInput input;
+ luci::CircleSqueeze squeeze;
+
+ input.shape({1, 4, 3, 1});
+ input.shape_status(luci::ShapeStatus::VALID);
+
+ squeeze.input(&input);
+ squeeze.squeeze_dims({0});
+
+ loco::TensorShape shape;
+ luci::sinf::Rule shape_inf_rule;
+
+ ASSERT_TRUE(shape_inf_rule.infer(&squeeze, shape));
+ ASSERT_EQ(3, shape.rank());
+ ASSERT_EQ(4, shape.dim(0).value());
+ ASSERT_EQ(3, shape.dim(1).value());
+ ASSERT_EQ(1, shape.dim(2).value());
+}
+
+TEST(ShapeRuleTest, squeeze_all)
+{
+ luci::CircleInput input;
+ luci::CircleSqueeze squeeze;
+
+ input.shape({1, 4, 3, 1});
+ input.shape_status(luci::ShapeStatus::VALID);
+
+ squeeze.input(&input);
+ squeeze.squeeze_dims({});
+
+ loco::TensorShape shape;
+ luci::sinf::Rule shape_inf_rule;
+
+ ASSERT_TRUE(shape_inf_rule.infer(&squeeze, shape));
+ ASSERT_EQ(2, shape.rank());
+ ASSERT_EQ(4, shape.dim(0).value());
+ ASSERT_EQ(3, shape.dim(1).value());
+}
+
+TEST(CloneNodeTest, clone_Squeeze)
+{
+ auto g = loco::make_graph();
+ auto node_squ = g->nodes()->create<luci::CircleSqueeze>();
+ node_squ->squeeze_dims({2, 3});
+
+ auto gc = loco::make_graph();
+ auto cloned = luci::clone_node(node_squ, gc.get());
+ ASSERT_NE(nullptr, cloned);
+ ASSERT_EQ(gc.get(), cloned->graph());
+
+ auto cloned_squ = dynamic_cast<luci::CircleSqueeze *>(cloned);
+ ASSERT_NE(nullptr, cloned_squ);
+ ASSERT_EQ(node_squ->squeeze_dims().size(), cloned_squ->squeeze_dims().size());
+ for (size_t s = 0; s < node_squ->squeeze_dims().size(); ++s)
+ ASSERT_EQ(node_squ->squeeze_dims().at(s), cloned_squ->squeeze_dims().at(s));
+}
diff --git a/compiler/luci/service/src/Nodes/CircleStridedSlice.cpp b/compiler/luci/service/src/Nodes/CircleStridedSlice.cpp
new file mode 100644
index 000000000..c4d199316
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleStridedSlice.cpp
@@ -0,0 +1,36 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "CircleCloneNode.h"
+
+namespace luci
+{
+
+luci::CircleNode *CloneNode::visit(const luci::CircleStridedSlice *node)
+{
+ auto *cloned = _graph->nodes()->create<luci::CircleStridedSlice>();
+ if (cloned != nullptr)
+ {
+ cloned->begin_mask(node->begin_mask());
+ cloned->end_mask(node->end_mask());
+ cloned->ellipsis_mask(node->ellipsis_mask());
+ cloned->new_axis_mask(node->new_axis_mask());
+ cloned->shrink_axis_mask(node->shrink_axis_mask());
+ }
+ return cloned;
+}
+
+} // namespace luci
diff --git a/compiler/luci/service/src/Nodes/CircleStridedSlice.test.cpp b/compiler/luci/service/src/Nodes/CircleStridedSlice.test.cpp
new file mode 100644
index 000000000..d633f3022
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleStridedSlice.test.cpp
@@ -0,0 +1,43 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Service/CircleNodeClone.h"
+
+#include <gtest/gtest.h>
+
+TEST(CloneNodeTest, clone_StridedSlice)
+{
+ auto g = loco::make_graph();
+ auto node_ss = g->nodes()->create<luci::CircleStridedSlice>();
+ node_ss->begin_mask(1);
+ node_ss->end_mask(2);
+ node_ss->ellipsis_mask(3);
+ node_ss->new_axis_mask(4);
+ node_ss->shrink_axis_mask(5);
+
+ auto gc = loco::make_graph();
+ auto cloned = luci::clone_node(node_ss, gc.get());
+ ASSERT_NE(nullptr, cloned);
+ ASSERT_EQ(gc.get(), cloned->graph());
+
+ auto cloned_ss = dynamic_cast<luci::CircleStridedSlice *>(cloned);
+ ASSERT_NE(nullptr, cloned_ss);
+ ASSERT_EQ(node_ss->begin_mask(), cloned_ss->begin_mask());
+ ASSERT_EQ(node_ss->end_mask(), cloned_ss->end_mask());
+ ASSERT_EQ(node_ss->ellipsis_mask(), cloned_ss->ellipsis_mask());
+ ASSERT_EQ(node_ss->new_axis_mask(), cloned_ss->new_axis_mask());
+ ASSERT_EQ(node_ss->shrink_axis_mask(), cloned_ss->shrink_axis_mask());
+}
diff --git a/compiler/luci/service/src/Nodes/CircleSub.cpp b/compiler/luci/service/src/Nodes/CircleSub.cpp
new file mode 100644
index 000000000..fb4bab19a
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleSub.cpp
@@ -0,0 +1,33 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "CircleCloneNode.h"
+
+namespace luci
+{
+
+luci::CircleNode *CloneNode::visit(const luci::CircleSub *node)
+{
+ if (node->fusedActivationFunction() == luci::FusedActFunc::UNDEFINED)
+ return nullptr;
+
+ auto *cloned = _graph->nodes()->create<luci::CircleSub>();
+ if (cloned != nullptr)
+ cloned->fusedActivationFunction(node->fusedActivationFunction());
+ return cloned;
+}
+
+} // namespace luci
diff --git a/compiler/luci/service/src/Nodes/CircleSub.test.cpp b/compiler/luci/service/src/Nodes/CircleSub.test.cpp
new file mode 100644
index 000000000..e6bd7b8ff
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleSub.test.cpp
@@ -0,0 +1,46 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Service/CircleNodeClone.h"
+
+#include <gtest/gtest.h>
+
+TEST(CloneNodeTest, clone_Sub)
+{
+ auto g = loco::make_graph();
+ auto node_sub = g->nodes()->create<luci::CircleSub>();
+ node_sub->fusedActivationFunction(luci::FusedActFunc::RELU);
+
+ auto gc = loco::make_graph();
+ auto cloned = luci::clone_node(node_sub, gc.get());
+ ASSERT_NE(nullptr, cloned);
+ ASSERT_EQ(gc.get(), cloned->graph());
+
+ auto cloned_sub = dynamic_cast<luci::CircleSub *>(cloned);
+ ASSERT_NE(nullptr, cloned_sub);
+ ASSERT_EQ(node_sub->fusedActivationFunction(), cloned_sub->fusedActivationFunction());
+}
+
+TEST(CloneNodeTest, clone_Sub_NEG)
+{
+ auto g = loco::make_graph();
+ auto node_sub = g->nodes()->create<luci::CircleSub>();
+ node_sub->fusedActivationFunction(luci::FusedActFunc::UNDEFINED);
+
+ auto gc = loco::make_graph();
+ auto cloned = luci::clone_node(node_sub, gc.get());
+ ASSERT_EQ(nullptr, cloned);
+}
diff --git a/compiler/luci/service/src/Nodes/CircleSum.cpp b/compiler/luci/service/src/Nodes/CircleSum.cpp
index 9ef90e8e0..29e6ee5f1 100644
--- a/compiler/luci/service/src/Nodes/CircleSum.cpp
+++ b/compiler/luci/service/src/Nodes/CircleSum.cpp
@@ -1,11 +1,11 @@
/*
- * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
- * http://www.apache.org/licenses/LICENSE-2.0
+ * http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
@@ -14,15 +14,17 @@
* limitations under the License.
*/
-#include <luci/Service/CircleShapeSignatureInference.h>
+#include "CircleCloneNode.h"
namespace luci
{
-ShapeSignature ssinf::Algorithm::visit(const luci::CircleSum *node)
+luci::CircleNode *CloneNode::visit(const luci::CircleSum *node)
{
- return legalized_signature(
- reduced_signature(node->input(), node->reduction_indices(), node->keep_dims()));
+ auto *cloned = _graph->nodes()->create<luci::CircleSum>();
+ if (cloned != nullptr)
+ cloned->keep_dims(node->keep_dims());
+ return cloned;
}
} // namespace luci
diff --git a/compiler/luci/service/src/Nodes/CircleSum.test.cpp b/compiler/luci/service/src/Nodes/CircleSum.test.cpp
new file mode 100644
index 000000000..aa1b0d128
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleSum.test.cpp
@@ -0,0 +1,35 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Service/CircleNodeClone.h"
+
+#include <gtest/gtest.h>
+
+TEST(CloneNodeTest, clone_Sum)
+{
+ auto g = loco::make_graph();
+ auto node_sum = g->nodes()->create<luci::CircleSum>();
+ node_sum->keep_dims(true);
+
+ auto gc = loco::make_graph();
+ auto cloned = luci::clone_node(node_sum, gc.get());
+ ASSERT_NE(nullptr, cloned);
+ ASSERT_EQ(gc.get(), cloned->graph());
+
+ auto cloned_sum = dynamic_cast<luci::CircleSum *>(cloned);
+ ASSERT_NE(nullptr, cloned_sum);
+ ASSERT_EQ(node_sum->keep_dims(), cloned_sum->keep_dims());
+}
diff --git a/compiler/luci/service/src/Nodes/CircleTanh.cpp b/compiler/luci/service/src/Nodes/CircleTanh.cpp
new file mode 100644
index 000000000..9cb35932f
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleTanh.cpp
@@ -0,0 +1,27 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "CircleCloneNode.h"
+
+namespace luci
+{
+
+luci::CircleNode *CloneNode::visit(const luci::CircleTanh *)
+{
+ return _graph->nodes()->create<luci::CircleTanh>();
+}
+
+} // namespace luci
diff --git a/compiler/luci/service/src/Nodes/CircleTanh.test.cpp b/compiler/luci/service/src/Nodes/CircleTanh.test.cpp
new file mode 100644
index 000000000..0215b42ca
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleTanh.test.cpp
@@ -0,0 +1,33 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Service/CircleNodeClone.h"
+
+#include <gtest/gtest.h>
+
+TEST(CloneNodeTest, clone_Tanh)
+{
+ auto g = loco::make_graph();
+ auto node_tanh = g->nodes()->create<luci::CircleTanh>();
+
+ auto gc = loco::make_graph();
+ auto cloned = luci::clone_node(node_tanh, gc.get());
+ ASSERT_NE(nullptr, cloned);
+ ASSERT_EQ(gc.get(), cloned->graph());
+
+ auto cloned_tanh = dynamic_cast<luci::CircleTanh *>(cloned);
+ ASSERT_NE(nullptr, cloned_tanh);
+}
diff --git a/compiler/luci/service/src/Nodes/CircleTile.cpp b/compiler/luci/service/src/Nodes/CircleTile.cpp
new file mode 100644
index 000000000..21c32e021
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleTile.cpp
@@ -0,0 +1,27 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "CircleCloneNode.h"
+
+namespace luci
+{
+
+luci::CircleNode *CloneNode::visit(const luci::CircleTile *)
+{
+ return _graph->nodes()->create<luci::CircleTile>();
+}
+
+} // namespace luci
diff --git a/compiler/luci/service/src/Nodes/CircleTile.test.cpp b/compiler/luci/service/src/Nodes/CircleTile.test.cpp
new file mode 100644
index 000000000..089c86ccb
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleTile.test.cpp
@@ -0,0 +1,33 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Service/CircleNodeClone.h"
+
+#include <gtest/gtest.h>
+
+TEST(CloneNodeTest, clone_Tile)
+{
+ auto g = loco::make_graph();
+ auto node_tile = g->nodes()->create<luci::CircleTile>();
+
+ auto gc = loco::make_graph();
+ auto cloned = luci::clone_node(node_tile, gc.get());
+ ASSERT_NE(nullptr, cloned);
+ ASSERT_EQ(gc.get(), cloned->graph());
+
+ auto cloned_tile = dynamic_cast<luci::CircleTile *>(cloned);
+ ASSERT_NE(nullptr, cloned_tile);
+}
diff --git a/compiler/luci/service/src/Nodes/CircleTopKV2.cpp b/compiler/luci/service/src/Nodes/CircleTopKV2.cpp
new file mode 100644
index 000000000..e940c03dd
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleTopKV2.cpp
@@ -0,0 +1,27 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "CircleCloneNode.h"
+
+namespace luci
+{
+
+luci::CircleNode *CloneNode::visit(const luci::CircleTopKV2 *)
+{
+ return _graph->nodes()->create<luci::CircleTopKV2>();
+}
+
+} // namespace luci
diff --git a/compiler/luci/service/src/Nodes/CircleTopKV2.test.cpp b/compiler/luci/service/src/Nodes/CircleTopKV2.test.cpp
new file mode 100644
index 000000000..7f68a408d
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleTopKV2.test.cpp
@@ -0,0 +1,33 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Service/CircleNodeClone.h"
+
+#include <gtest/gtest.h>
+
+TEST(CloneNodeTest, clone_TopKV2)
+{
+ auto g = loco::make_graph();
+ auto node_top = g->nodes()->create<luci::CircleTopKV2>();
+
+ auto gc = loco::make_graph();
+ auto cloned = luci::clone_node(node_top, gc.get());
+ ASSERT_NE(nullptr, cloned);
+ ASSERT_EQ(gc.get(), cloned->graph());
+
+ auto cloned_top = dynamic_cast<luci::CircleTopKV2 *>(cloned);
+ ASSERT_NE(nullptr, cloned_top);
+}
diff --git a/compiler/luci/service/src/Nodes/CircleTopKV2Out.cpp b/compiler/luci/service/src/Nodes/CircleTopKV2Out.cpp
new file mode 100644
index 000000000..5c13f2be1
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleTopKV2Out.cpp
@@ -0,0 +1,30 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "CircleCloneNode.h"
+
+namespace luci
+{
+
+luci::CircleNode *CloneNode::visit(const luci::CircleTopKV2Out *node)
+{
+ auto *cloned = _graph->nodes()->create<luci::CircleTopKV2Out>();
+ if (cloned != nullptr)
+ cloned->index(node->index());
+ return cloned;
+}
+
+} // namespace luci
diff --git a/compiler/luci/service/src/Nodes/CircleTopKV2Out.test.cpp b/compiler/luci/service/src/Nodes/CircleTopKV2Out.test.cpp
new file mode 100644
index 000000000..cfba61f10
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleTopKV2Out.test.cpp
@@ -0,0 +1,35 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Service/CircleNodeClone.h"
+
+#include <gtest/gtest.h>
+
+TEST(CloneNodeTest, clone_TopKV2Out)
+{
+ auto g = loco::make_graph();
+ auto node_tout = g->nodes()->create<luci::CircleTopKV2Out>();
+ node_tout->index(1);
+
+ auto gc = loco::make_graph();
+ auto cloned = luci::clone_node(node_tout, gc.get());
+ ASSERT_NE(nullptr, cloned);
+ ASSERT_EQ(gc.get(), cloned->graph());
+
+ auto cloned_tout = dynamic_cast<luci::CircleTopKV2Out *>(cloned);
+ ASSERT_NE(nullptr, cloned_tout);
+ ASSERT_EQ(node_tout->index(), cloned_tout->index());
+}
diff --git a/compiler/luci/service/src/Nodes/CircleTranspose.cpp b/compiler/luci/service/src/Nodes/CircleTranspose.cpp
new file mode 100644
index 000000000..81db55269
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleTranspose.cpp
@@ -0,0 +1,27 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "CircleCloneNode.h"
+
+namespace luci
+{
+
+luci::CircleNode *CloneNode::visit(const luci::CircleTranspose *)
+{
+ return _graph->nodes()->create<luci::CircleTranspose>();
+}
+
+} // namespace luci
diff --git a/compiler/luci/service/src/Nodes/CircleTranspose.test.cpp b/compiler/luci/service/src/Nodes/CircleTranspose.test.cpp
new file mode 100644
index 000000000..9447d1a5b
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleTranspose.test.cpp
@@ -0,0 +1,69 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Service/CircleNodeClone.h"
+
+#include <luci/IR/CircleNodes.h>
+#include <luci/Service/CircleShapeInference.h>
+
+#include <loco/IR/TensorShape.h>
+
+#include <gtest/gtest.h>
+
+TEST(ShapeRuleTest, transpose_simple)
+{
+ luci::CircleInput input;
+ luci::CircleConst perm;
+ luci::CircleTranspose transpose;
+
+ input.shape({3, 8, 1});
+ input.shape_status(luci::ShapeStatus::VALID);
+
+ perm.dtype(loco::DataType::S32);
+ perm.rank(1);
+ perm.dim(0).set(3);
+ perm.size<loco::DataType::S32>(3);
+ perm.at<loco::DataType::S32>(0) = 1;
+ perm.at<loco::DataType::S32>(1) = 2;
+ perm.at<loco::DataType::S32>(2) = 0;
+ perm.shape_status(luci::ShapeStatus::VALID);
+
+ transpose.a(&input);
+ transpose.perm(&perm);
+
+ loco::TensorShape shape;
+ luci::sinf::Rule shape_inf_rule;
+
+ ASSERT_TRUE(shape_inf_rule.infer(&transpose, shape));
+ ASSERT_EQ(3, shape.rank());
+ ASSERT_EQ(8, shape.dim(0).value());
+ ASSERT_EQ(1, shape.dim(1).value());
+ ASSERT_EQ(3, shape.dim(2).value());
+}
+
+TEST(CloneNodeTest, clone_Transpose)
+{
+ auto g = loco::make_graph();
+ auto node_tr = g->nodes()->create<luci::CircleTranspose>();
+
+ auto gc = loco::make_graph();
+ auto cloned = luci::clone_node(node_tr, gc.get());
+ ASSERT_NE(nullptr, cloned);
+ ASSERT_EQ(gc.get(), cloned->graph());
+
+ auto cloned_tr = dynamic_cast<luci::CircleTranspose *>(cloned);
+ ASSERT_NE(nullptr, cloned_tr);
+}
diff --git a/compiler/luci/service/src/Nodes/CircleTransposeConv.cpp b/compiler/luci/service/src/Nodes/CircleTransposeConv.cpp
new file mode 100644
index 000000000..1fe41bdb2
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleTransposeConv.cpp
@@ -0,0 +1,37 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "CircleCloneNode.h"
+
+namespace luci
+{
+
+luci::CircleNode *CloneNode::visit(const luci::CircleTransposeConv *node)
+{
+ if (node->padding() == luci::Padding::UNDEFINED)
+ return nullptr;
+
+ auto *cloned = _graph->nodes()->create<luci::CircleTransposeConv>();
+ if (cloned != nullptr)
+ {
+ cloned->padding(node->padding());
+ cloned->stride()->h(node->stride()->h());
+ cloned->stride()->w(node->stride()->w());
+ }
+ return cloned;
+}
+
+} // namespace luci
diff --git a/compiler/luci/service/src/Nodes/CircleTransposeConv.test.cpp b/compiler/luci/service/src/Nodes/CircleTransposeConv.test.cpp
new file mode 100644
index 000000000..29a656c03
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleTransposeConv.test.cpp
@@ -0,0 +1,46 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Service/CircleNodeClone.h"
+
+#include <gtest/gtest.h>
+
+TEST(CloneNodeTest, clone_TransposeConv)
+{
+ auto g = loco::make_graph();
+ auto node_trconv = g->nodes()->create<luci::CircleTransposeConv>();
+ node_trconv->padding(luci::Padding::SAME);
+
+ auto gc = loco::make_graph();
+ auto cloned = luci::clone_node(node_trconv, gc.get());
+ ASSERT_NE(nullptr, cloned);
+ ASSERT_EQ(gc.get(), cloned->graph());
+
+ auto cloned_trconv = dynamic_cast<luci::CircleTransposeConv *>(cloned);
+ ASSERT_NE(nullptr, cloned_trconv);
+ ASSERT_EQ(node_trconv->padding(), cloned_trconv->padding());
+}
+
+TEST(CloneNodeTest, clone_TransposeConv_padding_NEG)
+{
+ auto g = loco::make_graph();
+ auto node_trconv = g->nodes()->create<luci::CircleTransposeConv>();
+ node_trconv->padding(luci::Padding::UNDEFINED);
+
+ auto gc = loco::make_graph();
+ auto cloned = luci::clone_node(node_trconv, gc.get());
+ ASSERT_EQ(nullptr, cloned);
+}
diff --git a/compiler/luci/service/src/Nodes/CircleUnidirectionalSequenceLSTM.cpp b/compiler/luci/service/src/Nodes/CircleUnidirectionalSequenceLSTM.cpp
new file mode 100644
index 000000000..12205f3b0
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleUnidirectionalSequenceLSTM.cpp
@@ -0,0 +1,39 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "CircleCloneNode.h"
+
+namespace luci
+{
+
+luci::CircleNode *CloneNode::visit(const luci::CircleUnidirectionalSequenceLSTM *node)
+{
+ if (node->fusedActivationFunction() == luci::FusedActFunc::UNDEFINED)
+ return nullptr;
+
+ auto *cloned = _graph->nodes()->create<luci::CircleUnidirectionalSequenceLSTM>();
+ if (cloned != nullptr)
+ {
+ cloned->fusedActivationFunction(node->fusedActivationFunction());
+ cloned->cell_clip(node->cell_clip());
+ cloned->proj_clip(node->proj_clip());
+ cloned->time_major(node->time_major());
+ cloned->asymmetric_quantize_inputs(node->asymmetric_quantize_inputs());
+ }
+ return cloned;
+}
+
+} // namespace luci
diff --git a/compiler/luci/service/src/Nodes/CircleUnidirectionalSequenceLSTM.test.cpp b/compiler/luci/service/src/Nodes/CircleUnidirectionalSequenceLSTM.test.cpp
new file mode 100644
index 000000000..c3816ab27
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleUnidirectionalSequenceLSTM.test.cpp
@@ -0,0 +1,54 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Service/CircleNodeClone.h"
+
+#include <gtest/gtest.h>
+
+TEST(CloneNodeTest, clone_UnidirectionalSequenceLSTM)
+{
+ auto g = loco::make_graph();
+ auto node_uslstm = g->nodes()->create<luci::CircleUnidirectionalSequenceLSTM>();
+ node_uslstm->fusedActivationFunction(luci::FusedActFunc::RELU);
+ node_uslstm->cell_clip(1.1f);
+ node_uslstm->proj_clip(2.2f);
+ node_uslstm->time_major(true);
+ node_uslstm->asymmetric_quantize_inputs(true);
+
+ auto gc = loco::make_graph();
+ auto cloned = luci::clone_node(node_uslstm, gc.get());
+ ASSERT_NE(nullptr, cloned);
+ ASSERT_EQ(gc.get(), cloned->graph());
+
+ auto cloned_uslstm = dynamic_cast<luci::CircleUnidirectionalSequenceLSTM *>(cloned);
+ ASSERT_NE(nullptr, cloned_uslstm);
+ ASSERT_EQ(node_uslstm->fusedActivationFunction(), cloned_uslstm->fusedActivationFunction());
+ ASSERT_EQ(node_uslstm->cell_clip(), cloned_uslstm->cell_clip());
+ ASSERT_EQ(node_uslstm->proj_clip(), cloned_uslstm->proj_clip());
+ ASSERT_EQ(node_uslstm->time_major(), cloned_uslstm->time_major());
+ ASSERT_EQ(node_uslstm->asymmetric_quantize_inputs(), cloned_uslstm->asymmetric_quantize_inputs());
+}
+
+TEST(CloneNodeTest, clone_UnidirectionalSequenceLSTM_NEG)
+{
+ auto g = loco::make_graph();
+ auto node_uslstm = g->nodes()->create<luci::CircleUnidirectionalSequenceLSTM>();
+ node_uslstm->fusedActivationFunction(luci::FusedActFunc::UNDEFINED);
+
+ auto gc = loco::make_graph();
+ auto cloned = luci::clone_node(node_uslstm, gc.get());
+ ASSERT_EQ(nullptr, cloned);
+}
diff --git a/compiler/luci/service/src/Nodes/CircleUnique.cpp b/compiler/luci/service/src/Nodes/CircleUnique.cpp
new file mode 100644
index 000000000..bde2ea0dc
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleUnique.cpp
@@ -0,0 +1,30 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "CircleCloneNode.h"
+
+namespace luci
+{
+
+luci::CircleNode *CloneNode::visit(const luci::CircleUnique *node)
+{
+ auto *cloned = _graph->nodes()->create<luci::CircleUnique>();
+ if (cloned != nullptr)
+ cloned->idx_out_type(node->idx_out_type());
+ return cloned;
+}
+
+} // namespace luci
diff --git a/compiler/luci/service/src/Nodes/CircleUnique.test.cpp b/compiler/luci/service/src/Nodes/CircleUnique.test.cpp
new file mode 100644
index 000000000..a8ff9eade
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleUnique.test.cpp
@@ -0,0 +1,35 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Service/CircleNodeClone.h"
+
+#include <gtest/gtest.h>
+
+TEST(CloneNodeTest, clone_Unique)
+{
+ auto g = loco::make_graph();
+ auto node_uniq = g->nodes()->create<luci::CircleUnique>();
+ node_uniq->idx_out_type(loco::DataType::S32);
+
+ auto gc = loco::make_graph();
+ auto cloned = luci::clone_node(node_uniq, gc.get());
+ ASSERT_NE(nullptr, cloned);
+ ASSERT_EQ(gc.get(), cloned->graph());
+
+ auto cloned_uniq = dynamic_cast<luci::CircleUnique *>(cloned);
+ ASSERT_NE(nullptr, cloned_uniq);
+ ASSERT_EQ(node_uniq->idx_out_type(), cloned_uniq->idx_out_type());
+}
diff --git a/compiler/luci/service/src/Nodes/CircleUniqueOut.cpp b/compiler/luci/service/src/Nodes/CircleUniqueOut.cpp
new file mode 100644
index 000000000..30093f9db
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleUniqueOut.cpp
@@ -0,0 +1,30 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "CircleCloneNode.h"
+
+namespace luci
+{
+
+luci::CircleNode *CloneNode::visit(const luci::CircleUniqueOut *node)
+{
+ auto *cloned = _graph->nodes()->create<luci::CircleUniqueOut>();
+ if (cloned != nullptr)
+ cloned->index(node->index());
+ return cloned;
+}
+
+} // namespace luci
diff --git a/compiler/luci/service/src/Nodes/CircleUniqueOut.test.cpp b/compiler/luci/service/src/Nodes/CircleUniqueOut.test.cpp
new file mode 100644
index 000000000..780ad4b78
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleUniqueOut.test.cpp
@@ -0,0 +1,35 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Service/CircleNodeClone.h"
+
+#include <gtest/gtest.h>
+
+TEST(CloneNodeTest, clone_UniqueOut)
+{
+ auto g = loco::make_graph();
+ auto node_uout = g->nodes()->create<luci::CircleUniqueOut>();
+ node_uout->index(1);
+
+ auto gc = loco::make_graph();
+ auto cloned = luci::clone_node(node_uout, gc.get());
+ ASSERT_NE(nullptr, cloned);
+ ASSERT_EQ(gc.get(), cloned->graph());
+
+ auto cloned_uout = dynamic_cast<luci::CircleUniqueOut *>(cloned);
+ ASSERT_NE(nullptr, cloned_uout);
+ ASSERT_EQ(node_uout->index(), cloned_uout->index());
+}
diff --git a/compiler/luci/service/src/Nodes/CircleUnpack.cpp b/compiler/luci/service/src/Nodes/CircleUnpack.cpp
new file mode 100644
index 000000000..f9d61c426
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleUnpack.cpp
@@ -0,0 +1,33 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "CircleCloneNode.h"
+
+namespace luci
+{
+
+luci::CircleNode *CloneNode::visit(const luci::CircleUnpack *node)
+{
+ auto *cloned = _graph->nodes()->create<luci::CircleUnpack>();
+ if (cloned != nullptr)
+ {
+ cloned->num(node->num());
+ cloned->axis(node->axis());
+ }
+ return cloned;
+}
+
+} // namespace luci
diff --git a/compiler/luci/service/src/Nodes/CircleUnpack.test.cpp b/compiler/luci/service/src/Nodes/CircleUnpack.test.cpp
new file mode 100644
index 000000000..6559a9276
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleUnpack.test.cpp
@@ -0,0 +1,37 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Service/CircleNodeClone.h"
+
+#include <gtest/gtest.h>
+
+TEST(CloneNodeTest, clone_Unpack)
+{
+ auto g = loco::make_graph();
+ auto node_unp = g->nodes()->create<luci::CircleUnpack>();
+ node_unp->num(1);
+ node_unp->axis(2);
+
+ auto gc = loco::make_graph();
+ auto cloned = luci::clone_node(node_unp, gc.get());
+ ASSERT_NE(nullptr, cloned);
+ ASSERT_EQ(gc.get(), cloned->graph());
+
+ auto cloned_unp = dynamic_cast<luci::CircleUnpack *>(cloned);
+ ASSERT_NE(nullptr, cloned_unp);
+ ASSERT_EQ(node_unp->num(), cloned_unp->num());
+ ASSERT_EQ(node_unp->axis(), cloned_unp->axis());
+}
diff --git a/compiler/luci/service/src/Nodes/CircleUnpackOut.cpp b/compiler/luci/service/src/Nodes/CircleUnpackOut.cpp
new file mode 100644
index 000000000..342d5daca
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleUnpackOut.cpp
@@ -0,0 +1,30 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "CircleCloneNode.h"
+
+namespace luci
+{
+
+luci::CircleNode *CloneNode::visit(const luci::CircleUnpackOut *node)
+{
+ auto *cloned = _graph->nodes()->create<luci::CircleUnpackOut>();
+ if (cloned != nullptr)
+ cloned->index(node->index());
+ return cloned;
+}
+
+} // namespace luci
diff --git a/compiler/luci/service/src/Nodes/CircleUnpackOut.test.cpp b/compiler/luci/service/src/Nodes/CircleUnpackOut.test.cpp
new file mode 100644
index 000000000..ec9bb974e
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleUnpackOut.test.cpp
@@ -0,0 +1,35 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Service/CircleNodeClone.h"
+
+#include <gtest/gtest.h>
+
+TEST(CloneNodeTest, clone_UnpackOut)
+{
+ auto g = loco::make_graph();
+ auto node_uout = g->nodes()->create<luci::CircleUnpackOut>();
+ node_uout->index(1);
+
+ auto gc = loco::make_graph();
+ auto cloned = luci::clone_node(node_uout, gc.get());
+ ASSERT_NE(nullptr, cloned);
+ ASSERT_EQ(gc.get(), cloned->graph());
+
+ auto cloned_uout = dynamic_cast<luci::CircleUnpackOut *>(cloned);
+ ASSERT_NE(nullptr, cloned_uout);
+ ASSERT_EQ(node_uout->index(), cloned_uout->index());
+}
diff --git a/compiler/luci/service/src/Nodes/CircleWhere.cpp b/compiler/luci/service/src/Nodes/CircleWhere.cpp
new file mode 100644
index 000000000..73f4b64ac
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleWhere.cpp
@@ -0,0 +1,27 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "CircleCloneNode.h"
+
+namespace luci
+{
+
+luci::CircleNode *CloneNode::visit(const luci::CircleWhere *)
+{
+ return _graph->nodes()->create<luci::CircleWhere>();
+}
+
+} // namespace luci
diff --git a/compiler/luci/service/src/Nodes/CircleWhere.test.cpp b/compiler/luci/service/src/Nodes/CircleWhere.test.cpp
new file mode 100644
index 000000000..352719d85
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleWhere.test.cpp
@@ -0,0 +1,33 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Service/CircleNodeClone.h"
+
+#include <gtest/gtest.h>
+
+TEST(CloneNodeTest, clone_Where)
+{
+ auto g = loco::make_graph();
+ auto node_wh = g->nodes()->create<luci::CircleWhere>();
+
+ auto gc = loco::make_graph();
+ auto cloned = luci::clone_node(node_wh, gc.get());
+ ASSERT_NE(nullptr, cloned);
+ ASSERT_EQ(gc.get(), cloned->graph());
+
+ auto cloned_wh = dynamic_cast<luci::CircleWhere *>(cloned);
+ ASSERT_NE(nullptr, cloned_wh);
+}
diff --git a/compiler/luci/service/src/Nodes/CircleZerosLike.cpp b/compiler/luci/service/src/Nodes/CircleZerosLike.cpp
new file mode 100644
index 000000000..2ee455857
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleZerosLike.cpp
@@ -0,0 +1,27 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "CircleCloneNode.h"
+
+namespace luci
+{
+
+luci::CircleNode *CloneNode::visit(const luci::CircleZerosLike *)
+{
+ return _graph->nodes()->create<luci::CircleZerosLike>();
+}
+
+} // namespace luci
diff --git a/compiler/luci/service/src/Nodes/CircleZerosLike.test.cpp b/compiler/luci/service/src/Nodes/CircleZerosLike.test.cpp
new file mode 100644
index 000000000..6e0a4b3be
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleZerosLike.test.cpp
@@ -0,0 +1,33 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Service/CircleNodeClone.h"
+
+#include <gtest/gtest.h>
+
+TEST(CloneNodeTest, clone_ZerosLike)
+{
+ auto g = loco::make_graph();
+ auto node_zl = g->nodes()->create<luci::CircleZerosLike>();
+
+ auto gc = loco::make_graph();
+ auto cloned = luci::clone_node(node_zl, gc.get());
+ ASSERT_NE(nullptr, cloned);
+ ASSERT_EQ(gc.get(), cloned->graph());
+
+ auto cloned_zl = dynamic_cast<luci::CircleZerosLike *>(cloned);
+ ASSERT_NE(nullptr, cloned_zl);
+}
diff --git a/compiler/luci/service/src/ShapeDescription.cpp b/compiler/luci/service/src/ShapeDescription.cpp
index 01a638f8f..adfb7e342 100644
--- a/compiler/luci/service/src/ShapeDescription.cpp
+++ b/compiler/luci/service/src/ShapeDescription.cpp
@@ -31,7 +31,7 @@ ShapeDescription to_shape_description(const luci::CircleNode *circle_node)
res._dims.resize(circle_node->rank());
for (uint32_t i = 0; i < circle_node->rank(); ++i)
- res._dims.at(i) = circle_node->dim(i).value();
+ res._dims.at(i) = circle_node->dim(i).known() ? circle_node->dim(i).value() : -1;
return res;
}
@@ -53,95 +53,12 @@ ShapeDescription to_shape_description(const loco::TensorShape &shape)
return res;
}
-ShapeDescription to_shape_description(const loco::FeatureShape &shape)
-{
- ShapeDescription res;
-
- res._rank_known = true;
-
- // T/F Lite encodes a feature map as a NHWC tensor
- res._dims.resize(4);
- res._dims.at(0) = shape.count().value();
- res._dims.at(1) = shape.height().value();
- res._dims.at(2) = shape.width().value();
- res._dims.at(3) = shape.depth().value();
-
- return res;
-}
-
-ShapeDescription to_shape_description(const loco::FilterShape &shape)
-{
- ShapeDescription res;
-
- res._rank_known = true;
-
- // T/F Lite encodes a convolution filter as a NHWC tensor
- res._dims.resize(4);
- res._dims.at(0) = shape.count().value();
- res._dims.at(1) = shape.height().value();
- res._dims.at(2) = shape.width().value();
- res._dims.at(3) = shape.depth().value();
-
- return res;
-}
-
-ShapeDescription to_shape_description(const loco::DepthwiseFilterShape &shape)
-{
- ShapeDescription res;
-
- res._rank_known = true;
-
- // T/F Lite encodes a depthwise convolution filter as a [1, H, W, C*M] tensor
- res._dims.resize(4);
- res._dims.at(0) = 1;
- res._dims.at(1) = shape.height().value();
- res._dims.at(2) = shape.width().value();
- res._dims.at(3) = shape.depth().value() * shape.multiplier().value();
-
- return res;
-}
-
-ShapeDescription to_shape_description(const loco::BiasShape &shape)
-{
- ShapeDescription res;
-
- res._rank_known = true;
-
- res._dims.resize(1);
- res._dims.at(0) = shape.length().value();
-
- return res;
-}
-
-ShapeDescription to_shape_description(const loco::MatrixShape &shape)
-{
- ShapeDescription res;
-
- res._rank_known = true;
-
- res._dims.resize(2);
- res._dims.at(0) = shape.height().value();
- res._dims.at(1) = shape.width().value();
-
- return res;
-}
-
ShapeDescription to_shape_description(const loco::NodeShape &shape)
{
switch (shape.domain())
{
case loco::Domain::Tensor:
return to_shape_description(shape.as<loco::TensorShape>());
- case loco::Domain::Feature:
- return to_shape_description(shape.as<loco::FeatureShape>());
- case loco::Domain::Filter:
- return to_shape_description(shape.as<loco::FilterShape>());
- case loco::Domain::DepthwiseFilter:
- return to_shape_description(shape.as<loco::DepthwiseFilterShape>());
- case loco::Domain::Bias:
- return to_shape_description(shape.as<loco::BiasShape>());
- case loco::Domain::Matrix:
- return to_shape_description(shape.as<loco::MatrixShape>());
default:
break;
}
diff --git a/compiler/luci/service/src/ShapeDescription.test.cpp b/compiler/luci/service/src/ShapeDescription.test.cpp
new file mode 100644
index 000000000..6e53aac75
--- /dev/null
+++ b/compiler/luci/service/src/ShapeDescription.test.cpp
@@ -0,0 +1,56 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Service/ShapeDescription.h"
+
+#include <luci/IR/CircleNode.h>
+#include <luci/IR/Nodes/CircleConst.h>
+
+#include <gtest/gtest.h>
+
+TEST(ShapeDescriptionTest, CircleNode)
+{
+ // Use CircleConst as CircleNode
+ luci::CircleConst circle_const;
+ circle_const.shape({1, 2, 3, 4});
+
+ auto sd = luci::to_shape_description(&circle_const);
+
+ ASSERT_EQ(4, sd._dims.size());
+ ASSERT_EQ(1, sd._dims.at(0));
+ ASSERT_TRUE(sd._rank_known);
+}
+
+TEST(ShapeDescriptionTest, TensorShape)
+{
+ loco::TensorShape tensor_shape{1, 2, 3, 4};
+ loco::NodeShape node_shape(tensor_shape);
+
+ auto sd = luci::to_shape_description(node_shape);
+
+ ASSERT_EQ(4, sd._dims.size());
+ ASSERT_EQ(1, sd._dims.at(0));
+ ASSERT_TRUE(sd._rank_known);
+}
+
+TEST(ShapeDescriptionTest, BiasShape_NEG)
+{
+ loco::BiasShape bias_shape;
+ bias_shape.length() = 1;
+ loco::NodeShape node_shape(bias_shape);
+
+ EXPECT_THROW(luci::to_shape_description(node_shape), std::exception);
+}
diff --git a/compiler/luci/service/src/ShapeInfer_StridedSlice.cpp b/compiler/luci/service/src/ShapeInfer_StridedSlice.cpp
index 341201148..c5864f938 100644
--- a/compiler/luci/service/src/ShapeInfer_StridedSlice.cpp
+++ b/compiler/luci/service/src/ShapeInfer_StridedSlice.cpp
@@ -17,12 +17,12 @@
#include "ShapeInfer_StridedSlice.h"
#include "Check.h"
+#include "CircleShapeInferenceHelper.h"
#include <luci/IR/CircleNode.h>
#include <loco/IR/DataType.h>
#include <loco/IR/NodeShape.h>
#include <oops/InternalExn.h>
-#include <loco/Service/ShapeInference.h>
#include <cmath>
#include <cstdint>
@@ -245,7 +245,7 @@ loco::TensorShape infer_output_shape(const CircleStridedSlice *node)
assert(node->new_axis_mask() == 0);
auto op_params = BuildStridedSliceParams(node);
- loco::TensorShape input_shape = loco::shape_get(input_node).as<loco::TensorShape>();
+ loco::TensorShape input_shape = luci::shape_get(input_node).as<loco::TensorShape>();
uint32_t num_input_axes = input_shape.rank();
assert(begin_node->size<S32>() <= num_input_axes);
diff --git a/compiler/luci/service/src/Validate.cpp b/compiler/luci/service/src/Validate.cpp
index 3f732b6fe..7ed14c356 100644
--- a/compiler/luci/service/src/Validate.cpp
+++ b/compiler/luci/service/src/Validate.cpp
@@ -20,10 +20,9 @@
#include <luci/Log.h>
#include <loco/IR/NodeShape.h>
-#include <loco/Service/ShapeInference.h>
-#include <loco/Service/TypeInference.h>
#include <cassert>
+#include <unordered_map>
#include <vector>
namespace
@@ -36,7 +35,11 @@ std::ostream &operator<<(std::ostream &os, const loco::TensorShape &tensor_shape
{
if (r)
os << ",";
- os << tensor_shape.dim(r).value();
+
+ if (tensor_shape.dim(r).known())
+ os << tensor_shape.dim(r).value();
+ else
+ os << "?";
}
os << "]";
return os;
@@ -49,7 +52,11 @@ std::ostream &operator<<(std::ostream &os, const luci::CircleNode *circle_node)
{
if (r)
os << ",";
- os << circle_node->dim(r).value();
+
+ if (circle_node->dim(r).known())
+ os << circle_node->dim(r).value();
+ else
+ os << "?";
}
os << "]";
return os;
@@ -99,10 +106,24 @@ bool validate_shape_dtype(loco::Graph *g)
auto go_tensor_shape = graph_out->shape();
assert(go_tensor_shape);
+ // NOTE Even if shape of graph output is [] (which means "shape inference was impossible")
+ // but shape of CircleNode is not, it can be valid case because shape inference
+ // algorithm of CircleNode may be upgraded than before. The opposite is possible either.
+ // If such cases are appeared, following validation code should be fixed.
bool is_shape_valid = (circle_node->rank() == go_tensor_shape->rank());
for (uint32_t i = 0; is_shape_valid && i < circle_node->rank(); ++i)
- if (circle_node->dim(i).value() != go_tensor_shape->dim(i).value())
+ {
+ if (!circle_node->dim(i).known() || !go_tensor_shape->dim(i).known())
+ {
+ // If at least one of two dimensions is unknown,
+ // the unknown dimension can accept any value.
+ INFO(l) << "Unknown dimension is matched with known dimension" << std::endl;
+ }
+ else if (circle_node->dim(i).value() != go_tensor_shape->dim(i).value())
+ {
is_shape_valid = false;
+ }
+ }
if (is_shape_valid == false)
{
@@ -124,72 +145,62 @@ bool validate_shape_dtype(loco::Graph *g)
return true;
}
-bool validate_shape_signature(loco::Graph *g)
-{
- LOGGER(l);
-
- for (auto node : loco::postorder_traversal(loco::output_nodes(g)))
- {
- auto circle_node = loco::must_cast<luci::CircleNode *>(node);
- const auto shape_signature = circle_node->shape_signature();
+} // namespace
- if (shape_signature.rank() == 0)
- continue;
+namespace luci
+{
- // Rank of shape and shape signature should be same
- if (circle_node->rank() != shape_signature.rank())
- {
- INFO(l) << "[luci] Rank of shape signature for " << circle_node->name() << " do not match"
- << std::endl;
- return false;
- }
+bool validate(loco::Graph *g)
+{
+ if (!loco::valid(g))
+ return false;
- bool has_unknown = false;
+ if (!validate_shape_dtype(g))
+ return false;
- // If shape siganture is not -1, dimension value should be same
- for (uint32_t d = 0; d < shape_signature.rank(); ++d)
- {
- if (shape_signature.dim(d) != -1 &&
- shape_signature.dim(d) != (int32_t)(circle_node->dim(d).value()))
- {
- INFO(l) << "[luci] Dimension " << d << "of shape signature for " << circle_node->name()
- << " do not match" << std::endl;
- return false;
- }
+ // TODO add more validation
- if (shape_signature.dim(d) == -1)
- has_unknown = true;
- }
+ return true;
+}
- // Shape signature should have at least one -1 value.
- if (!has_unknown)
- {
- INFO(l) << "[luci] Shape signature in " << circle_node->name()
- << " do not have unknown dimension" << std::endl;
+bool validate_name(loco::Graph *g)
+{
+ auto nodes = g->nodes();
+ for (uint32_t n = 0; n < nodes->size(); ++n)
+ {
+ auto node = loco::must_cast<luci::CircleNode *>(nodes->at(n));
+ auto name = node->name();
+ if (name.empty())
return false;
- }
}
return true;
}
-} // namespace
-
-namespace luci
+bool validate_unique_name(luci::Module *m)
{
+ std::unordered_map<std::string, bool> names_col;
-bool validate(loco::Graph *g)
-{
- if (!loco::valid(g))
- return false;
-
- if (!validate_shape_dtype(g))
- return false;
-
- if (!validate_shape_signature(g))
- return false;
+ for (size_t g = 0; g < m->size(); ++g)
+ {
+ auto graph = m->graph(g);
+ auto nodes = graph->nodes();
+ for (uint32_t n = 0; n < nodes->size(); ++n)
+ {
+ auto node = loco::must_cast<luci::CircleNode *>(nodes->at(n));
+ // skip CircleOutput as it may have same name with from() node
+ auto output = dynamic_cast<luci::CircleOutput *>(node);
+ if (output != nullptr)
+ continue;
+
+ auto name = node->name();
+ auto it = names_col.find(name);
+ if (it != names_col.end())
+ return false;
- // TODO add more validation
+ names_col[name] = true;
+ }
+ }
return true;
}
diff --git a/compiler/luci/service/src/Validate.test.cpp b/compiler/luci/service/src/Validate.test.cpp
new file mode 100644
index 000000000..8ce6d895b
--- /dev/null
+++ b/compiler/luci/service/src/Validate.test.cpp
@@ -0,0 +1,139 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Service/Validate.h"
+
+#include <luci/test/TestIOGraph.h>
+
+#include <luci/IR/Nodes/CircleAdd.h>
+#include <luci/IR/Nodes/CircleSqrt.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+using namespace luci::test;
+
+class SqrtGraphlet
+{
+public:
+ SqrtGraphlet() = default;
+
+public:
+ void init(loco::Graph *g, const ShapeU32 input_shape)
+ {
+ _sqrt = g->nodes()->create<luci::CircleSqrt>();
+ _sqrt->dtype(loco::DataType::S32);
+ _sqrt->name("sqrt");
+ }
+
+protected:
+ luci::CircleSqrt *_sqrt = nullptr;
+};
+
+class SqrtGraph : public TestIOGraph, public SqrtGraphlet
+{
+public:
+ SqrtGraph() = default;
+
+public:
+ void init(const ShapeU32 shape)
+ {
+ TestIOGraph::init(shape, shape);
+ SqrtGraphlet::init(g(), shape);
+
+ _sqrt->x(input());
+
+ output()->from(_sqrt);
+
+ // set output name to _sqrt: CircleOutput may have duplicate name
+ output()->name(_sqrt->name());
+ }
+};
+
+class Sqrt2xGraphlet
+{
+public:
+ Sqrt2xGraphlet() = default;
+
+public:
+ void init(loco::Graph *g, const ShapeU32 input_shape)
+ {
+ _sqrt1 = g->nodes()->create<luci::CircleSqrt>();
+ _sqrt1->dtype(loco::DataType::S32);
+ _sqrt1->name("sqrt");
+
+ _sqrt2 = g->nodes()->create<luci::CircleSqrt>();
+ _sqrt2->dtype(loco::DataType::S32);
+ _sqrt2->name("sqrt");
+ }
+
+protected:
+ luci::CircleSqrt *_sqrt1 = nullptr;
+ luci::CircleSqrt *_sqrt2 = nullptr;
+};
+
+class Sqrt2xGraph : public TestIOGraph, public Sqrt2xGraphlet
+{
+public:
+ Sqrt2xGraph() = default;
+
+public:
+ void init(const ShapeU32 shape)
+ {
+ TestIOGraph::init(shape, shape);
+ Sqrt2xGraphlet::init(g(), shape);
+
+ _sqrt1->x(input());
+
+ _sqrt2->x(_sqrt1);
+
+ output()->from(_sqrt2);
+ }
+};
+
+} // namespace
+
+TEST(ValidateTest, non_empty_name)
+{
+ SqrtGraph g;
+ g.init({3, 3});
+
+ ASSERT_TRUE(luci::validate_name(g.g()));
+}
+
+TEST(ValidateTest, unique_name)
+{
+ luci::Module module;
+
+ SqrtGraph g;
+ g.init({3, 3});
+ g.transfer_to(&module);
+
+ ASSERT_TRUE(luci::validate_unique_name(&module));
+}
+
+TEST(ValidateTest, unique_name_NEG)
+{
+ luci::Module module;
+
+ Sqrt2xGraph g;
+ g.init({3, 3});
+ g.transfer_to(&module);
+
+ ASSERT_FALSE(luci::validate_unique_name(&module));
+}
diff --git a/compiler/luci/tester/CMakeLists.txt b/compiler/luci/tester/CMakeLists.txt
index 3ac06ef3a..13aab11e7 100644
--- a/compiler/luci/tester/CMakeLists.txt
+++ b/compiler/luci/tester/CMakeLists.txt
@@ -6,6 +6,7 @@ TargetRequire_Return(${REQUIRED_TARGETS})
set(SRCS_READ_TESTER
src/ReadTester.cpp
+ src/ReadModule.cpp
)
add_executable(luci_readtester "${SRCS_READ_TESTER}")
@@ -18,6 +19,7 @@ target_link_libraries(luci_readtester PRIVATE safemain)
set(SRCS_WRITE_TESTER
src/WriteTester.cpp
+ src/ReadModule.cpp
)
add_executable(luci_writetester "${SRCS_WRITE_TESTER}")
@@ -28,3 +30,22 @@ target_link_libraries(luci_writetester PRIVATE luci_export)
target_link_libraries(luci_writetester PRIVATE foder)
target_link_libraries(luci_writetester PRIVATE oops)
target_link_libraries(luci_writetester PRIVATE safemain)
+
+if(NOT ENABLE_TEST)
+ return()
+endif(NOT ENABLE_TEST)
+
+nnas_find_package(GTest REQUIRED)
+
+GTest_AddTest(luci_readtester_test src/ReadTester.test.cpp ${SRCS_READ_TESTER})
+target_link_libraries(luci_readtester_test luci_import)
+target_link_libraries(luci_readtester_test luci_service)
+target_link_libraries(luci_readtester_test luci_pass)
+target_link_libraries(luci_readtester_test foder)
+
+GTest_AddTest(luci_writetester_test src/WriteTester.test.cpp ${SRCS_WRITE_TESTER})
+target_link_libraries(luci_writetester_test luci_import)
+target_link_libraries(luci_writetester_test luci_service)
+target_link_libraries(luci_writetester_test luci_pass)
+target_link_libraries(luci_writetester_test luci_export)
+target_link_libraries(luci_writetester_test foder)
diff --git a/compiler/luci/tester/src/ReadModule.cpp b/compiler/luci/tester/src/ReadModule.cpp
new file mode 100644
index 000000000..87c1233f0
--- /dev/null
+++ b/compiler/luci/tester/src/ReadModule.cpp
@@ -0,0 +1,65 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ReadModule.h"
+
+#include <luci/Pass/CircleShapeInferencePass.h>
+#include <luci/Pass/CircleTypeInferencePass.h>
+#include <luci/Service/Validate.h>
+
+#include <logo/Phase.h>
+
+#include <iostream>
+#include <string>
+#include <vector>
+
+std::unique_ptr<luci::Module> ReadModule(std::string &input_path)
+{
+ // Load model from the file
+ foder::FileLoader file_loader{input_path};
+ std::vector<char> model_data = file_loader.load();
+ const circle::Model *circle_model = circle::GetModel(model_data.data());
+ if (circle_model == nullptr)
+ {
+ std::cerr << "ERROR: Failed to load circle '" << input_path << "'" << std::endl;
+ return nullptr;
+ }
+
+ luci::Importer importer;
+ auto module = importer.importModule(circle_model);
+ assert(module->size() > 0);
+
+ for (size_t g = 0; g < module->size(); ++g)
+ {
+ auto graph = module->graph(g);
+ if (graph == nullptr)
+ return nullptr;
+
+ {
+ logo::Phase phase;
+
+ phase.emplace_back(std::make_unique<luci::CircleShapeInferencePass>());
+ phase.emplace_back(std::make_unique<luci::CircleTypeInferencePass>());
+
+ logo::PhaseRunner<logo::PhaseStrategy::Saturate> phase_runner{graph};
+ phase_runner.run(phase);
+ }
+
+ if (!luci::validate(graph))
+ return nullptr;
+ }
+ return module;
+}
diff --git a/compiler/luci/tester/src/ReadModule.h b/compiler/luci/tester/src/ReadModule.h
new file mode 100644
index 000000000..dfa9bad6b
--- /dev/null
+++ b/compiler/luci/tester/src/ReadModule.h
@@ -0,0 +1,28 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LUCI_TESTER_READ_MODULE_H__
+#define __LUCI_TESTER_READ_MODULE_H__
+
+#include <luci/Importer.h>
+#include <foder/FileLoader.h>
+
+#include <memory>
+#include <string>
+
+std::unique_ptr<luci::Module> ReadModule(std::string &input_path);
+
+#endif // __LUCI_TESTER_READ_MODULE_H__
diff --git a/compiler/luci/tester/src/ReadTester.cpp b/compiler/luci/tester/src/ReadTester.cpp
index f270a232c..864343e43 100644
--- a/compiler/luci/tester/src/ReadTester.cpp
+++ b/compiler/luci/tester/src/ReadTester.cpp
@@ -14,18 +14,9 @@
* limitations under the License.
*/
-#include <foder/FileLoader.h>
-
-#include <luci/Importer.h>
-#include <luci/Service/Validate.h>
-#include <luci/Pass/ShapeInferencePass.h>
-#include <luci/Pass/TypeInferencePass.h>
-
-// Following passes will be removed after refactoring is finished
-#include <luci/Pass/MigrateLegacyShapeDtypePass.h>
+#include "ReadModule.h"
#include <iostream>
-#include <map>
#include <string>
namespace
@@ -68,45 +59,9 @@ int entry(int argc, char **argv)
std::cout << "[INFO] Circle is '" << input_path << "'" << std::endl;
- // Load model from the file
- foder::FileLoader file_loader{input_path};
- std::vector<char> model_data = file_loader.load();
- const circle::Model *circle_model = circle::GetModel(model_data.data());
- if (circle_model == nullptr)
- {
- std::cerr << "ERROR: Failed to load circle '" << input_path << "'" << std::endl;
+ auto module = ReadModule(input_path);
+ if (module == nullptr)
return EXIT_FAILURE;
- }
-
- luci::Importer importer;
- auto module = importer.importModule(circle_model);
- assert(module->size() > 0);
- for (size_t g = 0; g < module->size(); ++g)
- {
- auto graph = module->graph(g);
- if (graph == nullptr)
- return 255;
-
- {
- luci::ShapeInferencePass pass;
- while (pass.run(graph) == true)
- ;
- }
- {
- luci::TypeInferencePass pass;
- while (pass.run(graph) == true)
- ;
- }
- {
- // This pass will be removed after refactoring is finished
- luci::MigrateLegacyShapeDtypePass pass;
- while (pass.run(graph) == true)
- ;
- }
-
- if (!luci::validate(graph))
- return 255;
- }
return 0;
}
diff --git a/compiler/luci/tester/src/ReadTester.test.cpp b/compiler/luci/tester/src/ReadTester.test.cpp
new file mode 100644
index 000000000..f3850d517
--- /dev/null
+++ b/compiler/luci/tester/src/ReadTester.test.cpp
@@ -0,0 +1,43 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <gtest/gtest.h>
+
+// From ReadTester.cpp
+int entry(int argc, char **argv);
+
+TEST(ReadTesterTest, invalid_argc_NEG)
+{
+ char argv_1[20];
+ strcpy(argv_1, "ReadTesterTest");
+
+ int argc = 1;
+ char *argv[] = {argv_1};
+
+ ASSERT_NE(0, entry(argc, argv));
+}
+
+TEST(ReadTesterTest, invalid_file_NEG)
+{
+ char argv_1[20], argv_2[20];
+ strcpy(argv_1, "ReadTesterTest");
+ strcpy(argv_2, "not_a_file");
+
+ int argc = 2;
+ char *argv[] = {argv_1, argv_2};
+
+ EXPECT_THROW(entry(argc, argv), std::runtime_error);
+}
diff --git a/compiler/luci/tester/src/WriteTester.cpp b/compiler/luci/tester/src/WriteTester.cpp
index 9a6e8de05..0d3a1efa2 100644
--- a/compiler/luci/tester/src/WriteTester.cpp
+++ b/compiler/luci/tester/src/WriteTester.cpp
@@ -14,21 +14,13 @@
* limitations under the License.
*/
-#include <foder/FileLoader.h>
+#include "ReadModule.h"
-#include <luci/Importer.h>
-#include <luci/Pass/ShapeInferencePass.h>
-#include <luci/Pass/TypeInferencePass.h>
-#include <luci/Service/Validate.h>
#include <luci/CircleExporter.h>
#include <oops/InternalExn.h>
-// Following passes will be removed after refactoring is finished
-#include <luci/Pass/MigrateLegacyShapeDtypePass.h>
-
#include <fstream>
#include <iostream>
-#include <map>
#include <string>
namespace
@@ -51,12 +43,12 @@ struct CircleExpContract : public luci::CircleExporter::Contract
{
public:
CircleExpContract(loco::Graph *graph, const std::string &filename)
- : _graph(graph), _filepath(filename)
+ : _graph(graph), _filepath(filename)
{
// NOTHING TO DO
}
CircleExpContract(luci::Module *module, const std::string &filename)
- : _module(module), _filepath(filename)
+ : _module(module), _filepath(filename)
{
// NOTHING TO DO
}
@@ -111,47 +103,9 @@ int entry(int argc, char **argv)
std::cout << "[INFO] Circle from '" << input_path << "' to '" << output_path << "'" << std::endl;
- // Load model from the file
- foder::FileLoader file_loader{input_path};
- std::vector<char> model_data = file_loader.load();
- const circle::Model *circle_model = circle::GetModel(model_data.data());
- if (circle_model == nullptr)
- {
- std::cerr << "ERROR: Failed to load circle '" << input_path << "'" << std::endl;
+ auto module = ReadModule(input_path);
+ if (module == nullptr)
return EXIT_FAILURE;
- }
-
- // Import from input Circle file
- luci::Importer importer;
- auto module = importer.importModule(circle_model);
- assert(module->size() > 0);
-
- for (size_t g = 0; g < module->size(); ++g)
- {
- auto graph = module->graph(g);
- if (graph == nullptr)
- return 255;
-
- {
- luci::ShapeInferencePass pass;
- while (pass.run(graph) == true)
- ;
- }
- {
- luci::TypeInferencePass pass;
- while (pass.run(graph) == true)
- ;
- }
- {
- // This pass will be removed after refactoring is finished
- luci::MigrateLegacyShapeDtypePass pass;
- while (pass.run(graph) == true)
- ;
- }
-
- if (!luci::validate(graph))
- return 255;
- }
// Export to output Circle file
luci::CircleExporter exporter;
diff --git a/compiler/luci/tester/src/WriteTester.test.cpp b/compiler/luci/tester/src/WriteTester.test.cpp
new file mode 100644
index 000000000..9d34c5f98
--- /dev/null
+++ b/compiler/luci/tester/src/WriteTester.test.cpp
@@ -0,0 +1,44 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <gtest/gtest.h>
+
+// From WriteTester.cpp
+int entry(int argc, char **argv);
+
+TEST(WriteTesterTest, invalid_argc_NEG)
+{
+ char argv_1[20];
+ strcpy(argv_1, "WriteTesterTest");
+
+ int argc = 1;
+ char *argv[] = {argv_1};
+
+ ASSERT_NE(0, entry(argc, argv));
+}
+
+TEST(WriteTesterTest, invalid_file_NEG)
+{
+ char argv_1[20], argv_2[20], argv_3[20];
+ strcpy(argv_1, "WriteTesterTest");
+ strcpy(argv_2, "not_a_file");
+ strcpy(argv_3, "not_a_file");
+
+ int argc = 3;
+ char *argv[] = {argv_1, argv_2, argv_3};
+
+ EXPECT_THROW(entry(argc, argv), std::runtime_error);
+}
diff --git a/compiler/luci/testhelper/CMakeLists.txt b/compiler/luci/testhelper/CMakeLists.txt
new file mode 100644
index 000000000..86aa66225
--- /dev/null
+++ b/compiler/luci/testhelper/CMakeLists.txt
@@ -0,0 +1,25 @@
+if(NOT ENABLE_TEST)
+ return()
+endif(NOT ENABLE_TEST)
+
+nnas_find_package(GTest REQUIRED)
+
+# NOTE we are using "*.test.cpp" NOT to be included in static analyzer tools
+
+# testhelper library itself
+set(HELPER_SOURCE
+ src/TestShape.test.cpp
+ )
+
+add_library(luci_testhelper STATIC ${HELPER_SOURCE})
+target_include_directories(luci_testhelper PRIVATE src)
+target_include_directories(luci_testhelper PUBLIC include)
+target_link_libraries(luci_testhelper luci_lang)
+
+# test for testhelper library
+set(TESTER_SOURCE
+ src/TestIOGraph.test.cpp
+ )
+
+GTest_AddTest(luci_testhelper_test ${TESTER_SOURCE})
+target_link_libraries(luci_testhelper_test luci_testhelper)
diff --git a/compiler/luci/testhelper/README.md b/compiler/luci/testhelper/README.md
new file mode 100644
index 000000000..6bdb92aa4
--- /dev/null
+++ b/compiler/luci/testhelper/README.md
@@ -0,0 +1,3 @@
+# luci-testhelper
+
+_luci-testhelper_ provides Helper classes for unit testing
diff --git a/compiler/luci/testhelper/include/luci/test/TestIOGraph.h b/compiler/luci/testhelper/include/luci/test/TestIOGraph.h
new file mode 100644
index 000000000..ae04f4dbc
--- /dev/null
+++ b/compiler/luci/testhelper/include/luci/test/TestIOGraph.h
@@ -0,0 +1,198 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LUCI_TESTHELPER_TEST_IO_GRAPH_H__
+#define __LUCI_TESTHELPER_TEST_IO_GRAPH_H__
+
+#include "TestShape.h"
+
+#include <luci/IR/CircleNodes.h>
+#include <luci/IR/Module.h>
+
+#include <memory>
+#include <stdexcept>
+
+namespace luci
+{
+namespace test
+{
+
+/**
+ * @brief Graphlet with Inputs and loco::Graph for multiple inputs
+ * @note Every Graph will have Input(s) and Output(s)
+ * We put loco::Graph only in IsGraphlet not to declare separate
+ * class for loco::Graph
+ */
+template <unsigned N> class TestIsGraphlet
+{
+public:
+ TestIsGraphlet()
+ {
+ for (uint32_t n = 0; n < N; ++n)
+ {
+ _graph_inputs[n] = nullptr;
+ _inputs[n] = nullptr;
+ }
+ _g = loco::make_graph();
+ }
+
+public:
+ virtual void init(loco::Graph *g, const std::initializer_list<ShapeU32> shape_in)
+ {
+ if (shape_in.size() != N)
+ throw std::runtime_error("Failed to init TestIsGraphlet");
+
+ auto shpin = shape_in.begin();
+ for (uint32_t n = 0; n < N; ++n)
+ {
+ _graph_inputs[n] = g->inputs()->create();
+
+ _inputs[n] = g->nodes()->create<luci::CircleInput>();
+ _inputs[n]->shape(*shpin);
+ _inputs[n]->shape_status(luci::ShapeStatus::VALID);
+ _inputs[n]->dtype(loco::DataType::FLOAT32);
+ _inputs[n]->name("input_" + std::to_string(n));
+
+ _inputs[n]->index(_graph_inputs[n]->index());
+
+ auto input_shape = std::make_unique<loco::TensorShape>();
+ set_shape_vector(input_shape.get(), *shpin);
+ _graph_inputs[n]->shape(std::move(input_shape));
+ _graph_inputs[n]->dtype(loco::DataType::FLOAT32);
+
+ shpin++;
+ }
+ }
+
+public:
+ loco::Graph *g(void) { return _g.get(); }
+ luci::CircleInput *input(int idx) { return _inputs[idx]; }
+ uint32_t num_inputs(void) { return N; }
+
+public:
+ void transfer_to(luci::Module *module)
+ {
+ // WARNING: after g is transfered, _graph_inputs, _inputs
+ // and _graph_outputs, _outputs in TestOsGraphlet will be invalid.
+ // arrays are not cleared as this is just helpers to unit tests
+ module->add(std::move(_g));
+ }
+
+protected:
+ std::unique_ptr<loco::Graph> _g;
+ std::array<loco::GraphInput *, N> _graph_inputs;
+ std::array<luci::CircleInput *, N> _inputs;
+};
+
+/**
+ * @brief Graphlet with one Input
+ */
+class TestIGraphlet : public TestIsGraphlet<1>
+{
+public:
+ virtual void init(loco::Graph *g, const ShapeU32 shape_in)
+ {
+ TestIsGraphlet<1>::init(g, {shape_in});
+ }
+
+ luci::CircleInput *input() { return _inputs[0]; }
+};
+
+/**
+ * @brief Graphlet with Outputs for multiple outputs
+ */
+template <unsigned N> class TestOsGraphlet
+{
+public:
+ TestOsGraphlet()
+ {
+ for (uint32_t n = 0; n < N; ++n)
+ {
+ _graph_outputs[n] = nullptr;
+ _outputs[n] = nullptr;
+ }
+ }
+
+public:
+ virtual void init(loco::Graph *g, const std::initializer_list<ShapeU32> shape_out)
+ {
+ if (shape_out.size() != N)
+ throw std::runtime_error("Failed to init TestOsGraphlet");
+
+ auto shpout = shape_out.begin();
+ for (uint32_t n = 0; n < N; ++n)
+ {
+ _graph_outputs[n] = g->outputs()->create();
+
+ _outputs[n] = g->nodes()->create<luci::CircleOutput>();
+ _outputs[n]->shape(*shpout);
+ _outputs[n]->shape_status(luci::ShapeStatus::VALID);
+ _outputs[n]->dtype(loco::DataType::FLOAT32);
+ _outputs[n]->name("output_" + std::to_string(n));
+
+ _outputs[n]->index(_graph_outputs[n]->index());
+
+ auto output_shape = std::make_unique<loco::TensorShape>();
+ set_shape_vector(output_shape.get(), *shpout);
+ _graph_outputs[n]->shape(std::move(output_shape));
+ _graph_outputs[n]->dtype(loco::DataType::FLOAT32);
+
+ shpout++;
+ }
+ }
+
+public:
+ luci::CircleOutput *output(int idx) { return _outputs[idx]; }
+
+protected:
+ std::array<loco::GraphOutput *, N> _graph_outputs;
+ std::array<luci::CircleOutput *, N> _outputs;
+};
+
+/**
+ * @brief Graphlet with one Output
+ */
+class TestOGraphlet : public TestOsGraphlet<1>
+{
+public:
+ virtual void init(loco::Graph *g, const ShapeU32 shape_out)
+ {
+ TestOsGraphlet<1>::init(g, {shape_out});
+ }
+
+ luci::CircleOutput *output() { return _outputs[0]; }
+};
+
+/**
+ * @brief Graph with Input and Output
+ */
+class TestIOGraph : public TestIGraphlet, public TestOGraphlet
+{
+public:
+ TestIOGraph() = default;
+
+public:
+ virtual void init(const ShapeU32 shape_in, const ShapeU32 shape_out)
+ {
+ TestIGraphlet::init(g(), shape_in);
+ TestOGraphlet::init(g(), shape_out);
+ }
+};
+
+} // namespace test
+} // namespace luci
+
+#endif // __LUCI_TESTHELPER_TEST_IO_GRAPH_H__
diff --git a/compiler/luci/testhelper/include/luci/test/TestShape.h b/compiler/luci/testhelper/include/luci/test/TestShape.h
new file mode 100644
index 000000000..1a5adf7d6
--- /dev/null
+++ b/compiler/luci/testhelper/include/luci/test/TestShape.h
@@ -0,0 +1,40 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LUCI_TESTHELPER_TEST_SHAPE_H__
+#define __LUCI_TESTHELPER_TEST_SHAPE_H__
+
+#include <luci/IR/CircleNode.h>
+
+#include <initializer_list>
+
+namespace luci
+{
+namespace test
+{
+
+using ShapeU32 = std::initializer_list<uint32_t>;
+using ShapeI32 = std::initializer_list<int32_t>;
+
+void set_shape_vector(loco::TensorShape *shape, const ShapeU32 &values);
+void set_shape_vector(luci::CircleConst *const_node, const ShapeI32 &values);
+
+uint32_t num_elements(const ShapeU32 shape);
+
+} // namespace test
+} // namespace luci
+
+#endif // __LUCI_TESTHELPER_TEST_SHAPE_H__
diff --git a/compiler/luci/testhelper/src/TestIOGraph.test.cpp b/compiler/luci/testhelper/src/TestIOGraph.test.cpp
new file mode 100644
index 000000000..8a7d1e060
--- /dev/null
+++ b/compiler/luci/testhelper/src/TestIOGraph.test.cpp
@@ -0,0 +1,182 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/test/TestIOGraph.h"
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+using namespace luci::test;
+
+class SqrtGraphlet
+{
+public:
+ SqrtGraphlet() = default;
+
+ void init(loco::Graph *g)
+ {
+ _sqrt = g->nodes()->create<luci::CircleSqrt>();
+ _sqrt->name("sqrt");
+ }
+
+protected:
+ luci::CircleSqrt *_sqrt = nullptr;
+};
+
+class AddGraphlet
+{
+public:
+ AddGraphlet() = default;
+
+ void init(loco::Graph *g)
+ {
+ _add = g->nodes()->create<luci::CircleAdd>();
+ _add->name("add");
+ }
+
+protected:
+ luci::CircleAdd *_add = nullptr;
+};
+
+class ConvGraphlet
+{
+public:
+ ConvGraphlet() = default;
+
+ void init(loco::Graph *g)
+ {
+ _conv = g->nodes()->create<luci::CircleConv2D>();
+ _conv->name("conv");
+ }
+
+protected:
+ luci::CircleConv2D *_conv = nullptr;
+};
+
+} // namespace
+
+namespace
+{
+
+class TestOfTestIOGraph : public TestIOGraph, public SqrtGraphlet
+{
+public:
+ TestOfTestIOGraph() = default;
+
+public:
+ void init(void)
+ {
+ TestIOGraph::init({1}, {1});
+ SqrtGraphlet::init(g());
+
+ _sqrt->x(input());
+
+ output()->from(_sqrt);
+ }
+};
+
+class TestOfTestI2OGraph : public TestIsGraphlet<2>, public TestOGraphlet, public AddGraphlet
+{
+public:
+ TestOfTestI2OGraph() = default;
+
+public:
+ void init(void)
+ {
+ TestIsGraphlet<2>::init(g(), {{2, 3}, {2, 3}});
+ TestOsGraphlet<1>::init(g(), {{2, 3}});
+ AddGraphlet::init(g());
+
+ _add->x(input(0));
+ _add->y(input(1));
+
+ output()->from(_add);
+ }
+};
+
+class TestOfTestI3OGraph : public TestIsGraphlet<3>, public TestOGraphlet, public ConvGraphlet
+{
+public:
+ TestOfTestI3OGraph() = default;
+
+public:
+ void init(void)
+ {
+ TestIsGraphlet<3>::init(g(), {{2, 3, 3, 4}, {1, 1}, {4}});
+ TestOsGraphlet<1>::init(g(), {{2, 3, 3, 4}});
+ ConvGraphlet::init(g());
+
+ _conv->input(input(0));
+ _conv->filter(input(1));
+ _conv->bias(input(2));
+
+ output()->from(_conv);
+ }
+};
+
+class FailOfTestI3OGraph : public TestIsGraphlet<3>, public TestOGraphlet, public ConvGraphlet
+{
+public:
+ FailOfTestI3OGraph() = default;
+
+public:
+ void init(void)
+ {
+ TestIsGraphlet<3>::init(g(), {{2, 3, 3, 4}, {1, 1}});
+ TestOsGraphlet<1>::init(g(), {{2, 3, 3, 4}});
+ ConvGraphlet::init(g());
+
+ _conv->input(input(0));
+ _conv->filter(input(1));
+ _conv->bias(input(2));
+
+ output()->from(_conv);
+ }
+};
+
+} // namespace
+
+TEST(TestIOGraphTest, IOGraph_init)
+{
+ TestOfTestIOGraph tg;
+ tg.init();
+
+ SUCCEED();
+}
+
+TEST(TestIOGraphTest, I2OGraph_init)
+{
+ TestOfTestI2OGraph tg;
+ tg.init();
+
+ SUCCEED();
+}
+
+TEST(TestIOGraphTest, I3OGraph_init)
+{
+ TestOfTestI3OGraph tg;
+ tg.init();
+
+ SUCCEED();
+}
+
+TEST(TestIOGraphTest, I3OGraph_input_number_mismatch_NEG)
+{
+ FailOfTestI3OGraph fg;
+ EXPECT_THROW(fg.init(), std::runtime_error);
+}
diff --git a/compiler/luci/testhelper/src/TestShape.test.cpp b/compiler/luci/testhelper/src/TestShape.test.cpp
new file mode 100644
index 000000000..9838c6182
--- /dev/null
+++ b/compiler/luci/testhelper/src/TestShape.test.cpp
@@ -0,0 +1,57 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/test/TestShape.h"
+
+/**
+ * @note This file does not hold any test cases but provides methods for tests
+ */
+
+namespace luci
+{
+namespace test
+{
+
+void set_shape_vector(loco::TensorShape *shape, const ShapeU32 &values)
+{
+ uint32_t r = 0;
+ shape->rank(values.size());
+ for (auto v : values)
+ shape->dim(r++).set(v);
+}
+
+void set_shape_vector(luci::CircleConst *const_node, const ShapeI32 &values)
+{
+ const_node->rank(1);
+ const_node->dim(0).set(values.size());
+ const_node->shape_status(luci::ShapeStatus::VALID);
+ const_node->dtype(loco::DataType::S32);
+ const_node->size<loco::DataType::S32>(values.size());
+ uint32_t idx = 0;
+ for (auto val : values)
+ const_node->at<loco::DataType::S32>(idx++) = val;
+}
+
+uint32_t num_elements(const ShapeU32 shape)
+{
+ uint32_t result = 1;
+ for (auto val : shape)
+ result = result * val;
+ return result;
+}
+
+} // namespace test
+} // namespace luci
diff --git a/compiler/luci/tests/test.lst b/compiler/luci/tests/test.lst
index 897d41983..a278fa256 100644
--- a/compiler/luci/tests/test.lst
+++ b/compiler/luci/tests/test.lst
@@ -51,6 +51,8 @@ addread(ExpandDims_000)
addread(ExpandDims_001)
addread(ExpandDims_002)
addread(ExpandDims_003)
+addread(ExpandDims_004)
+addread(FakeQuant_000)
addread(Fill_000)
addread(Fill_001)
addread(Floor_000)
@@ -151,6 +153,7 @@ addread(SelectV2_002)
addread(Shape_000)
addread(Sin_000)
addread(Slice_000)
+addread(Slice_001)
addread(Softmax_000)
addread(Softmax_U8_000)
addread(SpaceToBatchND_000)
@@ -166,6 +169,7 @@ addread(Sqrt_000)
addread(Square_000)
addread(SquaredDifference_000)
addread(Squeeze_000)
+addread(Squeeze_001)
addread(StridedSlice_000)
addread(StridedSlice_001)
addread(StridedSlice_002)
@@ -268,6 +272,8 @@ addwrite(ExpandDims_000)
addwrite(ExpandDims_001)
addwrite(ExpandDims_002)
addwrite(ExpandDims_003)
+addwrite(ExpandDims_004)
+addwrite(FakeQuant_000)
addwrite(Fill_000)
addwrite(Fill_001)
addwrite(Floor_000)
@@ -367,6 +373,7 @@ addwrite(SelectV2_002)
addwrite(Shape_000)
addwrite(Sin_000)
addwrite(Slice_000)
+addwrite(Slice_001)
addwrite(Softmax_000)
addwrite(Softmax_U8_000)
addwrite(SpaceToBatchND_000)
@@ -382,6 +389,7 @@ addwrite(Sqrt_000)
addwrite(Square_000)
addwrite(SquaredDifference_000)
addwrite(Squeeze_000)
+addwrite(Squeeze_001)
addwrite(StridedSlice_000)
addwrite(StridedSlice_001)
addwrite(StridedSlice_002)