Diffstat (limited to 'runtime/onert/core/src')
-rw-r--r--runtime/onert/core/src/backend/BackendContext.cc30
-rw-r--r--runtime/onert/core/src/backend/IConstantInitializer.cc112
-rw-r--r--runtime/onert/core/src/backend/IPortableTensor.cc (renamed from runtime/onert/core/src/backend/cpu_common/Tensor.cc)21
-rw-r--r--runtime/onert/core/src/backend/ITensor.cc11
-rw-r--r--runtime/onert/core/src/backend/basic/Allocator.cc (renamed from runtime/onert/core/src/backend/cpu_common/Allocator.cc)6
-rw-r--r--runtime/onert/core/src/backend/basic/BackendContextHelpers.cc17
-rw-r--r--runtime/onert/core/src/backend/basic/DynamicTensorManager.cc53
-rw-r--r--runtime/onert/core/src/backend/basic/MemoryManager.cc (renamed from runtime/onert/core/src/backend/cpu_common/MemoryManager.cc)36
-rw-r--r--runtime/onert/core/src/backend/basic/MemoryPlanner.cc (renamed from runtime/onert/core/src/backend/cpu_common/MemoryPlanner.cc)35
-rw-r--r--runtime/onert/core/src/backend/basic/MemoryPlanner.h (renamed from runtime/onert/core/src/backend/cpu_common/MemoryPlanner.h)20
-rw-r--r--runtime/onert/core/src/backend/basic/MemoryPlanner.test.cc (renamed from runtime/onert/core/src/backend/cpu_common/MemoryPlanner.test.cc)8
-rw-r--r--runtime/onert/core/src/backend/basic/MemoryPlannerFactory.cc (renamed from runtime/onert/core/src/backend/cpu_common/MemoryPlannerFactory.cc)6
-rw-r--r--runtime/onert/core/src/backend/basic/MemoryPlannerFactory.h (renamed from runtime/onert/core/src/backend/cpu_common/MemoryPlannerFactory.h)15
-rw-r--r--runtime/onert/core/src/backend/basic/StaticTensorManager.cc (renamed from runtime/onert/core/src/backend/cpu_common/StaticTensorManager.cc)59
-rw-r--r--runtime/onert/core/src/backend/basic/Tensor.cc104
-rw-r--r--runtime/onert/core/src/backend/basic/TensorBuilder.cc91
-rw-r--r--runtime/onert/core/src/backend/basic/train/TrainableTensor.cc49
-rw-r--r--runtime/onert/core/src/backend/builtin/Backend.h (renamed from runtime/onert/core/src/backend/controlflow/Backend.h)45
-rw-r--r--runtime/onert/core/src/backend/builtin/BackendContext.cc58
-rw-r--r--runtime/onert/core/src/backend/builtin/BackendContext.h71
-rw-r--r--runtime/onert/core/src/backend/builtin/Config.cc (renamed from runtime/onert/core/src/backend/controlflow/Config.cc)8
-rw-r--r--runtime/onert/core/src/backend/builtin/Config.h (renamed from runtime/onert/core/src/backend/controlflow/Config.h)12
-rw-r--r--runtime/onert/core/src/backend/builtin/ConstantInitializer.h (renamed from runtime/onert/core/src/backend/controlflow/UserTensor.cc)23
-rw-r--r--runtime/onert/core/src/backend/builtin/DynamicTensorManager.h (renamed from runtime/onert/core/src/backend/controlflow/UserTensorRegistry.h)18
-rw-r--r--runtime/onert/core/src/backend/builtin/ExternalContext.h79
-rw-r--r--runtime/onert/core/src/backend/builtin/IOTensor.cc60
-rw-r--r--runtime/onert/core/src/backend/builtin/IOTensor.h114
-rw-r--r--runtime/onert/core/src/backend/builtin/KernelGenerator.cc159
-rw-r--r--runtime/onert/core/src/backend/builtin/KernelGenerator.h (renamed from runtime/onert/core/src/backend/controlflow/KernelGenerator.h)50
-rw-r--r--runtime/onert/core/src/backend/builtin/Tensor.h (renamed from runtime/onert/core/src/backend/controlflow/Tensor.h)15
-rw-r--r--runtime/onert/core/src/backend/builtin/TensorBuilder.cc (renamed from runtime/onert/core/src/backend/controlflow/TensorBuilder.cc)48
-rw-r--r--runtime/onert/core/src/backend/builtin/TensorBuilder.h (renamed from runtime/onert/core/src/backend/controlflow/TensorBuilder.h)46
-rw-r--r--runtime/onert/core/src/backend/builtin/TensorRegistry.h134
-rw-r--r--runtime/onert/core/src/backend/builtin/UserTensor.cc41
-rw-r--r--runtime/onert/core/src/backend/builtin/UserTensor.h63
-rw-r--r--runtime/onert/core/src/backend/builtin/kernel/IfLayer.cc81
-rw-r--r--runtime/onert/core/src/backend/builtin/kernel/IfLayer.h (renamed from runtime/onert/core/src/backend/controlflow/kernel/IfLayer.h)40
-rw-r--r--runtime/onert/core/src/backend/builtin/kernel/PermuteLayer.cc316
-rw-r--r--runtime/onert/core/src/backend/builtin/kernel/PermuteLayer.h150
-rw-r--r--runtime/onert/core/src/backend/builtin/kernel/WhileLayer.cc151
-rw-r--r--runtime/onert/core/src/backend/builtin/kernel/WhileLayer.h (renamed from runtime/onert/core/src/backend/controlflow/kernel/WhileLayer.h)39
-rw-r--r--runtime/onert/core/src/backend/builtin/train/BackendContext.cc78
-rw-r--r--runtime/onert/core/src/backend/builtin/train/BackendContext.h76
-rw-r--r--runtime/onert/core/src/backend/builtin/train/KernelGenerator.cc104
-rw-r--r--runtime/onert/core/src/backend/builtin/train/KernelGenerator.h75
-rw-r--r--runtime/onert/core/src/backend/builtin/train/Tensor.h40
-rw-r--r--runtime/onert/core/src/backend/builtin/train/TensorRegistry.h140
-rw-r--r--runtime/onert/core/src/backend/builtin/train/kernel/PermuteLayer.cc87
-rw-r--r--runtime/onert/core/src/backend/builtin/train/kernel/PermuteLayer.h61
-rw-r--r--runtime/onert/core/src/backend/controlflow/ConstantInitializer.h52
-rw-r--r--runtime/onert/core/src/backend/controlflow/DynamicTensorManager.cc144
-rw-r--r--runtime/onert/core/src/backend/controlflow/DynamicTensorManager.h72
-rw-r--r--runtime/onert/core/src/backend/controlflow/KernelGenerator.cc171
-rw-r--r--runtime/onert/core/src/backend/controlflow/TensorRegistry.h134
-rw-r--r--runtime/onert/core/src/backend/controlflow/UserTensor.h91
-rw-r--r--runtime/onert/core/src/backend/controlflow/kernel/IfLayer.cc128
-rw-r--r--runtime/onert/core/src/backend/controlflow/kernel/PermuteLayer.cc82
-rw-r--r--runtime/onert/core/src/backend/controlflow/kernel/PermuteLayer.h77
-rw-r--r--runtime/onert/core/src/backend/controlflow/kernel/WhileLayer.cc216
-rw-r--r--runtime/onert/core/src/backend/cpu_common/DynamicTensorManager.cc137
-rw-r--r--runtime/onert/core/src/compiler/BackendManager.cc130
-rw-r--r--runtime/onert/core/src/compiler/Compiler.cc333
-rw-r--r--runtime/onert/core/src/compiler/CompilerFactory.cc50
-rw-r--r--runtime/onert/core/src/compiler/CompilerHelpers.h52
-rw-r--r--runtime/onert/core/src/compiler/CompilerOptions.cc147
-rw-r--r--runtime/onert/core/src/compiler/ExecutorFactory.cc1035
-rw-r--r--runtime/onert/core/src/compiler/ExecutorFactory.h73
-rw-r--r--runtime/onert/core/src/compiler/Fp32ToFp16Converter.cc152
-rw-r--r--runtime/onert/core/src/compiler/Fp32ToFp16Converter.h8
-rw-r--r--runtime/onert/core/src/compiler/HEScheduler.cc119
-rw-r--r--runtime/onert/core/src/compiler/HEScheduler.h50
-rw-r--r--runtime/onert/core/src/compiler/HEScheduler.test.cc572
-rw-r--r--runtime/onert/core/src/compiler/Linear.cc201
-rw-r--r--runtime/onert/core/src/compiler/Linear.h20
-rw-r--r--runtime/onert/core/src/compiler/LoweredGraph.cc578
-rw-r--r--runtime/onert/core/src/compiler/ManualScheduler.cc33
-rw-r--r--runtime/onert/core/src/compiler/ManualScheduler.h4
-rw-r--r--runtime/onert/core/src/compiler/MultiModelCompiler.cc230
-rw-r--r--runtime/onert/core/src/compiler/MultiModelCompiler.h68
-rw-r--r--runtime/onert/core/src/compiler/OperationLowerInfo.cc (renamed from runtime/onert/core/src/ir/operation/LowerInfo.cc)13
-rw-r--r--runtime/onert/core/src/compiler/OperationValidator.cc1053
-rw-r--r--runtime/onert/core/src/compiler/ParamChecker.h73
-rw-r--r--runtime/onert/core/src/compiler/PermuteFactor.cc28
-rw-r--r--runtime/onert/core/src/compiler/ShapeValidator.cc1132
-rw-r--r--runtime/onert/core/src/compiler/ShapeValidator.h (renamed from runtime/onert/core/src/compiler/OperationValidator.h)25
-rw-r--r--runtime/onert/core/src/compiler/StaticShapeInference.cc1096
-rw-r--r--runtime/onert/core/src/compiler/StaticShapeInferer.cc1487
-rw-r--r--runtime/onert/core/src/compiler/TensorBuilders.h78
-rw-r--r--runtime/onert/core/src/compiler/TensorRegistries.h34
-rw-r--r--runtime/onert/core/src/compiler/pass/ConstantInsertionPass.cc44
-rw-r--r--runtime/onert/core/src/compiler/pass/ConstantInsertionPass.h9
-rw-r--r--runtime/onert/core/src/compiler/pass/ConstantLoweringPass.cc23
-rw-r--r--runtime/onert/core/src/compiler/pass/ConstantLoweringPass.h2
-rw-r--r--runtime/onert/core/src/compiler/pass/ConstantOutputPass.cc68
-rw-r--r--runtime/onert/core/src/compiler/pass/ConstantOutputPass.h63
-rw-r--r--runtime/onert/core/src/compiler/pass/IPass.h41
-rw-r--r--runtime/onert/core/src/compiler/pass/LoweredOperandPass.h8
-rw-r--r--runtime/onert/core/src/compiler/pass/LoweredOperationPass.h10
-rw-r--r--runtime/onert/core/src/compiler/pass/OddOutputPass.cc90
-rw-r--r--runtime/onert/core/src/compiler/pass/OddOutputPass.h89
-rw-r--r--runtime/onert/core/src/compiler/pass/OperandPass.cc2
-rw-r--r--runtime/onert/core/src/compiler/pass/OperationPass.cc4
-rw-r--r--runtime/onert/core/src/compiler/pass/OperationPass.h4
-rw-r--r--runtime/onert/core/src/compiler/pass/Pass.h6
-rw-r--r--runtime/onert/core/src/compiler/pass/PassRunner.cc45
-rw-r--r--runtime/onert/core/src/compiler/pass/PassRunner.h53
-rw-r--r--runtime/onert/core/src/compiler/pass/PermutationEliminationPass.cc102
-rw-r--r--runtime/onert/core/src/compiler/pass/PermutationEliminationPass.h4
-rw-r--r--runtime/onert/core/src/compiler/pass/PermutationInsertionPass.cc102
-rw-r--r--runtime/onert/core/src/compiler/pass/PermutationInsertionPass.h4
-rw-r--r--runtime/onert/core/src/compiler/pass/PermutationOperationPass.cc137
-rw-r--r--runtime/onert/core/src/compiler/pass/PermutationOperationPass.h3
-rw-r--r--runtime/onert/core/src/compiler/pass/UnusedOperandEliminationPass.cc64
-rw-r--r--runtime/onert/core/src/compiler/pass/UnusedOperandEliminationPass.h54
-rw-r--r--runtime/onert/core/src/compiler/pass/UnusedOperandEliminationPass.test.cc47
-rw-r--r--runtime/onert/core/src/compiler/train/LoweredTrainableGraph.cc286
-rw-r--r--runtime/onert/core/src/compiler/train/StaticBackwardShapeInferer.cc151
-rw-r--r--runtime/onert/core/src/compiler/train/StaticBackwardShapeInferer.h80
-rw-r--r--runtime/onert/core/src/compiler/train/TensorRegistries.h114
-rw-r--r--runtime/onert/core/src/compiler/train/TrainableOperationConverter.cc106
-rw-r--r--runtime/onert/core/src/compiler/train/TrainableOperationConverter.h61
-rw-r--r--runtime/onert/core/src/compiler/train/TrainingCompiler.cc310
-rw-r--r--runtime/onert/core/src/compiler/train/TrainingCompiler.h81
-rw-r--r--runtime/onert/core/src/compiler/train/UntrainableOperationConverter.cc53
-rw-r--r--runtime/onert/core/src/compiler/train/UntrainableOperationConverter.h52
-rw-r--r--runtime/onert/core/src/compiler/train/pass/LossInsertionPass.cc82
-rw-r--r--runtime/onert/core/src/compiler/train/pass/LossInsertionPass.h56
-rw-r--r--runtime/onert/core/src/compiler/train/pass/Pass.h62
-rw-r--r--runtime/onert/core/src/dumper/dot/DotBuilder.cc23
-rw-r--r--runtime/onert/core/src/dumper/dot/DotBuilder.h2
-rw-r--r--runtime/onert/core/src/dumper/dot/DotDumper.cc244
-rw-r--r--runtime/onert/core/src/dumper/dot/DotDumper.h37
-rw-r--r--runtime/onert/core/src/dumper/dot/DotSubgraphInfo.cc58
-rw-r--r--runtime/onert/core/src/dumper/dot/DotSubgraphInfo.h61
-rw-r--r--runtime/onert/core/src/dumper/dot/OperandNode.cc5
-rw-r--r--runtime/onert/core/src/dumper/dot/OperandNode.h1
-rw-r--r--runtime/onert/core/src/dumper/dot/OperationNode.cc5
-rw-r--r--runtime/onert/core/src/dumper/dot/OperationNode.h4
-rw-r--r--runtime/onert/core/src/dumper/h5/Dumper.cc (renamed from runtime/onert/core/src/compiler/ParamChecker.cc)21
-rw-r--r--runtime/onert/core/src/dumper/h5/Dumper.h51
-rw-r--r--runtime/onert/core/src/dumper/h5/MinMaxDumper.cc87
-rw-r--r--runtime/onert/core/src/dumper/h5/MinMaxDumper.h78
-rw-r--r--runtime/onert/core/src/dumper/text/GraphDumper.cc108
-rw-r--r--runtime/onert/core/src/dumper/text/GraphDumper.h62
-rw-r--r--runtime/onert/core/src/exec/DataflowExecutor.cc89
-rw-r--r--runtime/onert/core/src/exec/DataflowExecutor.h29
-rw-r--r--runtime/onert/core/src/exec/DynamicShapeInferer.cc (renamed from runtime/onert/core/src/exec/DynamicShapeInference.cc)471
-rw-r--r--runtime/onert/core/src/exec/EdgeTensor.cc55
-rw-r--r--runtime/onert/core/src/exec/EdgeTensor.h72
-rw-r--r--runtime/onert/core/src/exec/ExecTime.cc6
-rw-r--r--runtime/onert/core/src/exec/ExecTime.h4
-rw-r--r--runtime/onert/core/src/exec/ExecTime.test.cc106
-rw-r--r--runtime/onert/core/src/exec/Execution.cc210
-rw-r--r--runtime/onert/core/src/exec/Execution.test.cc783
-rw-r--r--runtime/onert/core/src/exec/ExecutionContext.cc34
-rw-r--r--runtime/onert/core/src/exec/ExecutionObservee.cc62
-rw-r--r--runtime/onert/core/src/exec/ExecutionObservee.h27
-rw-r--r--runtime/onert/core/src/exec/ExecutionObservers.cc138
-rw-r--r--runtime/onert/core/src/exec/ExecutionObservers.h91
-rw-r--r--runtime/onert/core/src/exec/ExecutorBase.cc228
-rw-r--r--runtime/onert/core/src/exec/ExecutorBase.h107
-rw-r--r--runtime/onert/core/src/exec/FunctionSequence.cc28
-rw-r--r--runtime/onert/core/src/exec/IPermuteFunction.cc320
-rw-r--r--runtime/onert/core/src/exec/IPermuteFunction.h393
-rw-r--r--runtime/onert/core/src/exec/IPermuteFunction.test.cc920
-rw-r--r--runtime/onert/core/src/exec/JSONExecTime.cc6
-rw-r--r--runtime/onert/core/src/exec/JSONExecTime.h18
-rw-r--r--runtime/onert/core/src/exec/LinearExecutor.cc61
-rw-r--r--runtime/onert/core/src/exec/LinearExecutor.h23
-rw-r--r--runtime/onert/core/src/exec/MinMaxData.cc135
-rw-r--r--runtime/onert/core/src/exec/MinMaxData.h75
-rw-r--r--runtime/onert/core/src/exec/MinMaxRecorder.cc161
-rw-r--r--runtime/onert/core/src/exec/MinMaxRecorder.h58
-rw-r--r--runtime/onert/core/src/exec/MultiModelExecutors.cc589
-rw-r--r--runtime/onert/core/src/exec/MultiModelExecutors.h152
-rw-r--r--runtime/onert/core/src/exec/ParallelExecutor.cc56
-rw-r--r--runtime/onert/core/src/exec/ParallelExecutor.h24
-rw-r--r--runtime/onert/core/src/exec/ParallelScheduler.cc4
-rw-r--r--runtime/onert/core/src/exec/SingleModelExecutors.cc170
-rw-r--r--runtime/onert/core/src/exec/SingleModelExecutors.h70
-rw-r--r--runtime/onert/core/src/exec/Sink.h199
-rw-r--r--runtime/onert/core/src/exec/Source.h208
-rw-r--r--runtime/onert/core/src/exec/ThreadPool.cc2
-rw-r--r--runtime/onert/core/src/exec/feature/MockTensor.test.h66
-rw-r--r--runtime/onert/core/src/exec/feature/nchw/Reader.h41
-rw-r--r--runtime/onert/core/src/exec/feature/nchw/Reader.test.cc85
-rw-r--r--runtime/onert/core/src/exec/feature/nchw/View.h4
-rw-r--r--runtime/onert/core/src/exec/feature/nchw/View.test.cc85
-rw-r--r--runtime/onert/core/src/exec/feature/nhwc/Reader.h40
-rw-r--r--runtime/onert/core/src/exec/feature/nhwc/Reader.test.cc86
-rw-r--r--runtime/onert/core/src/exec/feature/nhwc/View.h8
-rw-r--r--runtime/onert/core/src/exec/feature/nhwc/View.test.cc86
-rw-r--r--runtime/onert/core/src/exec/train/TrainableExecutor.cc225
-rw-r--r--runtime/onert/core/src/exec/train/TrainableExecutor.h143
-rw-r--r--runtime/onert/core/src/exec/train/TrainableExecutors.cc142
-rw-r--r--runtime/onert/core/src/exec/train/TrainableExecutors.h104
-rw-r--r--runtime/onert/core/src/exec/train/TrainableFnSequence.cc69
-rw-r--r--runtime/onert/core/src/exporter/CircleExporter.cc153
-rw-r--r--runtime/onert/core/src/exporter/TrainInfoBuilder.h116
-rw-r--r--runtime/onert/core/src/interp/Buffer.h91
-rw-r--r--runtime/onert/core/src/interp/ExecEnv.h212
-rw-r--r--runtime/onert/core/src/interp/InterpExecutor.cc126
-rw-r--r--runtime/onert/core/src/interp/InterpExecutor.h70
-rw-r--r--runtime/onert/core/src/interp/InterpOps.lst73
-rw-r--r--runtime/onert/core/src/interp/Interpreter.cc184
-rw-r--r--runtime/onert/core/src/interp/Interpreter.h64
-rw-r--r--runtime/onert/core/src/interp/Registration.h43
-rw-r--r--runtime/onert/core/src/interp/Tensor.cc53
-rw-r--r--runtime/onert/core/src/interp/Tensor.h184
-rw-r--r--runtime/onert/core/src/interp/operations/BinaryArithmeticOps.cc205
-rw-r--r--runtime/onert/core/src/interp/operations/Concat.cc147
-rw-r--r--runtime/onert/core/src/interp/operations/Conv2D.cc151
-rw-r--r--runtime/onert/core/src/interp/operations/DepthwiseConv2D.cc156
-rw-r--r--runtime/onert/core/src/interp/operations/ElementwiseActivations.cc161
-rw-r--r--runtime/onert/core/src/interp/operations/FullyConnected.cc136
-rw-r--r--runtime/onert/core/src/interp/operations/Gather.cc138
-rw-r--r--runtime/onert/core/src/interp/operations/InstanceNorm.cc121
-rw-r--r--runtime/onert/core/src/interp/operations/OperationUtil.h203
-rw-r--r--runtime/onert/core/src/interp/operations/Pad.cc106
-rw-r--r--runtime/onert/core/src/interp/operations/Pool2D.cc140
-rw-r--r--runtime/onert/core/src/interp/operations/Reshape.cc63
-rw-r--r--runtime/onert/core/src/interp/operations/Softmax.cc123
-rw-r--r--runtime/onert/core/src/interp/operations/TransposeConv.cc141
-rw-r--r--runtime/onert/core/src/ir/DataType.cc6
-rw-r--r--runtime/onert/core/src/ir/Graph.cc174
-rw-r--r--runtime/onert/core/src/ir/Graph.test.cc147
-rw-r--r--runtime/onert/core/src/ir/GraphIterator.cc121
-rw-r--r--runtime/onert/core/src/ir/GraphIterator.h90
-rw-r--r--runtime/onert/core/src/ir/LayoutSet.cc8
-rw-r--r--runtime/onert/core/src/ir/LayoutSet.h1
-rw-r--r--runtime/onert/core/src/ir/LayoutSet.test.cc67
-rw-r--r--runtime/onert/core/src/ir/MockNode.h47
-rw-r--r--runtime/onert/core/src/ir/OpSequence.cc95
-rw-r--r--runtime/onert/core/src/ir/OpSequences.cc124
-rw-r--r--runtime/onert/core/src/ir/Operand.cc6
-rw-r--r--runtime/onert/core/src/ir/Operand.test.cc86
-rw-r--r--runtime/onert/core/src/ir/OperandIndexSequence.cc13
-rw-r--r--runtime/onert/core/src/ir/OperandIndexSequence.test.cc52
-rw-r--r--runtime/onert/core/src/ir/Operands.cc2
-rw-r--r--runtime/onert/core/src/ir/Operands.test.cc45
-rw-r--r--runtime/onert/core/src/ir/Operation.cc21
-rw-r--r--runtime/onert/core/src/ir/Operation.test.cc98
-rw-r--r--runtime/onert/core/src/ir/OperationCloner.cc26
-rw-r--r--runtime/onert/core/src/ir/OperationCloner.h14
-rw-r--r--runtime/onert/core/src/ir/OperationDumper.cc285
-rw-r--r--runtime/onert/core/src/ir/OperationDumper.h6
-rw-r--r--runtime/onert/core/src/ir/OperationValidator.cc546
-rw-r--r--runtime/onert/core/src/ir/OperationValidator.h101
-rw-r--r--runtime/onert/core/src/ir/Operations.cc9
-rw-r--r--runtime/onert/core/src/ir/Operations.test.cc42
-rw-r--r--runtime/onert/core/src/ir/Padding.cc10
-rw-r--r--runtime/onert/core/src/ir/Shape.cc41
-rw-r--r--runtime/onert/core/src/ir/Shape.test.cc58
-rw-r--r--runtime/onert/core/src/ir/TypeInfo.cc2
-rw-r--r--runtime/onert/core/src/ir/operation/AddN.cc36
-rw-r--r--runtime/onert/core/src/ir/operation/ArgMinMax.cc (renamed from runtime/onert/core/src/ir/operation/ArgMax.cc)13
-rw-r--r--runtime/onert/core/src/ir/operation/BCQFullyConnected.cc5
-rw-r--r--runtime/onert/core/src/ir/operation/BCQGather.cc5
-rw-r--r--runtime/onert/core/src/ir/operation/BatchMatMul.cc2
-rw-r--r--runtime/onert/core/src/ir/operation/BatchToSpaceND.cc5
-rw-r--r--runtime/onert/core/src/ir/operation/BinaryArithmetic.cc14
-rw-r--r--runtime/onert/core/src/ir/operation/BroadcastTo.cc5
-rw-r--r--runtime/onert/core/src/ir/operation/Bulk.cc36
-rw-r--r--runtime/onert/core/src/ir/operation/Comparison.cc5
-rw-r--r--runtime/onert/core/src/ir/operation/Concat.cc5
-rw-r--r--runtime/onert/core/src/ir/operation/Conv2D.cc5
-rw-r--r--runtime/onert/core/src/ir/operation/ConvertFp16ToFp32.cc5
-rw-r--r--runtime/onert/core/src/ir/operation/ConvertFp32ToFp16.cc5
-rw-r--r--runtime/onert/core/src/ir/operation/Custom.cc2
-rw-r--r--runtime/onert/core/src/ir/operation/DepthToSpace.cc5
-rw-r--r--runtime/onert/core/src/ir/operation/DepthwiseConv2D.cc5
-rw-r--r--runtime/onert/core/src/ir/operation/DetectionPostProcess.cc37
-rw-r--r--runtime/onert/core/src/ir/operation/Einsum.cc2
-rw-r--r--runtime/onert/core/src/ir/operation/ElementwiseActivation.cc30
-rw-r--r--runtime/onert/core/src/ir/operation/ElementwiseBinary.cc16
-rw-r--r--runtime/onert/core/src/ir/operation/ElementwiseUnary.cc42
-rw-r--r--runtime/onert/core/src/ir/operation/EmbeddingLookup.cc5
-rw-r--r--runtime/onert/core/src/ir/operation/ExpandDims.cc5
-rw-r--r--runtime/onert/core/src/ir/operation/Fill.cc5
-rw-r--r--runtime/onert/core/src/ir/operation/FullyConnected.cc5
-rw-r--r--runtime/onert/core/src/ir/operation/FusedBatchNorm.cc2
-rw-r--r--runtime/onert/core/src/ir/operation/Gather.cc5
-rw-r--r--runtime/onert/core/src/ir/operation/HashtableLookup.cc5
-rw-r--r--runtime/onert/core/src/ir/operation/If.cc2
-rw-r--r--runtime/onert/core/src/ir/operation/InstanceNorm.cc5
-rw-r--r--runtime/onert/core/src/ir/operation/L2Normalization.cc5
-rw-r--r--runtime/onert/core/src/ir/operation/LSTM.cc13
-rw-r--r--runtime/onert/core/src/ir/operation/LocalResponseNormalization.cc5
-rw-r--r--runtime/onert/core/src/ir/operation/LogSoftmax.cc5
-rw-r--r--runtime/onert/core/src/ir/operation/Loss.cc39
-rw-r--r--runtime/onert/core/src/ir/operation/MatrixBandPart.cc5
-rw-r--r--runtime/onert/core/src/ir/operation/OneHot.cc2
-rw-r--r--runtime/onert/core/src/ir/operation/PReLU.cc5
-rw-r--r--runtime/onert/core/src/ir/operation/Pack.cc2
-rw-r--r--runtime/onert/core/src/ir/operation/Pad.cc2
-rw-r--r--runtime/onert/core/src/ir/operation/Permute.cc5
-rw-r--r--runtime/onert/core/src/ir/operation/Pool2D.cc12
-rw-r--r--runtime/onert/core/src/ir/operation/Pow.cc5
-rw-r--r--runtime/onert/core/src/ir/operation/RNN.cc5
-rw-r--r--runtime/onert/core/src/ir/operation/Range.cc5
-rw-r--r--runtime/onert/core/src/ir/operation/Rank.cc5
-rw-r--r--runtime/onert/core/src/ir/operation/Reduce.cc20
-rw-r--r--runtime/onert/core/src/ir/operation/Reshape.cc5
-rw-r--r--runtime/onert/core/src/ir/operation/ResizeBilinear.cc5
-rw-r--r--runtime/onert/core/src/ir/operation/ResizeNearestNeighbor.cc5
-rw-r--r--runtime/onert/core/src/ir/operation/Reverse.cc5
-rw-r--r--runtime/onert/core/src/ir/operation/Select.cc2
-rw-r--r--runtime/onert/core/src/ir/operation/Shape.cc5
-rw-r--r--runtime/onert/core/src/ir/operation/Slice.cc2
-rw-r--r--runtime/onert/core/src/ir/operation/Softmax.cc5
-rw-r--r--runtime/onert/core/src/ir/operation/SpaceToBatchND.cc5
-rw-r--r--runtime/onert/core/src/ir/operation/SpaceToDepth.cc5
-rw-r--r--runtime/onert/core/src/ir/operation/Split.cc5
-rw-r--r--runtime/onert/core/src/ir/operation/SplitV.cc5
-rw-r--r--runtime/onert/core/src/ir/operation/SquaredDifference.cc5
-rw-r--r--runtime/onert/core/src/ir/operation/Squeeze.cc2
-rw-r--r--runtime/onert/core/src/ir/operation/StatelessRandomUniform.cc5
-rw-r--r--runtime/onert/core/src/ir/operation/StridedSlice.cc5
-rw-r--r--runtime/onert/core/src/ir/operation/Tile.cc5
-rw-r--r--runtime/onert/core/src/ir/operation/TopKV2.cc5
-rw-r--r--runtime/onert/core/src/ir/operation/Transpose.cc8
-rw-r--r--runtime/onert/core/src/ir/operation/TransposeConv.cc5
-rw-r--r--runtime/onert/core/src/ir/operation/Unpack.cc3
-rw-r--r--runtime/onert/core/src/ir/operation/While.cc3
-rw-r--r--runtime/onert/core/src/ir/train/LossCode.cc39
-rw-r--r--runtime/onert/core/src/ir/train/OptimizerCode.cc39
-rw-r--r--runtime/onert/core/src/ir/train/TrainableGraph.cc337
-rw-r--r--runtime/onert/core/src/ir/train/TrainableGraph.test.cc378
-rw-r--r--runtime/onert/core/src/ir/train/TrainingInfo.cc49
-rw-r--r--runtime/onert/core/src/ir/train/UseDefChain.cc52
-rw-r--r--runtime/onert/core/src/ir/train/UseDefGenerator.cc187
-rw-r--r--runtime/onert/core/src/ir/train/UseDefGenerator.h87
-rw-r--r--runtime/onert/core/src/ir/train/operation/BinaryArithmetic.cc49
-rw-r--r--runtime/onert/core/src/ir/train/operation/Conv2D.cc49
-rw-r--r--runtime/onert/core/src/ir/train/operation/DepthwiseConv2D.cc49
-rw-r--r--runtime/onert/core/src/ir/train/operation/ElementwiseActivation.cc49
-rw-r--r--runtime/onert/core/src/ir/train/operation/FullyConnected.cc49
-rw-r--r--runtime/onert/core/src/ir/train/operation/Loss.cc48
-rw-r--r--runtime/onert/core/src/ir/train/operation/Pad.cc46
-rw-r--r--runtime/onert/core/src/ir/train/operation/Permute.cc50
-rw-r--r--runtime/onert/core/src/ir/train/operation/Pool2D.cc49
-rw-r--r--runtime/onert/core/src/ir/train/operation/Reduce.cc49
-rw-r--r--runtime/onert/core/src/ir/train/operation/Reshape.cc49
-rw-r--r--runtime/onert/core/src/ir/train/operation/Softmax.cc49
-rw-r--r--runtime/onert/core/src/ir/train/operation/UntrainableOperation.test.cc1239
-rw-r--r--runtime/onert/core/src/ir/verifier/Verifier.cc218
-rw-r--r--runtime/onert/core/src/ir/verifier/Verifier.h15
-rw-r--r--runtime/onert/core/src/ir/verifier/Verifier.test.cc93
-rw-r--r--runtime/onert/core/src/loader/BaseLoader.h1794
-rw-r--r--runtime/onert/core/src/loader/CircleLoader.cc239
-rw-r--r--runtime/onert/core/src/loader/ModelLoader.cc85
-rw-r--r--runtime/onert/core/src/loader/TFLiteLoader.cc167
-rw-r--r--runtime/onert/core/src/loader/TrainInfoLoader.cc139
-rw-r--r--runtime/onert/core/src/loader/tflite_schema.fbs1308
-rw-r--r--runtime/onert/core/src/loader/tflite_schema_generated.h11989
-rw-r--r--runtime/onert/core/src/odc/CodegenLoader.cc91
-rw-r--r--runtime/onert/core/src/odc/CodegenLoader.h96
-rw-r--r--runtime/onert/core/src/odc/CodegenManager.cc56
-rw-r--r--runtime/onert/core/src/odc/QuantizeManager.cc50
-rw-r--r--runtime/onert/core/src/odc/QuantizeManager.test.cc38
-rw-r--r--runtime/onert/core/src/odc/QuantizerLoader.cc104
-rw-r--r--runtime/onert/core/src/odc/QuantizerLoader.h89
-rw-r--r--runtime/onert/core/src/odc/QuantizerLoader.test.cc63
-rw-r--r--runtime/onert/core/src/util/ChromeTracingEventWriter.cc195
-rw-r--r--runtime/onert/core/src/util/ConfigSource.cc36
-rw-r--r--runtime/onert/core/src/util/EventCollector.cc77
-rw-r--r--runtime/onert/core/src/util/EventCollector.h70
-rw-r--r--runtime/onert/core/src/util/EventCollectorGlobal.cc93
-rw-r--r--runtime/onert/core/src/util/EventCollectorGlobal.h155
-rw-r--r--runtime/onert/core/src/util/EventRecorder.cc532
-rw-r--r--runtime/onert/core/src/util/EventRecorder.h63
-rw-r--r--runtime/onert/core/src/util/EventWriter.cc49
-rw-r--r--runtime/onert/core/src/util/EventWriter.h144
-rw-r--r--runtime/onert/core/src/util/GeneralConfigSource.cc45
-rw-r--r--runtime/onert/core/src/util/Index.test.cc (renamed from runtime/onert/core/src/util/EnvConfigSource.cc)34
-rw-r--r--runtime/onert/core/src/util/MDTableEventWriter.cc365
-rw-r--r--runtime/onert/core/src/util/ObjectManager.test.cc211
-rw-r--r--runtime/onert/core/src/util/SNPEEventWriter.cc186
-rw-r--r--runtime/onert/core/src/util/ShapeInference.cc262
-rw-r--r--runtime/onert/core/src/util/ShapeInference.test.cc544
-rw-r--r--runtime/onert/core/src/util/TracingCtx.cc30
381 files changed, 40641 insertions, 12463 deletions
diff --git a/runtime/onert/core/src/backend/BackendContext.cc b/runtime/onert/core/src/backend/BackendContext.cc
index bafa36d28..7b36f106d 100644
--- a/runtime/onert/core/src/backend/BackendContext.cc
+++ b/runtime/onert/core/src/backend/BackendContext.cc
@@ -16,40 +16,10 @@
#include "backend/BackendContext.h"
-#include "ir/Operation.h"
-#include "backend/IConstantInitializer.h"
-
namespace onert
{
namespace backend
{
-void BackendContext::initialize(const std::vector<OperationInfo> &operation_list,
- const std::vector<ir::OperandIndex> &operand_list)
-{
- _operation_list = operation_list;
- _operand_list = operand_list;
-}
-
-void BackendContext::initConsts()
-{
- for (auto &op : _operation_list)
- {
- constant_initializer->setLayout(op.layout);
- _graph->operations().at(op.index).accept(*constant_initializer);
- }
-
- for (auto ind : _operand_list)
- {
- const auto &obj = _graph->operands().at(ind);
- if (obj.isConstant() && !constant_initializer->exist(ind))
- {
- constant_initializer->registerDefaultInitializer(ind, obj);
- }
- }
-
- constant_initializer->run();
-}
-
} // namespace backend
} // namespace onert
diff --git a/runtime/onert/core/src/backend/IConstantInitializer.cc b/runtime/onert/core/src/backend/IConstantInitializer.cc
deleted file mode 100644
index 934a42753..000000000
--- a/runtime/onert/core/src/backend/IConstantInitializer.cc
+++ /dev/null
@@ -1,112 +0,0 @@
-/*
- * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "backend/IConstantInitializer.h"
-
-#include <Half.h>
-
-using float16 = Half;
-
-namespace onert
-{
-namespace backend
-{
-
-void IConstantInitializer::registerCopyInitializer(const ir::OperandIndex &index,
- const ir::Operand &obj)
-{
- // For only CONSTANTS
- // TODO Add to check if tensor has been allocated
- if (!obj.isConstant())
- return;
-
- const auto type = obj.typeInfo().type();
- using ir::DataType;
-
- switch (type)
- {
- case DataType::FLOAT32:
- _init_map[index] = copyInit<float>;
- break;
- case DataType::INT32:
- _init_map[index] = copyInit<int32_t>;
- break;
- case DataType::UINT32:
- _init_map[index] = copyInit<uint32_t>;
- break;
- case DataType::BOOL8:
- case DataType::QUANT_UINT8_ASYMM:
- _init_map[index] = copyInit<uint8_t>;
- break;
- case DataType::QUANT_INT8_SYMM:
- _init_map[index] = copyInit<int8_t>;
- break;
- case DataType::FLOAT16:
- _init_map[index] = copyInit<float16>;
- break;
- case DataType::INT64:
- _init_map[index] = copyInit<int64_t>;
- break;
- default:
- throw std::runtime_error("Not supported, yet");
- break;
- }
-}
-
-void IConstantInitializer::registerPermuteInitializer(const ir::OperandIndex &index,
- const ir::Operand &obj)
-{
- // For only CONSTANTS
- // TODO Add to check if tensor has been allocated
- if (!obj.isConstant())
- return;
-
- const auto type = obj.typeInfo().type();
- using ir::DataType;
- using namespace std::placeholders;
-
- switch (type)
- {
- case DataType::FLOAT32:
- _init_map[index] = std::bind(permuteInit<float>, _1, _2, _current_op_seq_layout);
- break;
- case DataType::INT32:
- _init_map[index] = std::bind(permuteInit<int32_t>, _1, _2, _current_op_seq_layout);
- break;
- case DataType::UINT32:
- _init_map[index] = std::bind(permuteInit<uint32_t>, _1, _2, _current_op_seq_layout);
- break;
- case DataType::BOOL8:
- case DataType::QUANT_UINT8_ASYMM:
- _init_map[index] = std::bind(permuteInit<uint8_t>, _1, _2, _current_op_seq_layout);
- break;
- case DataType::QUANT_INT8_SYMM:
- _init_map[index] = std::bind(permuteInit<int8_t>, _1, _2, _current_op_seq_layout);
- break;
- case DataType::FLOAT16:
- _init_map[index] = std::bind(permuteInit<float16>, _1, _2, _current_op_seq_layout);
- break;
- case DataType::INT64:
- _init_map[index] = std::bind(permuteInit<int64_t>, _1, _2, _current_op_seq_layout);
- break;
- default:
- throw std::runtime_error("Not supported, yet");
- break;
- }
-}
-
-} // namespace backend
-} // namespace onert
diff --git a/runtime/onert/core/src/backend/cpu_common/Tensor.cc b/runtime/onert/core/src/backend/IPortableTensor.cc
index f34564dd9..066ba0004 100644
--- a/runtime/onert/core/src/backend/cpu_common/Tensor.cc
+++ b/runtime/onert/core/src/backend/IPortableTensor.cc
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -14,30 +14,31 @@
* limitations under the License.
*/
-#include "backend/cpu_common/Tensor.h"
+#include "backend/IPortableTensor.h"
namespace onert
{
namespace backend
{
-namespace cpu_common
-{
-size_t Tensor::calcOffset(const ir::Coordinates &coords) const
+// `dynamic_cast` not working across library boundaries on NDK
+// With this as a key function, `dynamic_cast` works across dl
+IPortableTensor::~IPortableTensor() {}
+
+size_t IPortableTensor::calcOffset(const ir::Coordinates &coords) const
{
- size_t rank = num_dimensions();
+ auto shape = _info.shape();
+ size_t rank = shape.rank();
rank = rank == 0 ? 1 : rank;
size_t offset = 0;
for (size_t i = 0; i < rank; ++i)
{
- offset = offset * dimension(i) + coords[i];
+ auto dim = shape.rank() == 0 ? 1 : shape.dim(i);
+ offset = offset * dim + coords[i];
}
offset *= sizeOfDataType(data_type());
return offset;
}
-void Tensor::setShape(const ir::Shape &new_shape) { _info.shape(new_shape); }
-
-} // namespace cpu_common
} // namespace backend
} // namespace onert
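
Editor's note: the calcOffset moved here now derives rank and dimensions from the operand info (_info.shape()) instead of the removed num_dimensions()/dimension() virtuals. A minimal standalone sketch of the same row-major offset arithmetic, with simplified types standing in for the onert ones:

#include <cstddef>
#include <vector>

// Row-major byte offset of the element at `coords`, mirroring the loop in
// IPortableTensor::calcOffset. `elem_size` stands in for
// sizeOfDataType(data_type()); rank-0 tensors are treated as rank-1.
std::size_t calcOffset(const std::vector<std::size_t> &dims,
                       const std::vector<std::size_t> &coords, std::size_t elem_size)
{
  std::size_t rank = dims.empty() ? 1 : dims.size();
  std::size_t offset = 0;
  for (std::size_t i = 0; i < rank; ++i)
  {
    std::size_t dim = dims.empty() ? 1 : dims[i];
    std::size_t coord = coords.empty() ? 0 : coords[i];
    offset = offset * dim + coord; // accumulate strides dimension by dimension
  }
  return offset * elem_size;
}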
diff --git a/runtime/onert/core/src/backend/ITensor.cc b/runtime/onert/core/src/backend/ITensor.cc
index 7127ed93d..1339cb409 100644
--- a/runtime/onert/core/src/backend/ITensor.cc
+++ b/runtime/onert/core/src/backend/ITensor.cc
@@ -21,14 +21,9 @@ namespace onert
namespace backend
{
-ir::Shape ITensor::getShape() const
-{
- onert::ir::Shape shape(num_dimensions());
- for (uint32_t d = 0; d < num_dimensions(); d++)
- shape.dim(d) = dimension(d);
-
- return shape;
-}
+// `dynamic_cast` not working across library boundaries on NDK
+// With this as a key function, `dynamic_cast` works across dl
+ITensor::~ITensor() {}
} // namespace backend
} // namespace onert
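
Editor's note: the out-of-line destructor added above is the usual key-function idiom. Defining one virtual member in a single translation unit pins the class's vtable and type_info to that shared object, which is what makes dynamic_cast agree across dlopen'ed backend libraries on the NDK. A small illustration of the idiom with made-up names (not onert classes):

// Widget.h -- header included by every backend library
struct Widget
{
  virtual ~Widget(); // declared but NOT defined inline: this is the key function
  virtual void run() = 0;
};

// Widget.cc -- compiled into exactly one shared library
#include "Widget.h"
Widget::~Widget() {} // vtable and type_info are emitted only here, so
                     // dynamic_cast on Widget pointers agrees across .so boundaries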
diff --git a/runtime/onert/core/src/backend/cpu_common/Allocator.cc b/runtime/onert/core/src/backend/basic/Allocator.cc
index 0ba444ee6..61214dfad 100644
--- a/runtime/onert/core/src/backend/cpu_common/Allocator.cc
+++ b/runtime/onert/core/src/backend/basic/Allocator.cc
@@ -14,7 +14,7 @@
* limitations under the License.
*/
-#include "backend/cpu_common/Allocator.h"
+#include "backend/basic/Allocator.h"
#include "util/logging.h"
@@ -22,7 +22,7 @@ namespace onert
{
namespace backend
{
-namespace cpu_common
+namespace basic
{
Allocator::Allocator(uint32_t capacity)
@@ -33,6 +33,6 @@ Allocator::Allocator(uint32_t capacity)
VERBOSE(ALLOC) << "base pointer: " << static_cast<void *>(_base.get()) << std::endl;
}
-} // namespace cpu_common
+} // namespace basic
} // namespace backend
} // namespace onert
diff --git a/runtime/onert/core/src/backend/basic/BackendContextHelpers.cc b/runtime/onert/core/src/backend/basic/BackendContextHelpers.cc
new file mode 100644
index 000000000..c02cc0cf2
--- /dev/null
+++ b/runtime/onert/core/src/backend/basic/BackendContextHelpers.cc
@@ -0,0 +1,17 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/basic/BackendContextHelpers.h"
diff --git a/runtime/onert/core/src/backend/basic/DynamicTensorManager.cc b/runtime/onert/core/src/backend/basic/DynamicTensorManager.cc
new file mode 100644
index 000000000..07bcb09ee
--- /dev/null
+++ b/runtime/onert/core/src/backend/basic/DynamicTensorManager.cc
@@ -0,0 +1,53 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/basic/DynamicTensorManager.h"
+
+#include "util/logging.h"
+#include "misc/polymorphic_downcast.h"
+
+namespace onert
+{
+namespace backend
+{
+namespace basic
+{
+
+DynamicTensorManager::DynamicTensorManager(const std::shared_ptr<TensorRegistry> &reg)
+ : _dynamic_mem_mgr{new DynamicMemoryManager()}, _tensors{reg}
+{
+ // DO NOTHING
+}
+
+void DynamicTensorManager::buildTensor(const ir::OperandIndex &ind,
+ const ir::OperandInfo &tensor_info,
+ ir::Layout backend_layout)
+{
+ assert(_tensors->getNativeTensor(ind) == nullptr);
+ auto tensor = std::make_unique<Tensor>(tensor_info, backend_layout, _dynamic_mem_mgr.get());
+ _tensors->setNativeTensor(ind, std::move(tensor));
+}
+
+const ITensor *DynamicTensorManager::getRawITensor(ir::OperandIndex ind)
+{
+ auto ptr = _tensors->getITensor(ind);
+ assert(ptr);
+ return ptr;
+}
+
+} // namespace basic
+} // namespace backend
+} // namespace onert
diff --git a/runtime/onert/core/src/backend/cpu_common/MemoryManager.cc b/runtime/onert/core/src/backend/basic/MemoryManager.cc
index 8cb9c22ca..48144561b 100644
--- a/runtime/onert/core/src/backend/cpu_common/MemoryManager.cc
+++ b/runtime/onert/core/src/backend/basic/MemoryManager.cc
@@ -14,18 +14,19 @@
* limitations under the License.
*/
-#include <backend/cpu_common/MemoryManager.h>
+#include <backend/basic/MemoryManager.h>
#include <cassert>
#include "MemoryPlannerFactory.h"
#include "util/ConfigSource.h"
+#include "util/logging.h"
namespace onert
{
namespace backend
{
-namespace cpu_common
+namespace basic
{
MemoryManager::MemoryManager() : _mem_planner{createMemoryPlanner()}
@@ -34,20 +35,21 @@ MemoryManager::MemoryManager() : _mem_planner{createMemoryPlanner()}
}
MemoryManager::MemoryManager(const std::string planner_id)
- : _mem_planner{createMemoryPlanner(planner_id)}
+ : _mem_planner{createMemoryPlanner(planner_id)}
{
// DO NOTHING
}
-cpu_common::IMemoryPlanner *MemoryManager::createMemoryPlanner()
+basic::IMemoryPlanner<ir::OperandIndex> *MemoryManager::createMemoryPlanner()
{
auto planner_id = util::getConfigString(util::config::CPU_MEMORY_PLANNER);
- return cpu_common::MemoryPlannerFactory::get().create(planner_id);
+ return basic::MemoryPlannerFactory::get().create(planner_id);
}
-cpu_common::IMemoryPlanner *MemoryManager::createMemoryPlanner(const std::string planner_id)
+basic::IMemoryPlanner<ir::OperandIndex> *
+MemoryManager::createMemoryPlanner(const std::string planner_id)
{
- return cpu_common::MemoryPlannerFactory::get().create(planner_id);
+ return basic::MemoryPlannerFactory::get().create(planner_id);
}
void MemoryManager::claimPlan(const ir::OperandIndex &ind, uint32_t size)
@@ -59,7 +61,7 @@ void MemoryManager::releasePlan(const ir::OperandIndex &ind) { _mem_planner->rel
void MemoryManager::allocate(void)
{
- _mem_alloc = std::make_shared<cpu_common::Allocator>(_mem_planner->capacity());
+ _mem_alloc = std::make_shared<basic::Allocator>(_mem_planner->capacity());
assert(_mem_alloc->base());
}
@@ -70,20 +72,20 @@ uint8_t *MemoryManager::getBuffer(const ir::OperandIndex &ind) const
return _mem_alloc->base() + mem_blk.offset;
}
-std::shared_ptr<cpu_common::Allocator> DynamicMemoryManager::allocate(const ir::OperandIndex &ind,
- uint32_t capacity)
+std::shared_ptr<basic::Allocator> DynamicMemoryManager::allocate(const ITensor *tensor,
+ uint32_t capacity)
{
- auto find = _mem_alloc_map.find(ind);
+ auto find = _mem_alloc_map.find(tensor);
if (find != _mem_alloc_map.end())
throw std::runtime_error("Cannot allocate memory for a tensor. It was already allocated.");
- _mem_alloc_map[ind] = std::make_shared<cpu_common::Allocator>(capacity);
- return _mem_alloc_map[ind];
+ _mem_alloc_map[tensor] = std::make_shared<basic::Allocator>(capacity);
+ return _mem_alloc_map[tensor];
}
-void DynamicMemoryManager::deallocate(const ir::OperandIndex &ind)
+void DynamicMemoryManager::deallocate(const ITensor *tensor)
{
- auto find = _mem_alloc_map.find(ind);
+ auto find = _mem_alloc_map.find(tensor);
if (find == _mem_alloc_map.end())
throw std::runtime_error("Cannot find Allocator for the requested index");
@@ -93,7 +95,7 @@ void DynamicMemoryManager::deallocate(const ir::OperandIndex &ind)
void DynamicMemoryManager::deallocate(void)
{
- for (auto &mem_alloc : _mem_alloc_map)
+ for (auto &&mem_alloc : _mem_alloc_map)
{
// Release memory buffer of mem_alloc
mem_alloc.second->release();
@@ -102,6 +104,6 @@ void DynamicMemoryManager::deallocate(void)
_mem_alloc_map.clear();
}
-} // namespace cpu_common
+} // namespace basic
} // namespace backend
} // namespace onert
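
Editor's note: DynamicMemoryManager is re-keyed here from ir::OperandIndex to const ITensor *, so a dynamically allocated buffer follows the tensor object itself rather than an operand id. A standalone sketch of such a pointer-keyed pool, with simplified stand-in types rather than onert's Allocator/ITensor:

#include <cstdint>
#include <map>
#include <memory>
#include <stdexcept>
#include <vector>

// One heap allocation per tensor, looked up by the tensor's address.
struct Allocation
{
  explicit Allocation(std::uint32_t capacity) : buf(capacity) {}
  std::vector<std::uint8_t> buf;
};

class DynamicPool
{
public:
  std::shared_ptr<Allocation> allocate(const void *tensor, std::uint32_t capacity)
  {
    if (_allocs.count(tensor))
      throw std::runtime_error("already allocated for this tensor");
    return _allocs[tensor] = std::make_shared<Allocation>(capacity);
  }
  void deallocate(const void *tensor)
  {
    if (_allocs.erase(tensor) == 0)
      throw std::runtime_error("no allocation for this tensor");
  }

private:
  std::map<const void *, std::shared_ptr<Allocation>> _allocs;
};

A plausible motivation, not spelled out in this diff, is that buffers can then be managed for tensors that do not map one-to-one to operand indices.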
diff --git a/runtime/onert/core/src/backend/cpu_common/MemoryPlanner.cc b/runtime/onert/core/src/backend/basic/MemoryPlanner.cc
index 75c2da7d2..1c048043c 100644
--- a/runtime/onert/core/src/backend/cpu_common/MemoryPlanner.cc
+++ b/runtime/onert/core/src/backend/basic/MemoryPlanner.cc
@@ -22,24 +22,21 @@ namespace onert
{
namespace backend
{
-namespace cpu_common
+namespace basic
{
void BumpPlanner::claim(const ir::OperandIndex &ind, size_t size)
{
- assert(size != 0);
-
Block blk{_capacity, size};
_mem_plans[ind] = blk;
_capacity += size;
- VERBOSE(BP_PLANNER) << "CLAIM(#" << ind.value() << "): " << blk.offset << ", " << blk.size
- << std::endl;
+ VERBOSE(BP_PLANNER) << "CLAIM(" << ind << "): " << blk.offset << ", " << blk.size << std::endl;
}
void BumpPlanner::release(const ir::OperandIndex &ind)
{
- VERBOSE(BP_PLANNER) << "RELEASE(#" << ind.value() << "): "
+ VERBOSE(BP_PLANNER) << "RELEASE(" << ind << "): "
<< "NOTHING does" << std::endl;
}
@@ -59,11 +56,9 @@ void BumpPlanner::release(const ir::OperandIndex &ind)
// the previous claim_base_offset.
void FirstFitPlanner::claim(const ir::OperandIndex &ind, size_t size)
{
- assert(size != 0);
-
// Find the right position for claiming
uint32_t next_offset = 0;
- for (auto &mem_claim : _claim_table)
+ for (const auto &mem_claim : _claim_table)
{
auto claimed_base_offset = mem_claim.first;
auto claimed_size = _mem_plans[mem_claim.second].size;
@@ -81,7 +76,7 @@ void FirstFitPlanner::claim(const ir::OperandIndex &ind, size_t size)
_claim_table[next_offset] = ind;
_mem_plans[ind] = {next_offset, size};
- VERBOSE(FF_PLANNER) << "claim(#" << ind.value() << "): [+" << next_offset << ", " << size << "sz]"
+ VERBOSE(FF_PLANNER) << "claim(" << ind << "): [+" << next_offset << ", " << size << "sz]"
<< std::endl;
if (_capacity < next_offset + size)
@@ -102,7 +97,7 @@ void FirstFitPlanner::release(const ir::OperandIndex &ind)
_claim_table.erase(it);
- VERBOSE(FF_PLANNER) << "release(#" << index << "): [+" << offset << ", " << size << "sz]"
+ VERBOSE(FF_PLANNER) << "release(" << index << "): [+" << offset << ", " << size << "sz]"
<< std::endl;
return;
}
@@ -111,16 +106,14 @@ void FirstFitPlanner::release(const ir::OperandIndex &ind)
}
WICPlanner::WICPlanner()
- : _initialized(false), _capacity(0), _mem_plans(), _live_operands(), _interference_graph(),
- _operands()
+ : _initialized(false), _capacity(0), _mem_plans(), _live_operands(), _interference_graph(),
+ _operands()
{
// DO NOTHING
}
void WICPlanner::claim(const ir::OperandIndex &ind, size_t size)
{
- assert(size != 0);
-
_operands.emplace(size, ind);
_interference_graph[ind].insert(_interference_graph[ind].end(), _live_operands.cbegin(),
_live_operands.cend());
@@ -130,13 +123,13 @@ void WICPlanner::claim(const ir::OperandIndex &ind, size_t size)
}
_live_operands.emplace(ind);
- VERBOSE(WIC_PLANNER) << "claim(#" << ind.value() << "): [" << size << "sz]" << std::endl;
+ VERBOSE(WIC_PLANNER) << "claim(" << ind << "): [" << size << "sz]" << std::endl;
}
void WICPlanner::release(const ir::OperandIndex &ind)
{
_live_operands.erase(ind);
- VERBOSE(WIC_PLANNER) << "release(#" << ind.value() << ")" << std::endl;
+ VERBOSE(WIC_PLANNER) << "release(" << ind << ")" << std::endl;
}
/*
@@ -154,7 +147,7 @@ void WICPlanner::buildMemoryPlans()
{
uint32_t size = operand.first;
const ir::OperandIndex &ind = operand.second;
- VERBOSE(WIC_PLANNER) << "build_plan(#" << ind.value() << "): [" << size << "sz]" << std::endl;
+ VERBOSE(WIC_PLANNER) << "build_plan(" << ind << "): [" << size << "sz]" << std::endl;
uint32_t next_offset = 0;
if (_interference_graph.count(ind))
@@ -190,8 +183,8 @@ void WICPlanner::buildMemoryPlans()
}
_mem_plans[ind] = {next_offset, size};
- VERBOSE(WIC_PLANNER) << "alloc(#" << ind.value() << "): [+" << next_offset << ", " << size
- << "sz]" << std::endl;
+ VERBOSE(WIC_PLANNER) << "alloc(" << ind << "): [+" << next_offset << ", " << size << "sz]"
+ << std::endl;
if (_capacity < next_offset + size)
{
@@ -210,6 +203,6 @@ WICPlanner::MemoryPlans &WICPlanner::memory_plans()
return _mem_plans;
}
-} // namespace cpu_common
+} // namespace basic
} // namespace backend
} // namespace onert
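
Editor's note: for reference, the core of FirstFitPlanner::claim is a linear scan over the already-claimed blocks in offset order, taking the first gap large enough for the request. A standalone sketch of that scan (a plain map from base offset to block size, not the planner's actual members):

#include <cstdint>
#include <map>

std::uint32_t firstFitClaim(std::map<std::uint32_t, std::uint32_t> &claims, std::uint32_t size)
{
  std::uint32_t next_offset = 0;
  for (const auto &[base, claimed_size] : claims)
  {
    if (next_offset + size <= base)
      break;                           // the gap before this block is big enough
    next_offset = base + claimed_size; // otherwise jump past the claimed block
  }
  claims[next_offset] = size;          // record the new block
  return next_offset;
}

BumpPlanner is the degenerate case: release is a no-op and every claim lands at the current end, so capacity only ever grows.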
diff --git a/runtime/onert/core/src/backend/cpu_common/MemoryPlanner.h b/runtime/onert/core/src/backend/basic/MemoryPlanner.h
index 7c387e542..03e977500 100644
--- a/runtime/onert/core/src/backend/cpu_common/MemoryPlanner.h
+++ b/runtime/onert/core/src/backend/basic/MemoryPlanner.h
@@ -19,29 +19,29 @@
* @brief      This file contains Memory Planning related classes
*/
-#ifndef __ONERT_BACKEND_CPU_COMMON_MEMORY_PLANNER_H__
-#define __ONERT_BACKEND_CPU_COMMON_MEMORY_PLANNER_H__
+#ifndef __ONERT_BACKEND_BASIC_MEMORY_PLANNER_H__
+#define __ONERT_BACKEND_BASIC_MEMORY_PLANNER_H__
#include <map>
#include <vector>
#include <unordered_set>
#include <memory>
-#include "backend/cpu_common/Allocator.h"
-#include "backend/cpu_common/IMemoryPlanner.h"
+#include "backend/basic/Allocator.h"
+#include "backend/basic/IMemoryPlanner.h"
#include "ir/OperandIndexMap.h"
namespace onert
{
namespace backend
{
-namespace cpu_common
+namespace basic
{
/**
* @brief Class to plan memory by bump way
*/
-class BumpPlanner : public IMemoryPlanner
+class BumpPlanner : public IMemoryPlanner<ir::OperandIndex>
{
public:
/**
@@ -74,7 +74,7 @@ private:
/**
* @brief Class to plan memory by firstfit way
*/
-class FirstFitPlanner : public IMemoryPlanner
+class FirstFitPlanner : public IMemoryPlanner<ir::OperandIndex>
{
public:
/**
@@ -109,7 +109,7 @@ private:
/**
* @brief Class to plan memory by Weighted Interval Color algorithm
*/
-class WICPlanner : public IMemoryPlanner
+class WICPlanner : public IMemoryPlanner<ir::OperandIndex>
{
public:
WICPlanner();
@@ -153,8 +153,8 @@ private:
std::multimap<uint32_t, ir::OperandIndex, std::greater<uint32_t>> _operands;
};
-} // namespace cpu_common
+} // namespace basic
} // namespace backend
} // namespace onert
-#endif // __ONERT_BACKEND_CPU_COMMON_MEMORY_PLANNER_H__
+#endif // __ONERT_BACKEND_BASIC_MEMORY_PLANNER_H__
diff --git a/runtime/onert/core/src/backend/cpu_common/MemoryPlanner.test.cc b/runtime/onert/core/src/backend/basic/MemoryPlanner.test.cc
index 5208a94d4..a32228cbe 100644
--- a/runtime/onert/core/src/backend/cpu_common/MemoryPlanner.test.cc
+++ b/runtime/onert/core/src/backend/basic/MemoryPlanner.test.cc
@@ -21,13 +21,13 @@
TEST(Allocator, allocate_test)
{
- ::onert::backend::cpu_common::Allocator allocator(1024);
+ ::onert::backend::basic::Allocator allocator(1024);
ASSERT_NE(allocator.base(), nullptr);
}
TEST(BumpPlanner, claim_test)
{
- ::onert::backend::cpu_common::BumpPlanner planner;
+ ::onert::backend::basic::BumpPlanner planner;
auto claim = [&planner](uint32_t index, size_t size, uint32_t expected_offset) {
onert::ir::OperandIndex mem_idx(index);
@@ -44,7 +44,7 @@ TEST(BumpPlanner, claim_test)
TEST(FirstFitPlanner, claim_release_test)
{
- ::onert::backend::cpu_common::FirstFitPlanner planner;
+ ::onert::backend::basic::FirstFitPlanner planner;
auto claim = [&planner](uint32_t index, size_t size, uint32_t expected_offset) {
onert::ir::OperandIndex mem_idx(index);
@@ -128,7 +128,7 @@ TEST(FirstFitPlanner, claim_release_test)
TEST(WICPlanner, claim_release_test)
{
- ::onert::backend::cpu_common::WICPlanner planner;
+ ::onert::backend::basic::WICPlanner planner;
auto claim = [&planner](uint32_t index, size_t size) {
onert::ir::OperandIndex mem_idx(index);
diff --git a/runtime/onert/core/src/backend/cpu_common/MemoryPlannerFactory.cc b/runtime/onert/core/src/backend/basic/MemoryPlannerFactory.cc
index ead4f3294..7338f87b6 100644
--- a/runtime/onert/core/src/backend/cpu_common/MemoryPlannerFactory.cc
+++ b/runtime/onert/core/src/backend/basic/MemoryPlannerFactory.cc
@@ -22,7 +22,7 @@ namespace onert
{
namespace backend
{
-namespace cpu_common
+namespace basic
{
MemoryPlannerFactory &MemoryPlannerFactory::get()
@@ -31,7 +31,7 @@ MemoryPlannerFactory &MemoryPlannerFactory::get()
return instance;
}
-IMemoryPlanner *MemoryPlannerFactory::create(const std::string &key)
+IMemoryPlanner<ir::OperandIndex> *MemoryPlannerFactory::create(const std::string &key)
{
if (key == "FirstFit")
{
@@ -48,6 +48,6 @@ IMemoryPlanner *MemoryPlannerFactory::create(const std::string &key)
return new FirstFitPlanner; // Default Planner
}
-} // namespace cpu_common
+} // namespace basic
} // namespace backend
} // namespace onert
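
Editor's note: the factory keeps the same shape after the rename: a process-wide singleton maps a configuration string to a planner, with FirstFitPlanner as the fallback. A simplified standalone sketch of that pattern (made-up names; the real create() returns a raw IMemoryPlanner<ir::OperandIndex> * as shown above):

#include <memory>
#include <string>

struct Planner { virtual ~Planner() = default; };
struct FirstFit : Planner {};
struct Bump : Planner {};

class PlannerFactory
{
public:
  static PlannerFactory &get()
  {
    static PlannerFactory instance; // one factory per process
    return instance;
  }
  std::unique_ptr<Planner> create(const std::string &key)
  {
    if (key == "Bump")
      return std::make_unique<Bump>();
    return std::make_unique<FirstFit>(); // default, as in MemoryPlannerFactory
  }
};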
diff --git a/runtime/onert/core/src/backend/cpu_common/MemoryPlannerFactory.h b/runtime/onert/core/src/backend/basic/MemoryPlannerFactory.h
index d14ec13ca..b4173f749 100644
--- a/runtime/onert/core/src/backend/cpu_common/MemoryPlannerFactory.h
+++ b/runtime/onert/core/src/backend/basic/MemoryPlannerFactory.h
@@ -14,10 +14,11 @@
* limitations under the License.
*/
-#ifndef __ONERT_BACKEND_CPU_COMMON_MEMORY_PLANNER_FACTORY_H__
-#define __ONERT_BACKEND_CPU_COMMON_MEMORY_PLANNER_FACTORY_H__
+#ifndef __ONERT_BACKEND_BASIC_MEMORY_PLANNER_FACTORY_H__
+#define __ONERT_BACKEND_BASIC_MEMORY_PLANNER_FACTORY_H__
-#include "backend/cpu_common/IMemoryPlanner.h"
+#include "backend/basic/IMemoryPlanner.h"
+#include "MemoryPlanner.h"
#include <string>
@@ -25,7 +26,7 @@ namespace onert
{
namespace backend
{
-namespace cpu_common
+namespace basic
{
class MemoryPlannerFactory
@@ -37,11 +38,11 @@ private:
MemoryPlannerFactory() = default;
public:
- IMemoryPlanner *create(const std::string &key);
+ IMemoryPlanner<ir::OperandIndex> *create(const std::string &key);
};
-} // namespace cpu_common
+} // namespace basic
} // namespace backend
} // namespace onert
-#endif // __ONERT_BACKEND_CPU_COMMON_MEMORY_PLANNER_FACTORY_H__
+#endif // __ONERT_BACKEND_BASIC_MEMORY_PLANNER_FACTORY_H__
diff --git a/runtime/onert/core/src/backend/cpu_common/StaticTensorManager.cc b/runtime/onert/core/src/backend/basic/StaticTensorManager.cc
index 440f70c93..04dbc4a6b 100644
--- a/runtime/onert/core/src/backend/cpu_common/StaticTensorManager.cc
+++ b/runtime/onert/core/src/backend/basic/StaticTensorManager.cc
@@ -14,65 +14,55 @@
* limitations under the License.
*/
-#include "backend/cpu_common/StaticTensorManager.h"
+#include "backend/basic/StaticTensorManager.h"
-#include "backend/cpu_common/DynamicTensorManager.h"
+#include "backend/basic/DynamicTensorManager.h"
+#include "backend/basic/Tensor.h"
#include <util/logging.h>
namespace onert
{
namespace backend
{
-namespace cpu_common
+namespace basic
{
StaticTensorManager::StaticTensorManager(const std::shared_ptr<TensorRegistry> &reg,
- IDynamicTensorManager *dynamic_tensor_manager)
- : _const_mgr{new DynamicMemoryManager()}, _nonconst_mgr{new MemoryManager()}, _tensors{reg},
- _dynamic_tensor_manager{dynamic_tensor_manager}
+ DynamicTensorManager *dynamic_tensor_manager)
+ : _nonconst_mgr{new MemoryManager()}, _tensors{reg},
+ _dynamic_tensor_manager{dynamic_tensor_manager}
{
// DO NOTHING
}
-void StaticTensorManager::allocateConsts(void)
+StaticTensorManager::StaticTensorManager(const std::shared_ptr<TensorRegistry> &reg,
+ const std::string planner_id,
+ DynamicTensorManager *dynamic_tensor_manager)
+ : _nonconst_mgr{new MemoryManager(planner_id)}, _tensors{reg},
+ _dynamic_tensor_manager{dynamic_tensor_manager}
{
- for (auto &pair : _tensors->native_tensors())
- {
- const auto &ind = pair.first;
- auto tensor = pair.second;
- if (_as_constants[ind])
- {
- auto mem_alloc = _const_mgr->allocate(ind, tensor->total_size());
- tensor->setBuffer(mem_alloc);
- auto buffer = mem_alloc->base();
- VERBOSE(CPU_COMMON_StaticTensorManager) << "CONSTANT TENSOR(#" << ind.value()
- << "): " << static_cast<void *>(buffer)
- << "size : " << tensor->total_size() << std::endl;
- }
- }
+ // DO NOTHING
}
void StaticTensorManager::allocateNonconsts(void)
{
_nonconst_mgr->allocate();
- for (auto &pair : _tensors->native_tensors())
+ for (auto &&pair : _tensors->native_tensors())
{
const auto &ind = pair.first;
- auto tensor = pair.second;
+ auto tensor = pair.second.get();
if (!_as_constants[ind] && !tensor->is_dynamic())
{
auto *buffer = _nonconst_mgr->getBuffer(ind);
tensor->setBuffer(buffer);
- VERBOSE(CPU_COMMON_StaticTensorManager) << "TENSOR(#" << ind.value()
- << "): " << static_cast<void *>(buffer) << std::endl;
+ VERBOSE(CPU_StaticTensorManager)
+ << "TENSOR " << ind << " : " << static_cast<void *>(buffer) << std::endl;
}
}
}
-void StaticTensorManager::deallocateConsts(void) { _const_mgr->deallocate(); }
-
void StaticTensorManager::deallocateNonconsts(void) { _nonconst_mgr->deallocate(); }
void StaticTensorManager::buildTensor(const ir::OperandIndex &ind,
@@ -80,8 +70,17 @@ void StaticTensorManager::buildTensor(const ir::OperandIndex &ind,
bool as_const)
{
assert(!_tensors->getNativeTensor(ind));
- auto tensor = std::make_shared<Tensor>(tensor_info, backend_layout, _dynamic_tensor_manager);
- _tensors->setNativeTensor(ind, tensor);
+ if (as_const)
+ {
+ auto tensor = std::make_unique<ExternalTensor>(tensor_info, backend_layout);
+ _tensors->setNativeTensor(ind, std::move(tensor));
+ }
+ else
+ {
+ auto tensor = std::make_unique<Tensor>(tensor_info, backend_layout,
+ _dynamic_tensor_manager->dynamic_mem_mgr().get());
+ _tensors->setNativeTensor(ind, std::move(tensor));
+ }
_as_constants[ind] = as_const;
}
@@ -113,6 +112,6 @@ void StaticTensorManager::iterate(const std::function<void(const ir::OperandInde
fn(it.first);
}
-} // namespace cpu_common
+} // namespace basic
} // namespace backend
} // namespace onert
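
Editor's note: the rewritten buildTensor splits tensors by constness: constants become ExternalTensor objects that never go through the memory planner (the old allocateConsts/_const_mgr path is dropped), while non-constants stay pool-backed and receive their buffer in allocateNonconsts. A standalone sketch of that split, using stand-in types rather than the onert classes:

#include <cstdint>
#include <memory>

struct TensorBase { virtual ~TensorBase() = default; };

struct PoolTensor : TensorBase   // analogous to basic::Tensor
{
  std::uint8_t *buffer = nullptr; // set later from MemoryManager::getBuffer()
};

struct ConstTensor : TensorBase  // analogous to basic::ExternalTensor
{
  const std::uint8_t *data = nullptr; // points at data the model already owns
};

std::unique_ptr<TensorBase> buildTensor(bool as_const)
{
  if (as_const)
    return std::make_unique<ConstTensor>(); // constants skip the planner entirely
  return std::make_unique<PoolTensor>();    // non-constants are planned and pooled
}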
diff --git a/runtime/onert/core/src/backend/basic/Tensor.cc b/runtime/onert/core/src/backend/basic/Tensor.cc
new file mode 100644
index 000000000..7f33d4d74
--- /dev/null
+++ b/runtime/onert/core/src/backend/basic/Tensor.cc
@@ -0,0 +1,104 @@
+/*
+ * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/basic/Tensor.h"
+
+#include "ir/DataType.h"
+#include "backend/basic/MemoryManager.h"
+
+namespace onert
+{
+namespace backend
+{
+namespace basic
+{
+
+Tensor::~Tensor() {}
+
+void Tensor::setShape(const ir::Shape &new_shape) { _info.shape(new_shape); }
+
+bool Tensor::applyShape(const ir::Shape &new_shape)
+{
+ bool previously_dynamic = is_dynamic();
+
+ auto allocTensorMem = [&]() {
+ auto capacity = total_size();
+ assert(_dynamic_mem_mgr);
+ auto alloc = _dynamic_mem_mgr->allocate(this, capacity);
+ setBuffer(alloc);
+ };
+
+ if (!previously_dynamic || buffer() == nullptr)
+ {
+ // Always set shape - when buffer with same size was already allocated, shape could differ
+ setShape(new_shape);
+ set_dynamic();
+ allocTensorMem();
+ }
+ else
+ {
+ auto previous_size = total_size();
+ auto new_size = new_shape.num_elements() * ir::sizeOfDataType(data_type());
+ if (previous_size != new_size)
+ {
+ assert(_dynamic_mem_mgr);
+ _dynamic_mem_mgr->deallocate(this);
+
+ setShape(new_shape);
+ set_dynamic();
+ allocTensorMem();
+ }
+ else
+ { // when buffer with same size was already allocated, shape could differ
+ setShape(new_shape);
+ }
+ }
+ return true;
+}
+
+void Tensor::deallocBuffer()
+{
+ if (_allocator)
+ {
+ _buffer = nullptr;
+ _allocator.reset();
+ if (_dynamic_mem_mgr)
+ {
+ _dynamic_mem_mgr->deallocate(this);
+ }
+ }
+}
+
+} // namespace basic
+} // namespace backend
+} // namespace onert
+
+// ExternalTensor
+
+namespace onert
+{
+namespace backend
+{
+namespace basic
+{
+
+// `dynamic_cast` not working across library boundaries on NDK
+// With this as a key function, `dynamic_cast` works across dl
+ExternalTensor::~ExternalTensor() {}
+
+} // namespace basic
+} // namespace backend
+} // namespace onert
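
Editor's note: applyShape only touches memory when it has to: a tensor that was not dynamic yet (or has no buffer) gets its first dynamic allocation, a dynamic tensor whose byte size changed is freed and re-allocated, and a same-size reshape just updates the recorded shape. A standalone sketch of that policy, with a trivial stand-in for the DynamicMemoryManager allocation:

#include <cstddef>

struct Buf
{
  std::size_t size = 0;
  bool valid = false;
};

void applyShape(Buf &buf, std::size_t new_size)
{
  if (!buf.valid)
  {
    buf = {new_size, true};   // first allocation for a newly dynamic tensor
  }
  else if (buf.size != new_size)
  {
    buf = {new_size, true};   // size changed: drop the old buffer, allocate anew
  }
  // same size: keep the buffer; only the recorded shape changes
}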
diff --git a/runtime/onert/core/src/backend/basic/TensorBuilder.cc b/runtime/onert/core/src/backend/basic/TensorBuilder.cc
new file mode 100644
index 000000000..4912af1f5
--- /dev/null
+++ b/runtime/onert/core/src/backend/basic/TensorBuilder.cc
@@ -0,0 +1,91 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <backend/basic/TensorBuilder.h>
+
+#include <util/logging.h>
+
+#include <cassert>
+
+namespace onert
+{
+namespace backend
+{
+namespace basic
+{
+
+TensorBuilder::TensorBuilder(const std::shared_ptr<TensorRegistry> &tensor_reg)
+ : _tensor_reg{tensor_reg}, _dynamic_tensor_mgr{new DynamicTensorManager(_tensor_reg)},
+ _static_tensor_mgr{new StaticTensorManager(_tensor_reg, _dynamic_tensor_mgr.get())}
+{
+ /* empty */
+}
+
+TensorBuilder::TensorBuilder(const std::shared_ptr<TensorRegistry> &tensor_reg,
+ const std::string planner_id)
+ : _tensor_reg{tensor_reg}, _dynamic_tensor_mgr{new DynamicTensorManager(_tensor_reg)},
+ _static_tensor_mgr{new StaticTensorManager(_tensor_reg, planner_id, _dynamic_tensor_mgr.get())}
+{
+ /* empty */
+}
+
+void TensorBuilder::registerTensorInfo(const ir::OperandIndex &ind, const ir::OperandInfo &info,
+ ir::Layout layout)
+{
+ _tensor_info_map.emplace(ind, info);
+
+ // CPU backend supports only one layout as NHWC
+ assert(layout == ir::Layout::NHWC);
+ if (info.isDynamic())
+ {
+ _dynamic_tensor_mgr->buildTensor(ind, info, layout);
+ }
+ else
+ {
+ _static_tensor_mgr->buildTensor(ind, info, layout, info.isConstant());
+ }
+}
+
+void TensorBuilder::notifyFirstUse(const ir::OperandIndex &ind)
+{
+ assert(_tensor_info_map.find(ind) != _tensor_info_map.end());
+ const auto &tensor_info = _tensor_info_map.at(ind);
+
+ if (!_tensor_reg->getNativeTensor(ind)->is_dynamic())
+ {
+ const auto size = tensor_info.total_size();
+ _static_tensor_mgr->claimPlan(ind, size);
+ }
+}
+
+void TensorBuilder::notifyLastUse(const ir::OperandIndex &ind)
+{
+ if (!_tensor_reg->getNativeTensor(ind)->is_dynamic())
+ {
+ _static_tensor_mgr->releasePlan(ind);
+ }
+}
+
+bool TensorBuilder::isRegistered(const ir::OperandIndex &ind) const
+{
+ return _tensor_info_map.find(ind) != _tensor_info_map.end();
+}
+
+void TensorBuilder::allocate(void) { _static_tensor_mgr->allocateNonconsts(); }
+
+} // namespace basic
+} // namespace backend
+} // namespace onert
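
Editor's sketch (not in the patch) of the calling order the TensorBuilder above expects for static tensors; the operand index and info are placeholders, and the registry construction is assumed from the header.

auto reg = std::make_shared<basic::TensorRegistry>();     // assumed default-constructible
basic::TensorBuilder builder{reg};
builder.registerTensorInfo(ind, info, ir::Layout::NHWC);  // static info goes to the StaticTensorManager
builder.notifyFirstUse(ind);                              // claim a plan of info.total_size() bytes
builder.notifyLastUse(ind);                               // release the plan for reuse by later tensors
builder.allocate();                                       // allocateNonconsts() materializes the buffers
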
diff --git a/runtime/onert/core/src/backend/basic/train/TrainableTensor.cc b/runtime/onert/core/src/backend/basic/train/TrainableTensor.cc
new file mode 100644
index 000000000..d09604224
--- /dev/null
+++ b/runtime/onert/core/src/backend/basic/train/TrainableTensor.cc
@@ -0,0 +1,49 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <backend/basic/train/TrainableTensor.h>
+
+namespace onert
+{
+namespace backend
+{
+namespace basic
+{
+namespace train
+{
+
+std::vector<ITensor *> TrainableTensor::optVars()
+{
+ std::vector<ITensor *> ret;
+ for (auto &&e : _opt_vars)
+ {
+ ret.emplace_back(e.get());
+ }
+ return ret;
+}
+
+void TrainableTensor::fillBuffer(const std::shared_ptr<ir::Data> &data)
+{
+ auto *buffer = _tensor.buffer();
+ assert(buffer);
+ assert(total_size() == data->size());
+ std::memcpy(buffer, data->base(), data->size());
+}
+
+} // namespace train
+} // namespace basic
+} // namespace backend
+} // namespace onert
diff --git a/runtime/onert/core/src/backend/controlflow/Backend.h b/runtime/onert/core/src/backend/builtin/Backend.h
index 670f7750f..85d389505 100644
--- a/runtime/onert/core/src/backend/controlflow/Backend.h
+++ b/runtime/onert/core/src/backend/builtin/Backend.h
@@ -14,16 +14,20 @@
* limitations under the License.
*/
-#ifndef __ONERT_BACKEND_CONTROLFLOW_BACKEND_H__
-#define __ONERT_BACKEND_CONTROLFLOW_BACKEND_H__
+#ifndef __ONERT_BACKEND_BUILTIN_BACKEND_H__
+#define __ONERT_BACKEND_BUILTIN_BACKEND_H__
+#include "BackendContext.h"
#include "Config.h"
-#include "ConstantInitializer.h"
#include "KernelGenerator.h"
#include "TensorBuilder.h"
#include "Tensor.h"
+#include "train/BackendContext.h"
+#include "train/KernelGenerator.h"
+#include "train/TensorRegistry.h"
#include <backend/Backend.h>
+#include <backend/train/ITrainableBackend.h>
#include <memory>
@@ -31,22 +35,19 @@ namespace onert
{
namespace backend
{
-namespace controlflow
+namespace builtin
{
-class Backend : public ::onert::backend::Backend
+class Backend : public ::onert::backend::Backend, public backend::train::ITrainableBackend
{
public:
Backend() : _config{std::make_shared<Config>()} {}
std::shared_ptr<IConfig> config() const override { return _config; }
- std::unique_ptr<BackendContext> newContext(const ir::Graph &graph,
- const std::shared_ptr<custom::IKernelBuilder> &,
- bool) const override
+ std::unique_ptr<onert::backend::BackendContext> newContext(ContextData &&data) const override
{
- const auto &operands = graph.operands();
- auto context = std::make_unique<BackendContext>(this, &graph);
+ auto context = std::make_unique<BackendContext>(this, std::move(data));
// ControlFlow backend may not build tensors for itself because the backend's operation uses
    // tensors of other backends instead.
    // But the backend builds tensors in case the controlflow operation has constant
@@ -68,10 +69,22 @@ public:
auto tb = std::make_shared<TensorBuilder>(tr);
context->tensor_registry = tr;
context->tensor_builder = tb;
- context->constant_initializer = std::make_shared<ConstantInitializer>(operands, tr);
- context->kernel_gen = std::make_shared<KernelGenerator>(graph, tb->dynamicTensorManager(), tr);
- context->tensor_register = nullptr;
- context->optimizer = nullptr;
+ context->kernel_gen = std::make_shared<KernelGenerator>(
+ *context->graph(), tb->dynamicTensorManager(), tr, context->external_context());
+ return context;
+ }
+
+ std::unique_ptr<backend::train::TrainableBackendContext>
+ newContext(backend::train::TrainableContextData &&tdata) const override
+ {
+ const auto &tgraph = *tdata.tgraph;
+ auto tr = std::make_shared<train::TensorRegistry>();
+ // TODO Create TensorBuilder if necessary
+ auto tdata_ptr = std::make_unique<backend::train::TrainableContextData>(std::move(tdata));
+ auto context = std::make_unique<train::BackendContext>(this, std::move(tdata_ptr), tr);
+
+ context->kernel_gen =
+ std::make_shared<train::KernelGenerator>(tgraph, tr, context->external_context());
return context;
}
@@ -79,8 +92,8 @@ private:
std::shared_ptr<IConfig> _config;
};
-} // namespace controlflow
+} // namespace builtin
} // namespace backend
} // namespace onert
-#endif // __ONERT_BACKEND_CONTROLFLOW_BACKEND_H__
+#endif // __ONERT_BACKEND_BUILTIN_BACKEND_H__
diff --git a/runtime/onert/core/src/backend/builtin/BackendContext.cc b/runtime/onert/core/src/backend/builtin/BackendContext.cc
new file mode 100644
index 000000000..a66e97b6e
--- /dev/null
+++ b/runtime/onert/core/src/backend/builtin/BackendContext.cc
@@ -0,0 +1,58 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "BackendContext.h"
+
+#include "KernelGenerator.h"
+#include "backend/basic/BackendContextHelpers.h"
+
+namespace onert
+{
+namespace backend
+{
+namespace builtin
+{
+
+ITensorRegistry *BackendContext::genTensors() { return basic::genTensors(*this); }
+
+FunctionMap BackendContext::genKernels()
+{
+ FunctionMap ret;
+
+ for (auto &&op_ind : _data.op_order)
+ {
+ auto fn_seq = kernel_gen->generate(op_ind);
+ ret.emplace(op_ind, std::move(fn_seq));
+ }
+
+ basic::initConsts(*this);
+
+ // NOTE For memory optimization, we want to free some operand data
+ const_cast<ir::Graph *>(graph())->operands().iterate(
+ [&](const ir::OperandIndex &, ir::Operand &obj) { obj.releaseData(); });
+
+ for (auto &&it : ret)
+ {
+ auto &fn_seq = it.second;
+ fn_seq->iterate([&](exec::IFunction &ifunc) { ifunc.prepare(); });
+ }
+
+ return ret;
+}
+
+} // namespace builtin
+} // namespace backend
+} // namespace onert
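
Editor's sketch of how the FunctionMap returned by genKernels() might be consumed; the executor wiring around it is assumed and simplified, and only the genKernels() call itself comes from this file.

auto fn_map = backend_context.genKernels(); // per-operation FunctionSequences, already prepare()d above
for (auto &&[op_index, fn_seq] : fn_map)
{
  fn_seq->run(); // an executor would run these following _data.op_order
}
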
diff --git a/runtime/onert/core/src/backend/builtin/BackendContext.h b/runtime/onert/core/src/backend/builtin/BackendContext.h
new file mode 100644
index 000000000..93e825239
--- /dev/null
+++ b/runtime/onert/core/src/backend/builtin/BackendContext.h
@@ -0,0 +1,71 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_BACKEND_BUILTIN_BACKEND_CONTEXT_H__
+#define __ONERT_BACKEND_BUILTIN_BACKEND_CONTEXT_H__
+
+#include <backend/BackendContext.h>
+#include "TensorBuilder.h"
+#include "KernelGenerator.h"
+#include "ExternalContext.h"
+
+namespace onert
+{
+namespace backend
+{
+namespace builtin
+{
+
+class BackendContext : public onert::backend::BackendContext
+{
+public:
+ BackendContext(const Backend *backend, ContextData &&data,
+ std::shared_ptr<ITensorRegistry> tensor_registry = nullptr,
+ std::shared_ptr<TensorBuilder> tensor_builder = nullptr,
+ std::shared_ptr<KernelGenerator> kernel_gen = nullptr)
+ : onert::backend::BackendContext(backend, std::move(data), tensor_registry),
+ tensor_builder{tensor_builder}, kernel_gen{kernel_gen},
+ _external_context(std::make_shared<ExternalContext>())
+ {
+ }
+
+ ITensorRegistry *genTensors() override;
+
+ FunctionMap genKernels() override;
+
+ std::shared_ptr<ExternalContext> external_context() { return _external_context; }
+
+private:
+ void planTensors(const std::vector<onert::ir::OperationIndex> &order,
+ const compiler::GraphLowerInfo &lower_info);
+
+public:
+ // TODO Make it private
+ std::shared_ptr<TensorBuilder> tensor_builder;
+ std::shared_ptr<KernelGenerator> kernel_gen;
+
+private:
+ // NOTE ruy context has a thread pool, and when multiple ruy contexts are created,
+ // the thread pool is also created in duplicate
+ // TODO Create one ruy context for session
+ std::shared_ptr<ExternalContext> _external_context;
+};
+
+} // namespace builtin
+} // namespace backend
+} // namespace onert
+
+#endif // __ONERT_BACKEND_BUILTIN_BACKEND_CONTEXT_H__
diff --git a/runtime/onert/core/src/backend/controlflow/Config.cc b/runtime/onert/core/src/backend/builtin/Config.cc
index 5ec01fe11..e5f6d4c21 100644
--- a/runtime/onert/core/src/backend/controlflow/Config.cc
+++ b/runtime/onert/core/src/backend/builtin/Config.cc
@@ -20,18 +20,18 @@ namespace onert
{
namespace backend
{
-namespace controlflow
+namespace builtin
{
-std::string Config::ID = "controlflow";
+std::string Config::ID = "builtin";
bool Config::initialize() { return true; }
-ir::Layout Config::supportLayout(const ir::Operation &, ir::Layout frontend_layout)
+ir::Layout Config::supportLayout(const ir::IOperation &, ir::Layout frontend_layout)
{
return frontend_layout;
}
-} // namespace controlflow
+} // namespace builtin
} // namespace backend
} // namespace onert
diff --git a/runtime/onert/core/src/backend/controlflow/Config.h b/runtime/onert/core/src/backend/builtin/Config.h
index 6645ed59d..196b299d3 100644
--- a/runtime/onert/core/src/backend/controlflow/Config.h
+++ b/runtime/onert/core/src/backend/builtin/Config.h
@@ -14,8 +14,8 @@
* limitations under the License.
*/
-#ifndef __ONERT_BACKEND_CONTROLFLOW_CONFIG_H__
-#define __ONERT_BACKEND_CONTROLFLOW_CONFIG_H__
+#ifndef __ONERT_BACKEND_BUILTIN_CONFIG_H__
+#define __ONERT_BACKEND_BUILTIN_CONFIG_H__
#include <backend/IConfig.h>
#include <memory>
@@ -25,7 +25,7 @@ namespace onert
{
namespace backend
{
-namespace controlflow
+namespace builtin
{
class Config : public IConfig
@@ -34,7 +34,7 @@ public:
static std::string ID;
std::string id() override { return ID; }
bool initialize() override;
- ir::Layout supportLayout(const ir::Operation &node, ir::Layout frontend_layout) override;
+ ir::Layout supportLayout(const ir::IOperation &node, ir::Layout frontend_layout) override;
bool supportPermutation() override { return false; }
bool supportDynamicTensor() override
{
@@ -46,8 +46,8 @@ public:
std::unique_ptr<util::ITimer> timer() override { return std::make_unique<util::CPUTimer>(); }
};
-} // namespace controlflow
+} // namespace builtin
} // namespace backend
} // namespace onert
-#endif // __ONERT_BACKEND_CONTROLFLOW_CONFIG_H__
+#endif // __ONERT_BACKEND_BUILTIN_CONFIG_H__
diff --git a/runtime/onert/core/src/backend/controlflow/UserTensor.cc b/runtime/onert/core/src/backend/builtin/ConstantInitializer.h
index c8e2ebade..6b8eb3e9d 100644
--- a/runtime/onert/core/src/backend/controlflow/UserTensor.cc
+++ b/runtime/onert/core/src/backend/builtin/ConstantInitializer.h
@@ -14,27 +14,22 @@
* limitations under the License.
*/
-#include "UserTensor.h"
+#ifndef __ONERT_COMPILER_BUILTIN_CONSTANT_INITIALIZER_H__
+#define __ONERT_COMPILER_BUILTIN_CONSTANT_INITIALIZER_H__
+
+#include <backend/basic/ConstantInitializer.h>
namespace onert
{
namespace backend
{
-namespace controlflow
+namespace builtin
{
-size_t UserTensor::calcOffset(const ir::Coordinates &coords) const
-{
- size_t rank = num_dimensions();
- size_t offset = 0;
- for (size_t i = 0; i < rank; ++i)
- {
- offset = offset * dimension(i) + coords[i];
- }
- offset *= sizeOfDataType(data_type());
- return offset;
-}
+using ConstantInitializer = basic::ConstantInitializer;
-} // namespace controlflow
+} // namespace builtin
} // namespace backend
} // namespace onert
+
+#endif // __ONERT_COMPILER_BUILTIN_CONSTANT_INITIALIZER_H__
diff --git a/runtime/onert/core/src/backend/controlflow/UserTensorRegistry.h b/runtime/onert/core/src/backend/builtin/DynamicTensorManager.h
index fa2a2d54c..148948a9c 100644
--- a/runtime/onert/core/src/backend/controlflow/UserTensorRegistry.h
+++ b/runtime/onert/core/src/backend/builtin/DynamicTensorManager.h
@@ -14,23 +14,25 @@
* limitations under the License.
*/
-#ifndef __ONERT_BACKEND_CONTROLFLOW_USER_TENSOR_REGISTRY__
-#define __ONERT_BACKEND_CONTROLFLOW_USER_TENSOR_REGISTRY__
+#ifndef __ONERT_BACKEND_BUILTIN_DYNAMICTENSOR_MANAGER_H__
+#define __ONERT_BACKEND_BUILTIN_DYNAMICTENSOR_MANAGER_H__
-#include "backend/ITensorRegistry.h"
-#include "UserTensor.h"
+#include "TensorRegistry.h"
+#include "Tensor.h"
+
+#include <backend/basic/DynamicTensorManager.h>
namespace onert
{
namespace backend
{
-namespace controlflow
+namespace builtin
{
-using UserTensorRegistry = PortableTensorRegistryTemplate<UserTensor>;
+using DynamicTensorManager = basic::DynamicTensorManager;
-} // namespace controlflow
+} // namespace builtin
} // namespace backend
} // namespace onert
-#endif // __ONERT_BACKEND_CONTROLFLOW_USER_TENSOR_REGISTRY__
+#endif // __ONERT_BACKEND_BUILTIN_DYNAMICTENSOR_MANAGER_H__
diff --git a/runtime/onert/core/src/backend/builtin/ExternalContext.h b/runtime/onert/core/src/backend/builtin/ExternalContext.h
new file mode 100644
index 000000000..390dbb579
--- /dev/null
+++ b/runtime/onert/core/src/backend/builtin/ExternalContext.h
@@ -0,0 +1,79 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_BACKEND_BUILTIN_EXTERNAL_CONTEXT_H__
+#define __ONERT_BACKEND_BUILTIN_EXTERNAL_CONTEXT_H__
+
+#include <util/ConfigSource.h>
+
+#include <ruy/context.h>
+#include <ruy/context_get_ctx.h>
+#include <ruy/ctx.h>
+#include <ruy/tune.h>
+
+#include <memory>
+
+namespace onert
+{
+namespace backend
+{
+namespace builtin
+{
+
+// TODO Unify this with cpu::ExternalContext
+class ExternalContext
+{
+private:
+ static const int kDefaultNumThreadpoolThreads = 1;
+
+public:
+ ExternalContext() : _ruy_context(std::make_unique<ruy::Context>())
+ {
+ setMaxNumThreads(onert::util::getConfigInt(onert::util::config::RUY_THREADS));
+ initPerThreadState();
+ }
+
+ void setMaxNumThreads(int max_num_threads)
+ {
+ const int target_num_threads =
+ max_num_threads > -1 ? max_num_threads : kDefaultNumThreadpoolThreads;
+ _ruy_context->set_max_num_threads(target_num_threads);
+ }
+
+ ruy::Context *ruy_context() const { return _ruy_context.get(); }
+
+private:
+ void initPerThreadState()
+ {
+ // Initialize per-thread state.
+ const int thread_count = _ruy_context->max_num_threads();
+ auto ctx = ruy::get_ctx(_ruy_context.get());
+ ctx->EnsureThreadSpecificResources(thread_count);
+ for (int i = 0; i < thread_count; i++)
+ {
+ ctx->GetThreadSpecificTuningResolver(i)->SetTuning(ctx->explicit_tuning());
+ }
+ }
+
+private:
+ const std::unique_ptr<ruy::Context> _ruy_context;
+};
+
+} // namespace builtin
+} // namespace backend
+} // namespace onert
+
+#endif // __ONERT_BACKEND_BUILTIN_EXTERNAL_CONTEXT_H__
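
Editor's note: a short, assumed usage of the ExternalContext added above. It reads RUY_THREADS from the config source at construction and falls back to one thread when the configured value is negative.

auto ctx = std::make_shared<builtin::ExternalContext>(); // picks up RUY_THREADS, defaults to 1 thread
ctx->setMaxNumThreads(4);                                // explicit override of the pool size
ruy::Context *ruy_ctx = ctx->ruy_context();              // shared by kernels such as PermuteLayer
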
diff --git a/runtime/onert/core/src/backend/builtin/IOTensor.cc b/runtime/onert/core/src/backend/builtin/IOTensor.cc
new file mode 100644
index 000000000..e157a12e9
--- /dev/null
+++ b/runtime/onert/core/src/backend/builtin/IOTensor.cc
@@ -0,0 +1,60 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "IOTensor.h"
+
+#include <assert.h>
+
+namespace onert
+{
+namespace backend
+{
+namespace builtin
+{
+
+// `dynamic_cast` does not work across shared library boundaries on the NDK.
+// Defining this destructor out of line makes it the class's key function, so `dynamic_cast`
+// works across dynamically loaded libraries.
+IOTensor::~IOTensor() {}
+
+IOTensor::IOTensor(const ir::OperandInfo &info, ir::Layout layout)
+ : IPortableTensor{info}, _tensor{nullptr},
+ _orig{std::make_unique<UserTensor>(info, layout, (uint8_t *)nullptr, 0)}
+{
+ _tensor = _orig.get();
+}
+
+void IOTensor::setTensor(IPortableTensor *tensor)
+{
+ assert(tensor);
+ assert(tensor != this);
+ assert(tensor->layout() == _orig->layout()); // Changing layout is not considered yet
+ _tensor = tensor;
+ if (_info.shape() != tensor->getShape())
+ {
+ _info.shape(tensor->getShape());
+
+    // If the input tensor shape is updated, the affected buffers are handled by the dynamic
+    // memory manager, which deallocates its allocations after each execution.
+    // So once we mark the input tensor dynamic, it has to stay dynamic.
+    // If the dynamic memory manager ever keeps allocations alive across executions,
+    // we may need to reset the tensor to static on each setTensor() call.
+ _info.setDynamic();
+ }
+}
+
+} // namespace builtin
+} // namespace backend
+} // namespace onert
diff --git a/runtime/onert/core/src/backend/builtin/IOTensor.h b/runtime/onert/core/src/backend/builtin/IOTensor.h
new file mode 100644
index 000000000..3d684e07d
--- /dev/null
+++ b/runtime/onert/core/src/backend/builtin/IOTensor.h
@@ -0,0 +1,114 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_BACKEND_BUILTIN_IO_TENSOR_H__
+#define __ONERT_BACKEND_BUILTIN_IO_TENSOR_H__
+
+#include "backend/IPortableTensor.h"
+#include "UserTensor.h"
+
+namespace onert
+{
+namespace backend
+{
+namespace builtin
+{
+
+/**
+ * @brief Tensor object that indirects to the tensor it is pointing to.
+ *
+ * An executor's I/O tensor can be one of two types.
+ *
+ * 1. @c UserTensor, if it is the primary graph (package's input/output)
+ * 2. Any other derivative of @c IPortableTensor from another executor, otherwise
+ *
+ * To support these, this object indirects everything to the actual tensor pointer.
+ *
+ * IOTensor is derived from IPortableTensor, so it also has an "_info" field.
+ * The "_info" field is accessed through IPortableTensor's getter methods.
+ *
+ * It assumes that IOTensor's info is always the same as the actual tensor's info, except for the
+ * shape. setTensor() updates IOTensor's shape to the actual tensor's shape.
+ * The actual tensor's info should not be updated directly after a setTensor() call until the
+ * executor's execution is finished; instead, it may be updated indirectly through IOTensor's
+ * setter methods.
+ */
+class IOTensor : public IPortableTensor
+{
+public:
+ IOTensor(const ir::OperandInfo &info, ir::Layout layout);
+ ~IOTensor();
+
+public:
+ void setTensor(IPortableTensor *tensor);
+
+public:
+ uint8_t *buffer() const override { return _tensor->buffer(); }
+ ir::Layout layout() const override { return _orig->layout(); }
+ void set_dynamic() override
+ {
+ _info.setDynamic();
+ _tensor->set_dynamic();
+ }
+ void setShape(const ir::Shape &shape) override
+ {
+ _info.shape(shape);
+ _tensor->setShape(shape);
+ }
+
+ /*
+   * Changes the tensor shape and allocates memory when its shape has been changed,
+   * perhaps by nnfw_set_input_tensorinfo().
+ *
+ * Cases are:
+ * 1) static operand -> nnfw_set_input_tensorinfo() -> execute() -> execute()
+ * (a) (b)
+ *
+ * at (a), operand is static, tensor is static - memory dealloc is not needed
+ * (DynamicTensorManager cannot dealloc memory allocated by StaticTensorManager)
+ * at (b), operand is static, tensor is dynamic - memory dealloc is needed
+ *
+ * 2) dynamic operand -> nnfw_set_input_tensorinfo() -> execute() -> execute()
+ * (a) (b)
+ *
+ * at (a), operand is dynamic, tensor is dynamic - memory dealloc is not needed
+ * since it has not been allocated yet
+ * at (b), operand is dynamic, tensor is dynamic - memory dealloc is needed
+ */
+ bool applyShape(const ir::Shape &shape) override
+ {
+ auto return_val = _tensor->applyShape(shape);
+ if (return_val)
+ {
+ _info.shape(shape);
+ _info.setDynamic();
+ }
+ return return_val;
+ }
+
+private:
+  IPortableTensor *_tensor{nullptr}; //< The actual tensor this object currently indirects to
+  // "_orig" is a UserTensor holding the original operand info and layout with a nullptr buffer,
+  // and "_tensor" initially points to "_orig".
+  // After the first setTensor(tensor) call, "_tensor" is redirected to the actual tensor.
+ std::unique_ptr<UserTensor> _orig; //< If it is a user tensor, it is managed by this object
+};
+
+} // namespace builtin
+} // namespace backend
+} // namespace onert
+
+#endif // __ONERT_BACKEND_BUILTIN_IO_TENSOR_H__
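
Editor's sketch of the indirection the class comment describes; the operand info and the tensor coming from another backend are placeholders.

builtin::IOTensor io{info, ir::Layout::NHWC}; // initially backed by an internal, buffer-less UserTensor
io.setTensor(tensor_from_other_backend);      // buffer() and shape now come from the actual tensor
assert(io.buffer() == tensor_from_other_backend->buffer());
io.applyShape(new_shape);                     // forwarded; on success the IOTensor becomes dynamic too
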
diff --git a/runtime/onert/core/src/backend/builtin/KernelGenerator.cc b/runtime/onert/core/src/backend/builtin/KernelGenerator.cc
new file mode 100644
index 000000000..00c200a92
--- /dev/null
+++ b/runtime/onert/core/src/backend/builtin/KernelGenerator.cc
@@ -0,0 +1,159 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "KernelGenerator.h"
+
+#include "kernel/IfLayer.h"
+#include "kernel/PermuteLayer.h"
+#include "kernel/WhileLayer.h"
+
+#include "exec/FunctionSequence.h"
+
+namespace onert
+{
+namespace backend
+{
+namespace builtin
+{
+
+KernelGenerator::KernelGenerator(const ir::Graph &graph, DynamicTensorManager *dyn_tensor_manager,
+ const std::shared_ptr<TensorRegistry> &tensor_reg,
+ const std::shared_ptr<ExternalContext> &external_context)
+ : basic::KernelGeneratorBase{graph}, _dyn_tensor_manager{dyn_tensor_manager},
+ _tensor_reg{tensor_reg}, _tensor_registries{}, _executors{nullptr}, _model_index{},
+ _external_context{external_context}
+{
+ UNUSED_RELEASE(_graph);
+ UNUSED_RELEASE(_tensor_registries);
+ UNUSED_RELEASE(_executors);
+}
+
+std::unique_ptr<exec::FunctionSequence> KernelGenerator::generate(ir::OperationIndex ind)
+{
+ assert(_dyn_tensor_manager);
+ assert(_tensor_reg);
+
+ auto ret = std::make_unique<exec::FunctionSequence>();
+
+ // Prepare to handle dynamic tensors later
+ auto dyn_ctx = std::make_shared<exec::FunctionSequence::DynamicTensorCtx>();
+ {
+ dyn_ctx->op = &_graph.operations().at(ind);
+ dyn_ctx->dynamic_shape_inferer =
+ std::make_unique<exec::DynamicShapeInferer>(_graph.operands(), _tensor_reg);
+ }
+ ret->dynamic_tensor_ctx(dyn_ctx);
+
+ auto &op = _graph.operations().at(ind);
+ op.accept(*this);
+ assert(_return_fn); // _return_fn must have been generated
+ ret->append(std::move(_return_fn));
+
+ return ret;
+}
+
+void KernelGenerator::visit(const ir::operation::If &node)
+{
+ const auto then_subg_index = node.param().then_subg_index;
+ const auto else_subg_index = node.param().else_subg_index;
+
+ std::vector<backend::IPortableTensor *> input_tensors;
+ for (const auto &input_index : node.getInputs())
+ {
+ auto input_tensor = getPortableTensor(input_index);
+ input_tensors.emplace_back(input_tensor);
+ }
+
+ std::vector<backend::IPortableTensor *> output_tensors;
+ for (const auto &output_index : node.getOutputs())
+ {
+ auto output_tensor = getPortableTensor(output_index);
+ output_tensors.emplace_back(output_tensor);
+ }
+
+  // IfLayer just sets Executors instead of the then and else executors to avoid the complexity
+  // of creating executors recursively
+ const auto cond_tensor = input_tensors.front();
+ input_tensors.erase(input_tensors.begin());
+ auto fn = std::make_unique<::onert::backend::builtin::kernel::IfLayer>(
+ cond_tensor, input_tensors, output_tensors, then_subg_index, else_subg_index, _executors,
+ _model_index, _external_context);
+
+ _return_fn = std::move(fn);
+}
+
+void KernelGenerator::visit(const ir::operation::Permute &node)
+{
+ const auto output_index{node.getOutputs().at(0)};
+ const auto input_index{node.getInputs().at(0)};
+
+ // Add PermuteLayer
+ std::vector<ITensor *> output_tensors{getTensor(output_index)};
+ std::vector<ITensor *> input_tensors{getTensor(input_index)};
+
+ auto fn =
+ std::make_unique<kernel::PermuteLayer>(input_tensors, output_tensors, _external_context);
+ _return_fn = std::move(fn);
+}
+
+void KernelGenerator::visit(const ir::operation::While &node)
+{
+ const auto cond_subg_index = node.param().cond_subg_index;
+ const auto body_subg_index = node.param().body_subg_index;
+
+  // This op does not support constant inputs, because the builtin backend does not have a
+  // TensorBuilder
+ std::vector<backend::IPortableTensor *> input_tensors;
+ for (const auto &input_index : node.getInputs())
+ {
+ auto input_tensor = getPortableTensor(input_index);
+ input_tensors.emplace_back(input_tensor);
+ }
+
+ std::vector<backend::IPortableTensor *> output_tensors;
+ for (const auto &output_index : node.getOutputs())
+ {
+ auto output_tensor = getPortableTensor(output_index);
+ output_tensors.emplace_back(output_tensor);
+ }
+
+  // WhileLayer just sets Executors instead of the cond and body executors to avoid the
+  // complexity of creating executors recursively
+ auto fn = std::make_unique<::onert::backend::builtin::kernel::WhileLayer>(
+ input_tensors, output_tensors, cond_subg_index, body_subg_index, _executors, _model_index,
+ _dyn_tensor_manager->dynamic_mem_mgr().get(), _external_context);
+
+ _return_fn = std::move(fn);
+}
+
+backend::ITensor *KernelGenerator::getTensor(const ir::OperandIndex &index)
+{
+ // get Tensor from all tensor registries (for Permute op)
+ auto ret = _tensor_registries.getITensor(index);
+ assert(ret != nullptr);
+ return ret;
+}
+
+backend::IPortableTensor *KernelGenerator::getPortableTensor(const ir::OperandIndex &index)
+{
+ auto ret = _tensor_reg->getPortableTensor(index);
+ assert(ret != nullptr);
+ return ret;
+}
+
+} // namespace builtin
+} // namespace backend
+} // namespace onert
diff --git a/runtime/onert/core/src/backend/controlflow/KernelGenerator.h b/runtime/onert/core/src/backend/builtin/KernelGenerator.h
index b84a810e4..3c86fe306 100644
--- a/runtime/onert/core/src/backend/controlflow/KernelGenerator.h
+++ b/runtime/onert/core/src/backend/builtin/KernelGenerator.h
@@ -14,60 +14,66 @@
* limitations under the License.
*/
-#ifndef __ONERT_BACKEND_CONTROLFLOW_KERNEL_GENERATOR_H__
-#define __ONERT_BACKEND_CONTROLFLOW_KERNEL_GENERATOR_H__
+#ifndef __ONERT_BACKEND_BUILTIN_KERNEL_GENERATOR_H__
+#define __ONERT_BACKEND_BUILTIN_KERNEL_GENERATOR_H__
-#include <backend/IKernelGenerator.h>
-#include <backend/ITensorBuilder.h>
-#include <exec/IExecutor.h>
-#include <ir/Graph.h>
-#include "TensorBuilder.h"
-#include "compiler/TensorRegistries.h"
+#include "DynamicTensorManager.h"
+#include "ExternalContext.h"
#include "TensorRegistry.h"
+#include "../../compiler/TensorRegistries.h"
+
+#include "backend/basic/KernelGeneratorBase.h"
+#include "exec/IExecutors.h"
+#include "ir/Graph.h"
namespace onert
{
namespace backend
{
-namespace controlflow
+namespace builtin
{
-class KernelGenerator : public IKernelGenerator
+class KernelGenerator : public basic::KernelGeneratorBase
{
public:
- KernelGenerator(const ir::Graph &graph, IDynamicTensorManager *dyn_tensor_manager,
- const std::shared_ptr<TensorRegistry> &tensor_reg);
+ KernelGenerator(const ir::Graph &graph, DynamicTensorManager *dyn_tensor_manager,
+ const std::shared_ptr<TensorRegistry> &tensor_reg,
+ const std::shared_ptr<ExternalContext> &external_context);
void setTensorRegistries(const compiler::TensorRegistries &tensor_registries)
{
_tensor_registries = tensor_registries;
}
- void setExecutorMap(const std::shared_ptr<exec::ExecutorMap> &executor_map)
+ void setExecutors(const std::shared_ptr<exec::IExecutors> &executors)
{
// FIXME Using shared_ptr's raw pointer!
- _executor_map = executor_map.get();
+ _executors = executors.get();
}
- using IKernelGenerator::visit;
+ void setModelIndex(const ir::ModelIndex &index) { _model_index = index; }
+
+ std::unique_ptr<exec::FunctionSequence> generate(ir::OperationIndex ind) override;
- void visit(const ir::OpSequence &) override;
+private:
void visit(const ir::operation::If &) override;
void visit(const ir::operation::Permute &) override;
void visit(const ir::operation::While &) override;
private:
- std::shared_ptr<backend::ITensor> getTensor(const ir::OperandIndex &index);
+ backend::ITensor *getTensor(const ir::OperandIndex &index);
+ backend::IPortableTensor *getPortableTensor(const ir::OperandIndex &index);
private:
- const ir::Graph &_graph;
- IDynamicTensorManager *_dyn_tensor_manager;
+ DynamicTensorManager *_dyn_tensor_manager;
std::shared_ptr<TensorRegistry> _tensor_reg;
compiler::TensorRegistries _tensor_registries;
- exec::ExecutorMap *_executor_map;
+ exec::IExecutors *_executors;
+ ir::ModelIndex _model_index;
+ const std::shared_ptr<ExternalContext> _external_context;
};
-} // namespace controlflow
+} // namespace builtin
} // namespace backend
} // namespace onert
-#endif // __ONERT_BACKEND_CONTROLFLOW_KERNEL_GENERATOR_H__
+#endif // __ONERT_BACKEND_BUILTIN_KERNEL_GENERATOR_H__
diff --git a/runtime/onert/core/src/backend/controlflow/Tensor.h b/runtime/onert/core/src/backend/builtin/Tensor.h
index ba5bafd75..d55e64161 100644
--- a/runtime/onert/core/src/backend/controlflow/Tensor.h
+++ b/runtime/onert/core/src/backend/builtin/Tensor.h
@@ -14,22 +14,23 @@
* limitations under the License.
*/
-#ifndef __ONERT_BACKEND_CONTROLFLOW_TENSOR_H__
-#define __ONERT_BACKEND_CONTROLFLOW_TENSOR_H__
+#ifndef __ONERT_BACKEND_BUILTIN_TENSOR_H__
+#define __ONERT_BACKEND_BUILTIN_TENSOR_H__
-#include <backend/cpu_common/Tensor.h>
+#include <backend/basic/Tensor.h>
namespace onert
{
namespace backend
{
-namespace controlflow
+namespace builtin
{
-using Tensor = cpu_common::Tensor;
+using Tensor = basic::Tensor;
+using ExternalTensor = basic::ExternalTensor;
-} // namespace controlflow
+} // namespace builtin
} // namespace backend
} // namespace onert
-#endif // __ONERT_BACKEND_CONTROLFLOW_TENSOR_H__
+#endif // __ONERT_BACKEND_BUILTIN_TENSOR_H__
diff --git a/runtime/onert/core/src/backend/controlflow/TensorBuilder.cc b/runtime/onert/core/src/backend/builtin/TensorBuilder.cc
index e5c3f5fd5..a2f7af3ea 100644
--- a/runtime/onert/core/src/backend/controlflow/TensorBuilder.cc
+++ b/runtime/onert/core/src/backend/builtin/TensorBuilder.cc
@@ -24,13 +24,13 @@ namespace onert
{
namespace backend
{
-namespace controlflow
+namespace builtin
{
TensorBuilder::TensorBuilder(const std::shared_ptr<TensorRegistry> &tensor_reg)
- : _tensor_reg{tensor_reg}, _dynamic_tensor_mgr{new DynamicTensorManager(_tensor_reg)},
- _static_tensor_mgr{
- new cpu_common::StaticTensorManager(_tensor_reg->base_reg(), _dynamic_tensor_mgr.get())}
+ : _tensor_reg{tensor_reg}, _dynamic_tensor_mgr{new DynamicTensorManager(_tensor_reg->base_reg())},
+ _static_tensor_mgr{
+ new basic::StaticTensorManager(_tensor_reg->base_reg(), _dynamic_tensor_mgr.get())}
{
/* empty */
}
@@ -40,15 +40,14 @@ void TensorBuilder::registerTensorInfo(const ir::OperandIndex &ind, const ir::Op
{
_tensor_info_map.emplace(ind, info);
- _tensor_layout_map.insert({ind, backend_layout});
-
+ VERBOSE_F() << "cpucommon REGISTER!! " << ind << std::endl;
if (info.isDynamic())
{
- _dynamic_tensor_mgr->buildTensor(ind, info, _tensor_layout_map[ind]);
+ _dynamic_tensor_mgr->buildTensor(ind, info, backend_layout);
}
else
{
- _static_tensor_mgr->buildTensor(ind, info, _tensor_layout_map[ind], info.isConstant());
+ _static_tensor_mgr->buildTensor(ind, info, backend_layout, info.isConstant());
}
}
@@ -58,7 +57,7 @@ void TensorBuilder::notifyFirstUse(const ir::OperandIndex &ind)
if (_tensor_info_map.find(ind) == _tensor_info_map.end()) // Do not proceed for user tensors
return;
- const auto tensor_info = _tensor_info_map.at(ind);
+ const auto &tensor_info = _tensor_info_map.at(ind);
if (!nativeOwnTensorAt(ind)->is_dynamic())
{
@@ -89,39 +88,18 @@ bool TensorBuilder::isRegistered(const ir::OperandIndex &ind) const
return _tensor_info_map.find(ind) != _tensor_info_map.end();
}
-void TensorBuilder::prepare(void)
-{
- _static_tensor_mgr->allocateConsts();
- _static_tensor_mgr->allocateNonconsts();
-}
+void TensorBuilder::allocate(void) { _static_tensor_mgr->allocateNonconsts(); }
-void TensorBuilder::allocate()
+DynamicTensorManager *TensorBuilder::dynamicTensorManager(void)
{
- // NOTE For now nothing to do. Allocation is done in prepare stage, which is not appropriate
- // This is because CPU kernels require `ITensor`s to be allocated before Kernel Generation.
+ return _dynamic_tensor_mgr.get();
}
-std::shared_ptr<cpu_common::Tensor> TensorBuilder::nativeOwnTensorAt(const ir::OperandIndex &ind)
+basic::Tensor *TensorBuilder::nativeOwnTensorAt(const ir::OperandIndex &ind)
{
return _tensor_reg->getNativeOwnTensor(ind);
}
-std::unique_ptr<ITensorManager> TensorBuilder::releaseStaticTensorManager(void)
-{
- return std::move(_static_tensor_mgr);
-}
-
-std::unique_ptr<ITensorManager> TensorBuilder::releaseDynamicTensorManager(void)
-{
- return std::move(_dynamic_tensor_mgr);
-}
-
-void TensorBuilder::setNativeUserTensor(const ir::OperandIndex &ind,
- const std::shared_ptr<UserTensor> &tensor)
-{
- _tensor_reg->setNativeUserTensor(ind, tensor);
-}
-
-} // namespace controlflow
+} // namespace builtin
} // namespace backend
} // namespace onert
diff --git a/runtime/onert/core/src/backend/controlflow/TensorBuilder.h b/runtime/onert/core/src/backend/builtin/TensorBuilder.h
index 2f2a2c47e..1e364c927 100644
--- a/runtime/onert/core/src/backend/controlflow/TensorBuilder.h
+++ b/runtime/onert/core/src/backend/builtin/TensorBuilder.h
@@ -14,29 +14,27 @@
* limitations under the License.
*/
-#ifndef __ONERT_BACKEND_CONTROLFLOW_TENSOR_BUILDER_H__
-#define __ONERT_BACKEND_CONTROLFLOW_TENSOR_BUILDER_H__
+#ifndef __ONERT_BACKEND_BUILTIN_TENSOR_BUILDER_H__
+#define __ONERT_BACKEND_BUILTIN_TENSOR_BUILDER_H__
-#include <backend/cpu_common/StaticTensorManager.h>
-#include <backend/cpu_common/TensorRegistry.h>
-#include <backend/cpu_common/Tensor.h>
+#include <backend/basic/StaticTensorManager.h>
+#include <backend/basic/TensorRegistry.h>
+#include <backend/basic/Tensor.h>
-#include <backend/ITensorBuilder.h>
#include <ir/OperandIndexMap.h>
#include <unordered_map>
#include "DynamicTensorManager.h"
-#include "UserTensorRegistry.h"
namespace onert
{
namespace backend
{
-namespace controlflow
+namespace builtin
{
-class TensorBuilder : public ITensorBuilder
+class TensorBuilder
{
public:
TensorBuilder(const std::shared_ptr<TensorRegistry> &tensor_reg);
@@ -48,42 +46,34 @@ public:
* @param[in] layout Operand data layout
*/
void registerTensorInfo(const ir::OperandIndex &ind, const ir::OperandInfo &info,
- ir::Layout backend_layout) override;
+ ir::Layout backend_layout);
- void notifyFirstUse(const ir::OperandIndex &) override;
- void notifyLastUse(const ir::OperandIndex &) override;
+ void notifyFirstUse(const ir::OperandIndex &);
+ void notifyLastUse(const ir::OperandIndex &);
- bool isRegistered(const ir::OperandIndex &) const override;
+ bool isRegistered(const ir::OperandIndex &) const;
- void prepare(void) override;
- void allocate() override;
- void postFunctionPrepare() override { /* DO NOTHING */}
+ void allocate(void);
- std::unique_ptr<ITensorManager> releaseStaticTensorManager(void) override;
-
- IDynamicTensorManager *dynamicTensorManager(void) override { return _dynamic_tensor_mgr.get(); }
-
- std::unique_ptr<ITensorManager> releaseDynamicTensorManager(void) override;
+ DynamicTensorManager *dynamicTensorManager(void);
/**
* @brief Get tensor with a specific OperandIndex.
* @param ind OperandIndex for the tensor. There must exist a tensor with this ind.
* If not, program will crash with assert or exception.
- * @return shared_ptr<operand::Tensor>
+ * @return operand::Tensor *
*/
- std::shared_ptr<cpu_common::Tensor> nativeOwnTensorAt(const ir::OperandIndex &ind);
- void setNativeUserTensor(const ir::OperandIndex &ind, const std::shared_ptr<UserTensor> &tensor);
+ basic::Tensor *nativeOwnTensorAt(const ir::OperandIndex &ind);
private:
const std::shared_ptr<TensorRegistry> _tensor_reg;
std::unique_ptr<DynamicTensorManager> _dynamic_tensor_mgr;
- std::unique_ptr<cpu_common::StaticTensorManager> _static_tensor_mgr;
+ std::unique_ptr<basic::StaticTensorManager> _static_tensor_mgr;
ir::OperandIndexMap<ir::OperandInfo> _tensor_info_map;
- ir::OperandIndexMap<ir::Layout> _tensor_layout_map;
};
-} // namespace controlflow
+} // namespace builtin
} // namespace backend
} // namespace onert
-#endif // __ONERT_BACKEND_CONTROLFLOW_TENSOR_BUILDER_H__
+#endif // __ONERT_BACKEND_BUILTIN_TENSOR_BUILDER_H__
diff --git a/runtime/onert/core/src/backend/builtin/TensorRegistry.h b/runtime/onert/core/src/backend/builtin/TensorRegistry.h
new file mode 100644
index 000000000..ae68b1318
--- /dev/null
+++ b/runtime/onert/core/src/backend/builtin/TensorRegistry.h
@@ -0,0 +1,134 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_BACKEND_BUILTIN_TENSOR_REGISTRY_H__
+#define __ONERT_BACKEND_BUILTIN_TENSOR_REGISTRY_H__
+
+#include "backend/basic/TensorRegistry.h"
+#include "backend/ITensorRegistry.h"
+#include "Tensor.h"
+#include "IOTensor.h"
+#include <assert.h>
+
+namespace onert
+{
+namespace backend
+{
+namespace builtin
+{
+
+/**
+ * @brief Tensor registry class for builtin backend
+ *
+ * This class contains three types of tensors: two kinds of native tensors (tensors that are
+ * managed by this backend) and migrant tensors.
+ *
+ * - NativeIOTensor - @c IOTensor managed by this backend ( in @c _base_reg )
+ * - NOTE The tensor it actually points to can be from another backend
+ * - NativeOwnTensor - @c basic::Tensor managed by this backend ( in @c _base_reg )
+ * - MigrantTensor - @c IPortableTensor managed by other backends
+ *
+ * @note @c _base_reg is used in implementation to reuse @c basic::StaticTensorManager
+ *
+ */
+class TensorRegistry : public ITensorRegistry
+{
+public:
+ TensorRegistry() : _base_reg{new basic::TensorRegistry} {}
+
+ ITensor *getITensor(const ir::OperandIndex &ind) override
+ {
+ auto base_tensor = _base_reg->getITensor(ind);
+ if (base_tensor)
+ return base_tensor;
+ return getNativeIOTensor(ind);
+ }
+
+ ITensor *getNativeITensor(const ir::OperandIndex &ind) override
+ {
+ auto base_tensor = _base_reg->getNativeITensor(ind);
+ if (base_tensor)
+ return base_tensor;
+ return getNativeIOTensor(ind);
+ }
+
+ IPortableTensor *getPortableTensor(const ir::OperandIndex &ind)
+ {
+ auto base_tensor = _base_reg->getPortableTensor(ind);
+ if (base_tensor)
+ return base_tensor;
+ return getNativeIOTensor(ind);
+ }
+
+ IPortableTensor *getNativeTensor(const ir::OperandIndex &ind)
+ {
+ auto base_tensor = _base_reg->getNativeTensor(ind);
+ if (base_tensor)
+ return base_tensor;
+ return getNativeIOTensor(ind);
+ }
+
+ Tensor *getNativeOwnTensor(const ir::OperandIndex &ind)
+ {
+ return _base_reg->getNativeTensor(ind);
+ }
+
+ IOTensor *getNativeIOTensor(const ir::OperandIndex &ind)
+ {
+ auto tensor = _native_io_tensors.find(ind);
+ if (tensor != _native_io_tensors.end())
+ return tensor->second.get();
+ return nullptr;
+ }
+
+ bool setMigrantTensor(const ir::OperandIndex &ind, IPortableTensor *tensor) override
+ {
+ assert(tensor);
+ assert(!getITensor(ind)); // For the ind, tensor is not registered yet
+ _base_reg->setMigrantTensor(ind, tensor);
+ return true;
+ }
+
+ void setNativeOwnTensor(ir::OperandIndex ind, std::unique_ptr<Tensor> &&tensor)
+ {
+ assert(tensor);
+ assert(!getITensor(ind)); // For the ind, tensor is not registered yet
+ _base_reg->setNativeTensor(ind, std::move(tensor));
+ }
+
+ void setNativeIOTensor(ir::OperandIndex ind, std::unique_ptr<IOTensor> &&tensor)
+ {
+ assert(tensor);
+ assert(!getITensor(ind)); // For the ind, tensor is not registered yet
+ _native_io_tensors[ind] = std::move(tensor);
+ }
+
+ const ir::OperandIndexMap<std::unique_ptr<IOTensor>> &native_io_tensors()
+ {
+ return _native_io_tensors;
+ }
+ std::shared_ptr<basic::TensorRegistry> base_reg() { return _base_reg; }
+
+private:
+ std::shared_ptr<basic::TensorRegistry> _base_reg;
+ ir::OperandIndexMap<std::unique_ptr<IOTensor>> _native_io_tensors;
+};
+
+} // namespace builtin
+} // namespace backend
+} // namespace onert
+
+#endif // ifndef __ONERT_BACKEND_BUILTIN_TENSOR_REGISTRY_H__
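
Editor's sketch of the registration/lookup split documented above; the indices and tensor objects are placeholders. Lookups consult _base_reg first and then fall back to the IOTensor map.

builtin::TensorRegistry reg;
reg.setNativeOwnTensor(ind_a, std::move(own_tensor));     // basic::Tensor, stored in _base_reg
reg.setNativeIOTensor(ind_b, std::move(io_tensor));       // model input/output indirection
reg.setMigrantTensor(ind_c, tensor_from_another_backend); // registered into _base_reg as migrant
ITensor *t = reg.getITensor(ind_b);                       // _base_reg miss, found in the IOTensor map
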
diff --git a/runtime/onert/core/src/backend/builtin/UserTensor.cc b/runtime/onert/core/src/backend/builtin/UserTensor.cc
new file mode 100644
index 000000000..e260de275
--- /dev/null
+++ b/runtime/onert/core/src/backend/builtin/UserTensor.cc
@@ -0,0 +1,41 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "UserTensor.h"
+
+#include "util/Exceptions.h"
+#include "ir/DataType.h"
+
+namespace onert
+{
+namespace backend
+{
+namespace builtin
+{
+
+bool UserTensor::applyShape(const ir::Shape &new_shape)
+{
+ // User tensors cannot be reallocated.
+ auto new_size = new_shape.num_elements() * ir::sizeOfDataType(data_type());
+ if (_size < new_size)
+ throw InsufficientBufferSizeException{"User given buffer size is too small."};
+ setShape(new_shape);
+ return true;
+}
+
+} // namespace builtin
+} // namespace backend
+} // namespace onert
diff --git a/runtime/onert/core/src/backend/builtin/UserTensor.h b/runtime/onert/core/src/backend/builtin/UserTensor.h
new file mode 100644
index 000000000..b7f6ce091
--- /dev/null
+++ b/runtime/onert/core/src/backend/builtin/UserTensor.h
@@ -0,0 +1,63 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_BACKEND_BUILTIN_USER_TENSOR_H__
+#define __ONERT_BACKEND_BUILTIN_USER_TENSOR_H__
+
+#include "ir/OperandInfo.h"
+#include "backend/IPortableTensor.h"
+
+namespace onert
+{
+namespace backend
+{
+namespace builtin
+{
+
+/**
+ * @brief Tensor object that is for Input and Output tensors from the user.
+ *
+ * This class wraps a buffer that is allocated by the user, so it is not responsible for
+ * allocation or deallocation. All the model input/output tensors are wrapped with this class
+ * for execution.
+ *
+ */
+class UserTensor : public IPortableTensor
+{
+public:
+ UserTensor(const ir::OperandInfo &info, ir::Layout layout, uint8_t *buffer, size_t size)
+ : IPortableTensor{info}, _layout{layout}, _buffer{buffer}, _size{size}
+ {
+ }
+
+public:
+ uint8_t *buffer() const override { return _buffer; }
+ ir::Layout layout() const override { return _layout; }
+ void set_dynamic() override { _info.setDynamic(); }
+ void setShape(const ir::Shape &new_shape) override { _info.shape(new_shape); }
+ bool applyShape(const ir::Shape &) override;
+
+private:
+ ir::Layout _layout;
+ uint8_t *_buffer;
+ size_t _size;
+};
+
+} // namespace builtin
+} // namespace backend
+} // namespace onert
+
+#endif // __ONERT_BACKEND_BUILTIN_USER_TENSOR_H__
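
Editor's sketch of wrapping a caller-owned buffer with UserTensor; the info and shapes are placeholders. As UserTensor.cc above shows, applyShape() only succeeds while the new shape still fits in the user-provided byte size.

std::vector<uint8_t> user_buf(info.total_size());
builtin::UserTensor t{info, ir::Layout::NHWC, user_buf.data(), user_buf.size()};
t.applyShape(smaller_shape);   // ok: still fits in user_buf
// t.applyShape(bigger_shape); // would throw InsufficientBufferSizeException
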
diff --git a/runtime/onert/core/src/backend/builtin/kernel/IfLayer.cc b/runtime/onert/core/src/backend/builtin/kernel/IfLayer.cc
new file mode 100644
index 000000000..bf8c5fc68
--- /dev/null
+++ b/runtime/onert/core/src/backend/builtin/kernel/IfLayer.cc
@@ -0,0 +1,81 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "IfLayer.h"
+
+namespace onert
+{
+namespace backend
+{
+namespace builtin
+{
+namespace kernel
+{
+
+IfLayer::IfLayer(backend::IPortableTensor *cond_tensor,
+ const std::vector<backend::IPortableTensor *> input_tensors,
+ const std::vector<backend::IPortableTensor *> output_tensors,
+ const ir::SubgraphIndex &then_subg_index, const ir::SubgraphIndex &else_subg_index,
+ exec::IExecutors *executors, const ir::ModelIndex &model_index,
+ const std::shared_ptr<ExternalContext> &external_context)
+ : _cond_tensor{cond_tensor}, _input_tensors{input_tensors}, _output_tensors{output_tensors},
+ _then_subg_index{then_subg_index}, _else_subg_index{else_subg_index}, _executors{executors},
+ _model_index{model_index}, _external_context{external_context}
+{
+ // At this point, executors may not have executors of then subg and else subg
+}
+
+void IfLayer::run()
+{
+ // Check condition
+ // // If true
+ // // // Set _input_tensors -> then-subg's inputs
+ // // // Set outputs of then-subg -> _output_tensors
+ // // // Run then-subg
+ // // Else
+ // // // Set _input_tensors -> else-subg's inputs
+ // // // Set outputs of else-subg -> _output_tensors
+ // // // Run else-subg
+
+ auto getResultCond = [](backend::IPortableTensor *tensor) -> bool {
+ bool ret = false;
+ tensor->access([&](ITensor &tensor) { ret = *reinterpret_cast<bool *>(tensor.buffer()); });
+ return ret;
+ };
+
+ exec::IExecutor *subg_exec = nullptr;
+ bool cond_result = getResultCond(_cond_tensor);
+ if (cond_result)
+ {
+ VERBOSE(If) << "Call to $" << _then_subg_index << " (then)" << std::endl;
+ subg_exec = _executors->at(_model_index, _then_subg_index);
+ }
+ else
+ {
+ VERBOSE(If) << "Call to $" << _else_subg_index << " (else)" << std::endl;
+ subg_exec = _executors->at(_model_index, _else_subg_index);
+ }
+
+ subg_exec->execute(_input_tensors, _output_tensors,
+ _executors->entryExecutor()->currentOptions());
+ VERBOSE(If) << "Return from $" << (cond_result ? _then_subg_index : _else_subg_index)
+ << std::endl;
+}
+
+} // namespace kernel
+} // namespace builtin
+} // namespace backend
+} // namespace onert
diff --git a/runtime/onert/core/src/backend/controlflow/kernel/IfLayer.h b/runtime/onert/core/src/backend/builtin/kernel/IfLayer.h
index ef3a6e6f6..a9b8f2710 100644
--- a/runtime/onert/core/src/backend/controlflow/kernel/IfLayer.h
+++ b/runtime/onert/core/src/backend/builtin/kernel/IfLayer.h
@@ -14,17 +14,19 @@
* limitations under the License.
*/
-#ifndef __ONERT_BACKEND_CONTROLFLOW_KERNEL_IF_LAYER_H__
-#define __ONERT_BACKEND_CONTROLFLOW_KERNEL_IF_LAYER_H__
+#ifndef __ONERT_BACKEND_BUILTIN_KERNEL_IF_LAYER_H__
+#define __ONERT_BACKEND_BUILTIN_KERNEL_IF_LAYER_H__
-#include <backend/ITensor.h>
-#include <exec/IExecutor.h>
+#include <backend/IPortableTensor.h>
+#include <exec/IExecutors.h>
+#include <exec/IFunction.h>
+#include "../ExternalContext.h"
namespace onert
{
namespace backend
{
-namespace controlflow
+namespace builtin
{
namespace kernel
{
@@ -32,32 +34,30 @@ namespace kernel
class IfLayer : public ::onert::exec::IFunction
{
public:
- IfLayer(const std::shared_ptr<backend::ITensor> &cond_tensor,
- const std::vector<std::shared_ptr<backend::ITensor>> input_tensors,
- const std::vector<std::shared_ptr<backend::ITensor>> output_tensors,
- const ir::OperandIndexSequence &output_indices, const ir::Graph &graph,
- const exec::DynAllocInfoMap &outputs_dyn_alloc_info,
+ IfLayer(backend::IPortableTensor *cond_tensor,
+ const std::vector<backend::IPortableTensor *> input_tensors,
+ const std::vector<backend::IPortableTensor *> output_tensors,
const ir::SubgraphIndex &then_subg_index, const ir::SubgraphIndex &else_subg_index,
- exec::ExecutorMap *executor_map);
+ exec::IExecutors *executors, const ir::ModelIndex &model_index,
+ const std::shared_ptr<ExternalContext> &external_context);
public:
void run() override;
private:
- const std::shared_ptr<backend::ITensor> _cond_tensor;
- const std::vector<std::shared_ptr<backend::ITensor>> _input_tensors;
- const std::vector<std::shared_ptr<backend::ITensor>> _output_tensors;
- const ir::OperandIndexSequence &_output_indices;
- const ir::Graph &_graph;
- const exec::DynAllocInfoMap _outputs_dyn_alloc_info;
+ backend::IPortableTensor *_cond_tensor;
+ const std::vector<backend::IPortableTensor *> _input_tensors;
+ const std::vector<backend::IPortableTensor *> _output_tensors;
const ir::SubgraphIndex _then_subg_index;
const ir::SubgraphIndex _else_subg_index;
- exec::ExecutorMap *_executor_map;
+ exec::IExecutors *_executors;
+ ir::ModelIndex _model_index;
+ const std::shared_ptr<ExternalContext> _external_context;
};
} // namespace kernel
-} // namespace controlflow
+} // namespace builtin
} // namespace backend
} // namespace onert
-#endif // __ONERT_BACKEND_CONTROLFLOW_KERNEL_IF_LAYER_H__
+#endif // __ONERT_BACKEND_BUILTIN_KERNEL_IF_LAYER_H__
diff --git a/runtime/onert/core/src/backend/builtin/kernel/PermuteLayer.cc b/runtime/onert/core/src/backend/builtin/kernel/PermuteLayer.cc
new file mode 100644
index 000000000..600180077
--- /dev/null
+++ b/runtime/onert/core/src/backend/builtin/kernel/PermuteLayer.cc
@@ -0,0 +1,316 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "PermuteLayer.h"
+
+#include "../../../exec/ShapeConverter.h"
+
+#include <ruy/context.h> // from @ruy
+
+namespace onert
+{
+namespace backend
+{
+namespace builtin
+{
+namespace kernel
+{
+
+PermuteLayer::PermuteLayer(const std::vector<ITensor *> &src_tensors,
+ const std::vector<ITensor *> &dst_tensors,
+ const std::shared_ptr<ExternalContext> &external_context)
+ : _external_context{external_context}, _tasks_map{}
+{
+ assert(src_tensors.size() == dst_tensors.size());
+ _src_tensors = src_tensors;
+ _dst_tensors = dst_tensors;
+ _src_tensors_offsets.resize(src_tensors.size());
+ _dst_tensors_offsets.resize(dst_tensors.size());
+}
+
+void PermuteLayer::optimize()
+{
+  // Skip copies where source and destination are the same tensor, or either is nullptr
+ auto src_it = _src_tensors.begin();
+ auto dst_it = _dst_tensors.begin();
+ auto src_offsets_it = _src_tensors_offsets.begin();
+ auto dst_offsets_it = _dst_tensors_offsets.begin();
+ while (src_it != _src_tensors.end())
+ {
+ if ((*src_it == *dst_it) || (*src_it == nullptr || *dst_it == nullptr))
+ {
+ src_it = _src_tensors.erase(src_it);
+ dst_it = _dst_tensors.erase(dst_it);
+ src_offsets_it = _src_tensors_offsets.erase(src_offsets_it);
+ dst_offsets_it = _dst_tensors_offsets.erase(dst_offsets_it);
+ }
+ else
+ {
+ auto src = *src_it;
+ auto dst = *dst_it;
+ src_offsets_it->resize(0);
+ dst_offsets_it->resize(0);
+ if (underlying_type(src->data_type()) != underlying_type(dst->data_type()))
+ {
+ // Skip pairs whose underlying data types differ (run() falls back to the generic permute),
+ // but still advance the iterators so this loop terminates
+ src_it++;
+ dst_it++;
+ src_offsets_it++;
+ dst_offsets_it++;
+ continue;
+ }
+ const auto permute_type = [&]() -> PermuteType {
+ if (src->getShape().rank() == 4 && src->layout() == ir::Layout::NHWC &&
+ dst->layout() == ir::Layout::NCHW)
+ {
+ return PermuteType::NHWC_TO_NCHW;
+ }
+ else if (src->getShape().rank() == 4 && src->layout() == ir::Layout::NCHW &&
+ dst->layout() == ir::Layout::NHWC)
+ {
+ return PermuteType::NCHW_TO_NHWC;
+ }
+ else
+ {
+ return PermuteType::COPY;
+ }
+ }();
+
+ // TODO Support different types
+ auto fn = [&](backend::ITensor &src_tensor) {
+ dst->access([&](backend::ITensor &dst_tensor) {
+ // NOTE The buffers of both tensors can still be nullptr at this point
+ const auto data_size = ir::sizeOfDataType(src_tensor.data_type());
+
+ if (permute_type == PermuteType::COPY)
+ {
+ if ((!src_tensor.has_padding() && !dst_tensor.has_padding()))
+ {
+ const auto num_elements = src_tensor.getShape().num_elements();
+ const int thread_count =
+ _external_context->ruy_context()->max_num_threads() < static_cast<int>(num_elements)
+ ? _external_context->ruy_context()->max_num_threads()
+ : num_elements;
+
+ std::vector<PermuteWorkerTask> tasks;
+ auto start = 0;
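+ // Split the flat element range evenly across the threads; each task memcpy's its own contiguous slice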
+ for (auto i = 0; i < thread_count; ++i)
+ {
+ int end = start + (num_elements - start) / (thread_count - i);
+ tasks.emplace_back(src_tensor.buffer(), dst_tensor.buffer(), start * data_size,
+ start * data_size, (end - start) * data_size);
+ start = end;
+ }
+ assert(tasks.size() >= 1);
+ _tasks_map[src] = std::move(tasks);
+ }
+ else
+ {
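+ // Padded copy: keep the innermost dimension contiguous and copy it row by row over the remaining coordinates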
+ auto loop_shape = src_tensor.getShape();
+
+ auto copy_axis = loop_shape.rank() - 1;
+ copy_axis = copy_axis < 0 ? 1 : copy_axis;
+ const auto copy_len = loop_shape.dim(copy_axis) * data_size;
+ loop_shape.dim(copy_axis) = 1;
+
+ appendPermuteTasks(src, dst, loop_shape, copy_len);
+ }
+ }
+ else
+ {
+ assert(src_tensor.getShape().rank() == 4 &&
+ (permute_type == PermuteType::NHWC_TO_NCHW ||
+ permute_type == PermuteType::NCHW_TO_NHWC));
+ const auto loop_shape = src_tensor.getShape();
+ const auto copy_len = data_size;
+
+ appendPermuteTasks(src, dst, loop_shape, copy_len);
+ }
+ });
+ };
+ src->access(fn);
+ src_it++;
+ dst_it++;
+ src_offsets_it++;
+ dst_offsets_it++;
+ }
+ }
+}
+
+void PermuteLayer::appendPermuteTasks(const ITensor *src_tensor, ITensor *dst_tensor,
+ const ir::Shape &loop_shape, size_t size)
+{
+ size_t distributed_dim = 0;
+ auto src_shape = src_tensor->getShape();
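+ // Pick the dimension to split across threads: the largest of dims [0, rank-2] when the layouts match;
+ // layout-changing permutes always split the outermost dimension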
+ if (src_tensor->layout() == dst_tensor->layout())
+ {
+ for (int i = 1; i < src_shape.rank() - 1; ++i)
+ {
+ distributed_dim = src_shape.dim(distributed_dim) < src_shape.dim(i) ? i : distributed_dim;
+ }
+ }
+ const auto distributed_dim_val = src_shape.dim(distributed_dim);
+ const int thread_count =
+ _external_context->ruy_context()->max_num_threads() < static_cast<int>(distributed_dim_val)
+ ? _external_context->ruy_context()->max_num_threads()
+ : distributed_dim_val;
+ // NOTE Do not remove this assertion: exceeding the context's max_num_threads would degrade
+ // performance by forcing new threads to be created in the context's thread pool
+ assert(thread_count <= _external_context->ruy_context()->max_num_threads());
+
+ std::vector<PermuteWorkerTask> tasks;
+ int start = 0;
+ auto one_thread_loop_shape = loop_shape;
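+ // Each task covers a contiguous slice [start, end) of the distributed dimension, beginning at start_coords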
+ for (auto i = 0; i < thread_count; ++i)
+ {
+ ir::Coordinates start_coords(one_thread_loop_shape.rank());
+ start_coords.set(distributed_dim, start);
+ int end = start + (distributed_dim_val - start) / (thread_count - i);
+ one_thread_loop_shape.dim(distributed_dim) = end - start;
+ tasks.emplace_back(*src_tensor, *dst_tensor, start_coords, one_thread_loop_shape, size);
+ start = end;
+ }
+ assert(tasks.size() >= 1);
+ _tasks_map[src_tensor] = std::move(tasks);
+}
+
+void PermuteLayer::runPermuteTasks(backend::ITensor *src, uint8_t *dst_buffer)
+{
+ assert(src->getShape().num_elements() * ir::sizeOfDataType(src->data_type()) <=
+ src->total_size());
+ std::vector<PermuteWorkerTask> &tasks = _tasks_map.at(src);
+ for (size_t i = 0; i < tasks.size(); ++i)
+ {
+ tasks.at(i).setBuffers(src->buffer(), dst_buffer);
+ }
+ assert(tasks.size() >= 1);
+ _external_context->ruy_context()->mutable_thread_pool()->Execute(tasks.size(), tasks.data());
+}
+
+void PermuteLayer::run()
+{
+ assert(_src_tensors.size() == _dst_tensors.size());
+ // PermuteLayer infers dynamic shapes inside itself whenever run is called, for the following
+ // reasons:
+ // 1. PermuteLayer has to access the dynamic tensor manager for input/output tensors of other backends
+ // 2. Other control flow operations (If/While) use this layer for copying tensors of other
+ // subgraphs (with other backends)
+ // 3. Placing the inference code here avoids the code duplication the above two points would otherwise cause
+
+ // Infer and apply the output shape when either the source or destination tensor is dynamic
+ for (size_t i = 0; i < _src_tensors.size(); ++i)
+ {
+ auto dst_tensor = _dst_tensors.at(i);
+ auto src_tensor = _src_tensors.at(i);
+ if (src_tensor->is_dynamic() || dst_tensor->is_dynamic())
+ {
+ // getting output shape
+ auto src_shape = src_tensor->getShape();
+
+ // set output shape and output buffer
+ ir::Shape new_shape =
+ exec::convertShape(src_shape, src_tensor->layout(), dst_tensor->layout());
+
+ try
+ {
+ if (!dst_tensor->applyShape(new_shape))
+ throw std::runtime_error{
+ "Error: PermuteLayer: output's TensorManager does not support dynamic tensor"};
+ assert(dst_tensor->buffer() != nullptr);
+ }
+ catch (const std::out_of_range &e)
+ {
+ std::cerr << "Error: out_of_range in PermuteLayer: output's TensorManager does not support "
+ "dynamic tensor"
+ << '\n';
+ throw;
+ }
+ }
+ assert(exec::convertShape(src_tensor->getShape(), src_tensor->layout(), dst_tensor->layout()) ==
+ dst_tensor->getShape());
+ }
+ assert(_src_tensors.size() == _dst_tensors.size());
+ assert(_src_tensors.size() == _src_tensors_offsets.size());
+ assert(_dst_tensors.size() == _dst_tensors_offsets.size());
+ auto src_it = _src_tensors.begin();
+ auto dst_it = _dst_tensors.begin();
+ auto src_offsets_it = _src_tensors_offsets.begin();
+ auto dst_offsets_it = _dst_tensors_offsets.begin();
+ while (src_it != _src_tensors.end())
+ {
+ auto src = *src_it;
+ auto dst = *dst_it;
+ auto &src_offsets = *src_offsets_it;
+ auto &dst_offsets = *dst_offsets_it;
+
+ if (src->total_size() == 0)
+ {
+ assert(dst->total_size() == 0);
+ }
+ else
+ {
+ if (src != dst)
+ {
+ // Conditions to run the permutation with multithreading
+ // 1. The tasks for multithreading were created
+ // 2. There is more than one task
+ // 3. Neither tensor is dynamic
+ // 4. The underlying data types of both tensors are the same
+ if (_tasks_map.find(src) == _tasks_map.end() || _tasks_map.at(src).size() == 1 ||
+ src->is_dynamic() || dst->is_dynamic() ||
+ underlying_type(src->data_type()) != underlying_type(dst->data_type()))
+ {
+ permute(src, dst, src->getShape().rank(), src_offsets, dst_offsets);
+ }
+ // If dst is a subtensor, we have to use clEnqueueMapBuffer instead of clEnqueueWriteBuffer
+ else if (dst->needMemoryMap() && !dst->is_subtensor())
+ {
+ if (!src->has_padding() && !dst->has_padding() && src->layout() == dst->layout())
+ {
+ // This is more effective than multi-threading
+ src->access([&](backend::ITensor &) { dst->enqueueWriteBuffer(src->buffer(), false); });
+ }
+ else
+ {
+ // TODO Optimize this block for the case where the padding size of dst is large.
+ _buffers_map[dst].reserve(dst->total_size());
+ auto dst_buffer = _buffers_map[dst].data();
+
+ src->access([&](backend::ITensor &) { runPermuteTasks(src, dst_buffer); });
+ dst->enqueueWriteBuffer(dst_buffer, false);
+ }
+ }
+ else if (src->needMemoryMap() && !src->is_subtensor() && !src->has_padding() &&
+ !dst->has_padding() && src->layout() == dst->layout())
+ {
+ // This is more effective than multi-threading
+ assert(!dst->needMemoryMap());
+ dst->access([&](backend::ITensor &) { src->enqueueReadBuffer(dst->buffer(), true); });
+ }
+ else
+ {
+ auto fn = [&](backend::ITensor &) {
+ dst->access([&](backend::ITensor &) { runPermuteTasks(src, dst->buffer()); });
+ };
+ src->access(fn);
+ }
+ }
+ }
+ src_it++;
+ dst_it++;
+ src_offsets_it++;
+ dst_offsets_it++;
+ }
+}
+
+} // namespace kernel
+} // namespace builtin
+} // namespace backend
+} // namespace onert
diff --git a/runtime/onert/core/src/backend/builtin/kernel/PermuteLayer.h b/runtime/onert/core/src/backend/builtin/kernel/PermuteLayer.h
new file mode 100644
index 000000000..cf25f5447
--- /dev/null
+++ b/runtime/onert/core/src/backend/builtin/kernel/PermuteLayer.h
@@ -0,0 +1,150 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_BACKEND_BUILTIN_KERNEL_PERMUTELAYER_H__
+#define __ONERT_BACKEND_BUILTIN_KERNEL_PERMUTELAYER_H__
+
+#include "../ExternalContext.h"
+#include "../../../exec/IPermuteFunction.h"
+
+#include <ruy/thread_pool.h> // from @ruy
+
+namespace onert
+{
+namespace backend
+{
+namespace builtin
+{
+namespace kernel
+{
+
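+// Copies or layout-permutes tensors between backends; the work is split into PermuteWorkerTasks
+// and run on the ruy thread pool when multithreading is profitable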
+class PermuteLayer : public onert::exec::IPermuteFunction
+{
+public:
+ PermuteLayer(const std::vector<ITensor *> &src_tensors, const std::vector<ITensor *> &dst_tensors,
+ const std::shared_ptr<ExternalContext> &external_context);
+
+ void optimize() override;
+
+ void run() override;
+
+private:
+ std::shared_ptr<ExternalContext> _external_context;
+
+private:
+ void appendPermuteTasks(const ITensor *src_tensor, ITensor *dst_tensor,
+ const ir::Shape &loop_shape, size_t size);
+
+ void runPermuteTasks(backend::ITensor *src, uint8_t *dst_buffer);
+
+ struct PermuteWorkerTask : ruy::Task
+ {
+ using Strides = ir::Coordinates;
+
+ PermuteWorkerTask(const ITensor &src_tensor, ITensor &dst_tensor,
+ const ir::Coordinates &start_coords, const ir::Shape &loop_shape, size_t size)
+ : _src_buffer{src_tensor.buffer()}, _dst_buffer{dst_tensor.buffer()},
+ _src_start_offset{src_tensor.calcOffset(start_coords)},
+ _dst_start_offset{dst_tensor.calcOffset(start_coords)}, _src_strides{}, _dst_strides{},
+ _loop_shape{loop_shape}, _size{size}, _src_layout{src_tensor.layout()},
+ _dst_layout{dst_tensor.layout()}, _is_permutation{true}
+ {
+ // Set strides
+ setStrides(src_tensor, &_src_strides);
+ setStrides(dst_tensor, &_dst_strides);
+
+ _is_permutation = (_src_layout != _dst_layout && loop_shape.rank() == 4);
+ }
+ // Constructor for a copy
+ PermuteWorkerTask(const uint8_t *src_buffer, uint8_t *dst_buffer, uint32_t src_start_offset,
+ uint32_t dst_start_offset, size_t size)
+ : _src_buffer{src_buffer}, _dst_buffer{dst_buffer}, _src_start_offset{src_start_offset},
+ _dst_start_offset{dst_start_offset}, _src_strides{0}, _dst_strides{0}, _loop_shape{1},
+ _size{size}, _src_layout{}, _dst_layout{}, _is_permutation{false}
+ {
+ // DO NOTHING
+ }
+ void setBuffers(const uint8_t *src_buffer, uint8_t *dst_buffer)
+ {
+ _src_buffer = src_buffer;
+ _dst_buffer = dst_buffer;
+ }
+ void Run() override
+ {
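+ // Visit every coordinate of _loop_shape, map it to destination-layout coordinates when this is a
+ // layout permutation, and copy _size bytes per coordinate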
+ ShapeLoop(_loop_shape, [&](const onert::ir::Coordinates &coords) {
+ size_t src_offset = _src_start_offset;
+ size_t dst_offset = _dst_start_offset;
+ assert(static_cast<size_t>(_loop_shape.rank()) == coords.size());
+ ir::Coordinates dst_coords = coords;
+ if (_is_permutation)
+ {
+ dst_coords = ir::convertCoordinates(coords, _src_layout, _dst_layout);
+ }
+ for (auto i = 0; i < _loop_shape.rank(); ++i)
+ {
+ assert(coords[i] >= 0 && dst_coords[i] >= 0);
+ src_offset += coords[i] * _src_strides[i];
+ dst_offset += dst_coords[i] * _dst_strides[i];
+ }
+ memcpy(_dst_buffer + dst_offset, _src_buffer + src_offset, _size);
+ });
+ }
+
+ private:
+ void setStrides(const ITensor &tensor, Strides *strides)
+ {
+ auto shape = tensor.getShape();
+ const size_t rank = shape.rank();
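+ // The stride of a dimension is the byte-offset difference between a coordinate stepped by one
+ // in that dimension and the all-zero coordinate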
+ for (size_t i = 0; i < rank; ++i)
+ {
+ ir::Coordinates no_step(rank), one_step(rank);
+ one_step.set(i, 1);
+ if (shape.dim(i) > 1)
+ {
+ strides->set(i, tensor.calcOffset(one_step) - tensor.calcOffset(no_step));
+ }
+ else
+ {
+ // If the dimension value is 0 or 1, the stride of that dimension is never used.
+ // Do not call calcOffset() with a coordinate value greater than the dimension value.
+ strides->set(i, 0);
+ }
+ assert((*strides)[i] >= 0);
+ }
+ }
+
+ private:
+ const uint8_t *_src_buffer;
+ uint8_t *_dst_buffer;
+ size_t _src_start_offset;
+ size_t _dst_start_offset;
+ Strides _src_strides;
+ Strides _dst_strides;
+ const ir::Shape _loop_shape;
+ const size_t _size;
+ const ir::Layout _src_layout;
+ const ir::Layout _dst_layout;
+ bool _is_permutation;
+ };
+ std::unordered_map<const ITensor *, std::vector<PermuteWorkerTask>> _tasks_map;
+};
+
+} // namespace kernel
+} // namespace builtin
+} // namespace backend
+} // namespace onert
+
+#endif // __ONERT_BACKEND_BUILTIN_KERNEL_PERMUTELAYER_H__
diff --git a/runtime/onert/core/src/backend/builtin/kernel/WhileLayer.cc b/runtime/onert/core/src/backend/builtin/kernel/WhileLayer.cc
new file mode 100644
index 000000000..06e5722c8
--- /dev/null
+++ b/runtime/onert/core/src/backend/builtin/kernel/WhileLayer.cc
@@ -0,0 +1,151 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "WhileLayer.h"
+
+#include "PermuteLayer.h"
+#include "../../../exec/ExecutorBase.h"
+
+#include <misc/polymorphic_downcast.h>
+
+#include <algorithm>
+
+namespace onert
+{
+namespace backend
+{
+namespace builtin
+{
+namespace kernel
+{
+
+WhileLayer::WhileLayer(const std::vector<backend::IPortableTensor *> input_tensors,
+ const std::vector<backend::IPortableTensor *> output_tensors,
+ const ir::SubgraphIndex &cond_subg_index,
+ const ir::SubgraphIndex &body_subg_index, exec::IExecutors *executors,
+ const ir::ModelIndex &model_index,
+ basic::DynamicMemoryManager *dyn_memory_manager,
+ const std::shared_ptr<ExternalContext> &external_context)
+ : _cond_subg_index{cond_subg_index}, _body_subg_index{body_subg_index},
+ _input_tensors{input_tensors}, _output_tensors{output_tensors}, _executors{executors},
+ _model_index{model_index}, _dyn_memory_manager{dyn_memory_manager},
+ _external_context{external_context}
+{
+ // At this point, "executors" may not yet contain the executors of the cond and body subgraphs
+}
+
+void WhileLayer::run()
+{
+ // Copy "_input_tensors" -> "cond subg inputs"
+ // Run cond subg
+ // Loop while the output of cond subg is true
+ // // Copy "_input_tensors" -> "body subg inputs" in the first iteration, then copy "body subg
+ // outputs" -> "body subg inputs" in subsequent iterations
+ // // Run body subg
+ // // Copy "body subg outputs" -> "cond subg inputs"
+ // // Run cond subg
+ // If the loop never runs, copy "_input_tensors" -> "_output_tensors"; otherwise copy "cond subg
+ // inputs" -> "_output_tensors"
+ auto cond_exec = _executors->at(_model_index, _cond_subg_index);
+ auto body_exec = _executors->at(_model_index, _body_subg_index);
+
+ // Need a temp tensor to hold the cond subgraph output
+ assert(cond_exec->outputSize() == 1);
+ auto cond_output_tensor = [&]() {
+ auto tensor = std::make_unique<Tensor>(cond_exec->outputInfo(0), cond_exec->outputLayout(0),
+ _dyn_memory_manager);
+ tensor->set_dynamic();
+ tensor->setBuffer(_dyn_memory_manager->allocate(tensor.get(), tensor->total_size()));
+ return tensor;
+ }();
+
+ VERBOSE(While) << "Call to $" << _cond_subg_index << " (cond)" << std::endl;
+ const auto &options = _executors->entryExecutor()->currentOptions();
+ cond_exec->execute(_input_tensors, {cond_output_tensor.get()}, options);
+ VERBOSE(While) << "Return from $" << _cond_subg_index << std::endl;
+
+ auto getResultCond = [](backend::ITensor *tensor) -> bool {
+ bool ret = false;
+ tensor->access([&](ITensor &tensor) { ret = *reinterpret_cast<bool *>(tensor.buffer()); });
+ return ret;
+ };
+
+ std::vector<ITensor *> op_inputs(_input_tensors.begin(), _input_tensors.end());
+ std::vector<ITensor *> op_outputs(_output_tensors.begin(), _output_tensors.end());
+ // Copying body inputs to outputs when the loop body is never executed
+ if (!getResultCond(cond_output_tensor.get()))
+ {
+ PermuteLayer copy_body_inputs_to_op_outputs{op_inputs, op_outputs, _external_context};
+ copy_body_inputs_to_op_outputs.run();
+ return;
+ }
+
+ // Need some temp tensors to hold the body subgraph output
+ std::vector<std::unique_ptr<Tensor>> temp_outputs_o;
+ std::vector<IPortableTensor *> temp_outputs;
+ for (uint32_t i = 0; i < body_exec->outputSize(); i++)
+ {
+ auto tensor = std::make_unique<Tensor>(body_exec->outputInfo(i), body_exec->outputLayout(i),
+ _dyn_memory_manager);
+ tensor->set_dynamic();
+ tensor->setBuffer(_dyn_memory_manager->allocate(tensor.get(), tensor->total_size()));
+ temp_outputs.push_back(tensor.get());
+ temp_outputs_o.push_back(std::move(tensor));
+ }
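+ // temp_outputs_o owns the temporary tensors; temp_outputs exposes raw pointers for the executors
+ // and the copy layer below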
+
+ std::vector<ITensor *> body_outputs(temp_outputs.begin(), temp_outputs.end());
+ PermuteLayer copy_body_outputs_to_op_outputs{body_outputs, op_outputs, _external_context};
+
+ const auto body_execute_with_op_inputs = [&]() {
+ VERBOSE(While) << "Call to $" << _body_subg_index << " (body)" << std::endl;
+ body_exec->execute(_input_tensors, temp_outputs, options);
+ VERBOSE(While) << "Return from $" << _body_subg_index << std::endl;
+ };
+
+ const auto body_execute_with_body_outputs = [&]() {
+ VERBOSE(While) << "Call to $" << _body_subg_index << " (body)" << std::endl;
+ body_exec->execute(_output_tensors, temp_outputs, options);
+ VERBOSE(While) << "Return from $" << _body_subg_index << std::endl;
+ };
+
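+ // The first body iteration reads the op inputs; later iterations read the op outputs, which hold
+ // the previous iteration's results copied by copy_body_outputs_to_op_outputs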
+ std::function<void()> body_execute = body_execute_with_op_inputs;
+ const auto cond_execute = [&]() {
+ VERBOSE(While) << "Call to $" << _cond_subg_index << " (cond)" << std::endl;
+ cond_exec->execute(_output_tensors, {cond_output_tensor.get()}, options);
+ VERBOSE(While) << "Return from $" << _cond_subg_index << std::endl;
+ };
+
+ // Loop while Cond subgraph's output is true
+ while (getResultCond(cond_output_tensor.get()))
+ {
+ body_execute();
+ copy_body_outputs_to_op_outputs.run();
+ cond_execute();
+ body_execute = body_execute_with_body_outputs;
+ }
+
+ // Clean-up the temp tensors
+ _dyn_memory_manager->deallocate(cond_output_tensor.get());
+ for (auto &&tensor : temp_outputs)
+ {
+ _dyn_memory_manager->deallocate(tensor);
+ }
+}
+
+} // namespace kernel
+} // namespace builtin
+} // namespace backend
+} // namespace onert
diff --git a/runtime/onert/core/src/backend/controlflow/kernel/WhileLayer.h b/runtime/onert/core/src/backend/builtin/kernel/WhileLayer.h
index ebca8acdc..40ca4fe23 100644
--- a/runtime/onert/core/src/backend/controlflow/kernel/WhileLayer.h
+++ b/runtime/onert/core/src/backend/builtin/kernel/WhileLayer.h
@@ -14,20 +14,23 @@
* limitations under the License.
*/
-#ifndef __ONERT_BACKEND_CONTROLFLOW_KERNEL_WHILE_LAYER_H__
-#define __ONERT_BACKEND_CONTROLFLOW_KERNEL_WHILE_LAYER_H__
+#ifndef __ONERT_BACKEND_BUILTIN_KERNEL_WHILE_LAYER_H__
+#define __ONERT_BACKEND_BUILTIN_KERNEL_WHILE_LAYER_H__
-#include <backend/ITensor.h>
-#include <exec/IExecutor.h>
+#include <backend/IPortableTensor.h>
+#include <exec/IExecutors.h>
#include <exec/IFunction.h>
#include <ir/OperandIndexSequence.h>
#include <ir/Graph.h>
+#include "../ExternalContext.h"
+
+#include "backend/basic/MemoryManager.h"
namespace onert
{
namespace backend
{
-namespace controlflow
+namespace builtin
{
namespace kernel
{
@@ -35,12 +38,12 @@ namespace kernel
class WhileLayer : public ::onert::exec::IFunction
{
public:
- WhileLayer(const std::vector<std::shared_ptr<backend::ITensor>> &input_tensors,
- const std::vector<std::shared_ptr<backend::ITensor>> &output_tensors,
- const ir::OperandIndexSequence &output_indices, const ir::Graph &graph,
- const exec::DynAllocInfoMap &outputs_dyn_alloc_info,
+ WhileLayer(const std::vector<backend::IPortableTensor *> input_tensors,
+ const std::vector<backend::IPortableTensor *> output_tensors,
const ir::SubgraphIndex &cond_subg_index, const ir::SubgraphIndex &body_subg_index,
- exec::ExecutorMap *executor_map);
+ exec::IExecutors *executors, const ir::ModelIndex &model_index,
+ basic::DynamicMemoryManager *dyn_memory_manager,
+ const std::shared_ptr<ExternalContext> &external_context);
public:
void run() override;
@@ -48,17 +51,17 @@ public:
private:
const ir::SubgraphIndex _cond_subg_index;
const ir::SubgraphIndex _body_subg_index;
- const ir::OperandIndexSequence &_output_indices;
- const ir::Graph &_graph;
- const std::vector<std::shared_ptr<backend::ITensor>> _input_tensors;
- const std::vector<std::shared_ptr<backend::ITensor>> _output_tensors;
- const exec::DynAllocInfoMap _outputs_dyn_alloc_info;
- exec::ExecutorMap *_executor_map;
+ const std::vector<backend::IPortableTensor *> _input_tensors;
+ const std::vector<backend::IPortableTensor *> _output_tensors;
+ exec::IExecutors *_executors;
+ const ir::ModelIndex _model_index;
+ basic::DynamicMemoryManager *_dyn_memory_manager; // For generating temp tensors
+ const std::shared_ptr<ExternalContext> _external_context;
};
} // namespace kernel
-} // namespace controlflow
+} // namespace builtin
} // namespace backend
} // namespace onert
-#endif // __ONERT_BACKEND_CONTROLFLOW_KERNEL_WHILE_LAYER_H__
+#endif // __ONERT_BACKEND_BUILTIN_KERNEL_WHILE_LAYER_H__
diff --git a/runtime/onert/core/src/backend/builtin/train/BackendContext.cc b/runtime/onert/core/src/backend/builtin/train/BackendContext.cc
new file mode 100644
index 000000000..69483eade
--- /dev/null
+++ b/runtime/onert/core/src/backend/builtin/train/BackendContext.cc
@@ -0,0 +1,78 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "BackendContext.h"
+
+#include "backend/basic/train/TrainableBackendContextHelpers.h"
+#include "exec/FunctionSequence.h"
+
+namespace onert
+{
+namespace backend
+{
+namespace builtin
+{
+namespace train
+{
+
+backend::ITensorRegistry *BackendContext::genTensors()
+{
+ // For now, there is no need to generate tensors for forwarding.
+ // builtin train backend handles 3 operators: `Permute`, `IF`, `WHILE`.
+ // `Permute`: Tensor generation is not required.
+ // `IF`, `WHILE`: Not supported yet
+ return tensor_registry().get();
+}
+
+backend::train::ITensorRegistry *BackendContext::genTrainingTensors()
+{
+ // For now, there is no need to generate tensors for backwarding.
+ return tensor_registry().get();
+}
+
+backend::train::FunctionMap BackendContext::genKernels()
+{
+ backend::train::FunctionMap ret;
+
+ for (auto &&op_ind : _tdata->op_order)
+ {
+ auto tn_seq = kernel_gen->generate(op_ind);
+ ret.emplace(op_ind, std::move(tn_seq));
+ }
+
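+ // The builtin backend owns no trainable weights; reject graphs whose constant operands would
+ // have to be updated here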
+ trainable_graph()->operands().iterate(
+ [&](const ir::OperandIndex &ind, const ir::Operand &operand) {
+ if (!external_operands().contains(ind) && operand.isConstant())
+ {
+ throw std::runtime_error(
+ "BackendContext: builtin backend does not support updatable weights yet");
+ }
+ });
+
+ // TODO Enable prepare()
+ // for (auto &&it : ret)
+ // {
+ // auto &fn_seq = it.second;
+ // fn_seq->iterate([&](exec::IFunction &ifunc) { ifunc.prepare(); });
+ // }
+
+ return ret;
+}
+
+} // namespace train
+} // namespace builtin
+} // namespace backend
+} // namespace onert
diff --git a/runtime/onert/core/src/backend/builtin/train/BackendContext.h b/runtime/onert/core/src/backend/builtin/train/BackendContext.h
new file mode 100644
index 000000000..4782756c3
--- /dev/null
+++ b/runtime/onert/core/src/backend/builtin/train/BackendContext.h
@@ -0,0 +1,76 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_BACKEND_BUILTIN_TRAIN_BACKEND_CONTEXT_H__
+#define __ONERT_BACKEND_BUILTIN_TRAIN_BACKEND_CONTEXT_H__
+
+#include <backend/train/TrainableBackendContext.h>
+
+#include "KernelGenerator.h"
+#include "../ExternalContext.h"
+#include "../TensorBuilder.h"
+
+namespace onert
+{
+namespace backend
+{
+namespace builtin
+{
+namespace train
+{
+
+class BackendContext : public backend::train::TrainableBackendContext
+{
+public:
+ BackendContext(const backend::train::ITrainableBackend *backend,
+ std::unique_ptr<backend::train::TrainableContextData> &&data,
+ std::shared_ptr<backend::train::ITensorRegistry> tensor_registry = nullptr,
+ std::shared_ptr<TensorBuilder> tensor_builder = nullptr,
+ std::shared_ptr<KernelGenerator> kernel_gen = nullptr)
+ : backend::train::TrainableBackendContext(backend, std::move(data), tensor_registry),
+ kernel_gen{kernel_gen}, _external_context(new ExternalContext),
+ _tensor_builder{tensor_builder}
+ {
+ }
+
+ backend::ITensorRegistry *genTensors() override;
+ backend::train::ITensorRegistry *genTrainingTensors() override;
+
+public:
+ backend::train::FunctionMap genKernels() override;
+
+ std::shared_ptr<ExternalContext> external_context() { return _external_context; }
+
+public:
+ // TODO Make it private
+ std::shared_ptr<KernelGenerator> kernel_gen;
+
+private:
+ // NOTE A ruy context owns a thread pool, so creating multiple ruy contexts also duplicates
+ // the thread pool
+ // TODO Create one ruy context per session
+ std::shared_ptr<ExternalContext> _external_context;
+
+private:
+ std::shared_ptr<TensorBuilder> _tensor_builder;
+};
+
+} // namespace train
+} // namespace builtin
+} // namespace backend
+} // namespace onert
+
+#endif // __ONERT_BACKEND_BUILTIN_TRAIN_BACKEND_CONTEXT_H__
diff --git a/runtime/onert/core/src/backend/builtin/train/KernelGenerator.cc b/runtime/onert/core/src/backend/builtin/train/KernelGenerator.cc
new file mode 100644
index 000000000..32032de4a
--- /dev/null
+++ b/runtime/onert/core/src/backend/builtin/train/KernelGenerator.cc
@@ -0,0 +1,104 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "KernelGenerator.h"
+
+#include "kernel/PermuteLayer.h"
+
+namespace onert
+{
+namespace backend
+{
+namespace builtin
+{
+namespace train
+{
+
+KernelGenerator::KernelGenerator(const ir::train::TrainableGraph &tgraph,
+ const std::shared_ptr<TensorRegistry> &tensor_reg,
+ const std::shared_ptr<ExternalContext> &external_context)
+ : KernelGeneratorBase{tgraph}, _tensor_reg{tensor_reg}, _external_context(external_context)
+{
+}
+
+std::unique_ptr<exec::train::TrainableFnSequence> KernelGenerator::generate(ir::OperationIndex ind)
+{
+ auto ret = std::make_unique<exec::train::TrainableFnSequence>();
+ const auto &op = _tgraph.operation(ind);
+ op.accept(*this);
+ // _return_fn must have been generated
+ if (_return_fn == nullptr)
+ {
+ throw std::runtime_error(op.name() + " op is not supported as a trainable kernel yet");
+ }
+
+ ret->_functions.emplace_back(std::move(_return_fn));
+
+ return ret;
+}
+
+void KernelGenerator::visit(const ir::train::operation::Permute &node)
+{
+ const auto output_index{node.getOutputs().at(0)};
+ const auto input_index{node.getInputs().at(0)};
+
+ // Add PermuteLayer
+ std::vector<ITensor *> output_tensors{getTensor(output_index)};
+ std::vector<ITensor *> input_tensors{getTensor(input_index)};
+
+ std::vector<ITensor *> output_back_prop_tensors;
+ std::vector<ITensor *> input_back_prop_tensors;
+
+ auto input_back_prop_tensor = getBackPropTensor(input_index);
+ auto output_back_prop_tensor = getBackPropTensor(output_index);
+ output_back_prop_tensors.emplace_back(output_back_prop_tensor);
+ input_back_prop_tensors.emplace_back(input_back_prop_tensor);
+
+ // NOTE The output buffers of IOTensors are not essential for training. If there
+ // is no output buffer provided by the user, permute is not performed.
+ bool ignore_forward_in_training = false;
+ for (const auto dst_tensor : output_tensors)
+ {
+ if (dst_tensor->buffer() == nullptr || dst_tensor->total_size() == 0)
+ ignore_forward_in_training = true;
+ }
+
+ auto fn = std::make_unique<kernel::PermuteLayer>(
+ input_tensors, output_tensors, input_back_prop_tensors, output_back_prop_tensors,
+ ignore_forward_in_training, _external_context);
+
+ _return_fn = std::move(fn);
+}
+
+backend::ITensor *KernelGenerator::getTensor(const ir::OperandIndex &index)
+{
+ // Get Tensor from all tensor registries (for Permute op)
+ auto ret = _tensor_registries.getITensor(index);
+ assert(ret != nullptr);
+ return ret;
+}
+
+backend::ITensor *KernelGenerator::getBackPropTensor(const ir::OperandIndex &index)
+{
+ // Get back propagation Tensor from all tensor registries (for Permute op)
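+ // May be nullptr for model inputs/outputs, whose back-prop tensors are not registered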
+ auto ret = _tensor_registries.getBackPropITensor(index);
+ return ret;
+}
+
+} // namespace train
+} // namespace builtin
+} // namespace backend
+} // namespace onert
diff --git a/runtime/onert/core/src/backend/builtin/train/KernelGenerator.h b/runtime/onert/core/src/backend/builtin/train/KernelGenerator.h
new file mode 100644
index 000000000..162955b6d
--- /dev/null
+++ b/runtime/onert/core/src/backend/builtin/train/KernelGenerator.h
@@ -0,0 +1,75 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_BACKEND_BUILTIN_TRAIN_KERNEL_GENERATOR_H__
+#define __ONERT_BACKEND_BUILTIN_TRAIN_KERNEL_GENERATOR_H__
+
+#include "../ExternalContext.h"
+#include "../train/TensorRegistry.h"
+#include "../../../compiler/train/TensorRegistries.h"
+
+#include <backend/train/KernelGeneratorBase.h>
+#include <exec/train/TrainableFnSequence.h>
+#include <ir/train/TrainableGraph.h>
+
+namespace onert
+{
+namespace backend
+{
+namespace builtin
+{
+namespace train
+{
+
+class KernelGenerator : public backend::train::KernelGeneratorBase
+{
+public:
+ KernelGenerator(const ir::train::TrainableGraph &tgraph,
+ const std::shared_ptr<TensorRegistry> &tensor_reg,
+ const std::shared_ptr<ExternalContext> &external_context);
+
+ std::unique_ptr<exec::train::TrainableFnSequence> generate(ir::OperationIndex ind) override;
+
+ void setTensorRegistries(const compiler::train::TensorRegistries &tensor_registries)
+ {
+ _tensor_registries = tensor_registries;
+ }
+
+ void setWholeGraphOutputs(const ir::OperandIndexSequence &outputs)
+ {
+ _whole_graph_outputs = outputs;
+ }
+
+private:
+ void visit(const ir::train::operation::Permute &) override;
+
+private:
+ backend::ITensor *getTensor(const ir::OperandIndex &index);
+ backend::ITensor *getBackPropTensor(const ir::OperandIndex &index);
+
+private:
+ std::shared_ptr<TensorRegistry> _tensor_reg;
+ compiler::train::TensorRegistries _tensor_registries;
+ const std::shared_ptr<ExternalContext> _external_context;
+ ir::OperandIndexSequence _whole_graph_outputs;
+};
+
+} // namespace train
+} // namespace builtin
+} // namespace backend
+} // namespace onert
+
+#endif // __ONERT_BACKEND_BUILTIN_TRAIN_KERNEL_GENERATOR_H__
diff --git a/runtime/onert/core/src/backend/builtin/train/Tensor.h b/runtime/onert/core/src/backend/builtin/train/Tensor.h
new file mode 100644
index 000000000..baf42796c
--- /dev/null
+++ b/runtime/onert/core/src/backend/builtin/train/Tensor.h
@@ -0,0 +1,40 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_BACKEND_BUILTIN_TRAIN_TRAINABLE_TENSOR_H__
+#define __ONERT_BACKEND_BUILTIN_TRAIN_TRAINABLE_TENSOR_H__
+
+#include <backend/basic/train/TrainableTensor.h>
+
+namespace onert
+{
+namespace backend
+{
+namespace builtin
+{
+namespace train
+{
+
+using TrainableTensor = basic::train::TrainableTensor;
+using BackPropTensor = basic::Tensor;
+using GradientTensor = basic::Tensor;
+
+} // namespace train
+} // namespace builtin
+} // namespace backend
+} // namespace onert
+
+#endif // __ONERT_BACKEND_BUILTIN_TRAIN_TRAINABLE_TENSOR_H__
diff --git a/runtime/onert/core/src/backend/builtin/train/TensorRegistry.h b/runtime/onert/core/src/backend/builtin/train/TensorRegistry.h
new file mode 100644
index 000000000..7c8166bde
--- /dev/null
+++ b/runtime/onert/core/src/backend/builtin/train/TensorRegistry.h
@@ -0,0 +1,140 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_BACKEND_BUILTIN_TRAIN_TENSOR_REGISTRY_H__
+#define __ONERT_BACKEND_BUILTIN_TRAIN_TENSOR_REGISTRY_H__
+
+#include <backend/train/ITensorRegistry.h>
+
+#include "../IOTensor.h"
+#include "../Tensor.h"
+#include "Tensor.h"
+
+namespace onert
+{
+namespace backend
+{
+namespace builtin
+{
+namespace train
+{
+
+using BaseTensorRegistry =
+ backend::train::PortableTensorRegistryTemplate<Tensor, TrainableTensor, BackPropTensor,
+ GradientTensor>;
+
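+// Wraps BaseTensorRegistry and additionally owns the builtin IOTensors for model inputs/outputs;
+// lookups fall back to the IOTensor map when the base registry has no entry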
+class TensorRegistry : public backend::train::ITensorRegistry
+{
+public:
+ TensorRegistry() : _base_reg{new BaseTensorRegistry} {}
+
+ ITensor *getITensor(const ir::OperandIndex &index) override
+ {
+ auto base_tensor = _base_reg->getITensor(index);
+ if (base_tensor)
+ return base_tensor;
+ return getNativeIOTensor(index);
+ }
+
+ ITensor *getNativeITensor(const ir::OperandIndex &index) override
+ {
+ auto base_tensor = _base_reg->getNativeITensor(index);
+ if (base_tensor)
+ return base_tensor;
+ return getNativeIOTensor(index);
+ }
+
+ IPortableTensor *getPortableTensor(const ir::OperandIndex &index)
+ {
+ auto base_tensor = _base_reg->getPortableTensor(index);
+ if (base_tensor)
+ return base_tensor;
+ return getNativeIOTensor(index);
+ }
+
+ IOTensor *getNativeIOTensor(const ir::OperandIndex &index)
+ {
+ auto tensor = _native_io_tensors.find(index);
+ if (tensor != _native_io_tensors.end())
+ return tensor->second.get();
+ return nullptr;
+ }
+
+ ITensor *getBackPropITensor(const ir::OperandIndex &index) override
+ {
+ return _base_reg->getBackPropTensor(index);
+ }
+
+ ITensor *getGradientITensor(const ir::OperandIndex &index) override
+ {
+ return _base_reg->getGradientTensor(index);
+ }
+
+ BackPropTensor *getBackPropTensor(const ir::OperandIndex &index)
+ {
+ return _base_reg->getBackPropTensor(index);
+ }
+
+ bool setMigrantTensor(const ir::OperandIndex &index, IPortableTensor *tensor) override
+ {
+ assert(tensor);
+ assert(!getITensor(index)); // For the index, tensor is not registered yet
+ _base_reg->setMigrantTensor(index, tensor);
+ return true;
+ }
+
+ void iterateTrainableTensors(
+ const std::function<void(const ir::OperandIndex &, const backend::train::ITrainableTensor *)> &)
+ const override
+ {
+ // DO NOTHING
+ // Builtin tensor registry does not have trainable tensor.
+ }
+
+ void setBackPropTensor(const ir::OperandIndex &index, std::unique_ptr<BackPropTensor> tensor)
+ {
+ _base_reg->setBackPropTensor(index, std::move(tensor));
+ }
+
+ void setGradientTensor(const ir::OperandIndex &index, std::unique_ptr<GradientTensor> tensor)
+ {
+ _base_reg->setGradientTensor(index, std::move(tensor));
+ }
+
+ void setNativeIOTensor(ir::OperandIndex index, std::unique_ptr<IOTensor> &&tensor)
+ {
+ assert(tensor);
+ assert(!getITensor(index)); // For the index, tensor is not registered yet
+ _native_io_tensors[index] = std::move(tensor);
+ }
+
+ const ir::OperandIndexMap<std::unique_ptr<IOTensor>> &native_io_tensors()
+ {
+ return _native_io_tensors;
+ }
+ std::shared_ptr<BaseTensorRegistry> base_reg() { return _base_reg; }
+
+private:
+ std::shared_ptr<BaseTensorRegistry> _base_reg;
+ ir::OperandIndexMap<std::unique_ptr<IOTensor>> _native_io_tensors;
+};
+
+} // namespace train
+} // namespace builtin
+} // namespace backend
+} // namespace onert
+
+#endif // __ONERT_BACKEND_BUILTIN_TRAIN_TENSOR_REGISTRY_H__
diff --git a/runtime/onert/core/src/backend/builtin/train/kernel/PermuteLayer.cc b/runtime/onert/core/src/backend/builtin/train/kernel/PermuteLayer.cc
new file mode 100644
index 000000000..dce7482e2
--- /dev/null
+++ b/runtime/onert/core/src/backend/builtin/train/kernel/PermuteLayer.cc
@@ -0,0 +1,87 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "PermuteLayer.h"
+
+namespace onert
+{
+namespace backend
+{
+namespace builtin
+{
+namespace train
+{
+namespace kernel
+{
+
+PermuteLayer::PermuteLayer(const std::vector<ITensor *> &src_tensors,
+ const std::vector<ITensor *> &dst_tensors,
+ const std::vector<ITensor *> &input_back_prop_tensors,
+ const std::vector<ITensor *> &output_back_prop_tensors,
+ bool ignore_forward_in_training,
+ const std::shared_ptr<ExternalContext> &external_context)
+ : builtin::kernel::PermuteLayer{src_tensors, dst_tensors, external_context},
+ _input_back_prop_tensors{input_back_prop_tensors},
+ _output_back_prop_tensors{output_back_prop_tensors},
+ _ignore_forward_in_training{ignore_forward_in_training}
+{
+ assert(input_back_prop_tensors.size() == output_back_prop_tensors.size());
+ assert(src_tensors.size() == dst_tensors.size());
+}
+
+void PermuteLayer::optimize()
+{
+ builtin::kernel::PermuteLayer::optimize();
+
+ // TODO Calculate offsets of back propagation tensors if necessary
+}
+
+void PermuteLayer::forward(bool)
+{
+ if (_ignore_forward_in_training)
+ return;
+
+ builtin::kernel::PermuteLayer::run();
+}
+
+void PermuteLayer::backward()
+{
+ for (uint32_t i = 0; i < _output_back_prop_tensors.size(); ++i)
+ {
+ auto src_back_prop = _output_back_prop_tensors.at(i);
+ auto dst_back_prop = _input_back_prop_tensors.at(i);
+
+ // NOTE The back propagation tensors corresponding to inputs/outputs of model are nullptr
+ // because permuting those tensors is meaningless
+ if (src_back_prop && dst_back_prop)
+ {
+ const auto rank = src_back_prop->getShape().rank();
+ auto output_offsets = _dst_tensors_offsets.at(i);
+ auto input_offsets = _src_tensors_offsets.at(i);
+
+ exec::IPermuteFunction::permute(src_back_prop, dst_back_prop, rank, output_offsets,
+ input_offsets);
+ }
+ }
+}
+
+} // namespace kernel
+} // namespace train
+} // namespace builtin
+} // namespace backend
+} // namespace onert
diff --git a/runtime/onert/core/src/backend/builtin/train/kernel/PermuteLayer.h b/runtime/onert/core/src/backend/builtin/train/kernel/PermuteLayer.h
new file mode 100644
index 000000000..1dc221b09
--- /dev/null
+++ b/runtime/onert/core/src/backend/builtin/train/kernel/PermuteLayer.h
@@ -0,0 +1,61 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_BACKEND_BUILTIN_TRAIN_KERNEL_PERMUTELAYER_H__
+#define __ONERT_BACKEND_BUILTIN_TRAIN_KERNEL_PERMUTELAYER_H__
+
+#include "../../kernel/PermuteLayer.h"
+
+#include "exec/train/ITrainableFunction.h"
+
+namespace onert
+{
+namespace backend
+{
+namespace builtin
+{
+namespace train
+{
+namespace kernel
+{
+
+class PermuteLayer : public builtin::kernel::PermuteLayer, public exec::train::ITrainableFunction
+{
+public:
+ PermuteLayer(const std::vector<ITensor *> &src_tensors, const std::vector<ITensor *> &dst_tensors,
+ const std::vector<ITensor *> &input_back_prop_tensors,
+ const std::vector<ITensor *> &output_back_prop_tensors,
+ bool ignore_forward_in_training,
+ const std::shared_ptr<ExternalContext> &external_context);
+
+ void optimize() override;
+
+ void forward(bool training) override;
+ void backward() override;
+
+private:
+ std::vector<ITensor *> _input_back_prop_tensors;
+ std::vector<ITensor *> _output_back_prop_tensors;
+ bool _ignore_forward_in_training;
+};
+
+} // namespace kernel
+} // namespace train
+} // namespace builtin
+} // namespace backend
+} // namespace onert
+
+#endif // __ONERT_BACKEND_BUILTIN_TRAIN_KERNEL_PERMUTELAYER_H__
diff --git a/runtime/onert/core/src/backend/controlflow/ConstantInitializer.h b/runtime/onert/core/src/backend/controlflow/ConstantInitializer.h
deleted file mode 100644
index e21a8f357..000000000
--- a/runtime/onert/core/src/backend/controlflow/ConstantInitializer.h
+++ /dev/null
@@ -1,52 +0,0 @@
-/*
- * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef __ONERT_COMPILER_CONTROLFLOW_CONSTANT_INITIALIZER_H__
-#define __ONERT_COMPILER_CONTROLFLOW_CONSTANT_INITIALIZER_H__
-
-#include "TensorRegistry.h"
-
-#include <backend/IConstantInitializer.h>
-#include <ir/Operands.h>
-
-namespace onert
-{
-namespace backend
-{
-namespace controlflow
-{
-
-class ConstantInitializer : public IConstantInitializer
-{
-public:
- ConstantInitializer(const ir::Operands &operands,
- const std::shared_ptr<ITensorRegistry> &tensor_reg)
- : IConstantInitializer{operands}, _tensor_reg{tensor_reg}
- {
- }
-
-private:
- std::shared_ptr<ITensorRegistry> tensor_registry() const override { return _tensor_reg; }
-
-private:
- std::shared_ptr<ITensorRegistry> _tensor_reg;
-};
-
-} // namespace controlflow
-} // namespace backend
-} // namespace onert
-
-#endif // __ONERT_COMPILER_CONTROLFLOW_CONSTANT_INITIALIZER_H__
diff --git a/runtime/onert/core/src/backend/controlflow/DynamicTensorManager.cc b/runtime/onert/core/src/backend/controlflow/DynamicTensorManager.cc
deleted file mode 100644
index 1288e4c96..000000000
--- a/runtime/onert/core/src/backend/controlflow/DynamicTensorManager.cc
+++ /dev/null
@@ -1,144 +0,0 @@
-/*
- * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "DynamicTensorManager.h"
-
-#include "util/logging.h"
-#include "util/Exceptions.h"
-#include "ir/DataType.h"
-
-namespace onert
-{
-namespace backend
-{
-namespace controlflow
-{
-
-DynamicTensorManager::DynamicTensorManager(const std::shared_ptr<TensorRegistry> &tensors)
- : _dynamic_mem_mgr{new cpu_common::DynamicMemoryManager()}, _tensors{tensors}
-{
- // DO NOTHING
-}
-
-void DynamicTensorManager::applyShape(const ir::OperandIndex &ind, const ir::Shape &new_shape)
-{
- // NOTE Handle user tensors first
- auto user_tensor = _tensors->getNativeUserTensor(ind);
- if (user_tensor)
- {
- // User tensors cannot be reallocated.
- auto buffer_size = user_tensor->total_size();
- auto new_size = new_shape.num_elements() * sizeOfDataType(user_tensor->data_type());
- if (buffer_size < new_size)
- throw InsufficientBufferSizeException{"Output buffer size is less than output tensor size"};
- user_tensor->setShape(new_shape);
- return;
- }
-
- // NOTE Then handle own tensors
- auto tensor = _tensors->getNativeOwnTensor(ind);
- assert(tensor);
-
- bool previously_dynamic = tensor->is_dynamic();
-
- auto allocTensorMem = [&](bool overwrite = false) {
- auto capacity = tensor->total_size();
- auto alloc = _dynamic_mem_mgr->allocate(ind, capacity);
-
- if (overwrite)
- tensor->overwriteBuffer(alloc);
- else
- tensor->setBuffer(alloc);
- };
-
- if (!previously_dynamic)
- {
- // TODO deallocate tensor->buffer()
- // issue is that staticTensorManager might have allocate this memory
- tensor->setShape(new_shape);
- tensor->set_dynamic();
- allocTensorMem(true);
- }
- else if (tensor->buffer() == nullptr)
- {
- tensor->setShape(new_shape);
- tensor->set_dynamic();
- allocTensorMem();
- }
- // when buffer was already allocated and new_shape requires different size
- else
- {
- auto previous_size = tensor->total_size();
- auto new_size = new_shape.num_elements() * sizeOfDataType(tensor->data_type());
- if (previous_size != new_size)
- {
- _dynamic_mem_mgr->deallocate(ind);
-
- tensor->setShape(new_shape);
- tensor->set_dynamic();
- allocTensorMem(true);
- }
- else
- { // when buffer with same size was already allocated, shape could differ
- tensor->setShape(new_shape);
- }
- }
-}
-
-void DynamicTensorManager::buildTensor(const ir::OperandIndex &ind,
- const ir::OperandInfo &tensor_info,
- ir::Layout backend_layout)
-{
- auto tensor = std::make_shared<cpu_common::Tensor>(tensor_info, backend_layout, this);
- _tensors->setNativeOwnTensor(ind, tensor);
-}
-
-void DynamicTensorManager::planDealloc(ir::OperationIndex op_ind, ir::OperandIndex operand_ind)
-{
- _dealloc_tensor_map[op_ind].emplace(operand_ind);
-}
-
-void DynamicTensorManager::deallocInput(ir::OperationIndex op_ind)
-{
- auto find = _dealloc_tensor_map.find(op_ind);
- if (find == _dealloc_tensor_map.end())
- return;
-
- auto &input_set = find->second;
- for (auto input_ind : input_set)
- {
- if (!_tensors->getNativeTensor(input_ind)->is_dynamic())
- continue;
-
- _dynamic_mem_mgr->deallocate(input_ind);
- VERBOSE(DynamicTensorManager) << "Deallocating #" << input_ind.value()
- << " (input of op_ind: " << op_ind.value() << ")" << std::endl;
- }
-}
-
-void DynamicTensorManager::deallocSubgraphOutput(ir::OperandIndex output_ind)
-{
- if (!_tensors->getNativeTensor(output_ind)->is_dynamic())
- return;
-
- _dynamic_mem_mgr->deallocate(output_ind);
- VERBOSE(DynamicTensorManager) << "Deallocating #" << output_ind.value()
- << " (output of a subgraph)" << std::endl;
-}
-
-} // namespace controlflow
-} // namespace backend
-} // namespace onert
diff --git a/runtime/onert/core/src/backend/controlflow/DynamicTensorManager.h b/runtime/onert/core/src/backend/controlflow/DynamicTensorManager.h
deleted file mode 100644
index dbe388ba2..000000000
--- a/runtime/onert/core/src/backend/controlflow/DynamicTensorManager.h
+++ /dev/null
@@ -1,72 +0,0 @@
-/*
- * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef __ONERT_BACKEND_CONTROLFLOW_DYNAMICTENSOR_MANAGER_H__
-#define __ONERT_BACKEND_CONTROLFLOW_DYNAMICTENSOR_MANAGER_H__
-
-#include "TensorRegistry.h"
-#include "Tensor.h"
-
-#include <backend/IDynamicTensorManager.h>
-#include <backend/cpu_common/MemoryManager.h>
-#include <ir/OperandInfo.h>
-#include <ir/Operation.h>
-#include <ir/Index.h>
-
-namespace onert
-{
-namespace backend
-{
-namespace controlflow
-{
-
-/**
- * @brief Class to manage dynamic tensor and its memory
- */
-class DynamicTensorManager : public backend::IDynamicTensorManager
-{
-public:
- DynamicTensorManager(const std::shared_ptr<TensorRegistry> &tensors);
-
- virtual ~DynamicTensorManager() = default;
-
- void applyShape(const ir::OperandIndex &ind, const ir::Shape &new_shape) override;
-
- void buildTensor(const ir::OperandIndex &ind, const ir::OperandInfo &tensor_info,
- ir::Layout backend_layout);
-
- void planDealloc(ir::OperationIndex op_ind, ir::OperandIndex operand_ind) override;
- void deallocInput(ir::OperationIndex op_ind) override;
- void deallocSubgraphOutput(ir::OperandIndex ind) override;
-
-private:
- /**
- * @brief Memory manager for dynamic tensor.
- * @todo DynamicMemoryManager is not optimized. Optimized one is needed
- */
- std::shared_ptr<cpu_common::DynamicMemoryManager> _dynamic_mem_mgr;
- const std::shared_ptr<TensorRegistry> _tensors;
-
- // contains list of dynamic tensor index, which can be deallocated after running operation
- // note: this map could contain static tensor index too. Careful use is required.
- std::unordered_map<ir::OperationIndex, std::unordered_set<ir::OperandIndex>> _dealloc_tensor_map;
-};
-
-} // namespace controlflow
-} // namespace backend
-} // namespace onert
-
-#endif // __ONERT_BACKEND_CONTROLFLOW_DYNAMICTENSOR_MANAGER_H__
diff --git a/runtime/onert/core/src/backend/controlflow/KernelGenerator.cc b/runtime/onert/core/src/backend/controlflow/KernelGenerator.cc
deleted file mode 100644
index de5a6a5f6..000000000
--- a/runtime/onert/core/src/backend/controlflow/KernelGenerator.cc
+++ /dev/null
@@ -1,171 +0,0 @@
-/*
- * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "KernelGenerator.h"
-
-#include <backend/BackendContext.h>
-#include <util/Utils.h>
-#include "kernel/IfLayer.h"
-#include "kernel/WhileLayer.h"
-#include "kernel/PermuteLayer.h"
-#include "exec/ExecutorBase.h"
-#include "exec/FunctionSequence.h"
-
-namespace onert
-{
-namespace backend
-{
-namespace controlflow
-{
-
-KernelGenerator::KernelGenerator(const ir::Graph &graph, IDynamicTensorManager *dyn_tensor_manager,
- const std::shared_ptr<TensorRegistry> &tensor_reg)
- : _graph{graph}, _dyn_tensor_manager{dyn_tensor_manager}, _tensor_reg{tensor_reg},
- _tensor_registries{}, _executor_map{nullptr}
-{
- UNUSED_RELEASE(_graph);
- UNUSED_RELEASE(_tensor_registries);
- UNUSED_RELEASE(_executor_map);
-}
-
-void KernelGenerator::visit(const ir::OpSequence &op_seq)
-{
- assert(!_return_fn_seq);
- assert(_dyn_tensor_manager);
- assert(_tensor_reg);
-
- auto dyn_shape_inferer =
- std::make_unique<exec::DynamicShapeInferer>(_graph.operands(), _tensor_reg);
-
- _return_fn_seq = std::make_unique<exec::FunctionSequence>();
-
- // Prepare to handle dynamic tensors later
- auto dyn_ctx = std::make_shared<exec::FunctionSequence::DynamicTensorCtx>();
- {
- dyn_ctx->op_seq = &op_seq;
- dyn_ctx->operations = &_graph.operations();
- dyn_ctx->dynamic_shape_inferer = std::move(dyn_shape_inferer);
- dyn_ctx->tensor_registry = _tensor_reg;
- dyn_ctx->dynamic_tensor_manager = _dyn_tensor_manager;
-
- _return_fn_seq->dynamic_tensor_ctx(dyn_ctx);
- }
- _return_fn_seq->enableDynamicShapeInferer(true);
-
- for (const auto &op_idx : op_seq.operations())
- {
- const auto &node = _graph.operations().at(op_idx);
- node.accept(*this);
- _return_fn_seq->append(releaseFunction());
- }
-}
-
-void KernelGenerator::visit(const ir::operation::If &node)
-{
- const auto then_subg_index = node.param().then_subg_index;
- const auto else_subg_index = node.param().else_subg_index;
-
- std::vector<std::shared_ptr<backend::ITensor>> input_tensors;
- for (const auto input_index : node.getInputs())
- {
- auto input_tensor = getTensor(input_index);
-
- input_tensors.emplace_back(input_tensor);
- }
-
- std::vector<std::shared_ptr<backend::ITensor>> output_tensors;
- exec::DynAllocInfoMap outputs_dyn_alloc_info;
- for (const auto output_index : node.getOutputs())
- {
- auto output_tensor = getTensor(output_index);
-
- output_tensors.emplace_back(output_tensor);
- outputs_dyn_alloc_info[output_tensor] = exec::DynAllocInfo{output_index};
- }
-
- // IfLayer just set ExecutorMap instead of then and else executor to avoid complexity of
- // creating executor recusively
- const auto cond_tensor = input_tensors.front();
- input_tensors.erase(input_tensors.begin());
- auto fn = std::make_unique<::onert::backend::controlflow::kernel::IfLayer>(
- cond_tensor, input_tensors, output_tensors, node.getOutputs(), _graph, outputs_dyn_alloc_info,
- then_subg_index, else_subg_index, _executor_map);
-
- _return_fn = std::move(fn);
-}
-
-void KernelGenerator::visit(const ir::operation::Permute &node)
-{
- const auto output_index{node.getOutputs().at(0)};
- const auto input_index{node.getInputs().at(0)};
-
- // Add PermuteLayer
- std::vector<std::shared_ptr<ITensor>> output_tensors{getTensor(output_index)};
- std::vector<std::shared_ptr<ITensor>> input_tensors{getTensor(input_index)};
- std::unordered_map<std::shared_ptr<ITensor>, exec::DynAllocInfo> outputs_dyn_alloc_info;
- outputs_dyn_alloc_info[output_tensors.at(0)] = exec::DynAllocInfo{output_index};
-
- auto fn =
- std::make_unique<kernel::PermuteLayer>(input_tensors, output_tensors, outputs_dyn_alloc_info);
-
- _return_fn = std::move(fn);
-}
-
-void KernelGenerator::visit(const ir::operation::While &node)
-{
- const auto cond_subg_index = node.param().cond_subg_index;
- const auto body_subg_index = node.param().body_subg_index;
-
- // This op does not support constant inputs, because the controlflow backend does not have a
- // TensorBuilder
- std::vector<std::shared_ptr<backend::ITensor>> input_tensors;
- for (const auto input_index : node.getInputs())
- {
- auto input_tensor = getTensor(input_index);
-
- input_tensors.emplace_back(input_tensor);
- }
-
- std::vector<std::shared_ptr<backend::ITensor>> output_tensors;
- std::unordered_map<std::shared_ptr<ITensor>, exec::DynAllocInfo> outputs_dyn_alloc_info;
- for (const auto output_index : node.getOutputs())
- {
- auto output_tensor = getTensor(output_index);
-
- output_tensors.emplace_back(output_tensor);
-
- outputs_dyn_alloc_info[output_tensor] = exec::DynAllocInfo{output_index};
- }
-
- // WhileLayer just sets the ExecutorMap instead of the cond and body executors to avoid the
- // complexity of creating executors recursively
- auto fn = std::make_unique<::onert::backend::controlflow::kernel::WhileLayer>(
- input_tensors, output_tensors, node.getOutputs(), _graph, outputs_dyn_alloc_info,
- cond_subg_index, body_subg_index, _executor_map);
-
- _return_fn = std::move(fn);
-}
-
-std::shared_ptr<backend::ITensor> KernelGenerator::getTensor(const ir::OperandIndex &index)
-{
- std::shared_ptr<backend::ITensor> ret = _tensor_registries.getITensor(index);
- assert(ret != nullptr);
- return ret;
-}
-
-} // namespace controlflow
-} // namespace backend
-} // namespace onert
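For reference, a minimal standalone sketch of the visit-and-append pattern used by the deleted KernelGenerator above: each visited operation yields a callable kernel that is appended to a function sequence, which an executor later runs in order. This is an illustration in plain C++ with made-up names, not onert API.

#include <functional>
#include <iostream>
#include <string>
#include <vector>

// A toy "function sequence": kernels are appended during generation and run later.
class FunctionSequence
{
public:
  void append(std::function<void()> fn) { _fns.push_back(std::move(fn)); }
  void run() const
  {
    for (const auto &fn : _fns)
      fn();
  }

private:
  std::vector<std::function<void()>> _fns;
};

int main()
{
  const std::vector<std::string> op_seq{"If", "Permute", "While"};
  FunctionSequence seq;
  for (const auto &op : op_seq) // "visit" each operation and emit a kernel
    seq.append([op] { std::cout << "run " << op << " kernel\n"; });
  seq.run();
}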
diff --git a/runtime/onert/core/src/backend/controlflow/TensorRegistry.h b/runtime/onert/core/src/backend/controlflow/TensorRegistry.h
deleted file mode 100644
index 678c5b73b..000000000
--- a/runtime/onert/core/src/backend/controlflow/TensorRegistry.h
+++ /dev/null
@@ -1,134 +0,0 @@
-/*
- * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef __ONERT_BACKEND_CONTROLFLOW_TENSOR_REGISTRY_H__
-#define __ONERT_BACKEND_CONTROLFLOW_TENSOR_REGISTRY_H__
-
-#include "backend/cpu_common/TensorRegistry.h"
-#include "backend/ITensorRegistry.h"
-#include "Tensor.h"
-#include "UserTensor.h"
-#include <assert.h>
-
-namespace onert
-{
-namespace backend
-{
-namespace controlflow
-{
-
-/**
- * @brief Tensor registry class for controlflow backend
- *
- * This class contains three types of tensors: two kinds of native tensors (tensors that are
- * managed by this backend) and migrant tensors.
- *
- * - NativeUserTensor - @c UserTensor managed by this backend, buffer is user-given
- * - NativeOwnTensor - @c cpu_common::Tensor managed by this backend ( in @c _base_reg )
- * - MigrantTensor - @c IPortableTensor managed by other backends ( in @c _base_reg )
- *
- * @note @c _base_reg is used in implementation to reuse @c cpu_common::StaticTensorManager
- *
- */
-class TensorRegistry : public ITensorRegistry
-{
-public:
- TensorRegistry() : _base_reg{new cpu_common::TensorRegistry} {}
-
- std::shared_ptr<ITensor> getITensor(const ir::OperandIndex &ind) override
- {
- auto base_tensor = _base_reg->getITensor(ind);
- if (base_tensor)
- return base_tensor;
- return getNativeUserTensor(ind);
- }
-
- std::shared_ptr<ITensor> getNativeITensor(const ir::OperandIndex &ind) override
- {
- auto base_tensor = _base_reg->getNativeITensor(ind);
- if (base_tensor)
- return base_tensor;
- return getNativeUserTensor(ind);
- }
-
- std::shared_ptr<IPortableTensor> getPortableTensor(const ir::OperandIndex &ind)
- {
- auto base_tensor = _base_reg->getPortableTensor(ind);
- if (base_tensor)
- return base_tensor;
- return getNativeUserTensor(ind);
- }
-
- std::shared_ptr<IPortableTensor> getNativeTensor(const ir::OperandIndex &ind)
- {
- auto base_tensor = _base_reg->getNativeTensor(ind);
- if (base_tensor)
- return base_tensor;
- return getNativeUserTensor(ind);
- }
-
- std::shared_ptr<Tensor> getNativeOwnTensor(const ir::OperandIndex &ind)
- {
- return _base_reg->getNativeTensor(ind);
- }
-
- std::shared_ptr<UserTensor> getNativeUserTensor(const ir::OperandIndex &ind)
- {
- auto tensor = _native_user_tensors.find(ind);
- if (tensor != _native_user_tensors.end())
- return tensor->second;
- return nullptr;
- }
-
- bool setMigrantTensor(const ir::OperandIndex &ind,
- const std::shared_ptr<IPortableTensor> &tensor) override
- {
- assert(tensor);
- assert(!getITensor(ind)); // For the ind, tensor is not registered yet
- _base_reg->setMigrantTensor(ind, tensor);
- return true;
- }
-
- void setNativeOwnTensor(ir::OperandIndex ind, const std::shared_ptr<Tensor> &tensor)
- {
- assert(tensor);
- assert(!getITensor(ind)); // For the ind, tensor is not registered yet
- _base_reg->setNativeTensor(ind, tensor);
- }
-
- void setNativeUserTensor(ir::OperandIndex ind, const std::shared_ptr<UserTensor> &tensor)
- {
- assert(tensor);
- assert(!getITensor(ind)); // For the ind, tensor is not registered yet
- _native_user_tensors[ind] = tensor;
- }
-
- const ir::OperandIndexMap<std::shared_ptr<UserTensor>> &native_user_tensors()
- {
- return _native_user_tensors;
- }
- std::shared_ptr<cpu_common::TensorRegistry> base_reg() { return _base_reg; }
-
-private:
- std::shared_ptr<cpu_common::TensorRegistry> _base_reg;
- ir::OperandIndexMap<std::shared_ptr<UserTensor>> _native_user_tensors;
-};
-
-} // namespace controlflow
-} // namespace backend
-} // namespace onert
-
-#endif // ifndef __ONERT_BACKEND_CONTROLFLOW_TENSOR_REGISTRY_H__
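The registry above resolves an operand index by falling through two maps: the base registry of native/migrant tensors first, then the user-tensor map, and returns null when neither has it. A minimal sketch of that lookup order, using only the standard library and hypothetical names:

#include <cstdint>
#include <iostream>
#include <memory>
#include <string>
#include <unordered_map>

struct Tensor
{
  std::string kind;
};
using Index = uint32_t;

class Registry
{
public:
  // Lookup order: base registry (native/migrant) first, then user tensors, else nullptr.
  std::shared_ptr<Tensor> find(Index ind) const
  {
    if (auto it = _base.find(ind); it != _base.end())
      return it->second;
    if (auto it = _user.find(ind); it != _user.end())
      return it->second;
    return nullptr;
  }
  void setBase(Index ind, std::shared_ptr<Tensor> t) { _base[ind] = std::move(t); }
  void setUser(Index ind, std::shared_ptr<Tensor> t) { _user[ind] = std::move(t); }

private:
  std::unordered_map<Index, std::shared_ptr<Tensor>> _base;
  std::unordered_map<Index, std::shared_ptr<Tensor>> _user;
};

int main()
{
  Registry reg;
  reg.setBase(0, std::make_shared<Tensor>(Tensor{"native"}));
  reg.setUser(1, std::make_shared<Tensor>(Tensor{"user"}));
  std::cout << reg.find(0)->kind << ' ' << reg.find(1)->kind << ' '
            << (reg.find(2) == nullptr) << '\n'; // native user 1
}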
diff --git a/runtime/onert/core/src/backend/controlflow/UserTensor.h b/runtime/onert/core/src/backend/controlflow/UserTensor.h
deleted file mode 100644
index 9be33595d..000000000
--- a/runtime/onert/core/src/backend/controlflow/UserTensor.h
+++ /dev/null
@@ -1,91 +0,0 @@
-/*
- * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef __ONERT_BACKEND_CONTROLFLOW_USER_TENSOR_H__
-#define __ONERT_BACKEND_CONTROLFLOW_USER_TENSOR_H__
-
-#include "ir/OperandInfo.h"
-#include "backend/IPortableTensor.h"
-
-namespace onert
-{
-namespace backend
-{
-namespace controlflow
-{
-
-/**
- * @brief Tensor object that is for Input and Output tensors from the user.
- *
- * This class wraps a buffer that is allocated by the user, so it is responsible for neither
- * allocation nor deallocation. All the model input/output tensors are wrapped with this class
- * for execution.
- *
- */
-class UserTensor : public IPortableTensor
-{
-public:
- UserTensor(const ir::OperandInfo &info, ir::Layout layout, uint8_t *buffer, size_t size,
- IDynamicTensorManager *dynamic_tensor_manager)
- : _info{info}, _layout{layout}, _buffer{buffer}, _size{size}, _dynamic{false},
- _dynamic_tensor_manager{dynamic_tensor_manager}
- {
- }
-
- UserTensor(const ir::OperandInfo &info, ir::Layout layout,
- IDynamicTensorManager *dynamic_tensor_manager)
- : UserTensor{info, layout, nullptr, 0, dynamic_tensor_manager}
- {
- }
-
-public:
- void setBuffer(uint8_t *buffer, size_t size)
- {
- _buffer = buffer;
- _size = size;
- }
-
-public:
- uint8_t *buffer() const override { return _buffer; }
- size_t total_size() const override { return _size; }
- size_t dimension(size_t index) const override { return _info.shape().dim(index); }
- size_t num_dimensions() const override { return _info.shape().rank(); }
- size_t calcOffset(const ir::Coordinates &coords) const override;
- ir::Layout layout() const override { return _layout; }
- ir::DataType data_type() const override { return _info.typeInfo().type(); }
- float data_scale() const override { return _info.typeInfo().scale(); }
- int32_t data_offset() const override { return _info.typeInfo().offset(); }
- bool is_dynamic() const override { return _dynamic; }
- void set_dynamic() override { _dynamic = true; }
- ir::Shape getShape() const override { return _info.shape(); }
- void setShape(const ir::Shape &new_shape) override { _info.shape(new_shape); }
- bool is_constant() const override { return false; }
- IDynamicTensorManager *dynamic_tensor_manager() override { return _dynamic_tensor_manager; }
-
-private:
- ir::OperandInfo _info;
- ir::Layout _layout;
- uint8_t *_buffer;
- size_t _size;
- bool _dynamic;
- IDynamicTensorManager *_dynamic_tensor_manager;
-};
-
-} // namespace controlflow
-} // namespace backend
-} // namespace onert
-
-#endif // __ONERT_BACKEND_CONTROLFLOW_USER_TENSOR_H__
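The essential property of the deleted UserTensor above is that it is a non-owning view: it only stores the caller's pointer and size and never allocates or frees. A small self-contained sketch of that shape, with names invented for illustration:

#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

// Non-owning view over a caller-provided buffer; the view never allocates or frees.
class UserBufferView
{
public:
  UserBufferView() = default;
  void setBuffer(uint8_t *buffer, size_t size)
  {
    _buffer = buffer;
    _size = size;
  }
  uint8_t *buffer() const { return _buffer; }
  size_t total_size() const { return _size; }

private:
  uint8_t *_buffer = nullptr; // owned by the caller
  size_t _size = 0;
};

int main()
{
  std::vector<uint8_t> user_data(16, 0xAB); // lifetime managed by the "user"
  UserBufferView view;
  view.setBuffer(user_data.data(), user_data.size());
  std::cout << view.total_size() << ' ' << static_cast<int>(view.buffer()[0]) << '\n';
}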
diff --git a/runtime/onert/core/src/backend/controlflow/kernel/IfLayer.cc b/runtime/onert/core/src/backend/controlflow/kernel/IfLayer.cc
deleted file mode 100644
index 8377c7183..000000000
--- a/runtime/onert/core/src/backend/controlflow/kernel/IfLayer.cc
+++ /dev/null
@@ -1,128 +0,0 @@
-/*
- * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "IfLayer.h"
-
-#include <backend/ITensor.h>
-#include "exec/ExecutorBase.h"
-#include <misc/polymorphic_downcast.h>
-#include "PermuteLayer.h"
-
-namespace onert
-{
-namespace backend
-{
-namespace controlflow
-{
-namespace kernel
-{
-
-IfLayer::IfLayer(const std::shared_ptr<backend::ITensor> &cond_tensor,
- const std::vector<std::shared_ptr<backend::ITensor>> input_tensors,
- const std::vector<std::shared_ptr<backend::ITensor>> output_tensors,
- const ir::OperandIndexSequence &output_indices, const ir::Graph &graph,
- const exec::DynAllocInfoMap &outputs_dyn_alloc_info,
- const ir::SubgraphIndex &then_subg_index, const ir::SubgraphIndex &else_subg_index,
- exec::ExecutorMap *executor_map)
- : _cond_tensor{cond_tensor}, _input_tensors{input_tensors}, _output_tensors{output_tensors},
- _output_indices{output_indices}, _graph{graph},
- _outputs_dyn_alloc_info{outputs_dyn_alloc_info}, _then_subg_index{then_subg_index},
- _else_subg_index{else_subg_index}, _executor_map{executor_map}
-{
- // At this point, executor_map may not have executors of then subg and else subg
-}
-
-void IfLayer::run()
-{
- // Check condition
- // // If true
- // // // Copy _input_tensors -> then subg's inputs
- // // // Run then subg
- // // // Copy outputs of then subg -> _output_tensors
- // // Else
- // // // Copy _input_tensors -> else subg's inputs if false
- // // // Run else subg
- // // // Copy outputs of else subg -> _output_tensors
- auto getResultCond = [](backend::ITensor *tensor) -> bool {
- bool ret = false;
- tensor->access([&](ITensor &tensor) { ret = *reinterpret_cast<bool *>(tensor.buffer()); });
- return ret;
- };
-
- exec::ExecutorBase *subg_exec = nullptr;
- if (getResultCond(_cond_tensor.get()))
- {
- subg_exec = nnfw::misc::polymorphic_downcast<exec::ExecutorBase *>(
- _executor_map->at(_then_subg_index).get());
- }
- else
- {
- subg_exec = nnfw::misc::polymorphic_downcast<exec::ExecutorBase *>(
- _executor_map->at(_else_subg_index).get());
- }
-
- const auto &subg_graph = subg_exec->graph();
-
- std::vector<std::shared_ptr<backend::ITensor>> src_tensors;
- std::vector<std::shared_ptr<backend::ITensor>> dst_tensors;
- // Add tensors used in subgraph or contained in outputs of subgraph
- assert(subg_graph.getInputs().size() == _input_tensors.size());
- assert(subg_graph.getInputs().size() == subg_exec->getInputTensors().size());
- for (uint32_t i = 0; i < subg_graph.getInputs().size(); ++i)
- {
- const auto &subg_input_index = subg_graph.getInputs().at(i);
- const auto &subg_input = subg_graph.operands().at(subg_input_index);
- if (subg_input.getUses().size() > 0 || subg_graph.getOutputs().contains(subg_input_index))
- {
- src_tensors.emplace_back(_input_tensors.at(i));
- dst_tensors.emplace_back(subg_exec->getInputTensors().at(i));
- }
- }
- const auto &subg_inputs_dyn_alloc_info = subg_exec->getInputsDynamicAllocInfo();
- const auto permute_op_input_to_subg_input =
- std::make_shared<PermuteLayer>(src_tensors, dst_tensors, subg_inputs_dyn_alloc_info);
-
- // Add tensors used as output of operation or contained in outputs of operation
- src_tensors.clear();
- dst_tensors.clear();
- assert(_output_indices.size() == subg_exec->getOutputTensors().size());
- assert(_output_indices.size() == _output_tensors.size());
- for (uint32_t i = 0; i < _output_indices.size(); ++i)
- {
- const auto &output_index = _output_indices.at(i);
- const auto &output = _graph.operands().at(output_index);
- if (output.getUses().size() > 0 || _graph.getOutputs().contains(output_index))
- {
- src_tensors.emplace_back(subg_exec->getOutputTensors().at(i));
- dst_tensors.emplace_back(_output_tensors.at(i));
- }
- }
- const auto permute_subg_output_to_op_output =
- std::make_shared<PermuteLayer>(src_tensors, dst_tensors, _outputs_dyn_alloc_info);
-
- // Remove copying of unused tensor
- permute_op_input_to_subg_input->prepare();
- permute_subg_output_to_op_output->prepare();
-
- // Copy & run
- subg_exec->execute(_input_tensors, permute_op_input_to_subg_input);
- permute_subg_output_to_op_output->run();
-}
-
-} // namespace kernel
-} // namespace controlflow
-} // namespace backend
-} // namespace onert
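The core of IfLayer::run above is reading a boolean out of the condition tensor's raw buffer and dispatching to one of two subgraph executors. A stripped-down sketch of just that selection step, with the copy-in/copy-out permutes elided and all names hypothetical:

#include <cstdint>
#include <cstring>
#include <functional>
#include <iostream>
#include <vector>

// Read a bool value out of a raw byte buffer, as the condition tensor is read above.
bool readCond(const std::vector<uint8_t> &buffer)
{
  bool value = false;
  std::memcpy(&value, buffer.data(), sizeof(bool));
  return value;
}

int main()
{
  std::vector<uint8_t> cond_tensor(sizeof(bool));
  bool flag = true;
  std::memcpy(cond_tensor.data(), &flag, sizeof(bool));

  std::function<void()> then_branch = [] { std::cout << "run then subgraph\n"; };
  std::function<void()> else_branch = [] { std::cout << "run else subgraph\n"; };
  (readCond(cond_tensor) ? then_branch : else_branch)(); // pick the branch "executor"
}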
diff --git a/runtime/onert/core/src/backend/controlflow/kernel/PermuteLayer.cc b/runtime/onert/core/src/backend/controlflow/kernel/PermuteLayer.cc
deleted file mode 100644
index e8f1ea679..000000000
--- a/runtime/onert/core/src/backend/controlflow/kernel/PermuteLayer.cc
+++ /dev/null
@@ -1,82 +0,0 @@
-/*
- * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "PermuteLayer.h"
-
-#include "exec/ShapeConverter.h"
-
-namespace onert
-{
-namespace backend
-{
-namespace controlflow
-{
-namespace kernel
-{
-
-void PermuteLayer::run()
-{
- assert(_src_tensors.size() == _dst_tensors.size());
- // PermuteLayer infers dynamic shape inside itself whenever run is called for the following
- // reasons:
- // 1. PermuteLayer has to access the dynamic tensor manager for input/output tensors of other
- //    backends
- // 2. Other controlflow operations (If/While) use this layer for copying tensors of other
- //    subgraphs (with other backends)
- // 3. This inferring code is placed here to avoid duplicating it for the above two reasons
-
- // check if output is not dynamic
- for (size_t i = 0; i < _src_tensors.size(); ++i)
- {
- auto dst_tensor = _dst_tensors.at(i);
- auto src_tensor = _src_tensors.at(i);
- if (src_tensor->is_dynamic() || dst_tensor->is_dynamic())
- {
- // getting output shape
- auto src_shape = src_tensor->getShape();
-
- // set output shape and output buffer
- ir::Shape new_shape =
- exec::convertShape(src_shape, src_tensor->layout(), dst_tensor->layout());
-
- try
- {
- const auto dst_index = _dst_dyn_alloc_info_map.at(dst_tensor).ind;
- auto dyn_tensor_manager = dst_tensor->dynamic_tensor_manager();
- if (!dyn_tensor_manager)
- throw std::runtime_error{
- "Error: PermuteLayer: output's TensorManager does not support dynamic tensor"};
- dyn_tensor_manager->applyShape(dst_index, new_shape);
- assert(dst_tensor->buffer() != nullptr);
- }
- catch (const std::out_of_range &e)
- {
- std::cerr << "Error: out_of_range in PermuteLayer: output's TensorManager does not support "
- "dynamic tensor"
- << '\n';
- throw;
- }
- }
- assert(exec::convertShape(src_tensor->getShape(), src_tensor->layout(), dst_tensor->layout()) ==
- dst_tensor->getShape());
- }
- IPermuteFunction::run();
-}
-
-} // namespace kernel
-} // namespace controlflow
-} // namespace backend
-} // namespace onert
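The dynamic-shape branch of run() above boils down to: before copying, resize the destination to match the (possibly new) source shape, then copy. A simplified sketch of that step with a std::vector standing in for tensor memory; everything here is illustrative, not onert API:

#include <algorithm>
#include <cstddef>
#include <functional>
#include <iostream>
#include <numeric>
#include <vector>

// Number of elements implied by a shape.
size_t numElements(const std::vector<int> &shape)
{
  return std::accumulate(shape.begin(), shape.end(), size_t{1}, std::multiplies<size_t>());
}

int main()
{
  std::vector<int> src_shape{2, 3};
  std::vector<float> src(numElements(src_shape), 1.0f);
  std::vector<float> dst; // not yet allocated (dynamic output)

  if (dst.size() != src.size())
    dst.resize(numElements(src_shape)); // "apply shape" and (re)allocate the output
  std::copy(src.begin(), src.end(), dst.begin());
  std::cout << dst.size() << '\n'; // 6
}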
diff --git a/runtime/onert/core/src/backend/controlflow/kernel/PermuteLayer.h b/runtime/onert/core/src/backend/controlflow/kernel/PermuteLayer.h
deleted file mode 100644
index 403ac770d..000000000
--- a/runtime/onert/core/src/backend/controlflow/kernel/PermuteLayer.h
+++ /dev/null
@@ -1,77 +0,0 @@
-/*
- * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef __ONERT_BACKEND_CONTROLFLOW_KERNEL_PERMUTELAYER_H__
-#define __ONERT_BACKEND_CONTROLFLOW_KERNEL_PERMUTELAYER_H__
-
-#include "backend/ITensorBuilder.h"
-#include "exec/IPermuteFunction.h"
-#include "exec/IExecutor.h"
-
-namespace onert
-{
-namespace backend
-{
-namespace controlflow
-{
-namespace kernel
-{
-
-class PermuteLayer : public onert::exec::IPermuteFunction
-{
-public:
- PermuteLayer(const std::vector<std::shared_ptr<ITensor>> &src_tensors,
- const std::vector<std::shared_ptr<ITensor>> &dst_tensors,
- const exec::DynAllocInfoMap &dst_dyn_alloc_info_map)
- : _dst_dyn_alloc_info_map{dst_dyn_alloc_info_map}
- {
- assert(src_tensors.size() == dst_tensors.size());
- _src_tensors = src_tensors;
- _dst_tensors = dst_tensors;
- }
-
- void optimize() override
- {
- // Remove copies where the source and destination are the same tensor or either is nullptr
- auto src_it = _src_tensors.begin();
- auto dst_it = _dst_tensors.begin();
- while (src_it != _src_tensors.end())
- {
- if ((*src_it == *dst_it) || (*src_it == nullptr || *dst_it == nullptr))
- {
- src_it = _src_tensors.erase(src_it);
- dst_it = _dst_tensors.erase(dst_it);
- }
- else
- {
- ++src_it;
- ++dst_it;
- }
- }
- }
-
- void run() override;
-
-private:
- const exec::DynAllocInfoMap _dst_dyn_alloc_info_map;
-};
-
-} // namespace kernel
-} // namespace controlflow
-} // namespace backend
-} // namespace onert
-
-#endif // __ONERT_BACKEND_CONTROLFLOW_KERNEL_PERMUTELAYER_H__
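optimize() above walks two parallel vectors and erases matching positions from both, relying on erase() returning the next valid iterator. A minimal, self-contained reproduction of that erase-while-iterating pattern:

#include <cassert>
#include <iostream>
#include <vector>

int main()
{
  const char *shared = "same";
  std::vector<const char *> src{shared, nullptr, "c", "d"};
  std::vector<const char *> dst{shared, "x", nullptr, "d2"};
  assert(src.size() == dst.size());

  auto src_it = src.begin();
  auto dst_it = dst.begin();
  while (src_it != src.end())
  {
    if (*src_it == *dst_it || *src_it == nullptr || *dst_it == nullptr)
    {
      src_it = src.erase(src_it); // erase returns the iterator to the next element
      dst_it = dst.erase(dst_it);
    }
    else
    {
      ++src_it;
      ++dst_it;
    }
  }
  std::cout << src.size() << '\n'; // only the "d" -> "d2" copy remains (1)
}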
diff --git a/runtime/onert/core/src/backend/controlflow/kernel/WhileLayer.cc b/runtime/onert/core/src/backend/controlflow/kernel/WhileLayer.cc
deleted file mode 100644
index 50936e5f6..000000000
--- a/runtime/onert/core/src/backend/controlflow/kernel/WhileLayer.cc
+++ /dev/null
@@ -1,216 +0,0 @@
-/*
- * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "WhileLayer.h"
-
-#include <backend/ITensor.h>
-#include "exec/ExecutorBase.h"
-#include <misc/polymorphic_downcast.h>
-#include "PermuteLayer.h"
-
-namespace onert
-{
-namespace backend
-{
-namespace controlflow
-{
-namespace kernel
-{
-
-WhileLayer::WhileLayer(const std::vector<std::shared_ptr<backend::ITensor>> &input_tensors,
- const std::vector<std::shared_ptr<backend::ITensor>> &output_tensors,
- const ir::OperandIndexSequence &output_indices, const ir::Graph &graph,
- const exec::DynAllocInfoMap &outputs_dyn_alloc_info,
- const ir::SubgraphIndex &cond_subg_index,
- const ir::SubgraphIndex &body_subg_index, exec::ExecutorMap *executor_map)
- : _cond_subg_index{cond_subg_index}, _body_subg_index{body_subg_index},
- _output_indices{output_indices}, _graph{graph}, _input_tensors{input_tensors},
- _output_tensors{output_tensors}, _outputs_dyn_alloc_info{outputs_dyn_alloc_info},
- _executor_map{executor_map}
-{
- // At this point, executor_map may not have executors of cond subg and body subg
-}
-
-void WhileLayer::run()
-{
- // Copy "_input_tensors" -> "cond subg inputs"
- // Run cond subg
- // Loop while the output of cond subg is true
- // // Copy "_input_tensors" -> "body subg inputs" in the first iteration, then copy "body subg
- // outputs" -> "body subg inputs" from the second iteration onwards
- // // Run body subg
- // // Copy "body subg outputs" -> "cond subg inputs"
- // // Run cond subg
- // If there is no loop iteration, copy "_input_tensors" -> "_dst_tensors"; otherwise copy
- // "cond subg inputs" -> "_dst_tensors"
- auto cond_exec = nnfw::misc::polymorphic_downcast<exec::ExecutorBase *>(
- _executor_map->at(_cond_subg_index).get());
- auto body_exec = nnfw::misc::polymorphic_downcast<exec::ExecutorBase *>(
- _executor_map->at(_body_subg_index).get());
-
- const auto &cond_graph = cond_exec->graph();
- const auto &cond_inputs_dyn_alloc = cond_exec->getInputsDynamicAllocInfo();
- const auto &body_graph = body_exec->graph();
- const auto &body_inputs_dyn_alloc = body_exec->getInputsDynamicAllocInfo();
-
- std::vector<std::shared_ptr<backend::ITensor>> input_tensors;
- std::vector<std::shared_ptr<backend::ITensor>> cond_input_tensors;
- std::vector<std::shared_ptr<backend::ITensor>> body_input_tensors;
- std::vector<std::shared_ptr<backend::ITensor>> body_output_tensors;
- std::vector<std::shared_ptr<backend::ITensor>> output_tensors;
-
- // Add only used tensors in cond subgraph
- assert(cond_graph.getInputs().size() == _input_tensors.size());
- assert(cond_graph.getInputs().size() == cond_exec->getInputTensors().size());
- for (uint32_t i = 0; i < cond_graph.getInputs().size(); ++i)
- {
- const auto &cond_input = cond_graph.operands().at(cond_graph.getInputs().at(i));
- if (cond_input.getUses().size() > 0)
- {
- input_tensors.emplace_back(_input_tensors.at(i));
- cond_input_tensors.emplace_back(cond_exec->getInputTensors().at(i));
- }
- }
- const auto permute_op_input_to_cond_input =
- std::make_shared<PermuteLayer>(input_tensors, cond_input_tensors, cond_inputs_dyn_alloc);
-
- // Add only used tensors among outputs of while operation
- assert(_output_indices.size() == _input_tensors.size());
- assert(_output_indices.size() == _output_tensors.size());
- input_tensors.clear();
- output_tensors.clear();
- for (size_t i = 0; i < _output_indices.size(); ++i)
- {
- const auto &output_index = _output_indices.at(i);
- const auto &output = _graph.operands().at(output_index);
- if (output.getUses().size() > 0 || _graph.getOutputs().contains(output_index))
- {
- input_tensors.emplace_back(_input_tensors.at(i));
- output_tensors.emplace_back(_output_tensors.at(i));
- }
- }
- const auto permute_op_input_to_op_output =
- std::make_shared<PermuteLayer>(input_tensors, output_tensors, _outputs_dyn_alloc_info);
-
- // Add all tensors, including unused ones, in the body subgraph because unused input tensors
- // will be copied to output tensors in the body subgraph
- assert(_input_tensors.size() == body_exec->getInputTensors().size());
- input_tensors = _input_tensors;
- body_input_tensors = body_exec->getInputTensors();
- const auto permute_op_input_to_body_input =
- std::make_shared<PermuteLayer>(input_tensors, body_input_tensors, body_inputs_dyn_alloc);
-
- // Add only used tensors in cond subgraph
- assert(cond_graph.getInputs().size() == body_exec->getOutputTensors().size());
- assert(cond_graph.getInputs().size() == cond_exec->getInputTensors().size());
- body_output_tensors.clear();
- cond_input_tensors.clear();
- for (uint32_t i = 0; i < cond_graph.getInputs().size(); ++i)
- {
- const auto &cond_input = cond_graph.operands().at(cond_graph.getInputs().at(i));
- if (cond_input.getUses().size() > 0)
- {
- body_output_tensors.emplace_back(body_exec->getOutputTensors().at(i));
- cond_input_tensors.emplace_back(cond_exec->getInputTensors().at(i));
- }
- }
- const auto permute_body_output_to_cond_input = std::make_shared<PermuteLayer>(
- body_output_tensors, cond_input_tensors, cond_inputs_dyn_alloc);
-
- // Add only used tensors in body subgraph
- assert(body_graph.getInputs().size() == body_exec->getOutputTensors().size());
- assert(body_graph.getInputs().size() == body_exec->getInputTensors().size());
- body_output_tensors.clear();
- body_input_tensors.clear();
- for (uint32_t i = 0; i < body_graph.getInputs().size(); ++i)
- {
- const auto &body_input_index = body_graph.getInputs().at(i);
- const auto &body_input = body_graph.operands().at(body_input_index);
- if (body_input.getUses().size() > 0 &&
- !body_exec->graph().getOutputs().contains(body_input_index))
- {
- body_output_tensors.emplace_back(body_exec->getOutputTensors().at(i));
- body_input_tensors.emplace_back(body_exec->getInputTensors().at(i));
- }
- }
- const auto permute_body_output_to_body_input = std::make_shared<PermuteLayer>(
- body_output_tensors, body_input_tensors, body_inputs_dyn_alloc);
-
- // Add only used tensors among outputs of while operation
- assert(_output_indices.size() == body_exec->getOutputTensors().size());
- assert(_output_indices.size() == _output_tensors.size());
- body_output_tensors.clear();
- output_tensors.clear();
- for (size_t i = 0; i < _output_indices.size(); ++i)
- {
- const auto &output_index = _output_indices.at(i);
- const auto &output = _graph.operands().at(output_index);
- if (output.getUses().size() > 0 || _graph.getOutputs().contains(output_index))
- {
- body_output_tensors.emplace_back(body_exec->getOutputTensors().at(i));
- output_tensors.emplace_back(_output_tensors.at(i));
- }
- }
- const auto permute_body_output_to_op_output =
- std::make_shared<PermuteLayer>(body_output_tensors, output_tensors, _outputs_dyn_alloc_info);
-
- // Remove copying of unused tensor
- permute_op_input_to_cond_input->prepare();
- permute_op_input_to_op_output->prepare();
- permute_op_input_to_body_input->prepare();
- permute_body_output_to_cond_input->prepare();
- permute_body_output_to_body_input->prepare();
- permute_body_output_to_op_output->prepare();
-
- cond_exec->execute(_input_tensors, permute_op_input_to_cond_input);
-
- assert(cond_exec->getOutputTensors().size() == 1);
- auto &cond_output_tensor = cond_exec->getOutputTensors().at(0);
- auto getResultCond = [](backend::ITensor *tensor) -> bool {
- bool ret = false;
- tensor->access([&](ITensor &tensor) { ret = *reinterpret_cast<bool *>(tensor.buffer()); });
- return ret;
- };
-
- const auto body_execute_with_op_inputs = [&]() {
- body_exec->execute(_input_tensors, permute_op_input_to_body_input);
- };
-
- const auto body_execute_with_body_outputs = [&]() {
- body_exec->execute(body_exec->getOutputTensors(), permute_body_output_to_body_input);
- };
-
- std::function<void()> body_execute = body_execute_with_op_inputs;
- const auto cond_execute = [&]() {
- cond_exec->execute(body_exec->getOutputTensors(), permute_body_output_to_cond_input);
- };
- auto permute_to_outputs_fn = permute_op_input_to_op_output;
-
- // Loop while Cond subgraph's output is true
- while (getResultCond(cond_output_tensor.get()))
- {
- body_execute();
- cond_execute();
- body_execute = body_execute_with_body_outputs;
- permute_to_outputs_fn = permute_body_output_to_op_output;
- }
- permute_to_outputs_fn->run();
-}
-
-} // namespace kernel
-} // namespace controlflow
-} // namespace backend
-} // namespace onert
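The loop protocol implemented by run() above is: the first body iteration consumes the op inputs, later iterations consume the previous body outputs, and the cond subgraph is re-evaluated after every body run. A compact sketch of that control flow with plain lambdas standing in for the subgraph executors; all names are hypothetical:

#include <iostream>
#include <vector>

int main()
{
  std::vector<int> op_inputs{0};
  std::vector<int> body_outputs;

  auto run_cond = [](const std::vector<int> &in) { return in[0] < 3; };
  auto run_body = [](const std::vector<int> &in) { return std::vector<int>{in[0] + 1}; };

  bool keep_going = run_cond(op_inputs); // cond subg on the op inputs
  bool first_iteration = true;
  while (keep_going)
  {
    const auto &body_in = first_iteration ? op_inputs : body_outputs;
    body_outputs = run_body(body_in);    // body subg
    keep_going = run_cond(body_outputs); // cond subg on the body outputs
    first_iteration = false;
  }
  const auto &result = first_iteration ? op_inputs : body_outputs;
  std::cout << result[0] << '\n'; // 3
}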
diff --git a/runtime/onert/core/src/backend/cpu_common/DynamicTensorManager.cc b/runtime/onert/core/src/backend/cpu_common/DynamicTensorManager.cc
deleted file mode 100644
index f7ce3d011..000000000
--- a/runtime/onert/core/src/backend/cpu_common/DynamicTensorManager.cc
+++ /dev/null
@@ -1,137 +0,0 @@
-/*
- * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "backend/cpu_common/DynamicTensorManager.h"
-
-#include "util/logging.h"
-
-namespace onert
-{
-namespace backend
-{
-namespace cpu_common
-{
-
-DynamicTensorManager::DynamicTensorManager(const std::shared_ptr<TensorRegistry> &reg)
- : _dynamic_mem_mgr{new DynamicMemoryManager()}, _tensors{reg}
-{
- // DO NOTHING
-}
-
-void DynamicTensorManager::applyShape(const ir::OperandIndex &ind, const ir::Shape &new_shape)
-{
- VERBOSE_F() << ind << std::endl;
-
- auto tensor = _tensors->getNativeTensor(ind);
- assert(tensor);
-
- bool previously_dynamic = tensor->is_dynamic();
-
- auto allocTensorMem = [&](bool overwrite = false) {
- auto capacity = tensor->total_size();
- auto alloc = _dynamic_mem_mgr->allocate(ind, capacity);
-
- if (overwrite)
- tensor->overwriteBuffer(alloc);
- else
- tensor->setBuffer(alloc);
- };
-
- if (!previously_dynamic)
- {
- // TODO deallocate tensor->buffer()
- // issue is that StaticTensorManager might have allocated this memory
- tensor->setShape(new_shape);
- tensor->set_dynamic();
- allocTensorMem(true);
- }
- else if (tensor->buffer() == nullptr)
- {
- tensor->setShape(new_shape);
- tensor->set_dynamic();
- allocTensorMem();
- }
- // when buffer was already allocated and new_shape requires different size
- else
- {
- auto previous_size = tensor->total_size();
- auto new_size = new_shape.num_elements() * sizeOfDataType(tensor->data_type());
- if (previous_size != new_size)
- {
- _dynamic_mem_mgr->deallocate(ind);
-
- tensor->setShape(new_shape);
- tensor->set_dynamic();
- allocTensorMem(true);
- }
- else
- { // when buffer with same size was already allocated, shape could differ
- tensor->setShape(new_shape);
- }
- }
-}
-
-void DynamicTensorManager::buildTensor(const ir::OperandIndex &ind,
- const ir::OperandInfo &tensor_info,
- ir::Layout backend_layout)
-{
- assert(_tensors->getNativeTensor(ind) == nullptr);
- auto tensor = std::make_shared<Tensor>(tensor_info, backend_layout, this);
- _tensors->setNativeTensor(ind, tensor);
-}
-
-void DynamicTensorManager::planDealloc(ir::OperationIndex op_ind, ir::OperandIndex operand_ind)
-{
- _dealloc_tensor_map[op_ind].emplace(operand_ind);
-}
-
-void DynamicTensorManager::deallocInput(ir::OperationIndex op_ind)
-{
- auto find = _dealloc_tensor_map.find(op_ind);
- if (find == _dealloc_tensor_map.end())
- return;
-
- auto &input_set = find->second;
- for (auto input_ind : input_set)
- {
- auto *tensor = _tensors->getNativeTensor(input_ind).get();
- if (!tensor->is_dynamic())
- continue;
-
- _dynamic_mem_mgr->deallocate(input_ind);
- tensor->resetBuffer();
-
- VERBOSE(DynamicTensorManager) << "Deallocating #" << input_ind.value()
- << " (input of op_ind: " << op_ind.value() << ")" << std::endl;
- }
-}
-
-void DynamicTensorManager::deallocSubgraphOutput(ir::OperandIndex output_ind)
-{
- auto *tensor = _tensors->getNativeTensor(output_ind).get();
- if (!tensor->is_dynamic())
- return;
-
- _dynamic_mem_mgr->deallocate(output_ind);
- tensor->resetBuffer();
-
- VERBOSE(DynamicTensorManager) << "Deallocating #" << output_ind.value()
- << " (output of a subgraph)" << std::endl;
-}
-
-} // namespace cpu_common
-} // namespace backend
-} // namespace onert
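applyShape() above follows a three-way decision: allocate if the tensor has no buffer yet, reallocate only when the byte size actually changes, and otherwise just record the new shape while keeping the buffer. A simplified sketch of that decision flow for a single float tensor, with a std::vector standing in for managed memory (names invented for illustration):

#include <cstddef>
#include <iostream>
#include <vector>

struct DynTensor
{
  std::vector<int> shape;
  std::vector<float> buffer; // empty == not allocated
  size_t byteSize() const
  {
    size_t n = 1;
    for (int d : shape)
      n *= static_cast<size_t>(d);
    return n * sizeof(float);
  }
};

void applyShape(DynTensor &t, const std::vector<int> &new_shape)
{
  const size_t new_size = DynTensor{new_shape, {}}.byteSize();
  if (t.buffer.empty() || t.byteSize() != new_size)
  {
    t.shape = new_shape;
    t.buffer.assign(new_size / sizeof(float), 0.0f); // (re)allocate
  }
  else
  {
    t.shape = new_shape; // same byte size, different shape: keep the buffer
  }
}

int main()
{
  DynTensor t{{2, 2}, std::vector<float>(4, 0.0f)};
  applyShape(t, {4, 1}); // same byte size: buffer kept
  applyShape(t, {3, 3}); // larger: reallocated
  std::cout << t.buffer.size() << '\n'; // 9
}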
diff --git a/runtime/onert/core/src/compiler/BackendManager.cc b/runtime/onert/core/src/compiler/BackendManager.cc
index db7a14a96..44442c065 100644
--- a/runtime/onert/core/src/compiler/BackendManager.cc
+++ b/runtime/onert/core/src/compiler/BackendManager.cc
@@ -16,22 +16,17 @@
#include "compiler/BackendManager.h"
-#include <memory>
-#include <dlfcn.h>
+#include "../backend/builtin/Backend.h"
+#include "../backend/builtin/Config.h"
-#include "backend/Backend.h"
-#include "backend/controlflow/Backend.h"
-#include "backend/controlflow/Config.h"
-#include "backend/IConfig.h"
-#include "util/logging.h"
-#include "util/ConfigSource.h"
-#include "misc/string_helpers.h"
+#include <dlfcn.h>
+#include <memory>
static const char *SHARED_LIB_EXT =
#if defined(__APPLE__) && defined(__MACH__)
- ".dylib";
+ ".dylib";
#else
- ".so";
+ ".so";
#endif
namespace onert
@@ -45,20 +40,20 @@ BackendManager &BackendManager::get()
return object;
}
-BackendManager::BackendManager() { loadControlflowBackend(); }
+BackendManager::BackendManager() { loadBuiltinBackend(); }
-void BackendManager::loadControlflowBackend()
+void BackendManager::loadBuiltinBackend()
{
- auto backend_object = std::unique_ptr<backend::controlflow::Backend, backend_destroy_t>(
- new backend::controlflow::Backend, [](backend::Backend *backend) { delete backend; });
+ auto backend_object = std::unique_ptr<backend::builtin::Backend, backend_destroy_t>(
+ new backend::builtin::Backend, [](backend::Backend *backend) { delete backend; });
bool initialized = backend_object->config()->initialize(); // Call initialize here?
if (!initialized)
{
- throw std::runtime_error(backend::controlflow::Config::ID + " backend initialization failed");
+ throw std::runtime_error(backend::builtin::Config::ID + " backend initialization failed");
}
- _controlflow = backend_object.get(); // Save the controlflow backend implementation pointer
- assert(_controlflow);
+ _builtin = backend_object.get(); // Save the builtin backend implementation pointer
+ assert(_builtin);
_gen_map.emplace(backend_object->config()->id(), std::move(backend_object));
}
@@ -69,68 +64,67 @@ void BackendManager::loadBackend(const std::string &backend)
return;
}
- // TODO Remove indentation
- // Workaround If backend have dynamic library with "-boost" suffix naming,
- // BackendManager load library with "-boost" suffix instead of library without suffix
- // This feature is used for custom backend extension to support additional operations
- {
- const std::string backend_boost_so = "libbackend_" + backend + "-boost" + SHARED_LIB_EXT;
- const std::string backend_so = "libbackend_" + backend + SHARED_LIB_EXT;
+ const std::string backend_so = "libbackend_" + backend + SHARED_LIB_EXT;
+ void *handle = dlopen(backend_so.c_str(), RTLD_LAZY | RTLD_LOCAL);
- void *handle = dlopen(backend_boost_so.c_str(), RTLD_LAZY | RTLD_LOCAL);
- if (handle == nullptr)
- {
- handle = dlopen(backend_so.c_str(), RTLD_LAZY | RTLD_LOCAL);
+ if (handle == nullptr)
+ {
+ VERBOSE(BackendManager) << "Failed to load backend '" << backend << "' - " << dlerror() << "\n";
+ return;
+ }
- if (handle == nullptr)
- {
- VERBOSE_F() << "Failed to load backend '" << backend << "' - " << dlerror() << std::endl;
- return;
- }
+ VERBOSE(BackendManager) << "Successfully loaded '" << backend << "'(" << backend_so << ")\n";
- VERBOSE_F() << "Successfully loaded '" << backend << "' - " << backend_so << "\n";
+ {
+ // load object creator function
+ auto backend_create = (backend_create_t)dlsym(handle, "onert_backend_create");
+ if (backend_create == nullptr)
+ {
+ // TODO replace `fprintf` with `VERBOSE`
+ fprintf(stderr, "BackendManager: unable to find function `onert_backend_create` : %s\n",
+ dlerror());
+ dlclose(handle);
+ return;
}
- else
+
+ // load object creator function
+ auto backend_destroy = (backend_destroy_t)dlsym(handle, "onert_backend_destroy");
+ if (backend_destroy == nullptr)
{
- VERBOSE_F() << "Successfully loaded '" << backend << "' - " << backend_boost_so << "\n";
+ // TODO replace `fprintf` with `VERBOSE`
+ fprintf(stderr, "BackendManager: unable to find `function onert_backend_destroy` : %s\n",
+ dlerror());
+ dlclose(handle);
+ return;
}
+ auto backend_object =
+ std::unique_ptr<backend::Backend, backend_destroy_t>(backend_create(), backend_destroy);
+ bool initialized = backend_object->config()->initialize(); // Call initialize here?
+ if (!initialized)
{
- // load object creator function
- auto backend_create = (backend_create_t)dlsym(handle, "onert_backend_create");
- if (backend_create == nullptr)
- {
- fprintf(stderr, "BackendManager: unable to open function onert_backend_create : %s\n",
- dlerror());
- abort();
- }
+ VERBOSE(BackendManager) << backend.c_str()
+ << " backend initialization failed. Don't use this backend"
+ << std::endl;
+ dlclose(handle);
+ return;
+ }
+ _gen_map.emplace(backend_object->config()->id(), std::move(backend_object));
+ }
- // load object creator function
- auto backend_destroy = (backend_destroy_t)dlsym(handle, "onert_backend_destroy");
- if (backend_destroy == nullptr)
+ // Save the backend handle (avoids losing the handle without a matching dlclose())
+ auto u_handle = std::unique_ptr<void, dlhandle_destroy_t>{
+ handle, [id = backend, filename = backend_so](void *h) {
+ if (dlclose(h) == 0)
{
- fprintf(stderr, "BackendManager: unable to open function onert_backend_destroy : %s\n",
- dlerror());
- abort();
+ VERBOSE(BackendManager) << "Successfully unloaded '" << id << "'(" << filename << ")\n";
}
-
- auto backend_object =
- std::unique_ptr<backend::Backend, backend_destroy_t>(backend_create(), backend_destroy);
- bool initialized = backend_object->config()->initialize(); // Call initialize here?
- if (!initialized)
+ else
{
- VERBOSE_F() << backend.c_str() << " backend initialization failed. Don't use this backend"
- << std::endl;
- dlclose(handle);
- return;
+ VERBOSE(BackendManager) << "Failed to unload backend '" << id << "'- " << dlerror() << "\n";
}
- _gen_map.emplace(backend_object->config()->id(), std::move(backend_object));
- }
-
- // Save backend handle (avoid warning by handle lost without dlclose())
- auto u_handle = std::unique_ptr<void, dlhandle_destroy_t>{handle, [](void *h) { dlclose(h); }};
- _handle_map.emplace(backend, std::move(u_handle));
- }
+ }};
+ _handle_map.emplace(backend, std::move(u_handle));
}
backend::Backend *BackendManager::get(const std::string &key)
@@ -153,7 +147,7 @@ const backend::Backend *BackendManager::get(const std::string &key) const
return nullptr;
}
-const backend::controlflow::Backend *BackendManager::getControlflow() const { return _controlflow; }
+const backend::Backend *BackendManager::getBuiltin() const { return _builtin; }
} // namespace compiler
} // namespace onert
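The reworked loadBackend() above follows the usual POSIX plugin pattern: dlopen the shared object, look up C entry points with dlsym, and keep the handle alive in a unique_ptr whose deleter calls dlclose. A generic, self-contained sketch of that pattern; "libbackend_example.so" and "plugin_create" are placeholder names, not the actual onert symbols:

#include <dlfcn.h>
#include <cstdio>
#include <memory>
#include <string>

using handle_ptr = std::unique_ptr<void, int (*)(void *)>;

int main()
{
  const std::string so_name = "libbackend_example.so";
  // RAII handle: dlclose runs automatically when the unique_ptr goes out of scope.
  handle_ptr handle(dlopen(so_name.c_str(), RTLD_LAZY | RTLD_LOCAL), &dlclose);
  if (!handle)
  {
    std::fprintf(stderr, "failed to load %s: %s\n", so_name.c_str(), dlerror());
    return 0; // a missing optional plugin is not fatal
  }

  using create_fn = void *(*)();
  auto create = reinterpret_cast<create_fn>(dlsym(handle.get(), "plugin_create"));
  if (create == nullptr)
  {
    std::fprintf(stderr, "symbol lookup failed: %s\n", dlerror());
    return 0; // handle is closed by the unique_ptr deleter
  }
  void *plugin = create();
  (void)plugin;
}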
diff --git a/runtime/onert/core/src/compiler/Compiler.cc b/runtime/onert/core/src/compiler/Compiler.cc
index 93dbbc3b5..63667a063 100644
--- a/runtime/onert/core/src/compiler/Compiler.cc
+++ b/runtime/onert/core/src/compiler/Compiler.cc
@@ -16,284 +16,177 @@
#include "compiler/Compiler.h"
-#include "ParamChecker.h"
+#include "CompilerHelpers.h"
#include "ExecutorFactory.h"
-#include "OperationValidator.h"
-#include "Fp32ToFp16Converter.h"
+#include "ShapeValidator.h"
+#include "pass/ConstantOutputPass.h"
+#include "pass/OddOutputPass.h"
+#include "pass/PassRunner.h"
+#include "pass/UnusedOperandEliminationPass.h"
+#include "../dumper/dot/DotDumper.h"
+#include "../exec/SingleModelExecutors.h"
+#include "../ir/OperationDumper.h"
+#include "../ir/verifier/Verifier.h"
-#include <backend/controlflow/Config.h>
-#include "compiler/BackendManager.h"
-#include "compiler/IScheduler.h"
-#include "compiler/ManualScheduler.h"
-#include "compiler/HEScheduler.h"
-#include "compiler/StaticShapeInference.h"
-#include "exec/ExecTime.h"
-#include "ir/operation/LowerInfo.h"
-#include "dumper/dot/DotDumper.h"
-#include "compiler/Linear.h"
-#include "interp/InterpExecutor.h"
-#include "util/ConfigSource.h"
-#include "util/logging.h"
-#include "ir/OperationDumper.h"
-#include "misc/string_helpers.h"
+#include "compiler/StaticShapeInferer.h"
+
+#include <misc/string_helpers.h>
+#include <misc/polymorphic_downcast.h>
namespace onert
{
-
namespace compiler
{
-CompilerOptions fetchCompilerOptionsFromGlobalConfig(const ir::Subgraphs &subgs)
+Compiler::Compiler(const std::shared_ptr<ir::Model> &model, CompilerOptions *copts)
+ : _model{model}, _options{copts}
{
- CompilerOptions options;
- options.backend_list = nnfw::misc::split(util::getConfigString(util::config::BACKENDS), ';');
- options.is_primary_subgraph = false;
- options.trace_filepath = util::getConfigString(util::config::TRACE_FILEPATH);
- options.graph_dump_level = util::getConfigInt(util::config::GRAPH_DOT_DUMP);
- options.op_seq_max_node = util::getConfigInt(util::config::OP_SEQ_MAX_NODE);
- options.executor = util::getConfigString(util::config::EXECUTOR);
- options.he_scheduler = util::getConfigBool(util::config::USE_SCHEDULER);
- options.he_profiling_mode = util::getConfigBool(util::config::PROFILING_MODE);
- options.disable_compile = util::getConfigBool(util::config::DISABLE_COMPILE);
- options.fp16_enable = util::getConfigBool(util::config::FP16_ENABLE);
-#ifdef RUY_PROFILER
- options.op_seq_max_node = 1;
-#endif
-
- {
- // Backend for all
- auto &ms_options = options.manual_scheduler_options;
-
- // Default value for op_backend_all is first element in the backend list
- ms_options.backend_for_all = util::getConfigString(util::config::OP_BACKEND_ALLOPS);
-
-// Opcode to Backend
-#define OP(OpName) \
- { \
- const auto &backend_str = util::getConfigString(util::config::OP_BACKEND_##OpName); \
- if (!backend_str.empty()) \
- { \
- ms_options.opcode_to_backend[ir::OpCode::OpName] = backend_str; \
- } \
- }
-#include "ir/Operations.lst"
-#undef OP
-
- // Index to Backend
- // TODO Support multiple subgraphs for manual scheduling
- auto map_str = util::getConfigString(util::config::OP_BACKEND_MAP);
- auto key_val_list = nnfw::misc::split(map_str, ';');
- for (const auto &key_val_str : key_val_list)
- {
- if (key_val_str.empty())
- {
- continue;
- }
-
- auto key_val = nnfw::misc::split(key_val_str, '=');
- const auto &key_str = key_val.at(0);
- const auto &val = key_val.at(1);
- auto key = static_cast<uint32_t>(std::stoi(key_str));
-
- subgs.at(ir::SubgraphIndex{0})
- ->operations()
- .at(ir::OperationIndex{key}); // Check if it exists, or this will throw
- ms_options.index_to_backend.emplace(ir::OperationIndex{key}, val);
- }
- }
- return options;
+ // DO NOTHING
}
-Compiler::Compiler(const std::shared_ptr<ir::Subgraphs> &subgs)
- : _subgraphs{subgs}, _state{State::CREATED}
+Compiler::Compiler(const std::shared_ptr<ir::NNPkg> &nnpkg, CompilerOptions *copts)
+ : _model{nnpkg->primary_model()}, _options{copts}
{
- // Set default values for CompilerOptions
- // All these default values should not be fetched from Env, when we stop supporting Android NN
- // API.
- _options = fetchCompilerOptionsFromGlobalConfig(*subgs);
+ // Use for single model only
+ assert(nnpkg->model_count() == 1);
}
-void Compiler::enableToFp16() { _options.fp16_enable = true; }
-
-void Compiler::checkProfilerConditions()
+std::shared_ptr<CompilerArtifact> Compiler::compile(void)
{
- if (!_options.he_scheduler)
- throw std::runtime_error("Heterogeneous scheduler must be enabled during profiling.");
-
- if (_options.executor != "Dataflow")
- throw std::runtime_error("Profiling mode works only with 'Dataflow' executor");
-}
+ /***************************************************
+ * Prepare compilation phase
+ ***************************************************/
+ if (!_options)
+ throw std::runtime_error{"Empty compile option"};
-std::shared_ptr<exec::ExecutorMap> Compiler::compile(void)
-{
- // Set control flow backend for control flow operators
+ // Mode check
+ // TODO handle option for each model
+ if (_options->he_profiling_mode)
{
- _options.manual_scheduler_options.opcode_to_backend[ir::OpCode::If] =
- backend::controlflow::Config::ID;
- _options.manual_scheduler_options.opcode_to_backend[ir::OpCode::While] =
- backend::controlflow::Config::ID;
- }
+ if (!_options->he_scheduler)
+ throw std::runtime_error("Heterogeneous scheduler must be enabled during profiling.");
- // FIXME This is a workaround for bcq operations, should remove it
- {
- _options.manual_scheduler_options.opcode_to_backend[ir::OpCode::BCQFullyConnected] = "bcq";
- _options.manual_scheduler_options.opcode_to_backend[ir::OpCode::BCQGather] = "bcq";
+ if (_options->executor != "Dataflow")
+ throw std::runtime_error("Profiling mode works only with 'Dataflow' executor");
}
+ if (!_model->hasOnly<ir::Graph>())
{
- VERBOSE(Compiler) << std::boolalpha;
- VERBOSE(Compiler) << "==== Compiler Options ====" << std::endl;
- VERBOSE(Compiler) << "backend_list : "
- << nnfw::misc::join(_options.backend_list.begin(),
- _options.backend_list.end(), "/")
- << std::endl;
- VERBOSE(Compiler) << "trace_filepath : " << _options.trace_filepath << std::endl;
- VERBOSE(Compiler) << "graph_dump_level : " << _options.graph_dump_level << std::endl;
- VERBOSE(Compiler) << "op_seq_max_node : " << _options.op_seq_max_node << std::endl;
- VERBOSE(Compiler) << "executor : " << _options.executor << std::endl;
- VERBOSE(Compiler) << "manual_scheduler_options : (Too many things to print)" << std::endl;
- VERBOSE(Compiler) << "he_scheduler : " << _options.he_scheduler << std::endl;
- VERBOSE(Compiler) << "he_profiling_mode : " << _options.he_profiling_mode << std::endl;
- VERBOSE(Compiler) << "disable_compile : " << _options.disable_compile << std::endl;
- VERBOSE(Compiler) << "fp16_enable : " << _options.fp16_enable << std::endl;
- VERBOSE(Compiler) << std::noboolalpha;
+ throw std::runtime_error("Compiler can only compile models for inference.");
}
- /***************************************************
- * Prepare compilation phase
- ***************************************************/
+ _options->forceInternalOptions();
+ _options->verboseOptions();
- auto executors = std::make_shared<exec::ExecutorMap>();
+ auto custom_kernel_builder = _model->getKernelBuilder();
- // Compilable check
- // TODO: Support hybrid execution -
- // execution between interpreter and compiled executor (including control flow)
- if (!checkCompilable())
- {
- _subgraphs->iterate([&](const ir::SubgraphIndex &index, ir::Graph &subg) {
- executors->emplace(index, std::make_unique<interp::InterpExecutor>(subg));
- });
- _state = State::COMPILED;
- return executors;
- }
+ _model->iterate([&](const ir::SubgraphIndex &, ir::IGraph &graph) {
+ auto &subg = nnfw::misc::polymorphic_downcast<ir::Graph &>(graph);
- // Mode check
- if (_options.he_profiling_mode)
- checkProfilerConditions();
+ // Mandatory passes
+ pass::PassRunner{}
+ .append(std::make_unique<pass::ConstantOutputPass>(subg))
+ .append(std::make_unique<pass::OddOutputPass>(subg))
+ .run();
+
+ // Optimizations
+ pass::PassRunner{}.append(std::make_unique<pass::UnusedOperandEliminationPass>(subg)).run();
+ });
/***************************************************
* Backend independent analysis & optimization phase
***************************************************/
- auto dump_level = static_cast<dumper::dot::DotDumper::Level>(_options.graph_dump_level);
+ // TODO Handle dump level for each model
+ auto dump_level = static_cast<dumper::dot::DotDumper::Level>(_options->graph_dump_level);
+ onert::dumper::dot::DotDumper dot_dumper(dump_level);
+
+ // Tracing context
+ auto tracing_ctx = std::make_unique<util::TracingCtx>();
// Lower: Assign backend
std::unordered_map<ir::SubgraphIndex, std::unique_ptr<compiler::LoweredGraph>> lowered_subgs;
- _subgraphs->iterate([&](const ir::SubgraphIndex &index, ir::Graph &subg) {
- _options.is_primary_subgraph = (index == ir::SubgraphIndex{0});
- onert::dumper::dot::DotDumper dot_dumper(subg, dump_level);
- dot_dumper.dump(nnfw::misc::str("before_lower_subg-", index.value()));
-
- // Lower: Assign backend
- lowered_subgs[index] = std::make_unique<compiler::LoweredGraph>(subg, _options);
-
- // Check backend(s) for subgraph support FP16
- bool backends_support_fp16 = true;
- auto &contexts = (*lowered_subgs[index]).backend_contexts();
- for (auto it = contexts.begin(); it != contexts.end(); it++)
- {
- // The controlflow backend is not for actual computation of operations, so it is an exception
- if (it->first->config()->id() != backend::controlflow::Config::ID)
- backends_support_fp16 &= it->first->config()->supportFP16();
- }
+ {
+ _model->iterate([&](const ir::SubgraphIndex &subg_index, ir::IGraph &graph) {
+ auto &subg = nnfw::misc::polymorphic_downcast<ir::Graph &>(graph);
- if (_options.fp16_enable && backends_support_fp16)
- {
- // NOTE: the only acl_cl backend enables fp16 mode
- Fp32ToFp16Converter(*lowered_subgs[index]).run();
- }
+ // Lower: Assign backend
+ lowered_subgs[subg_index] = std::make_unique<compiler::LoweredGraph>(subg, *_options);
+ // Set tracing_ctx for copied graph
+ tracing_ctx->setSubgraphIndex(&(lowered_subgs[subg_index]->graph()), subg_index.value());
+ });
+ }
- subg.setSubgraphs(nullptr);
- });
+ _model.reset();
- _subgraphs.reset();
+ for (const auto &pair : lowered_subgs)
+ {
+ const auto &subg_index = pair.first;
+ const auto &lowered_subg = pair.second;
+ dot_dumper.dump(*lowered_subg, nnfw::misc::str("after_lower_subg-", subg_index.value()));
+ }
// Shape inference.
{
+ // Run the StaticShapeInferer of the primary subg. All child StaticShapeInferers are called
+ // recursively
+ std::unordered_map<ir::SubgraphIndex, std::unique_ptr<StaticShapeInferer>> inferers =
+ createStaticShapeInferers(lowered_subgs);
+
const auto primary_subg_idx = ir::SubgraphIndex{0};
- StaticShapeInferer inferer(primary_subg_idx, lowered_subgs);
- lowered_subgs.at(primary_subg_idx)
- ->iterateTopolOpSeqs([&](const ir::OpSequenceIndex &, ir::OpSequence &op_seq) {
- auto has_dynamic_tensor = inferer.infer(op_seq);
- op_seq.has_dynamic_tensor(has_dynamic_tensor);
- });
- inferer.dump();
- }
+ inferers.at(primary_subg_idx)->infer();
- /*************************************************************
- * Backend independent analysis & optimization phase finished
- *************************************************************/
+ for (const auto &pair_inferer : inferers)
+ {
+ const auto inferer = pair_inferer.second.get();
+ inferer->dump();
+ }
+ }
- // operation validation
- for (auto &pair : lowered_subgs)
+ // Shape validation
+ // TODO Move shape independent feature check from ShapeValidator to OperationValidator
+ // TODO Move ShapeValidator into shape inference
+ // - Check input tensor shape validation
+ // - Check parameter value validation which valid value is depend on input tensor shape
+ // - Output tensor shape validation check is needless because
+ // static/dynamic shape inferer will make valid output shape
+ for (const auto &pair : lowered_subgs)
{
auto &lowered_subg = pair.second;
- compiler::OperationValidator{lowered_subg->graph()}();
+ compiler::ShapeValidator{lowered_subg->graph()}();
}
- executors = std::make_shared<exec::ExecutorMap>();
- for (auto &pair : lowered_subgs)
+ /*************************************************************
+ * Backend independent analysis & optimization phase finished
+ *************************************************************/
+ auto executors = std::make_shared<exec::SingleModelExecutors>();
+ for (auto &&pair : lowered_subgs)
{
- const auto &subg_index = pair.first;
+ auto const model_index = ir::ModelIndex{0};
+ auto const subg_index = pair.first;
auto &lowered_subg = pair.second;
- auto indexed_ranks = lowered_subg->indexed_ranks();
-
- _options.is_primary_subgraph = (subg_index == ir::SubgraphIndex{0});
+ auto const indexed_ranks = lowered_subg->indexed_ranks();
- onert::dumper::dot::DotDumper dot_dumper_lowered(lowered_subg.get(), dump_level);
- dot_dumper_lowered.dump("after_lower_subg-" + std::to_string(subg_index.value()));
-
- ir::OperationDumper dumper("START SUBGRAPH " + std::to_string(subg_index.value()));
+ ir::OperationDumper dumper("Executor generation of Subgraph " +
+ std::to_string(subg_index.value()));
lowered_subg->graph().operations().iterate(
- [&](const ir::OperationIndex &, const ir::Operation &op) { op.accept(dumper); });
+ [&](const ir::OperationIndex &, const ir::IOperation &op) { op.accept(dumper); });
+
+ ExecutorFactoryArgs args;
+ args.tracing_ctx = tracing_ctx.get();
+ args.options = _options;
+ args.model_index = model_index;
+ args.custom_kernel_builder = custom_kernel_builder;
auto executor = std::unique_ptr<exec::IExecutor>{
- ExecutorFactory::get().create(std::move(lowered_subg), _options, executors)};
+ ExecutorFactory::get().create(std::move(lowered_subg), executors, args)};
executor->setIndexedRanks(indexed_ranks);
- executors->insert(std::make_pair(subg_index, std::move(executor)));
+ executors->emplace(model_index, subg_index, std::move(executor));
}
/********************************
* Code generation phase finished
********************************/
- _state = State::COMPILED;
- return executors;
-}
-
-bool Compiler::checkCompilable()
-{
- // Disable compile phase
- // When ready to use interpreter backend, remove this config and use backend setting
- if (_options.disable_compile)
- {
- return false;
- }
-
- // TODO check unspecified operand shape
-
- // Check compilable parameter
- for (uint32_t i = 0; i < _subgraphs->count(); ++i)
- {
- auto graph = _subgraphs->at(ir::SubgraphIndex{i});
- ParamChecker paramChecker{graph};
- paramChecker();
- if (paramChecker.haveNoneConstParam())
- {
- return false;
- }
- }
-
- return true;
+ return std::make_shared<CompilerArtifact>(executors, std::move(tracing_ctx));
}
} // namespace compiler
-
} // namespace onert
diff --git a/runtime/onert/core/src/compiler/CompilerFactory.cc b/runtime/onert/core/src/compiler/CompilerFactory.cc
new file mode 100644
index 000000000..3e1209a52
--- /dev/null
+++ b/runtime/onert/core/src/compiler/CompilerFactory.cc
@@ -0,0 +1,50 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "compiler/CompilerFactory.h"
+
+#include "MultiModelCompiler.h"
+#include "train/TrainingCompiler.h"
+#include "compiler/Compiler.h"
+
+namespace onert
+{
+namespace compiler
+{
+
+CompilerFactory &CompilerFactory::get()
+{
+ static CompilerFactory singleton;
+ return singleton;
+}
+
+std::unique_ptr<ICompiler> CompilerFactory::create(const std::shared_ptr<ir::NNPkg> &nnpkg,
+ CompilerOptions *copts,
+ const ir::train::TrainingInfo *training_info)
+{
+ // Returning compiler for training
+ if (training_info)
+ return std::make_unique<train::TrainingCompiler>(nnpkg, copts, *training_info);
+
+ // Returning compiler for inference
+ if (nnpkg->model_count() == 1)
+ return std::make_unique<Compiler>(nnpkg, copts);
+
+ return std::make_unique<MultiModelCompiler>(nnpkg, copts);
+}
+
+} // namespace compiler
+} // namespace onert
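
The factory above picks one of three compilers: a TrainingCompiler whenever training info is supplied, the single-model Compiler when the package holds exactly one model, and the MultiModelCompiler otherwise. The following stand-alone sketch mirrors only that selection policy; TrainingInfo and pickCompiler here are stand-ins for illustration, not the onert types.

#include <iostream>
#include <string>

struct TrainingInfo {}; // stand-in for ir::train::TrainingInfo

// Mirrors the selection in CompilerFactory::create(): training takes priority, then model count.
std::string pickCompiler(int model_count, const TrainingInfo *training_info)
{
  if (training_info)
    return "TrainingCompiler";
  if (model_count == 1)
    return "Compiler";
  return "MultiModelCompiler";
}

int main()
{
  TrainingInfo tinfo;
  std::cout << pickCompiler(1, nullptr) << '\n'; // Compiler
  std::cout << pickCompiler(3, nullptr) << '\n'; // MultiModelCompiler
  std::cout << pickCompiler(1, &tinfo) << '\n';  // TrainingCompiler
  return 0;
}
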
diff --git a/runtime/onert/core/src/compiler/CompilerHelpers.h b/runtime/onert/core/src/compiler/CompilerHelpers.h
new file mode 100644
index 000000000..798334b3b
--- /dev/null
+++ b/runtime/onert/core/src/compiler/CompilerHelpers.h
@@ -0,0 +1,52 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_COMPILER_COMPILER_HELPERS_H__
+#define __ONERT_COMPILER_COMPILER_HELPERS_H__
+
+#include <compiler/ILoweredGraph.h>
+#include <compiler/StaticShapeInferer.h>
+#include <ir/Index.h>
+
+#include <memory>
+#include <unordered_map>
+
+namespace onert
+{
+namespace compiler
+{
+
+/**
+ * @brief Create a shape inferer map for a lowered model
+ * @param[in] lowered_subgs lowered model map
+ * @return Shape inferer map
+ */
+template <typename LoweredGraphType,
+ typename = std::enable_if_t<std::is_base_of<ILoweredGraph, LoweredGraphType>::value>>
+static std::unordered_map<ir::SubgraphIndex, std::unique_ptr<StaticShapeInferer>>
+createStaticShapeInferers(
+ const std::unordered_map<ir::SubgraphIndex, std::unique_ptr<LoweredGraphType>> &lowered_subgs)
+{
+ std::unordered_map<ir::SubgraphIndex, ILoweredGraph *> lsubgs;
+ for (auto &&e : lowered_subgs)
+ lsubgs[e.first] = e.second.get();
+ return StaticShapeInferer::createStaticShapeInferers(lsubgs);
+}
+
+} // namespace compiler
+} // namespace onert
+
+#endif // __ONERT_COMPILER_COMPILER_HELPERS_H__
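
createStaticShapeInferers() above only adapts ownership: it turns a map of owning unique_ptr<LoweredGraphType> values into a map of non-owning ILoweredGraph pointers and forwards it to StaticShapeInferer. A minimal sketch of that adapter pattern, with stand-in Base/Derived/Index types rather than the onert classes:

#include <cassert>
#include <memory>
#include <type_traits>
#include <unordered_map>

struct Base { virtual ~Base() = default; }; // stand-in for ILoweredGraph
struct Derived : Base {};                   // stand-in for LoweredGraph

using Index = int;                          // stand-in for ir::SubgraphIndex

// Consumer that only needs read access through base pointers.
void consume(const std::unordered_map<Index, Base *> &views) { assert(!views.empty()); }

template <typename T, typename = std::enable_if_t<std::is_base_of<Base, T>::value>>
void adapt(const std::unordered_map<Index, std::unique_ptr<T>> &owned)
{
  std::unordered_map<Index, Base *> views;
  for (auto &&e : owned)
    views[e.first] = e.second.get(); // non-owning view; lifetime stays with `owned`
  consume(views);
}

int main()
{
  std::unordered_map<Index, std::unique_ptr<Derived>> owned;
  owned.emplace(0, std::make_unique<Derived>());
  adapt(owned);
  return 0;
}
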
diff --git a/runtime/onert/core/src/compiler/CompilerOptions.cc b/runtime/onert/core/src/compiler/CompilerOptions.cc
new file mode 100644
index 000000000..c5aee1956
--- /dev/null
+++ b/runtime/onert/core/src/compiler/CompilerOptions.cc
@@ -0,0 +1,147 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "compiler/CompilerOptions.h"
+
+#include "../backend/builtin/Backend.h"
+
+#include "util/ConfigSource.h"
+#include "util/logging.h"
+
+#include <misc/string_helpers.h>
+
+namespace
+{
+
+using namespace onert;
+
+std::string getOpBackends(std::unordered_map<ir::OpCode, std::string> &opcode_to_backend)
+{
+ std::unordered_map<ir::OpCode, std::string>::iterator it;
+ std::string opbackends;
+
+ for (it = opcode_to_backend.begin(); it != opcode_to_backend.end(); ++it)
+ {
+ if (!opbackends.empty())
+ opbackends = opbackends + ", ";
+
+ auto opcode = it->first;
+ const std::string opname = ir::toString(opcode);
+ opbackends += opname + "=" + it->second;
+ }
+ return opbackends;
+}
+
+} // namespace
+
+namespace onert
+{
+namespace compiler
+{
+
+void ManualSchedulerOptions::setBackendMap(const std::string &str)
+{
+ // TODO Support multiple subgraphs for manual scheduling
+ auto key_val_list = nnfw::misc::split(str, ';');
+ for (const auto &key_val_str : key_val_list)
+ {
+ if (key_val_str.empty())
+ {
+ continue;
+ }
+
+ auto key_val = nnfw::misc::split(key_val_str, '=');
+ if (key_val.size() != 2)
+ throw std::runtime_error{"Invalid key-value pair"};
+
+ const auto &key_str = key_val.at(0);
+ const auto &val = key_val.at(1);
+ auto key = static_cast<uint32_t>(std::stoi(key_str));
+ this->index_to_backend.emplace(ir::OperationIndex{key}, val);
+ }
+}
+
+std::unique_ptr<CompilerOptions> CompilerOptions::fromGlobalConfig()
+{
+ auto o = std::make_unique<CompilerOptions>();
+ o->backend_list = nnfw::misc::split(util::getConfigString(util::config::BACKENDS), ';');
+ o->graph_dump_level = util::getConfigInt(util::config::GRAPH_DOT_DUMP);
+ o->executor = util::getConfigString(util::config::EXECUTOR);
+ o->he_scheduler = util::getConfigBool(util::config::USE_SCHEDULER);
+ o->he_profiling_mode = util::getConfigBool(util::config::PROFILING_MODE);
+ o->fp16_enable = util::getConfigBool(util::config::FP16_ENABLE);
+ o->workspace_dir = util::getConfigString(util::config::WORKSPACE_DIR);
+ {
+ // Backend for all
+ auto &ms_options = o->manual_scheduler_options;
+
+ // Default value for op_backend_all is the first element in the backend list
+ ms_options.backend_for_all = util::getConfigString(util::config::OP_BACKEND_ALLOPS);
+
+// Opcode to Backend
+#define OP(OpName) \
+ { \
+ const auto &backend_str = util::getConfigString(util::config::OP_BACKEND_##OpName); \
+ if (!backend_str.empty()) \
+ { \
+ ms_options.opcode_to_backend[ir::OpCode::OpName] = backend_str; \
+ } \
+ }
+#include "ir/Operations.lst"
+#undef OP
+
+ // Index to Backend
+ auto map_str = util::getConfigString(util::config::OP_BACKEND_MAP);
+ ms_options.setBackendMap(map_str);
+ }
+ return o;
+}
+
+void CompilerOptions::forceInternalOptions()
+{
+ // Set control flow backend for control flow operators
+ auto &builtin_id = backend::builtin::Config::ID;
+ manual_scheduler_options.opcode_to_backend[ir::OpCode::If] = builtin_id;
+ manual_scheduler_options.opcode_to_backend[ir::OpCode::While] = builtin_id;
+ manual_scheduler_options.opcode_to_backend[ir::OpCode::Permute] = builtin_id;
+
+ // FIXME This is a workaround for bcq operations, should remove it
+ manual_scheduler_options.opcode_to_backend[ir::OpCode::BCQFullyConnected] = "bcq";
+ manual_scheduler_options.opcode_to_backend[ir::OpCode::BCQGather] = "bcq";
+
+ // FIXME This is a workaround for bulk operations, should remove it
+ manual_scheduler_options.opcode_to_backend[ir::OpCode::Bulk] = "trix";
+}
+
+void CompilerOptions::verboseOptions()
+{
+ VERBOSE(Compiler) << std::boolalpha << "==== Compiler Options ====" << std::endl;
+ VERBOSE(Compiler) << "backend_list : "
+ << nnfw::misc::join(backend_list.begin(), backend_list.end(), "/") << std::endl;
+ VERBOSE(Compiler) << "graph_dump_level : " << graph_dump_level << std::endl;
+ VERBOSE(Compiler) << "executor : " << executor << std::endl;
+ VERBOSE(Compiler) << "manual backend_for_all : " << manual_scheduler_options.backend_for_all
+ << std::endl;
+ VERBOSE(Compiler) << "manual_scheduler_options : "
+ << getOpBackends(manual_scheduler_options.opcode_to_backend) << std::endl;
+ VERBOSE(Compiler) << "he_scheduler : " << he_scheduler << std::endl;
+ VERBOSE(Compiler) << "he_profiling_mode : " << he_profiling_mode << std::endl;
+ VERBOSE(Compiler) << "fp16_enable : " << fp16_enable << std::endl
+ << std::noboolalpha;
+}
+
+} // namespace compiler
+} // namespace onert
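
ManualSchedulerOptions::setBackendMap() expects semicolon-separated "operation_index=backend_id" pairs (the OP_BACKEND_MAP config). The sketch below re-implements that parsing in isolation, using std::getline instead of nnfw::misc::split; the sample string "0=cpu;2=acl_cl" and its backend ids are illustrative only.

#include <cstdint>
#include <iostream>
#include <map>
#include <sstream>
#include <stdexcept>
#include <string>

// Parse "index=backend;index=backend;..." into an index-to-backend map.
std::map<uint32_t, std::string> parseBackendMap(const std::string &str)
{
  std::map<uint32_t, std::string> index_to_backend;
  std::istringstream pairs{str};
  std::string key_val;
  while (std::getline(pairs, key_val, ';'))
  {
    if (key_val.empty())
      continue; // tolerate empty entries, like the original loop
    const auto eq = key_val.find('=');
    if (eq == std::string::npos || eq == 0 || eq + 1 == key_val.size())
      throw std::runtime_error{"Invalid key-value pair"};
    const auto key = static_cast<uint32_t>(std::stoi(key_val.substr(0, eq)));
    index_to_backend.emplace(key, key_val.substr(eq + 1));
  }
  return index_to_backend;
}

int main()
{
  for (const auto &e : parseBackendMap("0=cpu;2=acl_cl"))
    std::cout << "operation #" << e.first << " -> " << e.second << '\n';
  return 0;
}
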
diff --git a/runtime/onert/core/src/compiler/ExecutorFactory.cc b/runtime/onert/core/src/compiler/ExecutorFactory.cc
index 062c6c9c3..eff3f5abe 100644
--- a/runtime/onert/core/src/compiler/ExecutorFactory.cc
+++ b/runtime/onert/core/src/compiler/ExecutorFactory.cc
@@ -16,24 +16,29 @@
#include "ExecutorFactory.h"
+#include "Linear.h"
+#include "../backend/builtin/BackendContext.h"
+#include "../backend/builtin/Config.h"
+#include "../backend/builtin/UserTensor.h"
+#include "../backend/builtin/train/BackendContext.h"
+#include "../dumper/text/GraphDumper.h"
+#include "../exec/DataflowExecutor.h"
+#include "../exec/ExecTime.h"
+#include "../exec/ExecutionObservers.h"
+#include "../exec/LinearExecutor.h"
+#include "../exec/MinMaxRecorder.h"
+#include "../exec/ParallelExecutor.h"
+#include "../exec/train/TrainableExecutor.h"
+#include "../ir/OperationCloner.h"
+
+#include <backend/IPortableTensor.h>
+#include <backend/train/TrainableBackendContext.h>
+#include <backend/train/ITrainableBackend.h>
+#include <compiler/BackendManager.h>
+#include <compiler/ExecutionBuilder.h>
+#include <util/TracingCtx.h>
+
#include <functional>
-#include "exec/ExecutionObservers.h"
-#include "exec/LinearExecutor.h"
-#include "exec/DataflowExecutor.h"
-#include "exec/ParallelExecutor.h"
-#include "compiler/BackendManager.h"
-#include "compiler/ExecutionBuilder.h"
-#include "exec/ExecTime.h"
-#include "compiler/Linear.h"
-#include "compiler/TensorBuilders.h"
-#include "backend/IConstantInitializer.h"
-#include "backend/IKernelGenerator.h"
-#include "backend/IOptimizer.h"
-#include "backend/ITensorRegister.h"
-#include "backend/controlflow/Config.h"
-#include "backend/controlflow/KernelGenerator.h"
-#include "backend/controlflow/UserTensor.h"
-#include "backend/controlflow/TensorBuilder.h"
#include <memory>
namespace onert
@@ -46,7 +51,7 @@ class SyncFunction final : public exec::IFunction
public:
virtual ~SyncFunction() = default;
SyncFunction(std::unique_ptr<exec::IFunction> fn, const std::shared_ptr<backend::IConfig> config)
- : _fn{std::move(fn)}, _config{config}
+ : _fn{std::move(fn)}, _config{config}
{
assert(_fn);
assert(_config);
@@ -65,21 +70,221 @@ private:
std::shared_ptr<backend::IConfig> _config;
};
-// TODO Think of a better way to manage TensorManagers
-backend::TensorManagerSet createTensorManagerSet(const compiler::TensorBuilders &tensor_builders)
+using DeallocList = std::vector<backend::ITensor *>;
+// Deallocates dynamic tensors after an operation's execution; used by the Linear Executor
+class DeallocFunction final : public exec::IFunction
+{
+public:
+ DeallocFunction(const DeallocList &tensors) : _dealloc_list{tensors} {}
+
+ void run() override
+ {
+ for (auto &&tensor : _dealloc_list)
+ {
+ if (!tensor->is_dynamic())
+ continue;
+ tensor->deallocBuffer();
+ }
+ }
+
+private:
+ DeallocList _dealloc_list;
+};
+
+// TODO Unify initializeSubgraphIOTensors
+void initializeSubgraphIOTensors(compiler::ILoweredGraph &lowered_graph,
+ const backend::BackendContexts &backend_contexts,
+ const ir::OperandIndexSequence &indices)
+{
+ // TODO Store builtin backend in BackendContext
+ std::shared_ptr<backend::builtin::TensorRegistry> builtin_tensor_reg;
+ for (const auto &e : backend_contexts)
+ {
+ auto backend = e.first;
+ auto &context = e.second;
+ if (backend->config()->id() == backend::builtin::Config::ID)
+ {
+ builtin_tensor_reg =
+ std::dynamic_pointer_cast<backend::builtin::TensorRegistry>(context->tensor_registry);
+ }
+ }
+ assert(builtin_tensor_reg);
+
+ for (auto &&ind : indices)
+ {
+ const auto &operand = lowered_graph.graph().operands().at(ind);
+ auto tensor = std::make_unique<backend::builtin::IOTensor>(
+ operand.info(),
+ ir::Layout::NHWC /* FIXME find operation for this operand and use frontend_layout */
+ );
+
+ // Add tensor to builtin TensorRegistry.
+ builtin_tensor_reg->setNativeIOTensor(ind, std::move(tensor));
+ }
+}
+
+void initializeSubgraphIOTensors(compiler::ILoweredGraph &lowered_graph,
+ const backend::train::TrainableBackendContexts &backend_contexts,
+ const ir::OperandIndexSequence &indices)
+{
+ std::shared_ptr<backend::builtin::train::TensorRegistry> builtin_tensor_reg;
+ for (const auto &e : backend_contexts)
+ {
+ auto backend = e.first;
+ auto &context = e.second;
+ if (backend->config()->id() == backend::builtin::Config::ID)
+ {
+ builtin_tensor_reg = std::dynamic_pointer_cast<backend::builtin::train::TensorRegistry>(
+ context->tensor_registry());
+ }
+ }
+ assert(builtin_tensor_reg);
+
+ for (auto &&ind : indices)
+ {
+ const auto &operand = lowered_graph.graph().operands().at(ind);
+ auto tensor = std::make_unique<backend::builtin::IOTensor>(
+ operand.info(),
+ ir::Layout::NHWC /* FIXME find operation for this operand and use frontend_layout */
+ );
+
+ // Add tensor to builtin TensorRegistry.
+ builtin_tensor_reg->setNativeIOTensor(ind, std::move(tensor));
+ }
+}
+
+backend::BackendContexts
+createBackendContexts(compiler::ILoweredGraph &lgraph, bool linear_executor,
+ std::shared_ptr<backend::custom::IKernelBuilder> custom_kernel_builder)
{
- backend::TensorManagerSet tensor_mgrs;
- for (auto &tensor_builder : tensor_builders)
+ backend::BackendContexts contexts;
+ std::unordered_map<const backend::Backend *, backend::ContextData> context_data_map;
+
+ // Generate partial graphs for each backend
+ auto init_context_data = [&](const backend::Backend *backend) {
+ auto &data = context_data_map[backend];
+ auto graph = std::make_unique<ir::Graph>();
+ graph->setLayout(lgraph.graph().layout());
+ data.graph = std::move(graph);
+ };
+
+ auto &whole_graph = lgraph.graph();
+ // Separate operands into partial graphs
+ whole_graph.operands().iterate([&](const ir::OperandIndex &operand_ind, ir::Operand &operand) {
+ auto &operand_li = lgraph.lower_info().operand;
+ const auto &def_factors = operand_li.at(operand_ind).def_factors();
+ if (def_factors.size() == 0) // Ignore unused tensor
+ return;
+ const auto &def_factor = def_factors.getOnlyElement();
+ const auto backend = def_factor.backend();
+ if (context_data_map.find(backend) == context_data_map.end())
+ init_context_data(backend);
+
+ auto &partial_graph = *context_data_map[backend].graph;
+ auto &operand_layouts = context_data_map[backend].operand_layouts;
+ assert(operand_layouts.find(operand_ind) == operand_layouts.end());
+ operand_layouts[operand_ind] = def_factor.layout();
+
+ // Copy the operand and insert it to the partial graph
+ auto new_operand = std::make_unique<ir::Operand>(operand);
+ new_operand->clearDefUse();
+ operand.releaseData(); // Deref data of LoweredGraph
+ auto new_operand_ind = partial_graph.addOperand(operand_ind, std::move(new_operand));
+ UNUSED_RELEASE(new_operand_ind);
+ assert(new_operand_ind == operand_ind);
+ });
+ // Separate operations into partial graphs
+ whole_graph.operations().iterate(
+ [&](const ir::OperationIndex &op_ind, const ir::IOperation &operation) {
+ auto &op_li = lgraph.lower_info().operation;
+ auto backend = op_li.at(op_ind).backend();
+ if (context_data_map.find(backend) == context_data_map.end())
+ init_context_data(backend);
+
+ auto &partial_graph = *context_data_map[backend].graph;
+ auto &external_operands = context_data_map[backend].external_operands;
+ auto &operand_layouts = context_data_map[backend].operand_layouts;
+
+ {
+ // Add missing operands (externals)
+ auto io_list = (operation.getInputs() + operation.getOutputs()) | ir::Remove::DUPLICATED |
+ ir::Remove::UNDEFINED;
+ for (auto &&operand_ind : io_list)
+ {
+ if (partial_graph.operands().exist(operand_ind))
+ continue;
+
+ // Copy the operand and insert it to the partial graph
+ const auto &operand = whole_graph.operands().at(operand_ind);
+ auto new_operand = std::make_unique<ir::Operand>(operand);
+ new_operand->clearDefUse();
+ auto new_operand_ind = partial_graph.addOperand(operand_ind, std::move(new_operand));
+ UNUSED_RELEASE(new_operand_ind);
+ assert(new_operand_ind == operand_ind);
+
+ auto layout =
+ lgraph.lower_info().operand.at(operand_ind).def_factors().getOnlyElement().layout();
+ assert(operand_layouts.find(operand_ind) == operand_layouts.end());
+ operand_layouts[operand_ind] = layout;
+ external_operands.add(operand_ind);
+ }
+
+ auto new_op_ind = partial_graph.addOperation(op_ind, clone(operation));
+ UNUSED_RELEASE(new_op_ind);
+ assert(new_op_ind == op_ind);
+ }
+ });
+
+ // Create contexts
+ auto whole_op_order = lgraph.graph().topolSortOperations();
+ for (auto &&pair : context_data_map)
{
- auto s_tensor_manager = tensor_builder->releaseStaticTensorManager();
- if (s_tensor_manager != nullptr)
- tensor_mgrs.insert(std::move(s_tensor_manager));
+ auto backend = pair.first;
+ auto &data = pair.second;
+ // Handle graph input/outputs or external tensors
+ data.graph->operands().iterate([&](const ir::OperandIndex &ind, const ir::Operand &operand) {
+ if (whole_graph.getInputs().contains(ind) || whole_graph.getOutputs().contains(ind))
+ data.external_operands.add(ind);
+ // Inputs are either "graph input" or "no def op and non-constant"
+ if (whole_graph.getInputs().contains(ind) ||
+ (!operand.getDef().valid() && !operand.isConstant()))
+ data.graph->addInput(ind);
+ // Outputs are either "graph output" or "no uses"
+ if (whole_graph.getOutputs().contains(ind) || operand.getUses().size() == 0)
+ data.graph->addOutput(ind);
+ });
+ VERBOSE(ExecutorFactory) << "createBackendContexts: partial graph for backend="
+ << backend->config()->id() << std::endl;
+ dumper::text::dumpGraph(*data.graph);
+
+ std::copy_if(whole_op_order.begin(), whole_op_order.end(), std::back_inserter(data.op_order),
+ [&](const auto &ind) { return data.graph->operations().exist(ind); });
+ data.is_linear_executor = linear_executor;
+ data.custom_kernel_builder = custom_kernel_builder;
+ contexts.emplace(backend, backend->newContext(std::move(data)));
+ }
+ return contexts;
+}
- auto d_tensor_manager = tensor_builder->releaseDynamicTensorManager();
- if (d_tensor_manager != nullptr)
- tensor_mgrs.insert(std::move(d_tensor_manager));
+template <typename Context>
+std::deque<std::pair<const backend::Backend *, Context *>> orderBackendContext(
+ const std::unordered_map<const backend::Backend *, std::unique_ptr<Context>> &tbackend_contexts)
+{
+ std::deque<std::pair<const backend::Backend *, Context *>> ordered_contexts;
+
+ for (auto &&pair : tbackend_contexts)
+ {
+ // NOTE The builtin backend must be processed last.
+ // This is because the Permute layer is the only operation that can have different ITensor
+ // objects for its input and output, and it requires all other backends' tensors to be ready to use.
+ if (pair.first->config()->id() == "builtin")
+ ordered_contexts.emplace_back(pair.first, pair.second.get());
+ else
+ ordered_contexts.emplace_front(pair.first, pair.second.get());
}
- return tensor_mgrs;
+
+ return ordered_contexts;
}
} // namespace
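
The orderBackendContext() helper defined just above relies only on std::deque's two-ended insertion: every backend goes to the front except "builtin", which goes to the back, so the builtin backend's kernels are generated after every other backend's tensors exist. A reduced sketch of the same trick, with a stand-in Backend struct instead of onert's backend::Backend:

#include <deque>
#include <iostream>
#include <string>
#include <vector>

struct Backend { std::string id; }; // stand-in for onert::backend::Backend

std::deque<const Backend *> orderContexts(const std::vector<Backend> &backends)
{
  std::deque<const Backend *> ordered;
  for (const auto &b : backends)
  {
    if (b.id == "builtin")
      ordered.emplace_back(&b);  // builtin last: its kernels may touch other backends' tensors
    else
      ordered.emplace_front(&b);
  }
  return ordered;
}

int main()
{
  const std::vector<Backend> backends{{"builtin"}, {"cpu"}, {"acl_cl"}};
  for (const auto *b : orderContexts(backends))
    std::cout << b->id << '\n'; // prints acl_cl, cpu, builtin
  return 0;
}

A deque is used rather than sorting because only one relative position matters: builtin at the end, everything else in front in any order.
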
@@ -106,412 +311,582 @@ ExecutorFactory::ExecutorFactory()
}
exec::IExecutor *ExecutorFactory::create(std::unique_ptr<compiler::LoweredGraph> lowered_graph,
- const compiler::CompilerOptions &options,
- const std::shared_ptr<exec::ExecutorMap> &executor_map)
+ const std::shared_ptr<exec::IExecutors> &executors,
+ const ExecutorFactoryArgs &args)
{
- return _map.at(options.executor)(std::move(lowered_graph), options, executor_map);
+ assert(args.options != nullptr);
+ return _map.at(args.options->executor)(std::move(lowered_graph), executors, args);
}
-void ExecutorFactory::initializeBackendContext(compiler::LoweredGraph *lowered_graph)
+void ExecutorFactory::prepareMigrantTensors(compiler::ILoweredGraph &lowered_graph,
+ const backend::BackendContexts &backend_contexts)
{
- struct Entry
- {
- std::vector<backend::BackendContext::OperationInfo> operation_list;
- std::vector<ir::OperandIndex> operand_list;
- };
- std::unordered_map<const backend::Backend *, Entry> backend_assets;
-
- // Build lists for operations
- lowered_graph->op_seqs().iterate(
- [&](const ir::OpSequenceIndex &op_seq_index, const ir::OpSequence &op_seq) {
- auto &op_seq_li = lowered_graph->getLowerInfo()->op_seq;
- auto backend = op_seq_li.at(op_seq_index)->backend();
- for (auto &operation_idx : op_seq.operations())
+ TensorRegistries tensor_regs{backend_contexts, true};
+
+ lowered_graph.graph().operations().iterate(
+ [&](const ir::OperationIndex &op_ind, const ir::IOperation &op) {
+ auto lower_info = lowered_graph.lower_info().operation.getRawPtr(op_ind);
+ auto &backend_ctx = backend_contexts.at(lower_info->backend());
+ for (auto &&ind :
+ (op.getInputs() + op.getOutputs()) | ir::Remove::DUPLICATED | ir::Remove::UNDEFINED)
+ {
+ // If an Operation's input/output tensor does not have its own tensor object,
+ // it must be using migrant tensors, so find the tensor in the other tensor registries
+ // and register it with the current tensor registry if it is portable
+ if (!backend_ctx->tensor_registry->getITensor(ind))
{
- backend_assets[backend].operation_list.emplace_back(operation_idx, op_seq.getLayout());
+ auto tensor = tensor_regs.getITensor(ind);
+ assert(tensor); // The tensor must have been registered
+ auto ptensor = dynamic_cast<backend::IPortableTensor *>(tensor);
+ if (ptensor)
+ backend_ctx->tensor_registry->setMigrantTensor(ind, ptensor);
}
- });
+ }
+ });
+}
- // Build lists for operands
- lowered_graph->graph().operands().iterate([&](const ir::OperandIndex &ind, const ir::Operand &) {
- const auto lower_info = lowered_graph->getLowerInfo(ind);
- for (auto factor : lower_info->def_factors())
+void ExecutorFactory::prepareBuiltinBackend(const TensorRegistries &tensor_regs,
+ const std::shared_ptr<exec::IExecutors> &executors,
+ const backend::BackendContexts &backend_contexts,
+ const ir::ModelIndex &index)
+{
+ for (auto &&pair : backend_contexts)
+ {
+ auto builtin_context = dynamic_cast<backend::builtin::BackendContext *>(pair.second.get());
+ if (builtin_context != nullptr)
{
- auto backend = factor.backend();
- backend_assets[backend].operand_list.emplace_back(ind);
+ auto builtin_kernel_gen = builtin_context->kernel_gen;
+ builtin_kernel_gen->setTensorRegistries(tensor_regs);
+ builtin_kernel_gen->setExecutors(executors);
+ builtin_kernel_gen->setModelIndex(index);
}
- });
+ }
+}
- for (auto &pair : backend_assets)
+std::deque<std::pair<const backend::Backend *, backend::BackendContext *>>
+ExecutorFactory::orderBackendContext(const backend::BackendContexts &backend_contexts)
+{
+ std::deque<std::pair<const backend::Backend *, backend::BackendContext *>> ordered_contexts;
+ for (auto &&pair : backend_contexts)
{
- auto backend = pair.first;
- auto &arg = pair.second;
- lowered_graph->backend_contexts().at(backend)->initialize(arg.operation_list, arg.operand_list);
+ // NOTE The builtin backend must be processed last.
+ // This is because the Permute layer is the only operation that can have different ITensor
+ // objects for its input and output, and it requires all other backends' tensors to be ready to use.
+ if (pair.first->config()->id() == "builtin")
+ ordered_contexts.emplace_back(pair.first, pair.second.get());
+ else
+ ordered_contexts.emplace_front(pair.first, pair.second.get());
}
+ return ordered_contexts;
}
-void ExecutorFactory::runTensorRegistration(compiler::LoweredGraph *lowered_graph,
- const std::vector<ir::OpSequenceIndex> &order)
+exec::IExecutor *
+ExecutorFactory::createLinearExecutor(std::unique_ptr<compiler::LoweredGraph> lowered_graph,
+ const std::shared_ptr<exec::IExecutors> &executors,
+ const ExecutorFactoryArgs &args)
{
- for (const auto index : order)
+ const auto options = args.options;
+ const auto &model_index = args.model_index;
+ const auto tracing_ctx = args.tracing_ctx;
+ auto custom_kernel_builder = args.custom_kernel_builder;
+ auto &graph = lowered_graph->graph();
+
+ backend::BackendContexts backend_contexts =
+ createBackendContexts(*lowered_graph, options->executor == "Linear", custom_kernel_builder);
+
+ TensorRegistries tensor_regs{backend_contexts, true};
+
+ initializeSubgraphIOTensors(
+ *lowered_graph, backend_contexts,
+ (lowered_graph->graph().getInputs() + lowered_graph->graph().getOutputs()) |
+ ir::Remove::DUPLICATED | ir::Remove::UNDEFINED);
+
+ // linearize
+ auto order = Linear::linearize(*lowered_graph);
+ Linear::dump(*lowered_graph, order);
+
+ for (auto &&pair : backend_contexts)
+ {
+ pair.second->genTensors();
+ }
+
+ prepareMigrantTensors(*lowered_graph, backend_contexts);
+
+ // Give some runtime objects to builtin KernelGenerator
+ prepareBuiltinBackend(tensor_regs, executors, backend_contexts, model_index);
+
+ ExecutionBuilder builder;
+
+ // Adjust the order of backends for the upcoming iteration
+ auto ordered_contexts = orderBackendContext(backend_contexts);
+
+ // Simulate the execution for deallocation of tensors
+ std::unordered_map<ir::OperationIndex, DeallocList> dealloc_list_map;
{
- const auto &op_seq = lowered_graph->op_seqs().at(index);
- const auto backend = lowered_graph->getLowerInfo(index)->backend();
- const auto tensor_register = lowered_graph->backend_contexts().at(backend)->tensor_register;
- auto tensor_builder = lowered_graph->backend_contexts().at(backend)->tensor_builder;
- auto model_io = lowered_graph->graph().getInputs() + lowered_graph->graph().getOutputs();
+ ir::OperandIndexMap<uint32_t> uses_map;
+ ir::OperandIndexSequence constants;
+
+ auto model_io =
+ (graph.getInputs() + graph.getOutputs()) | ir::Remove::UNDEFINED | ir::Remove::DUPLICATED;
+
+ // Prepare scanning
+ graph.operands().iterate([&](const ir::OperandIndex &ind, const ir::Operand &obj) {
+ uses_map[ind] = obj.getUses().size();
+
+ if (obj.isConstant())
+ constants.append(ind);
+ });
- if (tensor_register)
+ // A trick to consider constants as an exception
+ for (const auto &ind : constants)
{
- // Custom registration
- tensor_register->registerTensors(op_seq, lowered_graph->getLowerInfo());
+ uses_map[ind]++;
}
- else
+
+ for (const auto &op_ind : order)
{
- // Default registration
- for (const auto op_idx : op_seq)
+ const auto &op = graph.operations().at(op_ind);
+ auto op_inputs = op.getInputs() | ir::Remove::DUPLICATED | ir::Remove::UNDEFINED;
+ auto op_outputs = op.getOutputs() | ir::Remove::DUPLICATED | ir::Remove::UNDEFINED;
+
+ for (const auto &ind : op_inputs)
{
- const auto &op = lowered_graph->graph().operations().at(op_idx);
- for (const auto &index : (op.getInputs() | ir::Remove::UNDEFINED) + op.getOutputs())
+ const auto &operand = graph.operands().at(ind);
+ assert(uses_map.find(ind) != uses_map.end());
+ assert(uses_map[ind] > 0);
+ uses_map[ind]--;
+ if (uses_map[ind] == 0 && !operand.info().isVariable() && !model_io.contains(ind))
{
- if (!tensor_builder->isRegistered(index) && !model_io.contains(index))
- {
- const auto &operand_lower_info =
- lowered_graph->getLowerInfo(index)->def_factors().getOnlyElement();
-
- // E.g., permute (CPU) -> tensor A -> MaxPool2D(acl_cl)
- // op.getOutputs() of permute (CPU) returns tensor A
- // but tensor A belongs to the backend of acl_cl.
- // So, we have to make this tensor NOT registered for CPU.
- if (operand_lower_info.backend() != backend)
- continue;
-
- const auto &obj = lowered_graph->graph().operands().at(index);
- const auto frontend_layout = op_seq.getLayout();
- const auto backend_layout = operand_lower_info.layout();
- ir::OperandInfo backend_info{permuteShape(obj.shape(), frontend_layout, backend_layout),
- obj.typeInfo(), obj.info().memAllocType(),
- obj.isConstant()};
- tensor_builder->registerTensorInfo(index, backend_info, backend_layout);
- }
+ dealloc_list_map[op_ind].emplace_back(tensor_regs.getITensor(ind));
}
}
}
- }
-}
-std::vector<std::shared_ptr<backend::ITensor>>
-ExecutorFactory::initializeModelIOTensors(compiler::LoweredGraph &lowered_graph,
- const ir::OperandIndexSequence &indices)
-{
- std::vector<std::shared_ptr<backend::ITensor>> ret;
+ // Dispose and validate
+ for (const auto &ind : constants)
+ {
+ --uses_map[ind];
+ }
+
+ assert(
+ std::all_of(uses_map.begin(), uses_map.end(),
+ [](std::pair<const ir::OperandIndex, uint32_t> it) { return it.second == 0; }));
+ }
- // TODO Store controlflow backend in BackendContext
- std::shared_ptr<backend::controlflow::TensorBuilder> cf_tensor_builder;
- std::shared_ptr<backend::controlflow::TensorRegistry> cf_tensor_reg;
- for (const auto &e : lowered_graph.backend_contexts())
+ // Generate kernels
+ for (auto &&pair : ordered_contexts)
{
- auto backend = e.first;
- auto &context = e.second;
- if (backend->config()->id() == backend::controlflow::Config::ID)
+ auto codes = pair.second->genKernels();
+ for (auto &&pair : codes)
{
- cf_tensor_builder =
- std::dynamic_pointer_cast<backend::controlflow::TensorBuilder>(context->tensor_builder);
- cf_tensor_reg =
- std::dynamic_pointer_cast<backend::controlflow::TensorRegistry>(context->tensor_registry);
+ auto &op_ind = pair.first;
+ auto &fn_seq = pair.second;
+ auto &op = lowered_graph->graph().operations().at(op_ind);
+ auto lower_info = lowered_graph->lower_info().operation.getRawPtr(op_ind);
+ if (options->he_profiling_mode)
+ fn_seq->wrap<SyncFunction>(lower_info->backend()->config());
+ if (!dealloc_list_map[op_ind].empty())
+ fn_seq->append(std::make_unique<DeallocFunction>(dealloc_list_map[op_ind]));
+ builder.append(op_ind, {op_ind, &op, lower_info, std::move(fn_seq)});
}
}
- assert(cf_tensor_builder);
- assert(cf_tensor_reg);
- for (auto ind : indices)
+ auto code_map = builder.releaseCodeMap();
+
+ auto exec = new exec::LinearExecutor{std::move(lowered_graph),
+ std::move(backend_contexts),
+ tensor_regs,
+ std::move(code_map),
+ order,
+ tracing_ctx};
+
+ if (!options->workspace_dir.empty())
{
- const auto &operand = lowered_graph.graph().operands().at(ind);
- auto tensor = std::make_shared<backend::controlflow::UserTensor>(
- operand.info(),
- ir::Layout::NHWC, /* FIXME find op_seq for this operand and use frontend_layout */
- cf_tensor_builder->dynamicTensorManager());
-
- // Add tensor to controlflow TensorRegistry.
- cf_tensor_reg->setNativeUserTensor(ind, tensor);
- ret.push_back(tensor);
+ exec->addObserver(
+ std::make_unique<exec::TracingObserver>(options->workspace_dir, exec->graph(), tracing_ctx));
+ exec->addObserver(std::make_unique<exec::MinMaxRecorder>(options->workspace_dir, exec->graph(),
+ exec->getBackendContexts()));
}
- return ret;
-}
-void ExecutorFactory::prepareExternalTensors(compiler::LoweredGraph &lowered_graph)
-{
- TensorRegistries tensor_regs{lowered_graph.backend_contexts(), true};
-
- lowered_graph.op_seqs().iterate(
- [&](const ir::OpSequenceIndex &op_seq_index, const ir::OpSequence &op_seq) {
- auto lower_info = lowered_graph.getLowerInfo(op_seq_index);
- auto &backend_ctx = lowered_graph.backend_contexts().at(lower_info->backend());
- for (auto ind : (op_seq.getInputs() + op_seq.getOutputs()) | ir::Remove::DUPLICATED |
- ir::Remove::UNDEFINED)
- {
- // If an OpSequence input/output tensor does not have a own tensor object,
- // it must be using external tensors, so find the tensor from other tensor builders and
- // set the tensor to this tensor builder if portable
- if (!backend_ctx->tensor_registry->getITensor(ind))
- {
- auto tensor = tensor_regs.getITensor(ind);
- assert(tensor); // The tensor must have been registered
- auto ptensor = std::dynamic_pointer_cast<backend::IPortableTensor>(tensor);
- if (ptensor)
- backend_ctx->tensor_registry->setMigrantTensor(ind, ptensor);
- }
- }
- });
+ return exec;
}
exec::IExecutor *
-ExecutorFactory::createLinearExecutor(std::unique_ptr<compiler::LoweredGraph> lowered_graph,
- const compiler::CompilerOptions &options,
- const std::shared_ptr<exec::ExecutorMap> &executor_map)
+ExecutorFactory::createDataflowExecutor(std::unique_ptr<compiler::LoweredGraph> lowered_graph,
+ const std::shared_ptr<exec::IExecutors> &executors,
+ const ExecutorFactoryArgs &args, bool parallel)
{
- const auto &backend_contexts = lowered_graph->backend_contexts();
+ const auto options = args.options;
+ const auto &model_index = args.model_index;
+ const auto tracing_ctx = args.tracing_ctx;
+ auto custom_kernel_builder = args.custom_kernel_builder;
- initializeBackendContext(lowered_graph.get());
+ backend::BackendContexts backend_contexts =
+ createBackendContexts(*lowered_graph, options->executor == "Linear", custom_kernel_builder);
- // linearize
- assert(!lowered_graph->graph().isBuildingPhase());
+ TensorRegistries tensor_regs{backend_contexts, true};
- /*************************************************
- * Backend dependent analysis & optimization phase
- *************************************************/
+ initializeSubgraphIOTensors(
+ *lowered_graph, backend_contexts,
+ (lowered_graph->graph().getInputs() + lowered_graph->graph().getOutputs()) |
+ ir::Remove::DUPLICATED | ir::Remove::UNDEFINED);
- for (auto &pair : backend_contexts)
+ for (auto &&pair : backend_contexts)
{
- auto &optimizer = pair.second->optimizer;
- if (optimizer)
- optimizer->optimize();
+ pair.second->genTensors();
}
- /**********************************************************
- * Backend dependent analysis & optimization phase finished
- **********************************************************/
+ prepareMigrantTensors(*lowered_graph, backend_contexts);
- /***********************
- * Code generation phase
- ***********************/
+ // Give some runtime objects to builtin KernelGenerator
+ prepareBuiltinBackend(tensor_regs, executors, backend_contexts, model_index);
- auto order = Linear::linearize(*lowered_graph);
- runTensorRegistration(lowered_graph.get(), order);
+ ExecutionBuilder builder;
+
+ // Adjust the order of backends for the upcoming iteration
+ auto ordered_contexts = orderBackendContext(backend_contexts);
- std::vector<std::shared_ptr<backend::ITensor>> input_tensors;
- std::vector<std::shared_ptr<backend::ITensor>> output_tensors;
- if (options.is_primary_subgraph)
+ // Generate kernels
+ for (auto &&pair : ordered_contexts)
{
- input_tensors = initializeModelIOTensors(*lowered_graph, lowered_graph->graph().getInputs());
- output_tensors = initializeModelIOTensors(*lowered_graph, lowered_graph->graph().getOutputs());
+ auto codes = pair.second->genKernels();
+ for (auto &&pair : codes)
+ {
+ auto &op_ind = pair.first;
+ auto &fn_seq = pair.second;
+ auto &op = lowered_graph->graph().operations().at(op_ind);
+ auto lower_info = lowered_graph->lower_info().operation.getRawPtr(op_ind);
+ if (options->he_profiling_mode)
+ fn_seq->wrap<SyncFunction>(lower_info->backend()->config());
+ builder.append(op_ind, {op_ind, &op, lower_info, std::move(fn_seq)});
+ }
}
- Linear::dump(*lowered_graph, order);
- Linear::planTensors(*lowered_graph, order);
+ auto code_map = builder.releaseCodeMap();
- TensorBuilders tensor_builders{lowered_graph->backend_contexts(), true};
- TensorRegistries tensor_regs{lowered_graph->backend_contexts(), true};
+ exec::ExecutorBase *exec = nullptr;
+ if (parallel)
+ {
+ exec = new exec::ParallelExecutor{std::move(lowered_graph), std::move(backend_contexts),
+ tensor_regs, std::move(code_map), tracing_ctx};
+ }
+ else
+ {
+ auto dataflow_exec =
+ new exec::DataflowExecutor{std::move(lowered_graph), std::move(backend_contexts), tensor_regs,
+ std::move(code_map), tracing_ctx};
+ if (options->he_profiling_mode)
+ {
+ std::vector<const backend::Backend *> backends;
+ for (const auto &pair : backend_contexts)
+ {
+ backends.push_back(pair.first);
+ }
+ auto et = std::make_shared<exec::ExecTime>(backends);
+ std::unique_ptr<exec::IExecutionObserver> obs =
+ std::make_unique<exec::ProfileObserver>(et, dataflow_exec->graph());
+ dataflow_exec->addObserver(std::move(obs));
+ }
+ exec = dataflow_exec;
+ }
- for (auto &tensor_builder : tensor_builders)
+ if (!options->workspace_dir.empty())
{
- tensor_builder->prepare();
+ exec->addObserver(
+ std::make_unique<exec::TracingObserver>(options->workspace_dir, exec->graph(), tracing_ctx));
}
- prepareExternalTensors(*lowered_graph);
+ return exec;
+}
+
+exec::IExecutor *
+ExecutorFactory::create(std::unique_ptr<compiler::train::LoweredTrainableGraph> lowered_graph,
+ const std::shared_ptr<exec::IExecutors> &executors,
+ const ExecutorFactoryArgs &args,
+ const ir::train::TrainingInfo &training_info)
+{
+ assert(args.options != nullptr);
+
+ if (args.options->executor != "Linear")
+ throw std::runtime_error("ExecutorFactory: TrainableExecutor supports only 'Linear' now");
- ExecutionBuilder builder;
+ return createTrainableExecutor(std::move(lowered_graph), executors, args, training_info);
+}
- // Generate kernels
- lowered_graph->iterateTopolOpSeqs([&](const ir::OpSequenceIndex &op_seq_index,
- const ir::OpSequence &op_seq) {
- auto lower_info = lowered_graph->getLowerInfo(op_seq_index);
- auto kernel_gen = lowered_graph->backend_contexts().at(lower_info->backend())->kernel_gen;
- // Set TensorBuilderSet and ExecutorMap to kernel_gen of control flow
- auto cf_kernel_gen = dynamic_cast<backend::controlflow::KernelGenerator *>(kernel_gen.get());
- if (cf_kernel_gen != nullptr)
+void ExecutorFactory::prepareMigrantTensors(
+ compiler::ILoweredGraph &lowered_graph,
+ const backend::train::TrainableBackendContexts &backend_contexts)
+{
+ train::TensorRegistries tensor_regs{backend_contexts, true};
+
+ lowered_graph.graph().operations().iterate(
+ [&](const ir::OperationIndex &op_ind, const ir::IOperation &op) {
+ auto lower_info = lowered_graph.lower_info().operation.getRawPtr(op_ind);
+ auto &backend_ctx = backend_contexts.at(lower_info->backend());
+ for (auto &&ind :
+ (op.getInputs() + op.getOutputs()) | ir::Remove::DUPLICATED | ir::Remove::UNDEFINED)
+ {
+ // If an Operation's input/output tensor does not have its own tensor object,
+ // it must be using migrant tensors, so find the tensor in the other tensor registries
+ // and register it with the current tensor registry if it is portable
+ if (!backend_ctx->tensor_registry()->getITensor(ind))
+ {
+ auto tensor = tensor_regs.getITensor(ind);
+ assert(tensor); // The tensor must have been registered
+ auto ptensor = dynamic_cast<backend::IPortableTensor *>(tensor);
+ if (ptensor)
+ backend_ctx->tensor_registry()->setMigrantTensor(ind, ptensor);
+ }
+ }
+ });
+}
+
+exec::IExecutor *ExecutorFactory::createTrainableExecutor(
+ std::unique_ptr<compiler::train::LoweredTrainableGraph> lowered_graph,
+ const std::shared_ptr<exec::IExecutors> &, const ExecutorFactoryArgs &args,
+ const ir::train::TrainingInfo &training_info)
+{
+ const auto options = args.options;
+ const auto tracing_ctx = args.tracing_ctx;
+ auto custom_kernel_builder = args.custom_kernel_builder;
+
+ auto &graph = lowered_graph->graph();
+
+ lowered_graph->trainable_graph().operations().iterate([](const onert::ir::OperationIndex &,
+ const onert::ir::IOperation &op) {
+ try
{
- cf_kernel_gen->setTensorRegistries(tensor_regs);
- cf_kernel_gen->setExecutorMap(executor_map);
+ UNUSED_RELEASE(dynamic_cast<const ir::train::ITrainableOperation &>(op));
}
- auto fn_seq = kernel_gen->generate(op_seq);
- if (options.he_profiling_mode)
+ catch (std::bad_cast &)
{
- fn_seq->wrap<SyncFunction>(lower_info->backend()->config());
+ throw std::runtime_error("ExecutorFactory: " + op.name() + " is not trainable operation yet");
}
- builder.append(op_seq_index, {&op_seq, lower_info, std::move(fn_seq)});
});
- for (auto &tensor_builder : tensor_builders)
- {
- tensor_builder->allocate();
- }
+ // TODO Create context only once instead of replacing
+ backend::train::TrainableBackendContexts tbackend_contexts;
+ backend::BackendContexts base_backend_contexts =
+ createBackendContexts(*lowered_graph, true, custom_kernel_builder);
- for (auto &pair : backend_contexts)
+ // Replace BackendContext with TrainableBackendContext
+ for (auto &&pair : base_backend_contexts)
{
- pair.second->initConsts();
- }
-
- lowered_graph->graph().operands().iterate(
- [](const ir::OperandIndex &, ir::Operand &obj) { obj.releaseData(); });
-
- auto code_map = builder.releaseCodeMap();
-
- for (auto &it : code_map)
- {
- auto op_seq_index = it.first;
- auto &fn_seq = it.second.fn_seq;
-
- fn_seq->iterate([&](exec::IFunction &ifunc) {
- ifunc.prepare();
- auto backend = lowered_graph->getLowerInfo(op_seq_index)->backend();
- auto tensor_builder = lowered_graph->backend_contexts().at(backend)->tensor_builder;
- tensor_builder->postFunctionPrepare();
+ auto ctx = pair.second.get();
+ const auto &data = ctx->data();
+
+ // Create partial and trainable graphs
+ auto tgraph = std::make_unique<ir::train::TrainableGraph>(*data.graph);
+ data.graph->operations().iterate(
+ [&](const onert::ir::OperationIndex &op_index, const onert::ir::IOperation &) {
+ const auto &orig_tgraph = lowered_graph->trainable_graph();
+ const auto &trainable_op = orig_tgraph.operation(op_index);
+ auto gen_index = tgraph->replaceOperation(op_index, trainable_op.clone());
+ UNUSED_RELEASE(gen_index);
+ assert(gen_index == op_index);
+ });
+ data.graph->operands().iterate([&](const ir::OperandIndex &index, const ir::Operand &) {
+ const auto &orig_tgraph = lowered_graph->trainable_graph();
+ if (orig_tgraph.backward_operands().exist(index))
+ {
+ const auto &bwd_operand = orig_tgraph.backward_operands().at(index);
+ auto new_bwd_operand = std::make_unique<ir::Operand>(bwd_operand);
+ auto gen_index = tgraph->addBackwardOperand(index, std::move(new_bwd_operand));
+ UNUSED_RELEASE(gen_index);
+ assert(gen_index == index);
+ }
});
- }
- backend::TensorManagerSet tensor_mgrs = createTensorManagerSet(tensor_builders);
- auto exec = new exec::LinearExecutor{
- std::move(lowered_graph), input_tensors, output_tensors, tensor_regs,
- std::move(tensor_mgrs), std::move(code_map), order};
+ // Remove outputs of whole graph from external_operands
+ auto external_operands = data.external_operands;
+ for (const auto &index : lowered_graph->trainable_graph().getOutputs())
+ {
+ if (external_operands.contains(index))
+ external_operands.remove(index);
+ }
- if (!options.trace_filepath.empty())
- {
- std::unique_ptr<exec::IExecutionObserver> ctp =
- std::make_unique<exec::ChromeTracingObserver>(options.trace_filepath, exec->graph());
- exec->addObserver(std::move(ctp));
+ // Set trainable context data
+ backend::train::TrainableContextData tdata;
+ tdata.tgraph = std::move(tgraph);
+ tdata.op_order = std::move(data.op_order);
+ tdata.external_operands = std::move(external_operands);
+ tdata.operand_layouts = std::move(data.operand_layouts);
+ tdata.custom_kernel_builder = std::move(data.custom_kernel_builder);
+ tdata.is_linear_executor = data.is_linear_executor;
+ tdata.optim_info = training_info.optimizerInfo();
+
+ // TODO Remove dynamic_cast
+ const auto backend = pair.first;
+ const auto tbackend = dynamic_cast<const backend::train::ITrainableBackend *>(backend);
+ if (!tbackend)
+ {
+ throw std::runtime_error("ExecutorFactory: Invalid backend - TrainableExecutor does not "
+ "support non-trainble backends");
+ }
+ tbackend_contexts.emplace(backend, tbackend->newContext(std::move(tdata)));
}
+ base_backend_contexts.clear();
- return exec;
-}
+ train::TensorRegistries tensor_regs{tbackend_contexts, true};
-exec::IExecutor *ExecutorFactory::createDataflowExecutor(
- std::unique_ptr<compiler::LoweredGraph> lowered_graph, const compiler::CompilerOptions &options,
- const std::shared_ptr<exec::ExecutorMap> &executor_map, bool parallel)
-{
- const auto &backend_contexts = lowered_graph->backend_contexts();
-
- initializeBackendContext(lowered_graph.get());
+ initializeSubgraphIOTensors(
+ *lowered_graph, tbackend_contexts,
+ (lowered_graph->graph().getInputs() + lowered_graph->graph().getOutputs()) |
+ ir::Remove::DUPLICATED | ir::Remove::UNDEFINED);
+ // linearize for forwarding
auto order = Linear::linearize(*lowered_graph);
- runTensorRegistration(lowered_graph.get(), order);
+ VERBOSE(ExecutorFactory) << "Linearize for forwarding order" << std::endl;
+ Linear::dump(*lowered_graph, order);
+
+ // linearize for backwarding
+ auto backward_order = lowered_graph->trainable_graph().essentialBackwardOrder();
+ VERBOSE(ExecutorFactory) << "Linearize for backwarding order" << std::endl;
+ Linear::dump(*lowered_graph, backward_order);
- std::vector<std::shared_ptr<backend::ITensor>> input_tensors;
- std::vector<std::shared_ptr<backend::ITensor>> output_tensors;
- if (options.is_primary_subgraph)
+ for (auto &&pair : tbackend_contexts)
{
- input_tensors = initializeModelIOTensors(*lowered_graph, lowered_graph->graph().getInputs());
- output_tensors = initializeModelIOTensors(*lowered_graph, lowered_graph->graph().getOutputs());
+ pair.second->genTensors();
}
- TensorBuilders tensor_builders{lowered_graph->backend_contexts(), true};
- TensorRegistries tensor_regs{lowered_graph->backend_contexts(), true};
-
- // To make tensors never be deallocated, this is a workaround to use static memory planner
- for (auto &tensor_builder : tensor_builders)
+ for (auto &&pair : tbackend_contexts)
{
- lowered_graph->graph().operands().iterate(
- [&](const ir::OperandIndex &ind, const ir::Operand &) {
- if (tensor_builder->isRegistered(ind))
- {
- tensor_builder->notifyFirstUse(ind);
- }
- });
+ auto tctx = pair.second.get();
+ tctx->genTrainingTensors();
}
- for (auto &tensor_builder : tensor_builders)
+ prepareMigrantTensors(*lowered_graph, tbackend_contexts);
+
+ // Give some runtime objects to builtin KernelGenerator
+ for (auto &&pair : tbackend_contexts)
{
- tensor_builder->prepare();
+ auto builtin_context =
+ dynamic_cast<backend::builtin::train::BackendContext *>(pair.second.get());
+ if (builtin_context != nullptr)
+ {
+ auto builtin_kernel_gen = builtin_context->kernel_gen;
+ builtin_kernel_gen->setTensorRegistries(tensor_regs);
+ builtin_kernel_gen->setWholeGraphOutputs(lowered_graph->trainable_graph().getOutputs());
+ }
}
- prepareExternalTensors(*lowered_graph);
+ // Adjust the order of backends for the upcoming iteration
+ auto ordered_contexts =
+ onert::orderBackendContext<backend::train::TrainableBackendContext>(tbackend_contexts);
- ExecutionBuilder builder;
+ // TODO Remove this simulation
+ // Simulate the execution for deallocation of tensors
+ std::unordered_map<ir::OperationIndex, DeallocList> dealloc_list_map;
+ {
+ ir::OperandIndexMap<uint32_t> uses_map;
+ ir::OperandIndexSequence constants;
- // Generate kernels
- lowered_graph->iterateTopolOpSeqs([&](const ir::OpSequenceIndex &op_seq_index,
- const ir::OpSequence &op_seq) {
- auto lower_info = lowered_graph->getLowerInfo(op_seq_index);
- auto kernel_gen = lowered_graph->backend_contexts().at(lower_info->backend())->kernel_gen;
- // Set TensorBuilderSet and ExecutorMap to kernel_gen of control flow
- auto cf_kernel_gen = dynamic_cast<backend::controlflow::KernelGenerator *>(kernel_gen.get());
- if (cf_kernel_gen != nullptr)
+ auto model_io =
+ (graph.getInputs() + graph.getOutputs()) | ir::Remove::UNDEFINED | ir::Remove::DUPLICATED;
+
+ // Prepare scanning
+ graph.operands().iterate([&](const ir::OperandIndex &ind, const ir::Operand &obj) {
+ uses_map[ind] = obj.getUses().size();
+
+ if (obj.isConstant())
+ constants.append(ind);
+ });
+
+ // A trick to consider constants as an exception
+ for (const auto &ind : constants)
{
- assert(cf_kernel_gen != nullptr);
- cf_kernel_gen->setTensorRegistries(tensor_regs);
- cf_kernel_gen->setExecutorMap(executor_map);
+ uses_map[ind]++;
}
- auto fn_seq = kernel_gen->generate(op_seq);
- if (options.he_profiling_mode)
+
+ for (const auto &op_ind : order)
{
- fn_seq->wrap<SyncFunction>(lower_info->backend()->config());
+ const auto &op = graph.operations().at(op_ind);
+ auto op_inputs = op.getInputs() | ir::Remove::DUPLICATED | ir::Remove::UNDEFINED;
+ auto op_outputs = op.getOutputs() | ir::Remove::DUPLICATED | ir::Remove::UNDEFINED;
+
+ for (const auto &ind : op_inputs)
+ {
+ const auto &operand = graph.operands().at(ind);
+ assert(uses_map.find(ind) != uses_map.end());
+ assert(uses_map[ind] > 0);
+ uses_map[ind]--;
+ if (uses_map[ind] == 0 && !operand.info().isVariable() && !model_io.contains(ind))
+ {
+ dealloc_list_map[op_ind].emplace_back(tensor_regs.getITensor(ind));
+ }
+ }
}
- builder.append(op_seq_index, {&op_seq, lower_info, std::move(fn_seq)});
- });
- for (const auto &tensor_builder : tensor_builders)
- {
- tensor_builder->allocate();
- }
+ // Dispose and validate
+ for (const auto &ind : constants)
+ {
+ --uses_map[ind];
+ }
- for (auto &pair : backend_contexts)
- {
- pair.second->initConsts();
+ assert(
+ std::all_of(uses_map.begin(), uses_map.end(),
+ [](std::pair<const ir::OperandIndex, uint32_t> it) { return it.second == 0; }));
}
- lowered_graph->graph().operands().iterate(
- [](const ir::OperandIndex &, ir::Operand &obj) { obj.releaseData(); });
-
- auto code_map = builder.releaseCodeMap();
-
- for (auto &it : code_map)
+ // Check back propagation tensors
{
- auto op_seq_index = it.first;
- auto &fn_seq = it.second.fn_seq;
-
- fn_seq->iterate([&](exec::IFunction &ifunc) {
- ifunc.prepare();
- auto backend = lowered_graph->getLowerInfo(op_seq_index)->backend();
- auto tensor_builder = lowered_graph->backend_contexts().at(backend)->tensor_builder;
- tensor_builder->postFunctionPrepare();
- });
+ // TODO Support multiple subgraphs
+ // Check if the back propagation tensors corresponding to inputs of model are nullptr
+ // NOTE The back propagation tensors corresponding to inputs of model are for inputs of
+ // PermuteLayers, and they are nullptr because they are meaningless.
+ assert(std::all_of(
+ lowered_graph->trainable_graph().getInputs().begin(),
+ lowered_graph->trainable_graph().getInputs().end(),
+ [&](const auto &input_idx) { return tensor_regs.getBackPropITensor(input_idx) == nullptr; }));
+
+ // Check if the back propagation tensors corresponding to outputs of model exist
+ assert(std::all_of(lowered_graph->trainable_graph().getOutputs().begin(),
+ lowered_graph->trainable_graph().getOutputs().end(),
+ [&](const auto &output_idx) {
+ return tensor_regs.getBackPropITensor(output_idx) != nullptr;
+ }));
}
- backend::TensorManagerSet tensor_mgrs = createTensorManagerSet(tensor_builders);
-
- exec::ExecutorBase *exec = nullptr;
- if (parallel)
- {
- exec = new exec::ParallelExecutor{std::move(lowered_graph), input_tensors,
- output_tensors, tensor_regs,
- std::move(tensor_mgrs), std::move(code_map)};
- }
- else
+ train::TrainableCodeMap code_map;
+ // Generate kernels
+ for (auto &&pair : ordered_contexts)
{
- auto dataflow_exec = new exec::DataflowExecutor{std::move(lowered_graph), input_tensors,
- output_tensors, tensor_regs,
- std::move(tensor_mgrs), std::move(code_map)};
- if (options.he_profiling_mode)
+ auto codes = pair.second->genKernels();
+ for (auto &&pair : codes)
{
- std::vector<const backend::Backend *> backends;
- for (const auto &pair : backend_contexts)
- {
- backends.push_back(pair.first);
- }
- auto et = std::make_shared<exec::ExecTime>(backends);
- std::unique_ptr<exec::IExecutionObserver> obs =
- std::make_unique<exec::ProfileObserver>(et, dataflow_exec->graph());
- dataflow_exec->addObserver(std::move(obs));
+ auto &op_ind = pair.first;
+ auto &tn_seq = pair.second;
+ auto &op = lowered_graph->trainable_graph().operation(op_ind);
+ auto lower_info = lowered_graph->lower_info().operation.getRawPtr(op_ind);
+
+ assert(code_map.find(op_ind) == code_map.end());
+ code_map.insert(
+ {op_ind, train::TrainableCodeAndInfo{op_ind, &op, lower_info, std::move(tn_seq)}});
}
- exec = dataflow_exec;
}
- if (!options.trace_filepath.empty())
+ if (order.size() != code_map.size())
+ {
+ throw std::runtime_error("ExecutorFactory: Some kernels are not generated");
+ }
+
+ auto exec = new exec::train::TrainableExecutor{std::move(lowered_graph),
+ std::move(tbackend_contexts),
+ tensor_regs,
+ std::move(code_map),
+ order,
+ backward_order,
+ tracing_ctx,
+ training_info.lossInfo()};
+
+ if (!options->workspace_dir.empty())
{
- std::unique_ptr<exec::IExecutionObserver> ctp =
- std::make_unique<exec::ChromeTracingObserver>(options.trace_filepath, exec->graph());
- exec->addObserver(std::move(ctp));
+ exec->addObserver(
+ std::make_unique<exec::TracingObserver>(options->workspace_dir, exec->graph(), tracing_ctx));
}
+ // TODO Support MINMAX_H5DUMPER
return exec;
}
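
Both createLinearExecutor() and createTrainableExecutor() above pre-plan dynamic-tensor deallocation with the same use-count simulation: track the remaining uses of every operand, walk the operations in execution order, and record the operation after which an operand's count reaches zero, skipping model I/O, variables, and constants. The sketch below reproduces that idea in isolation; it derives the counts by scanning operation inputs and omits the constant/variable bookkeeping of the real code, and its Op/index types are stand-ins for the onert ir types.

#include <cassert>
#include <iostream>
#include <map>
#include <set>
#include <vector>

using OperandIndex = int;   // stand-in for ir::OperandIndex
using OperationIndex = int; // stand-in for ir::OperationIndex
struct Op { std::vector<OperandIndex> inputs, outputs; };

// Returns, per operation, the operands whose last use is that operation.
std::map<OperationIndex, std::vector<OperandIndex>>
planDealloc(const std::vector<Op> &order, const std::set<OperandIndex> &model_io)
{
  std::map<OperandIndex, int> uses; // remaining uses per operand
  for (const auto &op : order)
    for (auto ind : op.inputs)
      uses[ind]++;

  std::map<OperationIndex, std::vector<OperandIndex>> dealloc_after;
  for (OperationIndex op_ind = 0; op_ind < static_cast<OperationIndex>(order.size()); ++op_ind)
  {
    for (auto ind : order[op_ind].inputs)
    {
      assert(uses[ind] > 0);
      if (--uses[ind] == 0 && model_io.count(ind) == 0)
        dealloc_after[op_ind].push_back(ind); // this op is the operand's last consumer
    }
  }
  return dealloc_after;
}

int main()
{
  // op0: 0 -> 1, op1: 1 -> 2, op2: {1, 2} -> 3; operands 0 and 3 are model I/O
  const std::vector<Op> order{{{0}, {1}}, {{1}, {2}}, {{1, 2}, {3}}};
  for (const auto &e : planDealloc(order, {0, 3}))
    for (auto ind : e.second)
      std::cout << "free operand " << ind << " after op " << e.first << '\n';
  return 0;
}

In the runtime the recorded lists become DeallocFunction objects appended to each operation's function sequence, so only dynamic tensors are actually freed at run time.
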
diff --git a/runtime/onert/core/src/compiler/ExecutorFactory.h b/runtime/onert/core/src/compiler/ExecutorFactory.h
index b8893c03b..1b9bd4ab6 100644
--- a/runtime/onert/core/src/compiler/ExecutorFactory.h
+++ b/runtime/onert/core/src/compiler/ExecutorFactory.h
@@ -17,18 +17,32 @@
#ifndef __ONERT_COMPILER_EXECUTOR_FACTORY_H__
#define __ONERT_COMPILER_EXECUTOR_FACTORY_H__
-#include <unordered_map>
+#include "TensorRegistries.h"
#include "backend/ITensor.h"
-#include "exec/IExecutor.h"
+#include "backend/train/TrainableBackendContext.h"
#include "compiler/LoweredGraph.h"
-#include "TensorRegistries.h"
+#include "compiler/train/LoweredTrainableGraph.h"
+#include "exec/IExecutors.h"
+#include "ir/train/TrainingInfo.h"
+
+#include <deque>
+#include <unordered_map>
namespace onert
{
namespace compiler
{
+// TODO Change to a better name
+struct ExecutorFactoryArgs
+{
+ const util::TracingCtx *tracing_ctx;
+ const compiler::CompilerOptions *options;
+ ir::ModelIndex model_index;
+ std::shared_ptr<backend::custom::IKernelBuilder> custom_kernel_builder;
+};
+
class ExecutorFactory
{
public:
@@ -36,35 +50,52 @@ public:
public:
exec::IExecutor *create(std::unique_ptr<compiler::LoweredGraph> lowered_graph,
- const compiler::CompilerOptions &options,
- const std::shared_ptr<exec::ExecutorMap> &executor_map);
+ const std::shared_ptr<exec::IExecutors> &executors,
+ const ExecutorFactoryArgs &args);
+
+ // TODO Unify create()
+ exec::IExecutor *create(std::unique_ptr<compiler::train::LoweredTrainableGraph> lowered_graph,
+ const std::shared_ptr<exec::IExecutors> &executors,
+ const ExecutorFactoryArgs &args,
+ const ir::train::TrainingInfo &training_info);
private:
ExecutorFactory();
private:
- static void initializeBackendContext(compiler::LoweredGraph *lowered_graph);
- static void runTensorRegistration(compiler::LoweredGraph *lowered_graph,
- const std::vector<ir::OpSequenceIndex> &order);
- static std::vector<std::shared_ptr<backend::ITensor>>
- initializeModelIOTensors(compiler::LoweredGraph &lowered_graph,
- const ir::OperandIndexSequence &indices);
- static void prepareExternalTensors(compiler::LoweredGraph &lowered_graph);
+ static void prepareMigrantTensors(compiler::ILoweredGraph &lowered_graph,
+ const backend::BackendContexts &backend_contexts);
+ static void prepareBuiltinBackend(const TensorRegistries &tensor_regs,
+ const std::shared_ptr<exec::IExecutors> &executors,
+ const backend::BackendContexts &backend_contexts,
+ const ir::ModelIndex &index);
+ static std::deque<std::pair<const backend::Backend *, backend::BackendContext *>>
+ orderBackendContext(const backend::BackendContexts &backend_contexts);
+
static exec::IExecutor *
createLinearExecutor(std::unique_ptr<compiler::LoweredGraph> lowered_graph,
- const compiler::CompilerOptions &options,
- const std::shared_ptr<exec::ExecutorMap> &executor_map);
+ const std::shared_ptr<exec::IExecutors> &executors,
+ const ExecutorFactoryArgs &args);
static exec::IExecutor *
createDataflowExecutor(std::unique_ptr<compiler::LoweredGraph> lowered_graph,
- const compiler::CompilerOptions &options,
- const std::shared_ptr<exec::ExecutorMap> &executor_map, bool parallel);
+ const std::shared_ptr<exec::IExecutors> &executors,
+ const ExecutorFactoryArgs &args, bool parallel);
+ // TODO Unify prepareMigrantTensors
+ static void
+ prepareMigrantTensors(compiler::ILoweredGraph &lowered_graph,
+ const backend::train::TrainableBackendContexts &backend_contexts);
+ static exec::IExecutor *
+ createTrainableExecutor(std::unique_ptr<compiler::train::LoweredTrainableGraph> lowered_graph,
+ const std::shared_ptr<exec::IExecutors> &executors,
+ const ExecutorFactoryArgs &args,
+ const ir::train::TrainingInfo &training_info);
private:
- std::unordered_map<std::string, std::function<exec::IExecutor *(
- std::unique_ptr<compiler::LoweredGraph>,
- const compiler::CompilerOptions &options,
- const std::shared_ptr<exec::ExecutorMap> &executor_map)>>
- _map;
+ std::unordered_map<
+ std::string, std::function<exec::IExecutor *(std::unique_ptr<compiler::LoweredGraph>,
+ const std::shared_ptr<exec::IExecutors> &executors,
+ const ExecutorFactoryArgs &args)>>
+ _map;
};
} // namespace compiler
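
ExecutorFactory::create() dispatches on the executor name through _map, an unordered_map from the executor string to a factory std::function, via _map.at(args.options->executor)(...). A cut-down sketch of that registry pattern with stand-in executor types; the real map is presumably populated in the ExecutorFactory constructor, which these hunks do not show.

#include <functional>
#include <iostream>
#include <memory>
#include <stdexcept>
#include <string>
#include <unordered_map>

struct IExecutor { virtual ~IExecutor() = default; virtual std::string name() const = 0; };
struct LinearExecutor final : IExecutor { std::string name() const override { return "Linear"; } };
struct DataflowExecutor final : IExecutor { std::string name() const override { return "Dataflow"; } };

int main()
{
  // Name -> factory function registry, analogous to ExecutorFactory::_map.
  std::unordered_map<std::string, std::function<std::unique_ptr<IExecutor>()>> map;
  map.emplace("Linear", [] { return std::make_unique<LinearExecutor>(); });
  map.emplace("Dataflow", [] { return std::make_unique<DataflowExecutor>(); });

  try
  {
    auto exec = map.at("Linear")(); // unknown names make at() throw std::out_of_range
    std::cout << "created " << exec->name() << " executor\n";
  }
  catch (const std::out_of_range &)
  {
    std::cout << "no such executor registered\n";
  }
  return 0;
}
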
diff --git a/runtime/onert/core/src/compiler/Fp32ToFp16Converter.cc b/runtime/onert/core/src/compiler/Fp32ToFp16Converter.cc
index 23a6a253d..ce9b09c2d 100644
--- a/runtime/onert/core/src/compiler/Fp32ToFp16Converter.cc
+++ b/runtime/onert/core/src/compiler/Fp32ToFp16Converter.cc
@@ -14,6 +14,8 @@
* limitations under the License.
*/
+#if 0 // This file is temporarily unused
+
#include "Fp32ToFp16Converter.h"
#include "ir/operation/ConvertFp32ToFp16.h"
#include "ir/operation/ConvertFp16ToFp32.h"
@@ -45,7 +47,7 @@ namespace compiler
{
Fp32ToFp16Converter::Fp32ToFp16Converter(compiler::LoweredGraph &lowered_graph)
- : _lowered_graph{lowered_graph}
+ : _lowered_graph{lowered_graph}
{
VERBOSE(Fp32ToFp16Converter) << "Fp16 Enable on" << std::endl;
}
@@ -177,26 +179,26 @@ void Fp32ToFp16Converter::run()
void Fp32ToFp16Converter::appendOpSequences()
{
_lowered_graph.op_seqs().iterate(
- [&](const ir::OpSequenceIndex &op_seq_ind, ir::OpSequence &op_seq) {
- const auto lower_info = _lowered_graph.getLowerInfo(op_seq_ind);
- assert(lower_info != nullptr);
-
- // For now, the only acl_cl supports fully fp16 type
- // TODO Support fp16 on acl_neon. Current acl_neon supports the only reshape and concat
- // operations.
- // To do this, we could check the support by `operation by operation`. After that, we
- // would partition an op_seq if it contains unsupported operations.
- if (lower_info->backend()->config()->id() != kAclClBackendConfigId)
- return;
-
- // OpSeq's input set should be included in the first operation's input set or
- // OpSeq's output set should be included in the last operation's output set
- assert(checkOperandsOfOpSequence(op_seq));
-
- // Append converting OpSequence for fp16 but all operands' types are not fp16 still.
- appendNewOpSeqForConvertFp32ToFp16(op_seq_ind, op_seq);
- appendNewOpSeqForConvertFp16ToFp32(op_seq_ind, op_seq);
- });
+ [&](const ir::OpSequenceIndex &op_seq_ind, ir::OpSequence &op_seq) {
+ const auto &lower_info = _lowered_graph.getLowerInfo(op_seq_ind);
+ assert(lower_info != nullptr);
+
+ // For now, the only acl_cl supports fully fp16 type
+ // TODO Support fp16 on acl_neon. Current acl_neon supports the only reshape and concat
+ // operations.
+ // To do this, we could check the support by `operation by operation`. After that, we
+ // would partition an op_seq if it contains unsupported operations.
+ if (lower_info->backend()->config()->id() != kAclClBackendConfigId)
+ return;
+
+ // OpSeq's input set should be included in the first operation's input set or
+ // OpSeq's output set should be included in the last operation's output set
+ assert(checkOperandsOfOpSequence(op_seq));
+
+ // Append converting OpSequence for fp16 but all operands' types are not fp16 still.
+ appendNewOpSeqForConvertFp32ToFp16(op_seq_ind, op_seq);
+ appendNewOpSeqForConvertFp16ToFp32(op_seq_ind, op_seq);
+ });
}
//
@@ -253,7 +255,7 @@ void Fp32ToFp16Converter::appendNewOpSeqForConvertFp32ToFp16(const ir::OpSequenc
const auto new_op_seq_ind = newOpSequence(op_seq_ind, new_node_ind);
// set new lower_info for op_seq
- setNewOpSequenceLowerInfo(op_seq_ind, new_op_seq_ind);
+ setNewOperationLowerInfo(op_seq_ind, new_op_seq_ind);
_list_fp32_to_fp16.insert(new_op_seq_ind);
@@ -326,7 +328,7 @@ void Fp32ToFp16Converter::appendNewOpSeqForConvertFp16ToFp32(const ir::OpSequenc
auto new_op_seq_ind = newOpSequence(op_seq_ind, new_node_ind);
// set new lower_info for op_seq
- setNewOpSequenceLowerInfo(op_seq_ind, new_op_seq_ind);
+ setNewOperationLowerInfo(op_seq_ind, new_op_seq_ind);
_list_fp16_to_fp32.insert(new_op_seq_ind);
@@ -372,16 +374,16 @@ void Fp32ToFp16Converter::optimize()
void Fp32ToFp16Converter::convertOperands()
{
_lowered_graph.op_seqs().iterate(
- [&](const ir::OpSequenceIndex &op_seq_ind, ir::OpSequence &op_seq) {
- const auto lower_info = _lowered_graph.getLowerInfo(op_seq_ind);
- assert(lower_info != nullptr);
- // For now, the only acl_cl supports fully fp16
- if (lower_info->backend()->config()->id() != kAclClBackendConfigId)
- return;
-
- // Convert input,output operands' type to fp16
- convertOperandsOfOpSequence(op_seq);
- });
+ [&](const ir::OpSequenceIndex &op_seq_ind, ir::OpSequence &op_seq) {
+ const auto &lower_info = _lowered_graph.getLowerInfo(op_seq_ind);
+ assert(lower_info != nullptr);
+ // For now, the only acl_cl supports fully fp16
+ if (lower_info->backend()->config()->id() != kAclClBackendConfigId)
+ return;
+
+ // Convert input,output operands' type to fp16
+ convertOperandsOfOpSequence(op_seq);
+ });
}
void Fp32ToFp16Converter::convertOperandsOfOpSequence(ir::OpSequence &op_seq)
@@ -391,10 +393,10 @@ void Fp32ToFp16Converter::convertOperandsOfOpSequence(ir::OpSequence &op_seq)
const auto &op_seq_inputs = _lowered_graph.graph().getInputs();
const auto &op_seq_outputs = _lowered_graph.graph().getOutputs();
- for (auto &op_idx : op_seq)
+ for (const auto &op_idx : op_seq)
{
const auto &node = operations.at(op_idx);
- for (auto &ind : node.getInputs() | ir::Remove::UNDEFINED)
+ for (const auto &ind : node.getInputs() | ir::Remove::UNDEFINED)
{
if (node.opcode() == ir::OpCode::ConvertFp32ToFp16 || op_seq_inputs.contains(ind))
continue;
@@ -405,10 +407,10 @@ void Fp32ToFp16Converter::convertOperandsOfOpSequence(ir::OpSequence &op_seq)
obj.type(ir::DataType::FLOAT16);
- VERBOSE(Fp32ToFp16Converter) << "Input Operand #" << ind.value() << ": fp16" << std::endl;
+ VERBOSE(Fp32ToFp16Converter) << "Input Operand " << ind << ": fp16" << std::endl;
}
- for (auto &ind : node.getOutputs())
+ for (const auto &ind : node.getOutputs())
{
if (node.opcode() == ir::OpCode::ConvertFp16ToFp32 || op_seq_outputs.contains(ind))
continue;
@@ -419,7 +421,7 @@ void Fp32ToFp16Converter::convertOperandsOfOpSequence(ir::OpSequence &op_seq)
obj.type(ir::DataType::FLOAT16);
- VERBOSE(Fp32ToFp16Converter) << "Output Operand #" << ind.value() << ": fp16" << std::endl;
+ VERBOSE(Fp32ToFp16Converter) << "Output Operand " << ind << ": fp16" << std::endl;
}
}
}
@@ -444,7 +446,7 @@ void Fp32ToFp16Converter::convertDatas()
obj.data(std::move(new_data));
obj.type(ir::DataType::FLOAT16);
- VERBOSE(Fp32ToFp16Converter) << "Constant Operand #" << ind.value() << ": fp16" << std::endl;
+ VERBOSE(Fp32ToFp16Converter) << "Constant Operand " << ind << ": fp16" << std::endl;
}
});
}
@@ -513,23 +515,23 @@ ir::OperandIndex Fp32ToFp16Converter::newCopiedOperand(const ir::OperandIndex &o
void Fp32ToFp16Converter::setNewOperandLowerInfo(const ir::OpSequenceIndex &op_seq_ind,
const ir::OperandIndex &new_op_ind)
{
- const auto lower_info = _lowered_graph.getLowerInfo(op_seq_ind);
+ const auto &lower_info = _lowered_graph.getLowerInfo(op_seq_ind);
assert(lower_info != nullptr);
- auto new_lower_info = std::make_unique<ir::operand::LowerInfo>();
- auto permute_factor = ir::operand::PermuteFactor(lower_info->backend(), lower_info->layout());
+ auto new_lower_info = std::make_unique<compiler::OperandLowerInfo>();
+ auto permute_factor = compiler::PermuteFactor(lower_info->backend(), lower_info->layout());
new_lower_info->addDefPermuteFactor(permute_factor);
new_lower_info->addUsePermuteFactor(permute_factor);
_lowered_graph.setLowerInfo(new_op_ind, std::move(new_lower_info));
}
-void Fp32ToFp16Converter::setNewOpSequenceLowerInfo(const ir::OpSequenceIndex &op_seq_ind,
- const ir::OpSequenceIndex &new_op_seq_ind)
+void Fp32ToFp16Converter::setNewOperationLowerInfo(const ir::OpSequenceIndex &op_seq_ind,
+ const ir::OpSequenceIndex &new_op_seq_ind)
{
- const auto lower_info = _lowered_graph.getLowerInfo(op_seq_ind);
+ const auto &lower_info = _lowered_graph.getLowerInfo(op_seq_ind);
assert(lower_info != nullptr);
auto new_lower_info =
- std::make_unique<ir::operation::LowerInfo>(lower_info->backend(), lower_info->layout());
+ std::make_unique<compiler::OperationLowerInfo>(lower_info->backend(), lower_info->layout());
_lowered_graph.setLowerInfo(new_op_seq_ind, std::move(new_lower_info));
}
@@ -600,7 +602,7 @@ Fp32ToFp16Converter::newOperationConvertFp32ToFp16(const ir::OperandIndex &op_se
auto &new_op_obj = operands.at(new_op_ind);
std::unique_ptr<ir::Operation> new_node(
- new ir::operation::ConvertFp32ToFp16({op_seq_input_ind}, {new_op_ind}));
+ new ir::operation::ConvertFp32ToFp16({op_seq_input_ind}, {new_op_ind}));
const auto new_node_ind = operations.push(std::move(new_node));
input_obj.insertUse(new_node_ind);
@@ -620,7 +622,7 @@ Fp32ToFp16Converter::newOperationConvertFp16ToFp32(const ir::OperandIndex &op_se
auto &new_op_obj = operands.at(new_op_ind);
std::unique_ptr<ir::Operation> new_node(
- new ir::operation::ConvertFp16ToFp32({new_op_ind}, {op_seq_output_ind}));
+ new ir::operation::ConvertFp16ToFp32({new_op_ind}, {op_seq_output_ind}));
const auto new_node_ind = operations.push(std::move(new_node));
new_op_obj.insertUse(new_node_ind);
@@ -633,7 +635,7 @@ ir::OpSequenceIndex Fp32ToFp16Converter::newOpSequence(const ir::OpSequenceIndex
const ir::OperationIndex &node_index)
{
auto &node = _lowered_graph.graph().operations().at(node_index);
- const auto lower_info = _lowered_graph.getLowerInfo(op_seq_ind);
+ const auto &lower_info = _lowered_graph.getLowerInfo(op_seq_ind);
assert(lower_info != nullptr);
auto layout = lower_info->layout();
@@ -745,7 +747,7 @@ Fp32ToFp16Converter::findOpSequencesContiguous(const InputToOpSeqs &input_to_op_
// | |
// [OPERATION] [OPERATION]
//
- for (auto &op_seq_ind : found_input_in_op_seqs->second)
+ for (const auto &op_seq_ind : found_input_in_op_seqs->second)
{
auto found_in_fp32_to_fp16 = _list_fp32_to_fp16.find(op_seq_ind);
if (found_in_fp32_to_fp16 != _list_fp32_to_fp16.end())
@@ -759,9 +761,8 @@ Fp32ToFp16Converter::findOpSequencesContiguous(const InputToOpSeqs &input_to_op_
opseq_map_to_delete[op_seq_ind_fp16_to_fp32].insert(op_seq_ind);
}
- VERBOSE(Fp32ToFp16Converter)
- << "Contiguous from OpSeq#" << op_seq_ind_fp16_to_fp32.value() << "(ToFp32)"
- << " to OpSeq#" << op_seq_ind.value() << "(ToFp16)" << std::endl;
+ VERBOSE(Fp32ToFp16Converter) << "Contiguous from " << op_seq_ind_fp16_to_fp32 << "(ToFp32)"
+ << " to " << op_seq_ind << "(ToFp16)" << std::endl;
}
}
}
@@ -775,7 +776,7 @@ Fp32ToFp16Converter::InputToOpSeqs Fp32ToFp16Converter::prepareInputToOpSeqs() c
InputToOpSeqs input_to_op_seqs;
op_seqs.iterate([&](const ir::OpSequenceIndex &op_seq_idx, const ir::OpSequence &op_seq) {
- for (auto input : op_seq.getInputs() | ir::Remove::UNDEFINED)
+ for (auto &&input : op_seq.getInputs() | ir::Remove::UNDEFINED)
{
auto it = input_to_op_seqs.find(input);
if (it == input_to_op_seqs.end())
@@ -798,13 +799,13 @@ Fp32ToFp16Converter::getListOpSequences(const OpSeqIndexToOpSeqIndexList &opseq_
OpSeqIndexList list;
for (const auto &it : opseq_map_to_delete)
{
- auto &opseq_ind_fp16_to_fp32 = it.first;
+ const auto &opseq_ind_fp16_to_fp32 = it.first;
if (list.find(opseq_ind_fp16_to_fp32) == list.end())
{
list.emplace(opseq_ind_fp16_to_fp32);
}
- for (auto &opseq_ind_fp32_to_fp16 : it.second)
+ for (const auto &opseq_ind_fp32_to_fp16 : it.second)
{
if (list.find(opseq_ind_fp32_to_fp16) == list.end())
{
@@ -842,7 +843,7 @@ Fp32ToFp16Converter::findOperationsToDelete(const OpSeqIndexList &list_to_delete
}
void Fp32ToFp16Converter::manipulateContiguousOpSequences(
- const InputToOpSeqs &input_to_op_seqs, const OpSeqIndexToOpSeqIndexList &opseq_map_to_delete)
+ const InputToOpSeqs &input_to_op_seqs, const OpSeqIndexToOpSeqIndexList &opseq_map_to_delete)
{
auto &op_seqs = _lowered_graph.op_seqs();
@@ -861,14 +862,14 @@ void Fp32ToFp16Converter::manipulateContiguousOpSequences(
// |
// [OPERATION] // op_seq_ind_next_to_fp16
//
- for (auto it : opseq_map_to_delete)
+ for (auto &&it : opseq_map_to_delete)
{
// fp16_to_fp32's input/output num is always 1
auto &op_seq_ind_fp16_to_fp32 = it.first;
auto &op_seq_fp16_to_fp32 = op_seqs.at(op_seq_ind_fp16_to_fp32);
auto &input_ind_fp16_to_fp32 = op_seq_fp16_to_fp32.getInputs().at(0);
- for (auto &op_seq_ind_fp32_to_fp16 : it.second)
+ for (const auto &op_seq_ind_fp32_to_fp16 : it.second)
{
auto &op_seq_fp32_to_fp16 = op_seqs.at(op_seq_ind_fp32_to_fp16);
assert(op_seq_fp32_to_fp16.size() == 1);
@@ -878,7 +879,7 @@ void Fp32ToFp16Converter::manipulateContiguousOpSequences(
auto found_next_to_fp16 = input_to_op_seqs.find(output_ind_fp32_to_fp16);
assert(found_next_to_fp16 != input_to_op_seqs.end());
- for (auto &op_seq_ind_next_to_fp16 : found_next_to_fp16->second)
+ for (const auto &op_seq_ind_next_to_fp16 : found_next_to_fp16->second)
{
manipulateInput(op_seq_ind_next_to_fp16, output_ind_fp32_to_fp16, input_ind_fp16_to_fp32);
}
@@ -894,61 +895,62 @@ void Fp32ToFp16Converter::manipulateContiguousOpSequences(
}
void Fp32ToFp16Converter::deleteContiguousOpSequences(
- const OpSeqIndexList &list_to_delete_op_seqs,
- const ir::OperandIndexSequence &list_to_delete_ops)
+ const OpSeqIndexList &list_to_delete_op_seqs, const ir::OperandIndexSequence &list_to_delete_ops)
{
auto &operands = _lowered_graph.graph().operands();
auto &operations = _lowered_graph.graph().operations();
auto &op_seqs = _lowered_graph.op_seqs();
- for (auto &op_seq_ind : list_to_delete_op_seqs)
+ for (const auto &op_seq_ind : list_to_delete_op_seqs)
{
auto &op_seq = op_seqs.at(op_seq_ind);
assert(op_seq.size() == 1);
- VERBOSE(Fp32ToFp16Converter) << "Delete OpSeq #" << op_seq_ind.value() << std::endl;
+ VERBOSE(Fp32ToFp16Converter) << "Delete OpSeq " << op_seq_ind << std::endl;
auto &first_node_ind = op_seq.operations().at(0);
auto &first_node = operations.at(first_node_ind);
assert(first_node.opcode() == ir::OpCode::ConvertFp32ToFp16 ||
first_node.opcode() == ir::OpCode::ConvertFp16ToFp32);
- VERBOSE(Fp32ToFp16Converter) << "Delete Node #" << first_node_ind.value() << std::endl;
+ VERBOSE(Fp32ToFp16Converter) << "Delete Node " << first_node_ind << std::endl;
// Uses
- for (auto &ind : first_node.getInputs() | ir::Remove::DUPLICATED | ir::Remove::UNDEFINED)
+ for (const auto &ind : first_node.getInputs() | ir::Remove::DUPLICATED | ir::Remove::UNDEFINED)
{
auto &obj = operands.at(ind);
obj.removeUse(first_node_ind);
- VERBOSE(Fp32ToFp16Converter) << "Operand #" << ind.value() << "'s Use(Node#"
- << first_node_ind.value() << ") is removed" << std::endl;
+ VERBOSE(Fp32ToFp16Converter)
+ << "Operand " << ind << "'s Use(Node" << first_node_ind << ") is removed" << std::endl;
}
// Def
- for (auto &ind : first_node.getOutputs() | ir::Remove::DUPLICATED | ir::Remove::UNDEFINED)
+ for (const auto &ind : first_node.getOutputs() | ir::Remove::DUPLICATED | ir::Remove::UNDEFINED)
{
auto &obj = operands.at(ind);
assert(obj.getDef() == first_node_ind);
obj.unsetDef();
- VERBOSE(Fp32ToFp16Converter) << "Operand #" << ind.value() << "'s Def(Node#"
- << first_node_ind.value() << ") is removed" << std::endl;
+ VERBOSE(Fp32ToFp16Converter)
+ << "Operand " << ind << "'s Def(Node" << first_node_ind << ") is removed" << std::endl;
}
// Operation
operations.remove(first_node_ind);
- VERBOSE(Fp32ToFp16Converter) << "Node#" << first_node_ind.value() << " is removed" << std::endl;
+ VERBOSE(Fp32ToFp16Converter) << "Node" << first_node_ind << " is removed" << std::endl;
// OpSequence
op_seqs.remove(op_seq_ind);
- VERBOSE(Fp32ToFp16Converter) << "OpSeq#" << op_seq_ind.value() << " is removed" << std::endl;
+ VERBOSE(Fp32ToFp16Converter) << "OpSeq" << op_seq_ind << " is removed" << std::endl;
}
// Operand
- for (auto &ind : list_to_delete_ops)
+ for (const auto &ind : list_to_delete_ops)
{
operands.remove(ind);
- VERBOSE(Fp32ToFp16Converter) << "Operand #" << ind.value() << " is removed" << std::endl;
+ VERBOSE(Fp32ToFp16Converter) << "Operand " << ind << " is removed" << std::endl;
}
}
} // namespace compiler
} // namespace onert
+
+#endif
diff --git a/runtime/onert/core/src/compiler/Fp32ToFp16Converter.h b/runtime/onert/core/src/compiler/Fp32ToFp16Converter.h
index eeecb9846..87751ceb4 100644
--- a/runtime/onert/core/src/compiler/Fp32ToFp16Converter.h
+++ b/runtime/onert/core/src/compiler/Fp32ToFp16Converter.h
@@ -14,6 +14,8 @@
* limitations under the License.
*/
+#if 0 // This file is temporarily unused
+
#ifndef __ONERT_COMPILER_FP32_TO_FP16_CONVERTER_H__
#define __ONERT_COMPILER_FP32_TO_FP16_CONVERTER_H__
@@ -64,8 +66,8 @@ private:
void setNewOperandLowerInfo(const ir::OpSequenceIndex &op_seq_ind,
const ir::OperandIndex &new_op_ind);
- void setNewOpSequenceLowerInfo(const ir::OpSequenceIndex &op_seq_ind,
- const ir::OpSequenceIndex &new_op_seq_ind);
+ void setNewOperationLowerInfo(const ir::OpSequenceIndex &op_seq_ind,
+ const ir::OpSequenceIndex &new_op_seq_ind);
void manipulateInput(const ir::OpSequenceIndex &op_seq_ind,
const ir::OperandIndex &op_seq_input_ind,
@@ -99,3 +101,5 @@ private:
} // namespace onert
#endif // __ONERT_COMPILER_FP32_TO_FP16_CONVERTER_H__
+
+#endif
diff --git a/runtime/onert/core/src/compiler/HEScheduler.cc b/runtime/onert/core/src/compiler/HEScheduler.cc
index 5653b090e..2d04d42ce 100644
--- a/runtime/onert/core/src/compiler/HEScheduler.cc
+++ b/runtime/onert/core/src/compiler/HEScheduler.cc
@@ -14,34 +14,32 @@
* limitations under the License.
*/
-#include "ir/Operand.h"
-#include "compiler/HEScheduler.h"
-#include "ir/Graph.h"
-#include "util/ConfigSource.h"
+#include "HEScheduler.h"
+
#include "compiler/BackendResolver.h"
+#include "ir/Graph.h"
#include "util/logging.h"
-#include "util/Utils.h"
-#include "exec/FunctionSequence.h"
+
#include <cassert>
#include <cmath>
-#include <chrono>
-namespace onert
+namespace
{
-namespace compiler
-{
-static uint32_t getOperationsFlattenedIOSize(const ir::Graph &graph, const ir::Operation &node)
+using namespace onert;
+
+uint32_t getOperationsFlattenedIOSize(const ir::Graph &graph, const ir::IOperation &node)
{
uint32_t size = 0;
- for (const auto &ind : (node.getInputs() | ir::Remove::UNDEFINED) + node.getOutputs())
+ for (const auto &ind :
+ (node.getInputs() | ir::Remove::UNDEFINED) + (node.getOutputs() | ir::Remove::UNDEFINED))
{
size += graph.operands().at(ind).info().total_size();
}
return size;
}
-static bool isQuant(const ir::Graph &graph, const ir::Operation &node)
+bool isQuant(const ir::Graph &graph, const ir::IOperation &node)
{
for (const auto &input : node.getInputs() | ir::Remove::UNDEFINED)
{
@@ -54,18 +52,11 @@ static bool isQuant(const ir::Graph &graph, const ir::Operation &node)
return false;
}
-static bool isWorkaroundSkip(const ir::Graph &, const backend::Backend *, const ir::Operation &,
- bool)
-{
- // Now, there is no workaround
- return false;
-}
-
// if a node can be merged into op_seq
-static bool isMergeable(const ir::Graph &graph, const ir::Operation &node)
+bool isMergeable(const ir::Graph &graph, const ir::IOperation &node)
{
size_t prev_op_cnt = 0;
- for (const auto &input : node.getInputs())
+ for (const auto &input : node.getInputs() | ir::Remove::UNDEFINED)
{
// only valid_inputs
const auto &operand = graph.operands().at(input);
@@ -85,15 +76,23 @@ static bool isMergeable(const ir::Graph &graph, const ir::Operation &node)
return true;
}
+} // namespace
+
+namespace onert
+{
+
+namespace compiler
+{
+
void HEScheduler::scheduleShufflingBackends()
{
VERBOSE(HEScheduler::schedule)
- << "Started task scheduling: uses all backends to get more metrics for data transfer"
- << std::endl;
+ << "Started task scheduling: uses all backends to get more metrics for data transfer"
+ << std::endl;
size_t backend_ind = 0;
for (const auto &rank : _rank_to_op)
{
- VERBOSE(HEScheduler::schedule) << "scheduling (" << rank.second.value() << ")" << std::endl;
+ VERBOSE(HEScheduler::schedule) << "scheduling (" << rank.second << ")" << std::endl;
const auto &node = _graph->operations().at(rank.second);
const bool quant = isQuant(*_graph, node);
const auto size = getOperationsFlattenedIOSize(*_graph, node);
@@ -109,13 +108,8 @@ void HEScheduler::scheduleShufflingBackends()
{
backend_ind = 0;
}
- if (isWorkaroundSkip(*_graph, _all_backends[backend_ind], node, quant))
- {
- ++backend_ind;
- continue;
- }
const auto exec_time =
- _exec_time->getOperationExecTime(_all_backends[backend_ind], node.name(), quant, size);
+ _exec_time->getOperationExecTime(_all_backends[backend_ind], node.name(), quant, size);
// Scheduling to measure data transfer must be done after measuring all backends separately
assert(exec_time != _exec_time->NOT_FOUND);
if (exec_time == _exec_time->getMax())
@@ -132,7 +126,7 @@ void HEScheduler::scheduleShufflingBackends()
}
}
-bool HEScheduler::isNodeProfiled(const ir::Operation &node)
+bool HEScheduler::isNodeProfiled(const ir::IOperation &node)
{
const bool quant = isQuant(*_graph, node);
const auto size = getOperationsFlattenedIOSize(*_graph, node);
@@ -202,7 +196,7 @@ std::unique_ptr<compiler::BackendResolver> HEScheduler::schedule(const ir::Graph
{
// Check if profiling info about all backend/node pairs already exists
bool all_nodes_are_profiled = true;
- _graph->operations().iterate([&](const ir::OperationIndex &, const ir::Operation &op) {
+ _graph->operations().iterate([&](const ir::OperationIndex &, const ir::IOperation &op) {
if (all_nodes_are_profiled)
all_nodes_are_profiled = isNodeProfiled(op);
});
@@ -219,7 +213,7 @@ std::unique_ptr<compiler::BackendResolver> HEScheduler::schedule(const ir::Graph
ir::OperationIndexMap<bool> visited;
graph.operations().iterate(
- [&](const ir::OperationIndex &index, const ir::Operation &) { visited[index] = false; });
+ [&](const ir::OperationIndex &index, const ir::IOperation &) { visited[index] = false; });
// for each task select the backend with the smallest earliest finishing time(eft)
for (const auto &rank : _rank_to_op)
{
@@ -248,19 +242,20 @@ int64_t HEScheduler::getPermuteTime(const backend::Backend *src_backend,
if (time != _exec_time->NOT_FOUND)
return time;
+  // FIXME Permute time is not recorded, so control always reaches this point
// Makes the scheduler prefer keeping computations on one backend
- return size / 200;
+ return size / 400;
}
-int64_t HEScheduler::tryBackend(const ir::Operation &node, const backend::Backend *backend)
+int64_t HEScheduler::tryBackend(const ir::IOperation &node, const backend::Backend *backend)
{
// if there is no profiling info don't use this backend during scheduling
if (!_is_profiling_mode)
{
VERBOSE(HEScheduler::tryBackend)
- << "Trying to HE schedule while there is no profiling info for " << node.name()
- << " on backend " << backend->config()->id() << ". So this backend won't be used. "
- << std::endl;
+ << "Trying to HE schedule while there is no profiling info for " << node.name()
+ << " on backend " << backend->config()->id() << ". So this backend won't be used. "
+ << std::endl;
_is_supported[backend][node.name()] = false;
return _exec_time->getMax();
}
@@ -291,10 +286,10 @@ void HEScheduler::makeRank()
VERBOSE(HEScheduler::makeRank) << "task prioritizing" << std::endl;
_graph->operations().iterate(
- [&](const ir::OperationIndex &index, const ir::Operation &) { DFSMaxRank(index); });
+ [&](const ir::OperationIndex &index, const ir::IOperation &) { DFSMaxRank(index); });
// Check that ranks are calculated for all operations(nodes)
- _graph->operations().iterate([&](const ir::OperationIndex &index, const ir::Operation &) {
+ _graph->operations().iterate([&](const ir::OperationIndex &index, const ir::IOperation &) {
UNUSED_RELEASE(index);
assert(_op_to_rank->find(index) != _op_to_rank->end());
});
@@ -360,8 +355,8 @@ int64_t HEScheduler::DFSMaxRank(const ir::OperationIndex &index)
assert(rank >= 0);
_rank_to_op.emplace(rank, index);
_op_to_rank->emplace(index, rank);
- VERBOSE(HEScheduler::DFSMaxRank) << "rank of operation (" << index.value() << ")" << node.name()
- << " is " << rank << std::endl;
+ VERBOSE(HEScheduler::DFSMaxRank)
+ << "rank of operation (" << index << ")" << node.name() << " is " << rank << std::endl;
return rank;
}
@@ -370,7 +365,7 @@ int64_t HEScheduler::DFSChildrenMaxRank(const ir::OperationIndex &index)
{
const auto &node = _graph->operations().at(index);
int64_t max_child_rank = 0;
- for (const auto &output : node.getOutputs())
+ for (const auto &output : node.getOutputs() | ir::Remove::UNDEFINED)
{
const auto &operand = _graph->operands().at(output);
const bool quant = operand.typeInfo().type() == ir::DataType::QUANT_UINT8_ASYMM;
@@ -384,9 +379,9 @@ int64_t HEScheduler::DFSChildrenMaxRank(const ir::OperationIndex &index)
{
continue;
}
- // TODO Change it to controlflow backend
+ // TODO Change it to builtin backend
auto transfer_cost =
- getPermuteTime(backend, other_backend, quant, operand.info().total_size());
+ getPermuteTime(backend, other_backend, quant, operand.info().total_size());
avg_transfer_cost += transfer_cost;
}
}
@@ -403,7 +398,7 @@ int64_t HEScheduler::DFSChildrenMaxRank(const ir::OperationIndex &index)
int64_t HEScheduler::backendAvailableTime(const backend::Backend *backend,
const int64_t &starting_time, const int64_t &time_amount)
{
- const auto backend_times = _backends_avail_time.at(backend);
+ const auto &backend_times = _backends_avail_time.at(backend);
// finishing and starting times of an op, that will come after current op
auto next_op_fst = backend_times.upper_bound(starting_time);
// finishing time of an op, that will come before current op
@@ -419,7 +414,7 @@ int64_t HEScheduler::backendAvailableTime(const backend::Backend *backend,
bool HEScheduler::schedule(const ir::OperationIndex &index, const backend::Backend *parent_backend)
{
- VERBOSE(HEScheduler::schedule) << "scheduling (" << index.value() << ")" << std::endl;
+ VERBOSE(HEScheduler::schedule) << "scheduling (" << index << ")" << std::endl;
int64_t eft = std::numeric_limits<int64_t>::max(), selected_exec_time = 0;
const auto &node = _graph->operations().at(index);
@@ -487,10 +482,6 @@ HEScheduler::ESTAndExecTime(const backend::Backend *backend, const ir::Operation
{
permute_fine *= 2;
}
- if (isWorkaroundSkip(*_graph, backend, node, quant))
- {
- return {_exec_time->getMax(), _exec_time->getMax()};
- }
// get average exec time of the op on this backend
auto exec_time = getOpTime(backend, node.name(), quant, size);
if (backend->config()->id() == "cpu" && _is_parallel_exec)
@@ -506,7 +497,7 @@ HEScheduler::ESTAndExecTime(const backend::Backend *backend, const ir::Operation
// Find free time for data transferring and insert it into backend taskset. This is needed:
// 1. Time for multiple permutations for this node's input is found correctly
// 2. If backend==cpu, then free time for this node must come after permutations
- for (auto &it : transfer_st_exec_time)
+ for (auto &&it : transfer_st_exec_time)
{
if (_is_parallel_exec)
{
@@ -542,27 +533,27 @@ HEScheduler::ESTAndExecTime(const backend::Backend *backend, const ir::Operation
if (!_is_parallel_exec)
{
VERBOSE(HEScheduler::ESTAndExecTime)
- << "exec_time of (" << index.value() << ") " << node.name() << " quant==" << quant << " on "
- << backend->config()->id() << " is " << exec_time
- << " microseconds. Data transfer cost: " << total_transfer_cost << std::endl;
+ << "exec_time of (" << index << ") " << node.name() << " quant==" << quant << " on "
+ << backend->config()->id() << " is " << exec_time
+ << " microseconds. Data transfer cost: " << total_transfer_cost << std::endl;
return {total_transfer_cost, exec_time};
}
VERBOSE(HEScheduler::ESTAndExecTime)
- << "exec_time of (" << index.value() << ") " << node.name() << " quant==" << quant << " on "
- << backend->config()->id() << ": " << exec_time
- << " microseconds. Backend available time: " << prev_op_ft
- << " Parent's max eft: " << max_pred_eft - total_transfer_cost
- << " data transfer cost: " << total_transfer_cost << std::endl;
+ << "exec_time of (" << index << ") " << node.name() << " quant==" << quant << " on "
+ << backend->config()->id() << ": " << exec_time
+ << " microseconds. Backend available time: " << prev_op_ft
+ << " Parent's max eft: " << max_pred_eft - total_transfer_cost
+ << " data transfer cost: " << total_transfer_cost << std::endl;
return {prev_op_ft, exec_time};
}
-int64_t HEScheduler::predMaxEFT(const backend::Backend *backend, const ir::Operation &node,
+int64_t HEScheduler::predMaxEFT(const backend::Backend *backend, const ir::IOperation &node,
std::multimap<int64_t, int64_t> &transfer_st_exec_time)
{
int64_t max_pred_eft = 0;
- for (const auto &input_operand_idx : node.getInputs())
+ for (const auto &input_operand_idx : node.getInputs() | ir::Remove::UNDEFINED)
{
const auto &input_operand = _graph->operands().at(input_operand_idx);
const bool quant = input_operand.typeInfo().type() == ir::DataType::QUANT_UINT8_ASYMM;
@@ -578,7 +569,7 @@ int64_t HEScheduler::predMaxEFT(const backend::Backend *backend, const ir::Opera
{
// Multiply operand size by 2 because size must describe input+output size
int64_t transfer_cost =
- getPermuteTime(parent_backend, backend, quant, input_operand.info().total_size() * 2);
+ getPermuteTime(parent_backend, backend, quant, input_operand.info().total_size() * 2);
transfer_st_exec_time.emplace(_ops_eft.at(input_node_idx), transfer_cost);
}
}
diff --git a/runtime/onert/core/src/compiler/HEScheduler.h b/runtime/onert/core/src/compiler/HEScheduler.h
index b9cee5881..df6c07926 100644
--- a/runtime/onert/core/src/compiler/HEScheduler.h
+++ b/runtime/onert/core/src/compiler/HEScheduler.h
@@ -23,14 +23,16 @@
#ifndef __ONERT_COMPILER_H_E_SCHEDULER_H_
#define __ONERT_COMPILER_H_E_SCHEDULER_H_
-#include "compiler/IScheduler.h"
-#include "compiler/BackendManager.h"
-#include "compiler/Compiler.h"
-#include "ir/Graph.h"
-#include "exec/ExecTime.h"
-#include "backend/Backend.h"
-#include <memory>
-#include "ir/OperationIndexMap.h"
+#include "IScheduler.h"
+#include "../backend/builtin/Config.h"
+#include "../exec/ExecTime.h"
+
+#include <backend/Backend.h>
+#include <compiler/BackendManager.h>
+#include <compiler/Compiler.h>
+#include <ir/Graph.h>
+#include <ir/OperationIndexMap.h>
+
#include <map>
#include <memory>
@@ -50,26 +52,26 @@ public:
* @param[in] model Graph model
* @param[in] backend_resolver backend resolver
*/
- HEScheduler(const backend::BackendContexts &backend_contexts, const CompilerOptions &options)
- : _is_supported{}, _backends_avail_time{}, _ops_eft{},
- _op_to_rank{std::make_shared<ir::OperationIndexMap<int64_t>>()},
- _is_profiling_mode{options.he_profiling_mode},
- _is_linear_exec{options.executor == "Linear"},
- _is_parallel_exec{options.executor == "Parallel"}
+ HEScheduler(const std::vector<const backend::Backend *> &backends, const CompilerOptions &options)
+ : _is_supported{}, _backends_avail_time{}, _ops_eft{},
+ _op_to_rank{std::make_shared<ir::OperationIndexMap<int64_t>>()},
+ _is_profiling_mode{options.he_profiling_mode}, _is_linear_exec{options.executor == "Linear"},
+ _is_parallel_exec{options.executor == "Parallel"}
{
- for (auto &entry : backend_contexts)
+ for (auto &&entry : backends)
{
- if (entry.first->config()->id() == backend::controlflow::Config::ID)
+ if (entry->config()->id() == backend::builtin::Config::ID)
continue;
- _all_backends.push_back(entry.first);
+ _all_backends.push_back(entry);
}
_backend_resolver = std::make_unique<compiler::BackendResolver>();
_exec_time = std::make_unique<exec::ExecTime>(_all_backends);
// Find cpu backend
- auto cpu_backend_it = std::find_if(
- _all_backends.begin(), _all_backends.end(),
- [](const backend::Backend *backend) { return backend->config()->id() == "cpu"; });
+ auto cpu_backend_it =
+ std::find_if(_all_backends.begin(), _all_backends.end(), [](const backend::Backend *backend) {
+ return backend->config()->id() == "cpu";
+ });
if (cpu_backend_it == _all_backends.end())
throw std::runtime_error("HEScheduler could be used only if 'cpu' backend is available");
_cpu_backend = *cpu_backend_it;
@@ -86,7 +88,7 @@ public:
std::shared_ptr<ir::OperationIndexMap<int64_t>> getIndexedRanks() { return _op_to_rank; }
private:
- bool isNodeProfiled(const ir::Operation &);
+ bool isNodeProfiled(const ir::IOperation &);
bool schedule(const ir::OperationIndex &, const backend::Backend *parent_backend);
/**
@@ -113,7 +115,7 @@ private:
*
* @return earliest finishing time of parent nodes
*/
- int64_t predMaxEFT(const backend::Backend *backend, const ir::Operation &node,
+ int64_t predMaxEFT(const backend::Backend *backend, const ir::IOperation &node,
std::multimap<int64_t, int64_t> &transfer_st_exec_time);
void makeRank();
@@ -144,7 +146,7 @@ private:
void scheduleShufflingBackends();
- int64_t tryBackend(const ir::Operation &node, const backend::Backend *backend);
+ int64_t tryBackend(const ir::IOperation &node, const backend::Backend *backend);
/**
* @brief Schedule a node and its successor until:
@@ -173,7 +175,7 @@ private:
std::unique_ptr<exec::ExecTime> _exec_time;
const ir::Graph *_graph{nullptr};
std::vector<const backend::Backend *> _all_backends;
- const backend::Backend *_cpu_backend{nullptr}; // TODO Change this to controlflow_backend
+ const backend::Backend *_cpu_backend{nullptr}; // TODO Change this to _builtin_backend
bool _is_profiling_mode;
bool _is_linear_exec;
bool _is_parallel_exec;
diff --git a/runtime/onert/core/src/compiler/HEScheduler.test.cc b/runtime/onert/core/src/compiler/HEScheduler.test.cc
new file mode 100644
index 000000000..505fbbb48
--- /dev/null
+++ b/runtime/onert/core/src/compiler/HEScheduler.test.cc
@@ -0,0 +1,572 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "HEScheduler.h"
+#include "../exec/ExecTime.h"
+
+#include <ir/DataType.h>
+#include <ir/InternalType.h>
+#include <ir/Shape.h>
+#include <ir/TypeInfo.h>
+#include <ir/operation/BinaryArithmetic.h>
+#include <ir/operation/FullyConnected.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+using namespace onert;
+using namespace ir;
+using namespace backend;
+using namespace operation;
+using namespace exec;
+
+//
+// Mock backends classes
+//
+
+struct MockConfigCPU : public IConfig
+{
+ std::string id() override { return "cpu"; }
+ bool initialize() override { return true; };
+ bool supportPermutation() override { return false; }
+ Layout supportLayout(const IOperation &, Layout) override { return Layout::UNKNOWN; }
+ bool supportDynamicTensor() override { return false; }
+ bool supportFP16() override { return false; }
+};
+
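+// Note: the scheduler tests below only consult backend configs and ExecTime data, so this stub
+// context leaves tensor registration and kernel generation empty.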
+class MockBackendContext : public BackendContext
+{
+public:
+ using BackendContext::BackendContext;
+ ITensorRegistry *genTensors() override { return nullptr; }
+ FunctionMap genKernels() override { return {}; }
+};
+
+struct MockBackendCPU : public Backend
+{
+ std::shared_ptr<IConfig> config() const override { return std::make_shared<MockConfigCPU>(); }
+ std::unique_ptr<BackendContext> newContext(ContextData &&data) const override
+ {
+ return std::make_unique<MockBackendContext>(this, std::move(data), nullptr);
+ }
+};
+
+struct MockConfigGPU : public IConfig
+{
+ std::string id() override { return "gpu"; }
+ bool initialize() override { return true; };
+ bool supportPermutation() override { return false; }
+ ir::Layout supportLayout(const ir::IOperation &, ir::Layout) override
+ {
+ return ir::Layout::UNKNOWN;
+ }
+ bool supportDynamicTensor() override { return false; }
+ bool supportFP16() override { return false; }
+};
+
+struct MockBackendGPU : public Backend
+{
+ std::shared_ptr<IConfig> config() const override { return std::make_shared<MockConfigGPU>(); }
+ std::unique_ptr<BackendContext> newContext(ContextData &&data) const override
+ {
+ return std::make_unique<MockBackendContext>(this, std::move(data), nullptr);
+ }
+};
+
+struct MockConfigNPU : public IConfig
+{
+ std::string id() override { return "npu"; }
+ bool initialize() override { return true; };
+ bool supportPermutation() override { return false; }
+ ir::Layout supportLayout(const ir::IOperation &, ir::Layout) override
+ {
+ return ir::Layout::UNKNOWN;
+ }
+ bool supportDynamicTensor() override { return false; }
+ bool supportFP16() override { return false; }
+};
+
+struct MockBackendNPU : public Backend
+{
+ std::shared_ptr<IConfig> config() const override { return std::make_shared<MockConfigNPU>(); }
+ std::unique_ptr<BackendContext> newContext(ContextData &&data) const override
+ {
+ return std::make_unique<MockBackendContext>(this, std::move(data), nullptr);
+ }
+};
+
+//
+// Constants
+//
+
+const int OPERAND_ELEMS = 268203;
+const int OPERAND_SIZE = OPERAND_ELEMS * 4;
+const int OPERATION_SIZE = OPERAND_SIZE * 3;
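+// Note: every operation created in the test graphs below has two inputs and one output, each of
+// OPERAND_SIZE bytes, hence the factor of 3 in OPERATION_SIZE.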
+
+const std::string LINEAR("Linear");
+const std::string DATAFLOW("Dataflow");
+const std::string PARALLEL("Parallel");
+
+//
+// Helper functions
+//
+
+// Set executor through environment variable
+void setExecutor(const std::string &executor) { setenv("EXECUTOR", executor.c_str(), true); }
+
+// Set profiling mode through environment variable
+void setProfilingMode(const bool value) { setenv("PROFILING_MODE", value ? "1" : "0", true); }
+
+// Calculate operation size by adding the sizes of all input and output operands
+uint32_t calcOpSize(const std::shared_ptr<Graph> &graph, const OperationIndex &op_idx)
+{
+ uint32_t size = 0;
+ const auto &op = graph->operations().at(op_idx);
+ for (const auto &ind : op.getInputs() + op.getOutputs())
+ size += graph->operands().at(ind).info().total_size();
+ return size;
+}
+
+// Set an operation's execution time. This helper is needed since ExecTime only exposes the
+// 'updateOperationExecTime' method.
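+// The value written below compensates for averaging: assuming updateOperationExecTime() averages
+// the new measurement with the stored one, writing (2 * time - prev_time) makes the stored value
+// equal 'time' (the assert at the end checks this).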
+void setOperationExecTime(ExecTime &et, const Backend *backend, const std::string &operation,
+ bool quant, uint32_t op_size, int64_t time)
+{
+ // You shouldn't set negative time with this method since nnfw JSON deserializer can't read it
+ assert(time > 0);
+ int64_t prev_time = et.getOperationExecTime(backend, operation, quant, op_size);
+ int64_t time_to_set = prev_time == ExecTime::NOT_FOUND ? time : 2 * time - prev_time;
+ et.updateOperationExecTime(backend, operation, quant, op_size, time_to_set);
+ assert(et.getOperationExecTime(backend, operation, quant, op_size) == time);
+}
+
+// Set the same execution time for all given backend/operation pairs
+void setOperationsExecutionTime(const std::vector<const Backend *> &backends,
+ const std::vector<std::string> &op_names,
+ const std::vector<uint32_t> &op_sizes, int64_t exec_time)
+{
+ assert(op_names.size() == op_sizes.size());
+ ExecTime et(backends);
+ for (int i = 0; i < op_names.size(); ++i)
+ {
+ for (const auto backend : backends)
+ setOperationExecTime(et, backend, op_names[i], false, op_sizes[i], exec_time);
+ }
+ et.storeOperationsExecTime();
+}
+
+// Set the permute time from one backend to another. This helper is needed since ExecTime only
+// exposes the 'updatePermuteTime' method.
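+// As in setOperationExecTime() above, the written value is (2 * time - prev_time) to cancel out
+// the averaging presumably done by updatePermuteTime().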
+void setPermutationTime(ExecTime &et, const Backend *from_backend, const Backend *to_backend,
+ bool quant, uint32_t op_size, int64_t time)
+{
+ // You shouldn't set negative time with this method since nnfw JSON deserializer can't read it
+ assert(time > 0);
+ int64_t prev_time = et.getPermuteTime(from_backend, to_backend, quant, op_size);
+ int64_t time_to_set = prev_time == ExecTime::NOT_FOUND ? time : 2 * time - prev_time;
+ et.updatePermuteTime(from_backend, to_backend, quant, op_size, time_to_set);
+ assert(et.getPermuteTime(from_backend, to_backend, quant, op_size) == time);
+}
+
+// Set the same permutation time between every pair of the given backends
+void setPermutationsExecutionTime(const std::vector<const Backend *> &backends,
+ const int operand_size, const int64_t exec_time)
+{
+ ExecTime et(backends);
+ for (const auto &backend : backends)
+ {
+ for (const auto other_backend : backends)
+ {
+ if (backend == other_backend)
+ continue;
+ setPermutationTime(et, backend, other_backend, false, operand_size, exec_time);
+ }
+ }
+ et.storeOperationsExecTime();
+}
+
+//
+// Functions for creating graphs
+//
+
+using OIS = OperandIndexSequence;
+
+template <typename NodeT, typename... Types>
+OperationIndex create(std::shared_ptr<Graph> graph, Types &&...args)
+{
+ auto op = std::make_unique<NodeT>(std::forward<Types>(args)...);
+ auto op_idx = graph->addOperation(std::move(op));
+  // For now, all operations in the tested graphs have the same size (for simplicity)
+ assert(calcOpSize(graph, op_idx) == OPERATION_SIZE);
+ return op_idx;
+}
+
+// Create straight graph: Add->Sub->Mul
+std::shared_ptr<Graph> createStraightGraph()
+{
+ auto graph = std::make_shared<Graph>();
+ const TypeInfo float_op(DataType::FLOAT32);
+
+ // Create add node
+ auto add_lhs_idx = graph->addOperand(ir::Shape{OPERAND_ELEMS}, float_op);
+ auto add_rhs_idx = graph->addOperand(ir::Shape{OPERAND_ELEMS}, float_op);
+ auto add_out_idx = graph->addOperand(ir::Shape{OPERAND_ELEMS}, float_op);
+ BinaryArithmetic::Param add_op_params{BinaryArithmetic::ArithmeticType::ADD, Activation::NONE};
+ create<BinaryArithmetic>(graph, OIS{add_lhs_idx, add_rhs_idx}, OIS{add_out_idx}, add_op_params);
+
+ // Create sub node
+ auto sub_const_idx = graph->addOperand(ir::Shape{OPERAND_ELEMS}, float_op);
+ auto sub_out_idx = graph->addOperand(ir::Shape{OPERAND_ELEMS}, float_op);
+ BinaryArithmetic::Param sub_op_params{BinaryArithmetic::ArithmeticType::SUB, Activation::NONE};
+ create<BinaryArithmetic>(graph, OIS{add_out_idx, sub_const_idx}, OIS{sub_out_idx}, sub_op_params);
+
+ // Create mul node
+ auto mul_const_idx = graph->addOperand(ir::Shape{OPERAND_ELEMS}, float_op);
+ auto mul_out_idx = graph->addOperand(ir::Shape{OPERAND_ELEMS}, float_op);
+ BinaryArithmetic::Param mul_op_params{BinaryArithmetic::ArithmeticType::MUL, Activation::NONE};
+ create<BinaryArithmetic>(graph, OIS{sub_out_idx, mul_const_idx}, OIS{mul_out_idx}, mul_op_params);
+
+ graph->verify();
+ return graph;
+}
+
+/* Create branched graph:
+ *           [Add]
+ *          //    \\
+ *      [Mul1]    [FC1]
+ *        ||        ||
+ *      [Mul2]    [FC2]
+ *          \\    //
+ *           [Sub]
+ */
+std::shared_ptr<Graph> createBranchedGraph()
+{
+ auto graph = std::make_shared<Graph>();
+ const TypeInfo float_op(DataType::FLOAT32);
+
+ // Create add node
+ auto add_lhs_idx = graph->addOperand(ir::Shape{OPERAND_ELEMS}, float_op);
+ auto add_rhs_idx = graph->addOperand(ir::Shape{OPERAND_ELEMS}, float_op);
+ auto add_out_idx = graph->addOperand(ir::Shape{OPERAND_ELEMS}, float_op);
+ BinaryArithmetic::Param add_op_params{BinaryArithmetic::ArithmeticType::ADD, Activation::NONE};
+ create<BinaryArithmetic>(graph, OIS{add_lhs_idx, add_rhs_idx}, OIS{add_out_idx}, add_op_params);
+
+ // Create mul1 node
+ auto mul1_const_idx = graph->addOperand(ir::Shape{OPERAND_ELEMS}, float_op);
+ auto mul1_out_idx = graph->addOperand(ir::Shape{OPERAND_ELEMS}, float_op);
+ BinaryArithmetic::Param mul1_op_params{BinaryArithmetic::ArithmeticType::MUL, Activation::NONE};
+ create<BinaryArithmetic>(graph, OIS{add_out_idx, mul1_const_idx}, OIS{mul1_out_idx},
+ mul1_op_params);
+
+ // Create mul2 node
+ auto mul2_const_idx = graph->addOperand(ir::Shape{OPERAND_ELEMS}, float_op);
+ auto mul2_out_idx = graph->addOperand(ir::Shape{OPERAND_ELEMS}, float_op);
+ BinaryArithmetic::Param mul2_op_params{BinaryArithmetic::ArithmeticType::MUL, Activation::NONE};
+ create<BinaryArithmetic>(graph, OIS{mul1_out_idx, mul2_const_idx}, OIS{mul2_out_idx},
+ mul2_op_params);
+
+ // Create fc1 node
+ auto fc1_const_idx = graph->addOperand(ir::Shape{OPERAND_ELEMS}, float_op);
+ auto fc1_out_idx = graph->addOperand(ir::Shape{OPERAND_ELEMS}, float_op);
+ FullyConnected::Param fc1_op_params{Activation::NONE};
+ create<FullyConnected>(graph, OIS{add_out_idx, fc1_const_idx}, OIS{fc1_out_idx}, fc1_op_params);
+
+ // Create fc2 node
+ auto fc2_const_idx = graph->addOperand(ir::Shape{OPERAND_ELEMS}, float_op);
+ auto fc2_out_idx = graph->addOperand(ir::Shape{OPERAND_ELEMS}, float_op);
+ FullyConnected::Param fc2_op_params{Activation::NONE};
+ create<FullyConnected>(graph, OIS{fc1_out_idx, fc2_const_idx}, OIS{fc2_out_idx}, fc2_op_params);
+
+ // Create sub node
+ auto sub_out_idx = graph->addOperand(ir::Shape{OPERAND_ELEMS}, float_op);
+ BinaryArithmetic::Param sub_op_params{BinaryArithmetic::ArithmeticType::SUB, Activation::NONE};
+ create<BinaryArithmetic>(graph, OIS{mul2_out_idx, fc2_out_idx}, OIS{sub_out_idx}, sub_op_params);
+
+ graph->verify();
+ return graph;
+}
+
+//
+// Tests setup/teardown
+//
+
+// SetUp/TearDown methods run before/after each test and perform actions common to every test
+class HESchedulerTest : public ::testing::Test
+{
+protected:
+ void SetUp() override
+ {
+ // Initialize mock backends
+ _cpu_backend = new MockBackendCPU();
+ _gpu_backend = new MockBackendGPU();
+ _npu_backend = new MockBackendNPU();
+ _mock_backends = {_cpu_backend, _gpu_backend, _npu_backend};
+
+    // Remove previous profile data if it exists
+    if (remove("exec_time.json") != 0)
+    {
+      // DO NOTHING (there was no previous profile data)
+    }
+
+ // Remember original value of 'EXECUTOR' environment variable
+ char *executor = std::getenv("EXECUTOR");
+ _original_executor = executor == nullptr ? "" : executor;
+
+ // Remember original value of 'PROFILING_MODE' environment variable
+ char *profiling_mode = std::getenv("PROFILING_MODE");
+ _original_profiling_mode = profiling_mode == nullptr ? "" : profiling_mode;
+ }
+
+ void TearDown() override
+ {
+ delete _cpu_backend;
+ delete _gpu_backend;
+ delete _npu_backend;
+ EXPECT_EQ(remove("exec_time.json"), 0);
+ setenv("EXECUTOR", _original_executor.c_str(), true);
+ setenv("PROFILING_MODE", _original_profiling_mode.c_str(), true);
+ }
+
+ const MockBackendCPU *_cpu_backend{nullptr};
+ const MockBackendGPU *_gpu_backend{nullptr};
+ const MockBackendNPU *_npu_backend{nullptr};
+ std::vector<const Backend *> _mock_backends;
+
+ std::string _original_executor;
+ std::string _original_profiling_mode;
+};
+
+//
+// HEScheduler tests
+//
+
+class HESchedulerTestWithExecutorParam : public HESchedulerTest,
+ public testing::WithParamInterface<std::string>
+{
+};
+
+// HESchedulerTestWithExecutorParam tests are parameterized with the executor name and run three
+// times, once for each executor
+INSTANTIATE_TEST_SUITE_P(AllExecutors, HESchedulerTestWithExecutorParam,
+ testing::Values(LINEAR, DATAFLOW, PARALLEL));
+
+// Test scheduler behavior for a straight graph with known execution times for all nodes and permutations.
+TEST_P(HESchedulerTestWithExecutorParam, straight_graph_known_exec_time)
+{
+ setExecutor(GetParam());
+
+ // Prepare graph
+ ir::Model model;
+ auto graph(createStraightGraph());
+ model.push(ir::SubgraphIndex{0}, graph);
+ OperationIndex add_op_idx(0), sub_op_idx(1), mul_op_idx(2);
+
+ // Set default execution and transfer time
+ setPermutationsExecutionTime(_mock_backends, OPERAND_SIZE, 1);
+ setOperationsExecutionTime(_mock_backends, {"Add", "Sub", "Mul"},
+ {OPERATION_SIZE, OPERATION_SIZE, OPERATION_SIZE}, 1e4);
+
+ // Test 1
+  // Expected behaviour: the scheduler assigns a different backend to each node
+ {
+ // For each backend reduce execution time of one node
+ ExecTime et(_mock_backends);
+ setOperationExecTime(et, _cpu_backend, "Add", false, OPERATION_SIZE, 1);
+ setOperationExecTime(et, _gpu_backend, "Sub", false, OPERATION_SIZE, 1);
+ setOperationExecTime(et, _npu_backend, "Mul", false, OPERATION_SIZE, 1);
+ et.storeOperationsExecTime();
+
+ // Test scheduler
+ auto coptions = *onert::compiler::CompilerOptions::fromGlobalConfig();
+ auto scheduler = compiler::HEScheduler(_mock_backends, coptions);
+ const auto br = scheduler.schedule(*graph);
+ ASSERT_EQ(br->getBackend(add_op_idx)->config()->id(), "cpu");
+ ASSERT_EQ(br->getBackend(sub_op_idx)->config()->id(), "gpu");
+ ASSERT_EQ(br->getBackend(mul_op_idx)->config()->id(), "npu");
+ }
+
+ // Test 2
+  // Expected behaviour: the scheduler assigns a single backend to all nodes because of the large
+  // transfer time
+ {
+ // Increase transfer time
+ setPermutationsExecutionTime(_mock_backends, OPERAND_SIZE, 1e5);
+
+ // Test scheduler
+ auto coptions = *onert::compiler::CompilerOptions::fromGlobalConfig();
+ auto scheduler = compiler::HEScheduler(_mock_backends, coptions);
+ const auto br = scheduler.schedule(*graph);
+ ASSERT_EQ(br->getBackend(add_op_idx)->config()->id(), "cpu");
+ ASSERT_EQ(br->getBackend(sub_op_idx)->config()->id(), "cpu");
+ ASSERT_EQ(br->getBackend(mul_op_idx)->config()->id(), "cpu");
+ }
+}
+
+// Test scheduler behavior for a branched graph with known execution times for all nodes and permutations
+TEST_P(HESchedulerTestWithExecutorParam, branched_graph_known_exec_time)
+{
+ const int64_t NPU_ET = 5000;
+ setExecutor(GetParam());
+
+ // Prepare graph
+ ir::Model model;
+ auto graph(createBranchedGraph());
+ model.push(ir::SubgraphIndex{0}, graph);
+ OperationIndex add_op_idx(0), mul1_op_idx(1), mul2_op_idx(2), fc1_op_idx(3), fc2_op_idx(4),
+ sub_op_idx(5);
+
+ // Set default execution and transfer time
+ setPermutationsExecutionTime(_mock_backends, OPERAND_SIZE, 1000);
+ setOperationsExecutionTime(_mock_backends, {"Add", "Sub", "Mul", "FullyConnected"},
+ {OPERATION_SIZE, OPERATION_SIZE, OPERATION_SIZE, OPERATION_SIZE}, 1e4);
+
+ // Test 1
+  // Expected behaviour: for the dataflow and linear executors the scheduler assigns the fastest
+  // backend to all nodes; for the parallel executor it assigns different backends to the branches.
+ {
+ // Reduce execution time
+ ExecTime et(_mock_backends);
+ setOperationExecTime(et, _npu_backend, "Add", false, OPERATION_SIZE, NPU_ET);
+ setOperationExecTime(et, _npu_backend, "Mul", false, OPERATION_SIZE, NPU_ET);
+ setOperationExecTime(et, _npu_backend, "Sub", false, OPERATION_SIZE, NPU_ET);
+ setOperationExecTime(et, _npu_backend, "FullyConnected", false, OPERATION_SIZE, NPU_ET);
+ setOperationExecTime(et, _gpu_backend, "Mul", false, OPERATION_SIZE, NPU_ET + 1000);
+ setOperationExecTime(et, _gpu_backend, "FullyConnected", false, OPERATION_SIZE, NPU_ET + 1000);
+ et.storeOperationsExecTime();
+
+ // Test scheduler
+ auto coptions = *onert::compiler::CompilerOptions::fromGlobalConfig();
+ auto scheduler = compiler::HEScheduler(_mock_backends, coptions);
+ const auto br = scheduler.schedule(*graph);
+
+ std::string branch1_expected_backend("npu"), branch2_expected_backend("npu");
+ if (GetParam() == PARALLEL)
+ {
+ branch1_expected_backend =
+ br->getBackend(mul1_op_idx)->config()->id() == "npu" ? "npu" : "gpu";
+ branch2_expected_backend = branch1_expected_backend == "npu" ? "gpu" : "npu";
+ }
+
+ ASSERT_EQ(br->getBackend(add_op_idx)->config()->id(), "npu");
+ ASSERT_EQ(br->getBackend(mul1_op_idx)->config()->id(), branch1_expected_backend);
+ ASSERT_EQ(br->getBackend(mul2_op_idx)->config()->id(), branch1_expected_backend);
+ ASSERT_EQ(br->getBackend(fc1_op_idx)->config()->id(), branch2_expected_backend);
+ ASSERT_EQ(br->getBackend(fc2_op_idx)->config()->id(), branch2_expected_backend);
+ ASSERT_EQ(br->getBackend(sub_op_idx)->config()->id(), "npu");
+ }
+
+ // Test 2
+  // Expected behaviour: the scheduler assigns a single backend to all nodes
+ {
+ // Increase execution time for GPU backend
+ ExecTime et(_mock_backends);
+    /* For the parallel executor: set a time larger than sum_of_other_branches_nodes_cnt *
+     * npu_exec_time so that npu is preferred: the i-th branch will wait for npu until it finishes
+     * the nodes of branches [0; i-1] in DFS order. In each branch it goes deep until it encounters
+     * branching or the scheduler assigns another backend to a node. */
+ setOperationExecTime(et, _gpu_backend, "Mul", false, OPERATION_SIZE, NPU_ET * 3 + 1);
+ setOperationExecTime(et, _gpu_backend, "FullyConnected", false, OPERATION_SIZE, NPU_ET * 3 + 1);
+ et.storeOperationsExecTime();
+
+ // Test scheduler
+ auto coptions = *onert::compiler::CompilerOptions::fromGlobalConfig();
+ auto scheduler = compiler::HEScheduler(_mock_backends, coptions);
+ const auto br = scheduler.schedule(*graph);
+ ASSERT_EQ(br->getBackend(add_op_idx)->config()->id(), "npu");
+ ASSERT_EQ(br->getBackend(mul1_op_idx)->config()->id(), "npu");
+ ASSERT_EQ(br->getBackend(mul2_op_idx)->config()->id(), "npu");
+ ASSERT_EQ(br->getBackend(fc1_op_idx)->config()->id(), "npu");
+ ASSERT_EQ(br->getBackend(fc2_op_idx)->config()->id(), "npu");
+ ASSERT_EQ(br->getBackend(sub_op_idx)->config()->id(), "npu");
+ }
+}
+
+// Test scheduler behavior for a branched graph with profiling mode enabled
+TEST_F(HESchedulerTest, branched_graph_profiling_mode)
+{
+ const int ET = 1e5;
+
+ // Turn on profiling mode
+ setProfilingMode(true);
+ setExecutor(DATAFLOW);
+
+ // Prepare graph
+ ir::Model model;
+ auto graph(createBranchedGraph());
+ model.push(ir::SubgraphIndex{0}, graph);
+ OperationIndex add_op_idx(0), mul1_op_idx(1), mul2_op_idx(2), fc1_op_idx(3), fc2_op_idx(4),
+ sub_op_idx(5);
+
+ // Test 1
+  // Expected behaviour: the scheduler assigns backends to the nodes whose execution time is unknown
+ {
+ // Set execution time for all backends/nodes except for cpu/Sub, npu/Mul, gpu/FC
+ ExecTime et(_mock_backends);
+ setOperationExecTime(et, _cpu_backend, "Add", false, OPERATION_SIZE, ET);
+ setOperationExecTime(et, _cpu_backend, "Mul", false, OPERATION_SIZE, ET + 1);
+ setOperationExecTime(et, _cpu_backend, "FullyConnected", false, OPERATION_SIZE, ET);
+ setOperationExecTime(et, _npu_backend, "Add", false, OPERATION_SIZE, ET);
+ setOperationExecTime(et, _npu_backend, "FullyConnected", false, OPERATION_SIZE, ET);
+ setOperationExecTime(et, _npu_backend, "Sub", false, OPERATION_SIZE, ET);
+ setOperationExecTime(et, _gpu_backend, "Add", false, OPERATION_SIZE, ET);
+ setOperationExecTime(et, _gpu_backend, "Mul", false, OPERATION_SIZE, ET + 1);
+ setOperationExecTime(et, _gpu_backend, "Sub", false, OPERATION_SIZE, ET);
+ et.storeOperationsExecTime();
+
+ // Test scheduler
+ auto coptions = *onert::compiler::CompilerOptions::fromGlobalConfig();
+ auto scheduler = compiler::HEScheduler(_mock_backends, coptions);
+ const auto br = scheduler.schedule(*graph);
+ ASSERT_EQ(br->getBackend(mul1_op_idx)->config()->id(), "npu");
+ ASSERT_EQ(br->getBackend(mul2_op_idx)->config()->id(), "npu");
+ ASSERT_EQ(br->getBackend(fc1_op_idx)->config()->id(), "gpu");
+ ASSERT_EQ(br->getBackend(fc2_op_idx)->config()->id(), "gpu");
+ ASSERT_EQ(br->getBackend(sub_op_idx)->config()->id(), "cpu");
+ }
+
+ // Test 2
+  // Expected behaviour: the scheduler shuffles backends, so different backends are assigned to
+  // neighboring nodes
+ {
+    // Set execution time for the remaining backend/node pairs (cpu/Sub, npu/Mul, gpu/FC)
+ ExecTime et(_mock_backends);
+ setOperationExecTime(et, _cpu_backend, "Sub", false, OPERATION_SIZE, ET);
+ setOperationExecTime(et, _npu_backend, "Mul", false, OPERATION_SIZE, ET + 1);
+ setOperationExecTime(et, _gpu_backend, "FullyConnected", false, OPERATION_SIZE, ET);
+ et.storeOperationsExecTime();
+
+ // Test scheduler
+ auto coptions = *onert::compiler::CompilerOptions::fromGlobalConfig();
+ auto scheduler = compiler::HEScheduler(_mock_backends, coptions);
+ const auto br = scheduler.schedule(*graph);
+ ASSERT_NE(br->getBackend(add_op_idx)->config()->id(),
+ br->getBackend(mul1_op_idx)->config()->id());
+ ASSERT_NE(br->getBackend(add_op_idx)->config()->id(),
+ br->getBackend(fc1_op_idx)->config()->id());
+ ASSERT_NE(br->getBackend(mul1_op_idx)->config()->id(),
+ br->getBackend(mul2_op_idx)->config()->id());
+ ASSERT_NE(br->getBackend(fc1_op_idx)->config()->id(),
+ br->getBackend(fc2_op_idx)->config()->id());
+ ASSERT_NE(br->getBackend(mul2_op_idx)->config()->id(),
+ br->getBackend(sub_op_idx)->config()->id());
+ ASSERT_NE(br->getBackend(fc2_op_idx)->config()->id(),
+ br->getBackend(sub_op_idx)->config()->id());
+ }
+}
+
+// TODO: Add tests with unknown execution and permutation time
+
+} // unnamed namespace
diff --git a/runtime/onert/core/src/compiler/Linear.cc b/runtime/onert/core/src/compiler/Linear.cc
index 49a989500..663cf5450 100644
--- a/runtime/onert/core/src/compiler/Linear.cc
+++ b/runtime/onert/core/src/compiler/Linear.cc
@@ -14,207 +14,38 @@
* limitations under the License.
*/
-#include <algorithm>
-
#include "Linear.h"
-#include "backend/IConfig.h"
-#include "backend/IConstantInitializer.h"
-#include "backend/ITensorRegister.h"
-#include "backend/Backend.h"
+#include "../dumper/text/GraphDumper.h"
+
#include "util/logging.h"
+#include <sstream>
+
namespace onert
{
namespace compiler
{
-std::vector<ir::OpSequenceIndex> Linear::linearize(const compiler::LoweredGraph &lowered_graph)
+// TODO(easy) Change the LoweredGraph param to Graph
+std::vector<ir::OperationIndex> Linear::linearize(const compiler::ILoweredGraph &lowered_graph)
{
- std::vector<ir::OpSequenceIndex> order;
- lowered_graph.iterateTopolOpSeqs(
- [&](const ir::OpSequenceIndex &index, const ir::OpSequence &) -> void {
- order.emplace_back(index);
- });
- return order;
+ return lowered_graph.graph().topolSortOperations();
}
-void Linear::dump(const compiler::LoweredGraph &lowered_graph,
- const std::vector<ir::OpSequenceIndex> &order)
+// TODO(easy) Change the LoweredGraph param to Graph
+void Linear::dump(const compiler::ILoweredGraph &lowered_graph,
+ const std::vector<ir::OperationIndex> &order)
{
+ for (const auto &ind : order)
{
- const auto &toString = [](const onert::backend::Backend *backend) {
- assert(backend);
- std::string str;
- str += backend->config()->id();
- return "{" + str + "}";
- };
-
- VERBOSE(Linear) << "Final OpSequence" << std::endl;
- for (const auto index : order)
- {
- const auto &op_seq = lowered_graph.op_seqs().at(index);
- const auto lower_info = lowered_graph.getLowerInfo(index);
- const auto &operations = lowered_graph.graph().operations();
- VERBOSE(Linear) << "* OP_SEQ " << toString(lower_info->backend()) << " "
- << ir::getStrFromOpSeq(op_seq, operations) << std::endl;
- }
+    // TODO Could the logging system handle this itself? (Inserting a prefix for each line)
+ std::istringstream iss{dumper::text::formatOperation(lowered_graph.graph(), ind)};
+ std::string line;
+ while (std::getline(iss, line))
+ VERBOSE(Linearize) << line << std::endl;
}
}
-void Linear::planTensors(const compiler::LoweredGraph &lowered_graph,
- const std::vector<ir::OpSequenceIndex> &order)
-{
- const auto &graph = lowered_graph.graph();
- ir::OperandIndexMap<std::shared_ptr<backend::ITensorBuilder>> tensor_builder_map;
-
- ir::OperandIndexMap<uint32_t> uses_map;
- ir::OperandIndexMap<uint32_t> def_map;
- ir::OperandIndexSequence constants;
-
- // Prepare scanning
- graph.operands().iterate([&](const ir::OperandIndex &ind, const ir::Operand &obj) {
- const auto lower_info = lowered_graph.getLowerInfo(ind);
- // TODO Remove if onert doesn't support anymore such as
- // GeneratedTests.reshape_quant8_weights_as_inputs
- if (lower_info->def_factors().size() == 0 && lower_info->use_factors().size() == 0 &&
- !graph.getInputs().contains(ind))
- {
- VERBOSE(LINEAR) << "Operand #" << ind.value() << " will not be used. no more process."
- << std::endl;
- return;
- }
-
- // Unused input of subgraph
- // TODO Register unused input as nullptr in tensor_builder
- if (lower_info->def_factors().size() == 0 && lower_info->use_factors().size() == 0 &&
- graph.getInputs().contains(ind))
- {
- VERBOSE(LINEAR) << "Operand #" << ind.value() << " will not be used. no more process."
- << std::endl;
- return;
- }
-
- uses_map[ind] = obj.getUses().size();
- def_map[ind] = obj.getDef().valid() ? 1 : 0;
-
- bool is_const = obj.isConstant();
- if (is_const)
- {
- constants.append(ind);
- }
-
- auto factor = lower_info->def_factors().getOnlyElement();
- auto backend = factor.backend();
- auto tensor_builder = lowered_graph.backend_contexts().at(backend)->tensor_builder;
- if (!tensor_builder->isRegistered(ind))
- {
- // These tensors do not exist in any op_seq (No use and def)
- const auto info = obj.info();
- const auto backend_layout = factor.layout();
- // TODO Change tensor info to have permuted shape
- tensor_builder->registerTensorInfo(ind, info, backend_layout);
- }
-
- tensor_builder_map[ind] = tensor_builder;
- });
-
- // If a tensor is model output, increase the use of the tensor.
- // This aim is same to above one.
- for (const auto &ind : graph.getOutputs() | ir::Remove::DUPLICATED)
- {
- uses_map[ind]++;
- }
-
- // Start scanning to do notify{First|Last}Use for each tensor
-
- // If a tensor is a constant, increase the use of the tensor.
- // It makes the tensor not be dealloced. It means these will be deallocated last.
- // And allocate constant operands first
- VERBOSE(LINEAR) << "TENSORS as CONSTANT" << std::endl;
- for (const auto &ind : constants)
- {
- uses_map[ind]++;
- tensor_builder_map[ind]->notifyFirstUse(ind);
- }
-
- // Allocate Model's inputs
- VERBOSE(LINEAR) << "TENSORS as MODEL INPUT" << std::endl;
- for (const auto &ind : graph.getInputs() | ir::Remove::DUPLICATED)
- {
- auto tensor_builder = tensor_builder_map[ind];
- if (!tensor_builder) // for GeneratedTests.xxx_weights_as_inputs
- continue;
- tensor_builder->notifyFirstUse(ind);
- }
-
- // At each operation,
- // 1. Scan DEF of outputs. If the DEF, allocate it
- // 2. Scan USE of inputs. Decrease the USE and deallocate if the USE is 0
- VERBOSE(LINEAR) << "TENSORS" << std::endl;
- for (const auto op_seq_ind : order)
- {
- const auto &op_seq = lowered_graph.op_seqs().at(op_seq_ind);
- for (const auto &op_idx : op_seq.operations())
- {
- for (const auto &ind : graph.operations().at(op_idx).getOutputs() | ir::Remove::DUPLICATED |
- ir::Remove::UNDEFINED)
- {
- assert(def_map.find(ind) != def_map.end());
- if (def_map[ind])
- {
- def_map[ind] = 0;
- tensor_builder_map[ind]->notifyFirstUse(ind);
- }
- }
-
- for (const auto &ind : graph.operations().at(op_idx).getInputs() | ir::Remove::DUPLICATED |
- ir::Remove::UNDEFINED)
- {
- assert(uses_map.find(ind) != uses_map.end());
- assert(uses_map[ind] > 0);
- uses_map[ind]--;
- if (uses_map[ind] == 0)
- {
- // plan for deallocation of static tensornode
- tensor_builder_map[ind]->notifyLastUse(ind);
-
- // plan for deallocation of dynamic tensor
- auto dyn_tensor_manager = tensor_builder_map[ind]->dynamicTensorManager();
- if (dyn_tensor_manager)
- dyn_tensor_manager->planDealloc(op_idx, ind);
- }
- }
- }
- }
-
- // Dispose and validate
- for (const auto &ind : graph.getOutputs() | ir::Remove::DUPLICATED)
- {
- --uses_map[ind];
- if (uses_map[ind] == 0) // To prevent notifyLastUse from being called twice
- {
- tensor_builder_map[ind]->notifyLastUse(ind);
- }
- }
-
- for (const auto &ind : constants)
- {
- --uses_map[ind];
- if (uses_map[ind] == 0) // To prevent notifyLastUse from being called twice
- {
- tensor_builder_map[ind]->notifyLastUse(ind);
- }
- }
-
- assert(
- std::all_of(uses_map.begin(), uses_map.end(),
- [](std::pair<const ir::OperandIndex, uint32_t> it) { return it.second == 0; }));
-
- assert(
- std::all_of(def_map.begin(), def_map.end(),
- [](std::pair<const ir::OperandIndex, uint32_t> it) { return it.second == 0; }));
-}
-
} // namespace compiler
} // namespace onert
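For context on the hunk above: Linear::linearize() now simply returns Graph::topolSortOperations(), so callers iterate over ir::OperationIndex directly instead of op-sequence indices. A minimal sketch of how a caller might consume the new order, assuming a lowered_graph instance and using only the API visible in this hunk (the loop body is illustrative):

// Sketch only: walk the linearized order produced by the new API.
const auto order = onert::compiler::Linear::linearize(lowered_graph);
for (const auto &op_ind : order)
{
  // Each index refers to an operation in the lowered graph's operation set.
  const auto &op = lowered_graph.graph().operations().at(op_ind);
  (void)op; // e.g. generate a kernel or dump the operation here
}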
diff --git a/runtime/onert/core/src/compiler/Linear.h b/runtime/onert/core/src/compiler/Linear.h
index 1e24cf92b..4f92dc88d 100644
--- a/runtime/onert/core/src/compiler/Linear.h
+++ b/runtime/onert/core/src/compiler/Linear.h
@@ -20,18 +20,8 @@
#include <vector>
#include <memory>
-#include "ir/OpSequences.h"
#include "ir/Index.h"
-#include "backend/ITensorBuilder.h"
-#include "compiler/LoweredGraph.h"
-
-namespace onert
-{
-namespace ir
-{
-struct OperationVisitor;
-} // namespace ir
-} // namespace onert
+#include "compiler/ILoweredGraph.h"
namespace onert
{
@@ -41,11 +31,9 @@ namespace compiler
class Linear
{
public:
- static std::vector<ir::OpSequenceIndex> linearize(const compiler::LoweredGraph &lowered_graph);
- static void dump(const compiler::LoweredGraph &lowered_graph,
- const std::vector<ir::OpSequenceIndex> &order);
- static void planTensors(const compiler::LoweredGraph &lowered_graph,
- const std::vector<ir::OpSequenceIndex> &order);
+ static std::vector<ir::OperationIndex> linearize(const compiler::ILoweredGraph &lowered_graph);
+ static void dump(const compiler::ILoweredGraph &lowered_graph,
+ const std::vector<ir::OperationIndex> &order);
};
} // namespace compiler
diff --git a/runtime/onert/core/src/compiler/LoweredGraph.cc b/runtime/onert/core/src/compiler/LoweredGraph.cc
index 1489a1884..46a45e44a 100644
--- a/runtime/onert/core/src/compiler/LoweredGraph.cc
+++ b/runtime/onert/core/src/compiler/LoweredGraph.cc
@@ -16,21 +16,23 @@
#include "compiler/LoweredGraph.h"
-#include <assert.h>
-#include <sstream>
-#include "util/logging.h"
-#include "compiler/pass/ConstantInsertionPass.h"
-#include "compiler/pass/ConstantLoweringPass.h"
-#include "compiler/pass/PermutationOperationPass.h"
-#include "compiler/pass/PermutationInsertionPass.h"
-#include "compiler/pass/PermutationEliminationPass.h"
-#include "ir/GraphIterator.h"
-#include "ir/verifier/Verifier.h"
+#include "HEScheduler.h"
+#include "ManualScheduler.h"
+#include "pass/ConstantInsertionPass.h"
+#include "pass/ConstantLoweringPass.h"
+#include "pass/PassRunner.h"
+#include "pass/PermutationEliminationPass.h"
+#include "pass/PermutationInsertionPass.h"
+#include "pass/PermutationOperationPass.h"
+#include "../dumper/text/GraphDumper.h"
+#include "../ir/verifier/Verifier.h"
+
#include "backend/Backend.h"
-#include "backend/IConfig.h"
#include "compiler/BackendResolver.h"
-#include "compiler/ManualScheduler.h"
-#include "compiler/HEScheduler.h"
+#include "util/logging.h"
+
+#include <cassert>
+#include <sstream>
namespace onert
{
@@ -39,18 +41,15 @@ namespace compiler
LoweredGraph::LoweredGraph(const ir::Graph &graph, const CompilerOptions &options) : _graph{graph}
{
- bool linear_executor = (options.executor == "Linear");
+ lowerGraph(options);
+}
+void LoweredGraph::lowerGraph(const CompilerOptions &options)
+{
// Build backend contexts
auto &backend_manager = BackendManager::get();
-
- // Always create Controlflow backend context
- auto cf_backend = backend_manager.getControlflow();
- _backend_contexts.emplace(
- cf_backend, cf_backend->newContext(_graph, _graph.getKernelBuilder(), linear_executor));
-
// Create contexts for other backends
- for (auto backend_str : options.backend_list)
+ for (auto &&backend_str : options.backend_list)
{
backend_manager.loadBackend(backend_str);
auto backend = backend_manager.get(backend_str);
@@ -60,12 +59,9 @@ LoweredGraph::LoweredGraph(const ir::Graph &graph, const CompilerOptions &option
// we should change it back(throw if backend is not loaded) later.
if (!backend)
{
- VERBOSE(LoweredGraph) << "Cannot load backend - " << backend_str;
+ VERBOSE(LoweredGraph) << "Cannot load backend - " << backend_str << std::endl;
continue;
}
-
- _backend_contexts.emplace(
- backend, backend->newContext(_graph, _graph.getKernelBuilder(), linear_executor));
}
if (backend_manager.num_backends() == 0)
throw std::runtime_error{"No available backends loaded."};
@@ -73,317 +69,115 @@ LoweredGraph::LoweredGraph(const ir::Graph &graph, const CompilerOptions &option
// TODO Move "schedule" phase out of here
// Schedule
std::unique_ptr<BackendResolver> backend_resolver;
+ auto all_backends = backend_manager.getAll();
if (options.he_scheduler)
{
- auto scheduler = HEScheduler(_backend_contexts, options);
+ auto scheduler = HEScheduler(all_backends, options);
backend_resolver = scheduler.schedule(_graph);
_indexed_ranks = scheduler.getIndexedRanks();
}
else
{
- auto scheduler = ManualScheduler(_backend_contexts, options);
+ auto scheduler = ManualScheduler(all_backends, options);
backend_resolver = scheduler.schedule(_graph);
}
- {
- // operand::LowerInfo holder
- ir::OperandIndexMap<std::unique_ptr<ir::operand::LowerInfo>> operands_lower_info;
-
- _graph.operands().iterate([&](const ir::OperandIndex &index, const ir::Operand &) {
- operands_lower_info[index] = std::make_unique<ir::operand::LowerInfo>();
- });
-
- // Make op_seqs while checking whether a node can be merged into a op_seq.
- makeOpSequences(operands_lower_info, options, *backend_resolver);
+ makeLowerInfo(*backend_resolver);
+ VERBOSE(LoweredGraph) << "dump before mandatory passes" << std::endl;
+ dumper::text::dumpLoweredGraph(*this);
- _op_seqs.iterate([&](const ir::OpSequenceIndex &, ir::OpSequence &op_seq) {
- assert(op_seq.operations().size() > 0);
- std::reverse(std::begin(op_seq.operations()), std::end(op_seq.operations()));
- });
+ // Mandatory passes - kind of legalization(?)
+ pass::PassRunner{}
+ .append(std::make_unique<pass::ConstantInsertionPass>(*this))
+ .append(std::make_unique<pass::ConstantLoweringPass>(*this))
+ .append(std::make_unique<pass::PermutationOperationPass>(*this))
+ .append(std::make_unique<pass::PermutationInsertionPass>(*this))
+ .run();
- VERBOSE(OpSequences) << "dump without permutation" << std::endl;
- dumpOpSequences(_op_seqs, _graph.operations());
+ dumpLowerInfo();
- pass::ConstantInsertionPass ci_pass(*this);
- ci_pass.run();
+ // Optimization passes (optional)
+ pass::PassRunner{}.append(std::make_unique<pass::PermutationEliminationPass>(*this)).run();
- pass::ConstantLoweringPass cl_pass(*this);
- cl_pass.run();
-
- // Set LowerInfo for each operand from the operand::LowerInfo holder
- manipulateLowerInfo(operands_lower_info, options.is_primary_subgraph);
-
- dumpLowerInfo();
- }
-
- // Run Permutation Passes
- {
- pass::PermutationOperationPass po_pass(*this);
- po_pass.run();
-
- pass::PermutationInsertionPass pi_pass(*this);
- pi_pass.run();
-
- pass::PermutationEliminationPass pe_pass(*this);
- pe_pass.run();
-
- VERBOSE(OpSequences) << "dump with permutation" << std::endl;
- dumpOpSequences(_op_seqs, _graph.operations());
- }
+ VERBOSE(LoweredGraph) << "Dump after all the passes" << std::endl;
+ for (auto &&operand : _graph.getInputs())
+ VERBOSE(LoweredGraph) << "Graph Input : " << operand << std::endl;
+ for (auto &&operand : _graph.getOutputs())
+ VERBOSE(LoweredGraph) << "Graph Output : " << operand << std::endl;
+ dumper::text::dumpLoweredGraph(*this);
// Graph verifications
{
+ assert(ir::verifier::InputOutputChecker().verify(_graph));
assert(ir::verifier::DAGChecker().verify(_graph));
- assert(ir::verifier::EdgeConsistencyChecker().verify(_graph));
+ assert(ir::verifier::EdgeChecker().verify(_graph));
}
}
-const ir::operation::LowerInfo *
-LoweredGraph::getLowerInfo(const ir::OpSequenceIndex &op_seq_index) const
+void LoweredGraph::makeLowerInfo(const compiler::BackendResolver &backend_resolver)
{
- auto itr = _lower_info_map.op_seq.find(op_seq_index);
- if (itr == _lower_info_map.op_seq.end())
- return nullptr;
- return itr->second.get();
-}
-
-void LoweredGraph::setLowerInfo(const ir::OpSequenceIndex &op_seq_index,
- std::unique_ptr<ir::operation::LowerInfo> &&lower_info)
-{
- _lower_info_map.op_seq.insert(std::make_pair(op_seq_index, std::move(lower_info)));
-}
+ _graph.operands().iterate([&](const ir::OperandIndex &index, const ir::Operand &) {
+ lower_info().operand.set(index, std::make_unique<OperandLowerInfo>());
+ });
-void LoweredGraph::removeLowerInfo(const ir::OpSequenceIndex &op_seq_index)
-{
- auto &op_seq_lower_info = _lower_info_map.op_seq;
- assert(op_seq_lower_info.find(op_seq_index) != op_seq_lower_info.end());
- for (auto it = op_seq_lower_info.begin(); it != op_seq_lower_info.end(); ++it)
- {
- if (it->first == op_seq_index)
+  // Set operand lower info using the backends assigned to operations
+ _graph.operations().iterate([&](const ir::OperationIndex &op_ind, const ir::IOperation &) {
+ const ir::IOperation &op = _graph.operations().at(op_ind);
+ auto backend = backend_resolver.getBackend(op_ind);
+ if (!backend)
{
- op_seq_lower_info.erase(it);
- break;
+ throw std::runtime_error{"Fail to find backend for " + op.name() + " operation"};
}
- }
-}
-
-const ir::operand::LowerInfo *LoweredGraph::getLowerInfo(const ir::OperandIndex &index) const
-{
- auto itr = _lower_info_map.operand.find(index);
- if (itr == _lower_info_map.operand.end())
- return nullptr;
- return itr->second.get();
-}
-
-ir::operand::LowerInfo *LoweredGraph::getLowerInfo(const ir::OperandIndex &index)
-{
- auto itr = _lower_info_map.operand.find(index);
- if (itr == _lower_info_map.operand.end())
- return nullptr;
- return itr->second.get();
-}
-
-void LoweredGraph::setLowerInfo(const ir::OperandIndex &index,
- std::unique_ptr<ir::operand::LowerInfo> &&lower_info)
-{
- _lower_info_map.operand.insert(std::make_pair(index, std::move(lower_info)));
-}
-
-void LoweredGraph::removeLowerInfo(const ir::OperandIndex &index)
-{
- _lower_info_map.operand.erase(index);
-}
-
-void LoweredGraph::iterateTopolOpSeqs(
- const std::function<void(const ir::OpSequenceIndex &, const ir::OpSequence &)> &fn) const
-{
- // Topological Sorting for ir::OpSequences
- std::vector<ir::OpSequenceIndex> topol_sorted;
- ir::PostDfsIterator<true>{}.iterateOpSeqs(
- *this, [&](const ir::OpSequenceIndex &index, const ir::OpSequence &) {
- topol_sorted.emplace_back(index);
- });
- std::reverse(topol_sorted.begin(), topol_sorted.end());
- for (const auto op_seq_idx : topol_sorted)
- {
- const auto &op_seq = _op_seqs.at(op_seq_idx);
- fn(op_seq_idx, op_seq);
- }
-}
-
-void LoweredGraph::iterateTopolOpSeqs(
- const std::function<void(const ir::OpSequenceIndex &, ir::OpSequence &)> &fn)
-{
- // Topological Sorting for ir::OpSequences
- std::vector<ir::OpSequenceIndex> topol_sorted;
- ir::PostDfsIterator<false>{}.iterateOpSeqs(
- *this, [&](const ir::OpSequenceIndex &index, ir::OpSequence &) {
- topol_sorted.emplace_back(index);
- });
- std::reverse(topol_sorted.begin(), topol_sorted.end());
- for (const auto op_seq_idx : topol_sorted)
- {
- auto &op_seq = _op_seqs.at(op_seq_idx);
- fn(op_seq_idx, op_seq);
- }
-}
-
-ir::OpSequenceIndex LoweredGraph::appendFreshSingleOpSequence(const ir::OperationIndex &node_index,
- const ir::Operation &node)
-{
- // Create a fresh op_seq with one operation, and append it to op_seqs
- // Create a fresh op_seq
- auto op_seq = std::make_unique<ir::OpSequence>(_graph.layout());
-
- // Add an operation
- op_seq->appendOperation(node_index);
-
- // Update input/output
- op_seq->setOutputs(node.getOutputs());
- op_seq->setInputs(node.getInputs());
-
- return _op_seqs.emplace(std::move(op_seq));
-}
-
-void LoweredGraph::makeOpSequences(
- ir::OperandIndexMap<std::unique_ptr<ir::operand::LowerInfo>> &operands_lower_info,
- const CompilerOptions &options, const BackendResolver &backend_resolver)
-{
- // if SUBG_MAX_NODE == 0, no limit on nodes of a op_seq
- const int op_seq_max_node = options.op_seq_max_node;
- assert(op_seq_max_node >= 0);
-
- bool is_profiling = options.he_profiling_mode;
- ir::OpSequence *op_seq = nullptr;
- ir::OpSequenceIndex op_seq_index;
-
- // NOTE: The below method appends nodes while making one op_seq if needed. If something better
- // ways, happy to update this code.
- ir::PostDfsConstIterator{}.iterate(
- _graph, [&](const ir::OperationIndex &node_index, const ir::Operation &node) {
- // LowerInfo for in/output operands
- auto backend = backend_resolver.getBackend(node_index);
-
- // Get frontend's layout
- auto frontend_layout = _graph.layout();
-
- // The layout of each backend should be set at another place
- // TODO Change setting layout of each backend at another place
- auto backend_layout = backend->config()->supportLayout(node, frontend_layout);
-
- for (auto operand : node.getInputs() | ir::Remove::UNDEFINED)
- {
- auto &&lower_info = operands_lower_info.at(operand);
- lower_info->addUsePermuteFactor(ir::operand::PermuteFactor{backend, backend_layout});
- }
- for (auto operand : node.getOutputs())
- {
- auto &&lower_info = operands_lower_info.at(operand);
- lower_info->addDefPermuteFactor(ir::operand::PermuteFactor{backend, backend_layout});
- }
-
- bool new_op_seq = (op_seq == nullptr ||
- (op_seq_max_node != 0 &&
- op_seq->operations().size() >= static_cast<size_t>(op_seq_max_node)));
-
- // for profiling each op_seq must contain just one node,
- // so that we can measure a node separately
- if (new_op_seq || is_profiling ||
- !mergeable(op_seq_index, node_index, backend_layout, backend_resolver))
- {
- auto new_op_seq_index = appendFreshSingleOpSequence(node_index, node);
-
- // ir::OpSequence LowerInfo
- setLowerInfo(new_op_seq_index,
- std::make_unique<ir::operation::LowerInfo>(backend, backend_layout));
-
- op_seq_index = new_op_seq_index;
- op_seq = &(_op_seqs.at(new_op_seq_index));
-
- VERBOSE(Lower) << "OpSequence#" << op_seq_index.value() << " is created for "
- << "NODE#" << node_index.value() << "(" << node.name() << ")" << std::endl;
- }
- else
- {
- op_seq->appendOperation(node_index);
- // Set inputs
- auto new_inputs = node.getInputs();
- // Add inputs except outputs of the previous node
- for (auto ind : op_seq->getInputs())
- {
- if (!node.getOutputs().contains(ind))
- new_inputs.append(ind);
- }
- op_seq->setInputs(new_inputs);
- VERBOSE(Lower) << "OpSequence#" << op_seq_index.value() << " merges "
- << "NODE#" << node_index.value() << "(" << node.name() << ")" << std::endl;
- }
- });
-}
+ auto frontend_layout = _graph.layout();
-void LoweredGraph::manipulateLowerInfo(
- ir::OperandIndexMap<std::unique_ptr<ir::operand::LowerInfo>> &operands_lower_info,
- bool is_primary)
-{
- const auto controlflow_backend = BackendManager::get().getControlflow();
+    // The layout of each backend should be set elsewhere
+    // TODO Move the backend layout setting to another place
+ auto backend_layout = backend->config()->supportLayout(op, frontend_layout);
- // TODO Rather than handling primary graph specially,
- // let the permute inserted and remove it later
- if (is_primary)
- {
- // TODO Rather than using NHWC Get frontend layout of this node from IR
- auto factor = ir::operand::PermuteFactor{controlflow_backend, ir::Layout::NHWC};
- for (auto index : _graph.getInputs() | ir::Remove::UNDEFINED)
+ for (auto &&ind : op.getInputs() | ir::Remove::UNDEFINED)
{
- auto &&lower_info = operands_lower_info.at(index);
- assert(lower_info->def_factors().empty());
- lower_info->addDefPermuteFactor(factor);
+ auto &operand_li = lower_info().operand.at(ind);
+ operand_li.addUsePermuteFactor(PermuteFactor{backend, backend_layout});
}
- for (auto index : _graph.getOutputs())
+ for (auto &&ind : op.getOutputs() | ir::Remove::UNDEFINED)
{
- auto &&lower_info = operands_lower_info.at(index);
- lower_info->addUsePermuteFactor(factor);
+ auto &operand_li = lower_info().operand.at(ind);
+ operand_li.addDefPermuteFactor(PermuteFactor{backend, backend_layout});
}
- }
- else
+ lower_info().operation.set(
+ op_ind, std::make_unique<compiler::OperationLowerInfo>(backend, backend_layout));
+ });
+
+ // Handle graph inputs and outputs
+ const auto builtin_backend = BackendManager::get().getBuiltin();
+ auto factor = PermuteFactor{builtin_backend, _graph.layout()};
+ for (auto &&index : _graph.getInputs() | ir::Remove::UNDEFINED)
{
- for (auto index : _graph.getInputs() | ir::Remove::UNDEFINED)
- {
- auto &&lower_info = operands_lower_info.at(index);
- if (!(lower_info->def_factors().size() == 0 && lower_info->use_factors().size() == 0))
- {
- // In case of not that Graph's input is not used in any operation and not the graph's
- // output.
- // In other words, it is not unused input in Graph.
- lower_info->addDefPermuteFactor(*lower_info->use_factors().begin());
- }
- else
- {
- // In case of that an operand is Graph's input and not input or output of any operation
- lower_info->addDefPermuteFactor(ir::operand::PermuteFactor{
- controlflow_backend,
- ir::Layout::NHWC // TODO Get frontend layout of this node from IR
- });
- }
- }
+ auto &operand_li = lower_info().operand.at(index);
+ assert(operand_li.def_factors().empty());
+ operand_li.addDefPermuteFactor(factor);
}
- for (auto index : _graph.getOutputs())
+ for (auto &&index : _graph.getOutputs() | ir::Remove::UNDEFINED)
{
- auto &&lower_info = operands_lower_info.at(index);
- if (lower_info->def_factors().size() == 0)
- {
- // In case of that an operand is Graph's output and not input or output of any operation
- lower_info->addDefPermuteFactor(ir::operand::PermuteFactor{
- controlflow_backend,
- ir::Layout::NHWC // TODO Get frontend layout of this node from IR
- });
- }
+ auto &operand_li = lower_info().operand.at(index);
+ operand_li.addUsePermuteFactor(factor);
}
- // Set LowerInfo for each operand from the operand::LowerInfo holder
- _graph.operands().iterate([&](const ir::OperandIndex &index, ir::Operand &) {
- setLowerInfo(index, std::move(operands_lower_info[index]));
+ // Handle variable tensors
+ _graph.operands().iterate([&](const ir::OperandIndex &index, ir::Operand &operand) {
+    // Some inputs of an operation may be non-constant yet appear neither in the graph
+    // inputs/outputs nor as undefined operands - these are variable tensors. For example,
+    // UnidirectionalSequenceLSTM has such inputs.
+ if (operand.info().isVariable())
+ {
+      // A variable operand with a buffer is not supported yet
+ assert(operand.data() == nullptr);
+ assert(operand.getUses().size() == 1 && !operand.getDef().valid());
+ auto operand_li = lower_info().operand.at(index);
+ assert(operand_li.def_factors().empty());
+ operand_li.addDefPermuteFactor(operand_li.use_factors().getOnlyElement());
+ }
});
}
@@ -395,12 +189,22 @@ void LoweredGraph::dumpLowerInfo()
std::map<uint32_t, std::string> dumps;
_graph.operands().iterate([&](const ir::OperandIndex &index, ir::Operand &object) {
- std::stringstream sstream;
- if (!getLowerInfo(index)->def_factors().empty() || !getLowerInfo(index)->use_factors().empty())
+ const auto operand_lower_info = lower_info().operand.getRawPtr(index);
+ assert(operand_lower_info);
+ if (!operand_lower_info->def_factors().empty() || !operand_lower_info->use_factors().empty())
{
- auto factors_to_string = [](const ir::operand::PermuteFactorSet &factors) {
+ auto shape_to_string = [](const ir::Shape &shape) {
+ std::stringstream sstream;
+ sstream << "{ ";
+ for (auto i = 0; i < shape.rank(); ++i)
+ sstream << (shape.dim(i)) << " ";
+ sstream << "}";
+ return sstream.str();
+ };
+
+ auto factors_to_string = [](const PermuteFactorSet &factors) {
std::string str;
- for (auto factor : factors)
+ for (auto &&factor : factors)
{
str += factor.backend()->config()->id();
str += "(" + to_string(factor.layout()) + ")";
@@ -409,159 +213,45 @@ void LoweredGraph::dumpLowerInfo()
return "{ " + str + "}";
};
- auto operation_index_to_string = [](const ir::OperationIndexSet &operations) {
- std::string str;
- for (auto op : operations)
- {
- str += std::to_string(op.value());
- str += " ";
- }
- return "{ " + str + "}";
+ auto operation_index_set_to_string = [](const ir::OperationIndexSet &operations) {
+ std::stringstream sstream;
+ sstream << "{ ";
+ for (auto &&op : operations)
+ sstream << op << " ";
+ sstream << "}";
+ return sstream.str();
+ };
+
+ auto data_to_str = [](const ir::Data *data) {
+ return (data ? (std::to_string(data->size()) + " bytes") : "N/A");
};
- const auto lower_info = getLowerInfo(index);
- const auto &shape = object.shape();
- std::string def_ops =
- object.getDef().valid() ? std::to_string(object.getDef().value()) : "N/A";
- std::string use_ops = operation_index_to_string(object.getUses());
- std::string def_layouts = factors_to_string(lower_info->def_factors());
- std::string use_layouts = factors_to_string(lower_info->use_factors());
- sstream << "Operand #" << index.value() << " LowerInfo" << std::endl;
- sstream << " - Shape : { ";
- for (auto i = 0; i < shape.rank(); ++i)
- {
- sstream << (shape.dim(i)) << " ";
- }
- sstream << "}" << std::endl;
- sstream << " - Def ir::Operations : " << def_ops << std::endl;
- sstream << " - Use ir::Operations : " << use_ops << std::endl;
- sstream << " - Lower Info" << std::endl;
- sstream << " - Def Backends : " << def_layouts << std::endl;
- sstream << " - Use Backends : " << use_layouts << std::endl;
+ std::string shape_str = shape_to_string(object.shape());
+ std::string def_op = operation_index_set_to_string({object.getDef()});
+ std::string use_ops = operation_index_set_to_string(object.getUses());
+ std::string def_factors = factors_to_string(operand_lower_info->def_factors());
+ std::string use_factors = factors_to_string(operand_lower_info->use_factors());
+ std::stringstream sstream;
+ sstream << "Operand " << index << " Info" << std::endl;
+ sstream << " - Shape : " << shape_str << std::endl;
+ sstream << " - Def/Uses : Def " << def_op << " Uses " << use_ops << std::endl;
+ sstream << " - Data : " << data_to_str(object.data()) << std::endl;
+ sstream << " - LowerInfo : Def " << def_factors << " Uses " << use_factors << std::endl;
+ dumps.emplace(index.value(), sstream.str());
}
- dumps.emplace(index.value(), sstream.str());
});
for (const auto &e : dumps)
{
if (!e.second.empty())
{
- VERBOSE(Lower) << e.second;
+ std::istringstream iss(e.second);
+ std::string line;
+ while (std::getline(iss, line))
+ VERBOSE(Lower) << line << std::endl;
}
}
}
-bool LoweredGraph::mergeable(const ir::OpSequenceIndex &op_seq_index,
- const ir::OperationIndex &node_index, ir::Layout layout,
- const BackendResolver &backend_resolver)
-{
- // Are they mergeable?
- // 1. the same backend id and layout?
- // 2. Is op_seq or node branched?
- // 3. if 1 is true, the op_seq and a node are connected?
- const auto &op_seq = _op_seqs.at(op_seq_index);
- const auto &node = _graph.operations().at(node_index);
-
- // The same backend id and layout?
- {
- const auto op_seq_backend_layout = getLowerInfo(op_seq_index)->layout();
- const auto &op_seq_backend_id = getLowerInfo(op_seq_index)->backend()->config()->id();
- const auto &node_backend_id = backend_resolver.getBackend(node_index)->config()->id();
- VERBOSE(Lower) << "OpSequence#" << op_seq_index.value() << " { " << op_seq_backend_id << "("
- << to_string(op_seq_backend_layout) << ") } "
- << " NODE#" << node_index.value() << " (" << node.name() << ") { "
- << node_backend_id << "(" << to_string(layout) << ") } " << std::endl;
- if (op_seq_backend_id != node_backend_id || op_seq_backend_layout != layout)
- return false;
- }
-
- // Branched?
- {
- std::unordered_set<ir::OperationIndex> branched_set;
-
- // Check for branching up
- for (const auto &input : op_seq.getInputs() | ir::Remove::DUPLICATED | ir::Remove::UNDEFINED)
- {
- const auto &input_obj = _graph.operands().at(input);
- auto def = input_obj.getDef();
- if (def.valid())
- {
- branched_set.insert(def);
- if (branched_set.size() > 1)
- {
- return false;
- }
- }
- }
- branched_set.clear();
-
- // Check for branching down
- for (const auto &output : node.getOutputs() | ir::Remove::DUPLICATED)
- {
- // TODO Fix this workaround for the case of model outputs that are used by another operation
- // This is needed since the branching is decided by operation, but for model outputs,
- // there is controlflow backen(use backend) but no actual use operation exists
- if (_graph.getOutputs().contains(output))
- return false;
-
- const auto &output_obj = _graph.operands().at(output);
- for (const auto &use : output_obj.getUses())
- {
- branched_set.insert(use);
- if (branched_set.size() > 1)
- {
- return false;
- }
- }
- }
- }
-
- // Connected?
- // an input of one node is an output of the other node? or vice-versa?
- {
- const auto &node_inputs = node.getInputs();
- const auto &node_outputs = node.getOutputs();
-
- // op_seq's operations are in order so that we just check the first and the last
- std::vector<ir::OperationIndex> op_seq_ops{op_seq.operations()[0]};
- if (op_seq.operations().size() > 1)
- op_seq_ops.emplace_back(op_seq.operations()[op_seq.operations().size() - 1]);
-
- for (const auto &n_index : op_seq_ops)
- {
- const auto &n = _graph.operations().at(n_index);
-
- // node's output == op_seq's input?
- for (const auto input : n.getInputs() | ir::Remove::UNDEFINED)
- {
- if (node_outputs.contains(input))
- {
- VERBOSE(Lower) << "OpSequence#" << op_seq_index.value() << " 's NODE#" << n_index.value()
- << "(" << n.name() << ") is connected to NODE#" << node_index.value()
- << "(" << node.name() << ")" << std::endl;
- return true;
- }
- }
-
- // node's input == op_seq's output?
- for (const auto output : n.getOutputs())
- {
- if (node_inputs.contains(output))
- {
- VERBOSE(Lower) << "OpSequence#" << op_seq_index.value() << " 's NODE#" << n_index.value()
- << " (" << n.name() << ") is connected to NODE#" << node_index.value()
- << std::endl;
- return true;
- }
- }
- }
-
- VERBOSE(Lower) << "OpSequence#" << op_seq_index.value() << " is not connected to NODE#"
- << node_index.value() << "(" << node.name() << ")" << std::endl;
- }
-
- return false;
-}
-
} // namespace compiler
} // namespace onert
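The rewritten lowerGraph() above replaces the hand-rolled pass invocations with the chained PassRunner pattern. A minimal sketch of that pattern in isolation, assuming a LoweredGraph instance named lowered_graph and using only passes visible in the hunk (append() is chainable, as the hunk shows):

// Sketch only: mandatory passes run as one chained pipeline, optional
// optimizations as a second one, mirroring lowerGraph() above.
using namespace onert::compiler;
pass::PassRunner{}
  .append(std::make_unique<pass::ConstantInsertionPass>(lowered_graph))
  .append(std::make_unique<pass::PermutationInsertionPass>(lowered_graph))
  .run();
pass::PassRunner{}.append(std::make_unique<pass::PermutationEliminationPass>(lowered_graph)).run();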
diff --git a/runtime/onert/core/src/compiler/ManualScheduler.cc b/runtime/onert/core/src/compiler/ManualScheduler.cc
index ed49ee56f..ccd08893f 100644
--- a/runtime/onert/core/src/compiler/ManualScheduler.cc
+++ b/runtime/onert/core/src/compiler/ManualScheduler.cc
@@ -29,9 +29,9 @@ namespace onert
namespace compiler
{
-ManualScheduler::ManualScheduler(const backend::BackendContexts &backend_contexts,
+ManualScheduler::ManualScheduler(const std::vector<const backend::Backend *> &backends,
const compiler::CompilerOptions &options)
- : _backend_contexts{backend_contexts}, _options{options}
+ : _backends{backends}, _options{options}
{
}
@@ -42,7 +42,7 @@ std::unique_ptr<BackendResolver> ManualScheduler::schedule(const ir::Graph &grap
// This fallback will be used in case that `backend_for_all` is unavailable
auto fallback = [&]() -> const backend::Backend * {
- for (auto backend_id : _options.backend_list)
+ for (auto &&backend_id : _options.backend_list)
{
auto backend = resolveBackend(backend_id);
if (backend)
@@ -58,20 +58,20 @@ std::unique_ptr<BackendResolver> ManualScheduler::schedule(const ir::Graph &grap
VERBOSE(ManualScheduler) << "Default backend for all ops: " << backend_all->config()->id()
<< std::endl;
- graph.operations().iterate([&](const ir::OperationIndex &index, const ir::Operation &) {
+ graph.operations().iterate([&](const ir::OperationIndex &index, const ir::IOperation &) {
backend_resolver->setBackend(index, backend_all);
});
// 2. Backend per operation type
std::unordered_map<ir::OpCode, backend::Backend *> op_type_map;
- for (auto &pair : manual_options.opcode_to_backend)
+ for (const auto &pair : manual_options.opcode_to_backend)
{
op_type_map.emplace(pair.first, BackendManager::get().get(pair.second));
}
// By default, Custom uses cpu backend
op_type_map[ir::OpCode::Custom] = BackendManager::get().get("cpu");
- graph.operations().iterate([&](const ir::OperationIndex &index, const ir::Operation &operation) {
+ graph.operations().iterate([&](const ir::OperationIndex &index, const ir::IOperation &operation) {
auto itr = op_type_map.find(operation.opcode());
if (itr != op_type_map.end())
{
@@ -80,7 +80,7 @@ std::unique_ptr<BackendResolver> ManualScheduler::schedule(const ir::Graph &grap
});
// 3. Backend per operation
- for (auto &pair : manual_options.index_to_backend)
+ for (const auto &pair : manual_options.index_to_backend)
{
const auto &key = pair.first;
const auto &val = pair.second;
@@ -88,22 +88,21 @@ std::unique_ptr<BackendResolver> ManualScheduler::schedule(const ir::Graph &grap
try
{
graph.operations().at(key); // Check if exist, or this will throw
- backend_resolver->setBackend(
- key, BackendManager::get().get(
- val)); // TODO Ensure this backend is available in backend contexts
+ backend_resolver->setBackend(key, BackendManager::get().get(val));
}
catch (...)
{
- VERBOSE(ManualScheduler) << "Invalid value while OperationIndex to Backend mapping : @"
- << key.value() << " -> \"" << val << "\"" << std::endl;
+ VERBOSE(ManualScheduler) << "Invalid value while OperationIndex to Backend mapping : @" << key
+ << " -> \"" << val << "\"" << std::endl;
}
}
// Dump final assignment
- backend_resolver->iterate([&](const ir::OperationIndex &index, const backend::Backend &backend) {
- VERBOSE(ManualScheduler) << "backend for operation #" << index.value() << ": "
- << backend.config()->id() << std::endl;
- });
+ WHEN_LOG_ENABLED(backend_resolver->iterate(
+ [&](const ir::OperationIndex &index, const backend::Backend &backend) {
+ VERBOSE(ManualScheduler) << "backend for " << index << ": " << backend.config()->id()
+ << std::endl;
+ }));
return backend_resolver;
}
@@ -113,7 +112,7 @@ const backend::Backend *ManualScheduler::resolveBackend(const std::string &id,
{
// Ensure if the backend is available in the current backend context
const backend::Backend *backend = BackendManager::get().get(id);
- if (!backend || _backend_contexts.find(backend) == _backend_contexts.end())
+ if (!backend || std::find(_backends.begin(), _backends.end(), backend) == _backends.end())
{
backend = fallback;
}
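As the hunks above show, both schedulers now take a flat list of backend pointers instead of a BackendContexts map, and membership is checked with std::find. A minimal sketch of the new wiring, mirroring the calls in LoweredGraph::lowerGraph() (graph and options are assumed to exist; namespaces abbreviated):

// Sketch only: wiring the reworked ManualScheduler with the flat backend list.
auto &backend_manager = BackendManager::get();
auto all_backends = backend_manager.getAll(); // backend pointer list, as the new ctor expects
auto scheduler = ManualScheduler(all_backends, options);
auto backend_resolver = scheduler.schedule(graph);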
diff --git a/runtime/onert/core/src/compiler/ManualScheduler.h b/runtime/onert/core/src/compiler/ManualScheduler.h
index 41503f7ff..18732d744 100644
--- a/runtime/onert/core/src/compiler/ManualScheduler.h
+++ b/runtime/onert/core/src/compiler/ManualScheduler.h
@@ -28,7 +28,7 @@ namespace compiler
class ManualScheduler : public IScheduler
{
public:
- ManualScheduler(const backend::BackendContexts &backend_contexts,
+ ManualScheduler(const std::vector<const backend::Backend *> &backends,
const compiler::CompilerOptions &options);
std::unique_ptr<BackendResolver> schedule(const ir::Graph &graph) override;
@@ -37,7 +37,7 @@ private:
const backend::Backend *fallback = nullptr);
private:
- const backend::BackendContexts &_backend_contexts;
+ std::vector<const backend::Backend *> _backends;
compiler::CompilerOptions _options;
};
diff --git a/runtime/onert/core/src/compiler/MultiModelCompiler.cc b/runtime/onert/core/src/compiler/MultiModelCompiler.cc
new file mode 100644
index 000000000..7fdf700c7
--- /dev/null
+++ b/runtime/onert/core/src/compiler/MultiModelCompiler.cc
@@ -0,0 +1,230 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "MultiModelCompiler.h"
+
+#include "CompilerHelpers.h"
+#include "ExecutorFactory.h"
+#include "ShapeValidator.h"
+#include "pass/ConstantOutputPass.h"
+#include "pass/OddOutputPass.h"
+#include "pass/PassRunner.h"
+#include "pass/UnusedOperandEliminationPass.h"
+#include "../dumper/dot/DotDumper.h"
+#include "../exec/MultiModelExecutors.h"
+#include "../ir/OperationDumper.h"
+#include "../ir/verifier/Verifier.h"
+
+#include "compiler/StaticShapeInferer.h"
+
+#include <misc/string_helpers.h>
+#include <misc/polymorphic_downcast.h>
+
+namespace onert
+{
+namespace compiler
+{
+
+MultiModelCompiler::MultiModelCompiler(const std::shared_ptr<ir::NNPkg> &nnpkg,
+ CompilerOptions *copts)
+ : _nnpkg{nnpkg}, _options{copts}
+{
+ // DO NOTHING
+}
+
+std::shared_ptr<CompilerArtifact> MultiModelCompiler::compile(void)
+{
+ /***************************************************
+ * Prepare compilation phase
+ ***************************************************/
+ {
+ if (!_options)
+ throw std::runtime_error{"Empty compile option"};
+
+ // Mode check
+    // TODO Handle options for each model
+ if (_options->he_profiling_mode)
+ throw std::runtime_error("NYI: Profiling mode for multiple model is not supported yet");
+
+ _options->forceInternalOptions();
+ _options->verboseOptions();
+ }
+
+ // NYI: allow one model compilation
+ auto const model_count = _nnpkg->model_count();
+ for (uint16_t i = 0; i < model_count; i++)
+ {
+ if (!_nnpkg->model(ir::ModelIndex{i})->hasOnly<ir::Graph>())
+ throw std::runtime_error("MultiModelCompiler can only compile models for inference.");
+ }
+
+ for (uint16_t i = 0; i < model_count; i++)
+ {
+ _nnpkg->model(ir::ModelIndex{i})->iterate([&](const ir::SubgraphIndex &, ir::IGraph &graph) {
+ auto &subg = nnfw::misc::polymorphic_downcast<ir::Graph &>(graph);
+
+ // Mandatory passes
+ pass::PassRunner{}
+ .append(std::make_unique<pass::ConstantOutputPass>(subg))
+ .append(std::make_unique<pass::OddOutputPass>(subg))
+ .run();
+
+ // Optimizations
+ pass::PassRunner{}.append(std::make_unique<pass::UnusedOperandEliminationPass>(subg)).run();
+ });
+ }
+
+ /***************************************************
+ * Backend independent analysis & optimization phase
+ ***************************************************/
+ // TODO Handle dump level for each model
+ auto dump_level = static_cast<dumper::dot::DotDumper::Level>(_options->graph_dump_level);
+ onert::dumper::dot::DotDumper dot_dumper(dump_level);
+
+ // Tracing context
+  // TODO Support tracing_ctx for multiple models
+ std::unique_ptr<util::TracingCtx> tracing_ctx = nullptr;
+
+  // Model edge context: copy the model edges from the package
+ auto model_edges = std::make_unique<ir::ModelEdges>(_nnpkg->model_edges());
+
+ // Custom kernels
+ std::unordered_map<ir::ModelIndex, std::shared_ptr<backend::custom::IKernelBuilder>>
+ custom_kernel_builders;
+ for (uint16_t i = 0; i < model_count; i++)
+ {
+ auto const model_index = ir::ModelIndex{i};
+ custom_kernel_builders[model_index] = _nnpkg->model(model_index)->getKernelBuilder();
+ }
+
+ // Lower: Assign backend
+ std::unordered_map<ir::ModelIndex,
+ std::unordered_map<ir::SubgraphIndex, std::unique_ptr<compiler::LoweredGraph>>>
+ lowered_subgs;
+
+ for (uint16_t i = 0; i < model_count; i++)
+ {
+ auto const model_index = ir::ModelIndex{i};
+ auto model = _nnpkg->model(model_index);
+
+ model->iterate([&](const ir::SubgraphIndex &subg_index, ir::IGraph &graph) {
+ auto &subg = nnfw::misc::polymorphic_downcast<ir::Graph &>(graph);
+
+ dot_dumper.dump(subg,
+ nnfw::misc::str("before_lower_model-", i, "-subg-", subg_index.value()));
+ // Lower: Assign backend
+ lowered_subgs[model_index][subg_index] =
+ std::make_unique<compiler::LoweredGraph>(subg, *_options);
+ // Set tracing_ctx for copied graph
+ if (tracing_ctx != nullptr)
+ tracing_ctx->setSubgraphIndex(&(lowered_subgs[model_index][subg_index]->graph()),
+ subg_index.value());
+ });
+ }
+
+ _nnpkg.reset();
+
+ for (const auto &pair : lowered_subgs)
+ {
+ const auto &model_index = pair.first;
+ const auto &model_lsubg = pair.second;
+
+ for (const auto &pair_inner : model_lsubg)
+ {
+ const auto &subg_index = pair_inner.first;
+ const auto &lowered_subg = pair_inner.second;
+ dot_dumper.dump(*lowered_subg, nnfw::misc::str("after_lower_model-", model_index.value(),
+ "-subg-", subg_index.value()));
+ }
+ }
+
+ // Shape inference.
+ for (auto &&pair : lowered_subgs)
+ {
+ auto &model_lsubgs = pair.second;
+    // Run the StaticShapeInferer of the primary subgraph. All child StaticShapeInferers are
+    // called recursively
+ std::unordered_map<ir::SubgraphIndex, std::unique_ptr<StaticShapeInferer>> inferers =
+ createStaticShapeInferers(model_lsubgs);
+
+ const auto primary_subg_idx = ir::SubgraphIndex{0};
+ inferers.at(primary_subg_idx)->infer();
+
+ for (const auto &pair_inferer : inferers)
+ {
+ const auto inferer = pair_inferer.second.get();
+ inferer->dump();
+ }
+ }
+
+ // Shape validation
+  // TODO Move shape-independent feature checks from ShapeValidator to OperationValidator
+  // TODO Move ShapeValidator into shape inference
+  //  - Validate input tensor shapes
+  //  - Validate parameter values whose valid range depends on the input tensor shape
+  //  - Output tensor shape validation is unnecessary because the static/dynamic shape
+  //    inferers produce valid output shapes
+ for (const auto &pair : lowered_subgs)
+ {
+ const auto &model_lsubgs = pair.second;
+
+ for (const auto &pair_inner : model_lsubgs)
+ {
+ const auto &lowered_subg = pair_inner.second;
+ compiler::ShapeValidator{lowered_subg->graph()}();
+ }
+ }
+
+ /*************************************************************
+ * Backend independent analysis & optimization phase finished
+ *************************************************************/
+ auto executors = std::make_shared<exec::MultiModelExecutors>(std::move(model_edges));
+ for (auto &&pair : lowered_subgs)
+ {
+ auto const &model_index = pair.first;
+ auto &model_lsubgs = pair.second;
+
+ for (auto &&pair_inner : model_lsubgs)
+ {
+ auto const subg_index = pair_inner.first;
+ auto &lowered_subg = pair_inner.second;
+ auto const indexed_ranks = lowered_subg->indexed_ranks();
+
+ ir::OperationDumper dumper("Executor generation of Subgraph " +
+ std::to_string(subg_index.value()));
+ lowered_subg->graph().operations().iterate(
+ [&](const ir::OperationIndex &, const ir::IOperation &op) { op.accept(dumper); });
+
+ ExecutorFactoryArgs args;
+ args.tracing_ctx = tracing_ctx.get();
+ args.options = _options;
+ args.model_index = model_index;
+ args.custom_kernel_builder = custom_kernel_builders[model_index];
+ auto executor = std::unique_ptr<exec::IExecutor>{
+ ExecutorFactory::get().create(std::move(lowered_subg), executors, args)};
+ executor->setIndexedRanks(indexed_ranks);
+ executors->emplace(model_index, subg_index, std::move(executor));
+ }
+ }
+
+ /********************************
+ * Code generation phase finished
+ ********************************/
+ return std::make_shared<CompilerArtifact>(executors, std::move(tracing_ctx));
+}
+
+} // namespace compiler
+} // namespace onert
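MultiModelCompiler::compile() above repeats one traversal idiom for the mandatory passes, lowering, and executor generation: loop over model indices, then iterate each model's subgraphs. A minimal sketch of that idiom, using only calls that appear in the new file; the lambda body is illustrative and nnpkg is assumed to be a populated ir::NNPkg:

// Sketch only: the per-model / per-subgraph traversal used by compile().
auto const model_count = nnpkg->model_count();
for (uint16_t i = 0; i < model_count; i++)
{
  auto const model_index = onert::ir::ModelIndex{i};
  nnpkg->model(model_index)->iterate(
    [&](const onert::ir::SubgraphIndex &subg_index, onert::ir::IGraph &graph) {
      // Per-subgraph work (passes, lowering, executor generation, ...) goes here.
      (void)subg_index;
      (void)graph;
    });
}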
diff --git a/runtime/onert/core/src/compiler/MultiModelCompiler.h b/runtime/onert/core/src/compiler/MultiModelCompiler.h
new file mode 100644
index 000000000..7e202a71f
--- /dev/null
+++ b/runtime/onert/core/src/compiler/MultiModelCompiler.h
@@ -0,0 +1,68 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * @file MultiModelCompiler.h
+ * @brief This file contains the MultiModelCompiler class, which defines and runs the compilation phase
+ */
+
+#ifndef __ONERT_COMPILER_MULTI_MODEL_COMPILER_H__
+#define __ONERT_COMPILER_MULTI_MODEL_COMPILER_H__
+
+#include "compiler/CompilerOptions.h"
+#include "compiler/ICompiler.h"
+#include "ir/NNPkg.h"
+
+namespace onert
+{
+namespace compiler
+{
+
+/**
+ * @brief Class to compile an NN package
+ */
+class MultiModelCompiler final : public ICompiler
+{
+public:
+ /**
+   * @brief Construct a new MultiModelCompiler object for an NN package
+   * @param[in] nnpkg NN package to compile
+   * @param[in] copts Compiler options for the package
+ */
+ MultiModelCompiler(const std::shared_ptr<ir::NNPkg> &nnpkg, CompilerOptions *copts);
+
+ /**
+ * @brief Destroy the MultiModelCompiler object
+ */
+ ~MultiModelCompiler() = default;
+
+public:
+ /**
+ * @brief Do compilation with the options
+ *
+ * @return std::shared_ptr<CompilerArtifact> MultiModelExecutors as a result of compilation
+ */
+ std::shared_ptr<CompilerArtifact> compile(void);
+
+private:
+ std::shared_ptr<ir::NNPkg> _nnpkg;
+ CompilerOptions *_options;
+};
+
+} // namespace compiler
+} // namespace onert
+
+#endif // __ONERT_COMPILER_MULTI_MODEL_COMPILER_H__
diff --git a/runtime/onert/core/src/ir/operation/LowerInfo.cc b/runtime/onert/core/src/compiler/OperationLowerInfo.cc
index 249918bd6..e8a438130 100644
--- a/runtime/onert/core/src/ir/operation/LowerInfo.cc
+++ b/runtime/onert/core/src/compiler/OperationLowerInfo.cc
@@ -14,21 +14,18 @@
* limitations under the License.
*/
-#include "ir/operation/LowerInfo.h"
+#include "compiler/OperationLowerInfo.h"
namespace onert
{
-namespace ir
-{
-namespace operation
+namespace compiler
{
-LowerInfo::LowerInfo(const backend::Backend *backend, Layout layout)
- : _permute_factor{backend, layout}
+OperationLowerInfo::OperationLowerInfo(const backend::Backend *backend, ir::Layout layout)
+ : _permute_factor{backend, layout}
{
// DO NOTHING
}
-} // namespace operation
-} // namespace ir
+} // namespace compiler
} // namespace onert
diff --git a/runtime/onert/core/src/compiler/OperationValidator.cc b/runtime/onert/core/src/compiler/OperationValidator.cc
deleted file mode 100644
index f7f659e3e..000000000
--- a/runtime/onert/core/src/compiler/OperationValidator.cc
+++ /dev/null
@@ -1,1053 +0,0 @@
-/*
- * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "OperationValidator.h"
-
-#include <typeinfo>
-
-#include "ir/Graph.h"
-#include "ir/operation/LowerInfo.h"
-
-#include "util/logging.h"
-#include "util/Utils.h"
-
-#define OP_REQUIRES(EXP) \
- do \
- { \
- if (!(EXP)) \
- throw std::runtime_error("OperationValidator failed at line " + std::to_string(__LINE__)); \
- } while (0)
-
-namespace onert
-{
-namespace compiler
-{
-
-OperationValidator::OperationValidator(const ir::Graph &graph)
- : _graph{graph}, _ctx{graph.operands()}, _current_op_seq_layout{ir::Layout::UNKNOWN}
-{
-}
-
-void OperationValidator::checkUnaryOp(const ir::Operation &node)
-{
- const auto output_index{node.getOutputs().at(0)};
- const auto input_index{node.getInputs().at(0)};
-
- // Check if I/O types match
- OP_REQUIRES(_ctx.at(output_index).typeInfo().type() == _ctx.at(input_index).typeInfo().type());
-
- if (_ctx.at(output_index).info().isDynamic())
- return;
-
- // Check if I/O shapes match
- OP_REQUIRES(_ctx.at(output_index).shape() == _ctx.at(input_index).shape());
-}
-
-void OperationValidator::operator()()
-{
- // There is no reason for each subgraph to have subgraphs since compiler has subgraphs when
- // creating Compiler
- assert(_graph.subgraphs() == nullptr);
-
- _current_op_seq_layout = _graph.layout();
-
- _graph.operations().iterate(
- [&](const ir::OperationIndex &, const ir::Operation &node) { node.accept(*this); });
-}
-
-void OperationValidator::visit(const ir::operation::BatchMatMul &node)
-{
- const auto lhs_index(node.getInputs().at(ir::operation::BatchMatMul::Input::LHS));
- const auto rhs_index(node.getInputs().at(ir::operation::BatchMatMul::Input::RHS));
- const auto out_index{node.getOutputs().at(0)};
-
- // Constant lhs and rhs is not implemented yet
- OP_REQUIRES(!_ctx.at(lhs_index).isConstant() && !_ctx.at(rhs_index).isConstant());
-
- if (_ctx.at(out_index).info().isDynamic())
- return;
-
- OP_REQUIRES(_ctx.at(lhs_index).shape().rank() <= 4);
- OP_REQUIRES(_ctx.at(rhs_index).shape().rank() <= 4);
- OP_REQUIRES(_ctx.at(lhs_index).shape().rank() >= 2);
- OP_REQUIRES(_ctx.at(rhs_index).shape().rank() >= 2);
-}
-
-void OperationValidator::visit(const ir::operation::BatchToSpaceND &node)
-{
- const auto ofm_index{node.getOutputs().at(0)};
- if (_ctx.at(ofm_index).info().isDynamic())
- return;
-
- const auto ifm_index{node.getInputs().at(ir::operation::BatchToSpaceND::Input::INPUT)};
- const auto block_size_index{
- node.getInputs().at(ir::operation::BatchToSpaceND::Input::BLOCK_SIZE)};
-
- const auto frontend_layout = _current_op_seq_layout;
- const auto input_shape = _ctx.at(ifm_index).shape().asFeature(frontend_layout);
- const auto output_shape = _ctx.at(ofm_index).shape().asFeature(frontend_layout);
-
- // All requirement as per NNAPI specification.
- OP_REQUIRES(_ctx.at(ifm_index).shape().rank() == 4);
- OP_REQUIRES(_ctx.at(ofm_index).shape().rank() == 4);
- OP_REQUIRES(_ctx.at(block_size_index).shape().rank() == 1);
-
- OP_REQUIRES(_ctx.at(block_size_index).shape().dim(0) == 2);
-
- OP_REQUIRES(_ctx.at(block_size_index).isConstant());
-
- OP_REQUIRES(input_shape.C == output_shape.C);
-}
-
-void OperationValidator::visit(const ir::operation::Comparison &node)
-{
- const auto output_index{node.getOutputs().at(0)};
- // This validator does not check shape. So checking isDynamic() is skipped.
-
- const auto lhs_index{node.getInputs().at(ir::operation::Comparison::Input::INPUT0)};
- const auto rhs_index{node.getInputs().at(ir::operation::Comparison::Input::INPUT1)};
-
- OP_REQUIRES(_ctx.at(lhs_index).typeInfo().type() == _ctx.at(rhs_index).typeInfo().type());
- OP_REQUIRES(_ctx.at(output_index).typeInfo().type() == ir::DataType::BOOL8);
-}
-
-void OperationValidator::visit(const ir::operation::Softmax &node)
-{
- VERBOSE(Softmax) << "Configure SOFTMAX operation" << std::endl;
-
- const auto output_index{node.getOutputs().at(0)};
- if (_ctx.at(output_index).info().isDynamic())
- return;
-
- const auto input_index{node.getInputs().at(0)};
-
- OP_REQUIRES(_ctx.at(output_index).shape().rank() == _ctx.at(input_index).shape().rank());
-}
-
-void OperationValidator::visit(const ir::operation::InstanceNorm &node)
-{
- const auto ofm_index{node.getOutputs().at(0)};
- if (_ctx.at(ofm_index).info().isDynamic())
- return;
-
- const auto ifm_index{node.getInputs().at(ir::operation::InstanceNorm::Input::INPUT)};
- const auto gamma_index{node.getInputs().at(ir::operation::InstanceNorm::Input::GAMMA)};
- const auto beta_index{node.getInputs().at(ir::operation::InstanceNorm::Input::BETA)};
-
- OP_REQUIRES(_ctx.at(ifm_index).shape().rank() == 4);
- OP_REQUIRES(_ctx.at(ifm_index).shape() == _ctx.at(ofm_index).shape());
- OP_REQUIRES(_ctx.at(gamma_index).shape().rank() == 1);
- OP_REQUIRES(_ctx.at(beta_index).shape().rank() == 1);
-}
-
-void OperationValidator::visit(const ir::operation::Pool2D &node)
-{
- const auto ofm_index{node.getOutputs().at(0)};
- if (_ctx.at(ofm_index).info().isDynamic())
- return;
-
- const auto ifm_index{node.getInputs().at(ir::operation::Pool2D::Input::INPUT)};
-
- OP_REQUIRES(_ctx.at(ifm_index).shape().rank() == 4);
-}
-
-void OperationValidator::visit(const ir::operation::Permute &node)
-{
- VERBOSE(Permute) << "Configure Permute operation" << std::endl;
-
- const auto output_index{node.getOutputs().at(0)};
- if (_ctx.at(output_index).info().isDynamic())
- return;
-
- const auto input_index{node.getInputs().at(0)};
-
- OP_REQUIRES(_ctx.at(output_index).shape().rank() == _ctx.at(input_index).shape().rank());
-}
-
-void OperationValidator::visit(const ir::operation::Reduce &node)
-{
- VERBOSE(Permute) << "Configure " + node.name() + " operation" << std::endl;
-
- const auto output_index{node.getOutputs().at(0)};
- if (_ctx.at(output_index).info().isDynamic())
- return;
-
- const auto input_index{node.getInputs().at(ir::operation::Reduce::Input::INPUT)};
- const auto input_shape = _ctx.at(input_index).shape();
- const auto output_shape = _ctx.at(output_index).shape();
-
- OP_REQUIRES(input_shape.rank() <= 4);
- OP_REQUIRES(output_shape.rank() <= input_shape.rank());
-
- // NOTE For the 4-dimensions, if the rank of input and output are different, this runtime only
- // supports cases reducing height and width or reducing depth.
- // TODO We have to support all cases of dimensions up to 4.
- // For correct permuting, we have to set output's shape to be equal in dimension position of the
- // input. But the positions of the same dimensions in the input and output may be set differently.
- // For example {2,3,4,5}(input's shape) can be reduced to {3,5}(output's shape). The original
- // output shape should be {1,3,1,5}, but real output shape may be {3,5}. If you simply try to
- // extend it in 4 dimensions, it should be {1,1,3,5}.
- // Even if output shape is changed to {1,3,1,5}, there is another problem. It is that shape of
- // output tensor used at next operation is changed to {1,3,1,5} after this operation even if the
- // next operation is not desired.
- if (input_shape.rank() == 4 && input_shape.rank() != output_shape.rank())
- {
- if (output_shape.rank() == 2)
- {
- // Reducing HW
- OP_REQUIRES(input_shape.dim(0) == output_shape.dim(0) &&
- input_shape.dim(3) == output_shape.dim(1));
- }
- else if (output_shape.rank() == 3)
- {
- // Reducing C or
- // (Reducing H and C(input and output) == 1) or (Reducing W and C(input and output) == 1)
- OP_REQUIRES((input_shape.dim(0) == output_shape.dim(0) &&
- input_shape.dim(1) == output_shape.dim(1) &&
- input_shape.dim(2) == output_shape.dim(2)) ||
- (input_shape.dim(0) == output_shape.dim(0) &&
- (input_shape.dim(1) == output_shape.dim(1) ||
- input_shape.dim(2) == output_shape.dim(1)) &&
- input_shape.dim(3) == 1 && output_shape.dim(2) == 1));
- }
- }
-}
-
-void OperationValidator::visit(const ir::operation::Transpose &node)
-{
- const auto output_index{node.getOutputs().at(0)};
- if (_ctx.at(output_index).info().isDynamic())
- return;
-
- const auto input_index{node.getInputs().at(ir::operation::Transpose::Input::INPUT)};
- const auto &perm{node.param().perm};
-
- const auto &output_shape = _ctx.at(output_index).shape();
- const auto &input_shape = _ctx.at(input_index).shape();
-
- OP_REQUIRES(input_shape.rank() == static_cast<int>(perm.size()));
- OP_REQUIRES(input_shape.rank() == output_shape.rank());
-}
-
-void OperationValidator::visit(const ir::operation::RNN &node)
-{
- // NOTE This validation is for static rnn(non-dynamic shape), but not for dynamic rnn
- // TODO Support dynamic rnn
- const auto output_index{node.getOutputs().at(ir::operation::RNN::Output::OUTPUT)};
- if (_ctx.at(output_index).info().isDynamic())
- return;
-
- const auto hidden_state_out_index{
- node.getOutputs().at(ir::operation::RNN::Output::HIDDEN_STATE_OUT)};
-
- const auto input_index{node.getInputs().at(ir::operation::RNN::Input::INPUT)};
- const auto weights_index{node.getInputs().at(ir::operation::RNN::Input::WEIGHTS)};
- const auto recurrent_weights_index{
- node.getInputs().at(ir::operation::RNN::Input::RECURRENT_WEIGHTS)};
- const auto bias_index{node.getInputs().at(ir::operation::RNN::Input::BIAS)};
- const auto hidden_state_in_index{node.getInputs().at(ir::operation::RNN::Input::HIDDEN_STATE_IN)};
-
- const auto batch_size = _ctx.at(output_index).shape().dim(0);
- const auto num_units = _ctx.at(output_index).shape().dim(1);
-
- OP_REQUIRES(_ctx.at(output_index).shape().rank() == 2 &&
- _ctx.at(hidden_state_out_index).shape().rank() == 2 &&
- _ctx.at(input_index).shape().rank() == 2 &&
- _ctx.at(weights_index).shape().rank() == 2 &&
- _ctx.at(recurrent_weights_index).shape().rank() == 2 &&
- _ctx.at(hidden_state_in_index).shape().rank() == 2);
- OP_REQUIRES(_ctx.at(bias_index).shape().rank() == 1);
-
- OP_REQUIRES(batch_size == _ctx.at(input_index).shape().dim(0) &&
- batch_size == _ctx.at(hidden_state_in_index).shape().dim(0) &&
- batch_size == _ctx.at(hidden_state_out_index).shape().dim(0));
- OP_REQUIRES(_ctx.at(input_index).shape().dim(1) == _ctx.at(weights_index).shape().dim(1));
-
- OP_REQUIRES(num_units == _ctx.at(weights_index).shape().dim(0) &&
- num_units == _ctx.at(recurrent_weights_index).shape().dim(0) &&
- num_units == _ctx.at(bias_index).shape().dim(0));
- OP_REQUIRES(num_units == _ctx.at(output_index).shape().dim(1) &&
- num_units == _ctx.at(recurrent_weights_index).shape().dim(1) &&
- num_units == _ctx.at(hidden_state_in_index).shape().dim(1) &&
- num_units == _ctx.at(hidden_state_out_index).shape().dim(1));
-}
-
-void OperationValidator::visit(const ir::operation::SpaceToBatchND &node)
-{
- const auto ofm_index{node.getOutputs().at(0)};
- if (_ctx.at(ofm_index).info().isDynamic())
- return;
-
- const auto ifm_index{node.getInputs().at(ir::operation::SpaceToBatchND::Input::INPUT)};
- const auto block_size_index{
- node.getInputs().at(ir::operation::SpaceToBatchND::Input::BLOCK_SIZE)};
- const auto paddings_index{node.getInputs().at(ir::operation::SpaceToBatchND::Input::PADDINGS)};
-
- const auto frontend_layout = _current_op_seq_layout;
- const auto input_shape = _ctx.at(ifm_index).shape().asFeature(frontend_layout);
- const auto output_shape = _ctx.at(ofm_index).shape().asFeature(frontend_layout);
-
- // All requirements as per the NNAPI specification.
- OP_REQUIRES(_ctx.at(ifm_index).shape().rank() == 4);
- OP_REQUIRES(_ctx.at(ofm_index).shape().rank() == 4);
- OP_REQUIRES(_ctx.at(block_size_index).shape().rank() == 1);
- OP_REQUIRES(_ctx.at(paddings_index).shape().rank() == 2);
-
- OP_REQUIRES(_ctx.at(block_size_index).shape().dim(0) == 2);
- OP_REQUIRES(_ctx.at(paddings_index).shape().dim(0) == 2);
- OP_REQUIRES(_ctx.at(paddings_index).shape().dim(1) == 2);
-
- OP_REQUIRES(_ctx.at(block_size_index).isConstant());
- OP_REQUIRES(_ctx.at(paddings_index).isConstant());
-
- OP_REQUIRES(input_shape.C == output_shape.C);
-}
-
-void OperationValidator::visit(const ir::operation::SpaceToDepth &node)
-{
- const auto ofm_index{node.getOutputs().at(0)};
- if (_ctx.at(ofm_index).info().isDynamic())
- return;
-
- const auto ifm_index{node.getInputs().at(ir::operation::SpaceToDepth::Input::INPUT)};
-
- const auto frontend_layout = _current_op_seq_layout;
- const auto input_shape = _ctx.at(ifm_index).shape().asFeature(frontend_layout);
- const auto output_shape = _ctx.at(ofm_index).shape().asFeature(frontend_layout);
- const auto block_size = node.param().block_size;
-
- // All assertions as per NNAPI specification.
- OP_REQUIRES(_ctx.at(ifm_index).shape().rank() == 4);
- OP_REQUIRES(_ctx.at(ofm_index).shape().rank() == 4);
- OP_REQUIRES((block_size >= 1) && (input_shape.H % block_size == 0) &&
- (input_shape.W % block_size == 0));
- OP_REQUIRES(input_shape.N == output_shape.N);
- OP_REQUIRES(input_shape.C * block_size * block_size == output_shape.C);
-}
-
-void OperationValidator::visit(const ir::operation::ElementwiseActivation &node)
-{
- checkUnaryOp(node);
-}
-
-void OperationValidator::visit(const ir::operation::ElementwiseBinary &node)
-{
- const auto output_index{node.getOutputs().at(0)};
- const auto lhs_index{node.getInputs().at(ir::operation::ElementwiseBinary::Input::LHS)};
- const auto rhs_index{node.getInputs().at(ir::operation::ElementwiseBinary::Input::RHS)};
-
- OP_REQUIRES(_ctx.at(lhs_index).typeInfo().type() == _ctx.at(rhs_index).typeInfo().type());
- OP_REQUIRES(_ctx.at(lhs_index).typeInfo().type() == _ctx.at(output_index).typeInfo().type());
-}
-
-void OperationValidator::visit(const ir::operation::ElementwiseUnary &node)
-{
- const auto output_index{node.getOutputs().at(0)};
- const auto input_index{node.getInputs().at(ir::operation::ElementwiseUnary::Input::INPUT)};
-
- OP_REQUIRES(node.getInputs().size() == 1);
- OP_REQUIRES(node.getOutputs().size() == 1);
-
- // Check if I/O types match
- if (node.param().op_type == ir::operation::ElementwiseUnary::Type::DEQUANTIZE)
- {
- OP_REQUIRES(_ctx.at(input_index).typeInfo().type() == ir::DataType::QUANT_UINT8_ASYMM);
- OP_REQUIRES(_ctx.at(output_index).typeInfo().type() == ir::DataType::FLOAT32);
- }
- else if (node.param().op_type == ir::operation::ElementwiseUnary::Type::QUANTIZE)
- {
- OP_REQUIRES(_ctx.at(input_index).typeInfo().type() == ir::DataType::FLOAT32);
- OP_REQUIRES(_ctx.at(output_index).typeInfo().type() == ir::DataType::QUANT_UINT8_ASYMM);
- }
- else if (node.param().op_type != ir::operation::ElementwiseUnary::Type::CAST)
- {
- OP_REQUIRES(_ctx.at(output_index).typeInfo().type() == _ctx.at(input_index).typeInfo().type());
- }
-
- if (_ctx.at(output_index).info().isDynamic())
- return;
-
- OP_REQUIRES(_ctx.at(output_index).shape() == _ctx.at(input_index).shape());
-}
-
-void OperationValidator::visit(const ir::operation::EmbeddingLookup &node)
-{
- const auto output_index{node.getOutputs().at(0)};
- const auto lookups_index{node.getInputs().at(ir::operation::EmbeddingLookup::Input::LOOKUPS)};
- const auto values_index{node.getInputs().at(ir::operation::EmbeddingLookup::Input::VALUES)};
-
- const auto &output_obj = _ctx.at(output_index);
- const auto &lookups_obj = _ctx.at(lookups_index);
- const auto &values_obj = _ctx.at(values_index);
-
- // Verify the operands here, not in SimpleEmbeddingLookup::configure(), to avoid ACL occasionally
- // modifying TensorShape (Issue: https://github.sec.samsung.net/STAR/nnfw/issues/729)
- {
- OP_REQUIRES(lookups_obj.typeInfo().type() == ir::DataType::INT32);
-
- if (_ctx.at(output_index).info().isDynamic())
- return;
-
- const auto &output_shape = output_obj.shape();
- const auto &lookups_shape = lookups_obj.shape();
- const auto &values_shape = values_obj.shape();
-
- OP_REQUIRES(lookups_shape.rank() == 1);
- OP_REQUIRES(values_shape.rank() >= 2);
-
- // The output should be an n-D tensor with the same rank and shape as the values tensor, except
- // for the first dimension, which has the same size as lookups' only dimension.
- OP_REQUIRES(output_shape.rank() == values_shape.rank());
- OP_REQUIRES(output_shape.dim(0) == lookups_shape.dim(0));
- for (int n = 1; n < output_shape.rank(); ++n)
- {
- OP_REQUIRES(output_shape.dim(n) == values_shape.dim(n));
- }
- }
-}
-
-void OperationValidator::visit(const ir::operation::ExpandDims &node)
-{
- const auto output_index{node.getOutputs().at(0)};
- const auto input_index{node.getInputs().at(ir::operation::ExpandDims::Input::INPUT)};
- const auto axis_index{node.getInputs().at(ir::operation::ExpandDims::Input::AXIS)};
-
- OP_REQUIRES(_ctx.at(output_index).typeInfo().type() == _ctx.at(input_index).typeInfo().type());
- OP_REQUIRES(_ctx.at(axis_index).typeInfo().type() == ir::DataType::INT32);
-
- if (_ctx.at(axis_index).info().isDynamic())
- return;
- OP_REQUIRES(_ctx.at(axis_index).shape().rank() <= 1);
-}
-
-void OperationValidator::visit(const ir::operation::HashtableLookup &node)
-{
- const auto output_index{node.getOutputs().at(ir::operation::HashtableLookup::Output::OUTPUT)};
- const auto hits_index{node.getOutputs().at(ir::operation::HashtableLookup::Output::HITS)};
-
- const auto lookups_index{node.getInputs().at(ir::operation::HashtableLookup::Input::LOOKUPS)};
- const auto keys_index{node.getInputs().at(ir::operation::HashtableLookup::Input::KEYS)};
- const auto values_index{node.getInputs().at(ir::operation::HashtableLookup::Input::VALUES)};
-
- const auto &output_obj = _ctx.at(output_index);
- const auto &hits_obj = _ctx.at(hits_index);
-
- const auto &lookups_obj = _ctx.at(lookups_index);
- const auto &keys_obj = _ctx.at(keys_index);
- const auto &values_obj = _ctx.at(values_index);
-
- OP_REQUIRES(lookups_obj.typeInfo().type() == ir::DataType::INT32);
- OP_REQUIRES(keys_obj.typeInfo().type() == ir::DataType::INT32);
- OP_REQUIRES(hits_obj.typeInfo().type() == ir::DataType::QUANT_UINT8_ASYMM);
-
- if (_ctx.at(output_index).info().isDynamic())
- return;
-
- const auto &output_shape = output_obj.shape();
- const auto &lookups_shape = lookups_obj.shape();
- const auto &keys_shape = keys_obj.shape();
- const auto &values_shape = values_obj.shape();
-
- OP_REQUIRES(values_shape.rank() == output_shape.rank());
- OP_REQUIRES(lookups_shape.rank() == 1);
- OP_REQUIRES(keys_shape.rank() == 1);
- OP_REQUIRES(values_shape.dim(0) == keys_shape.dim(0));
- OP_REQUIRES(lookups_shape.dim(0) == output_shape.dim(0));
-}
-
-void OperationValidator::visit(const ir::operation::TransposeConv &node)
-{
- // param check
- OP_REQUIRES((node.param().padding.type == ir::PaddingType::SAME) ||
- (node.param().padding.type == ir::PaddingType::VALID));
-
- // shape check
- const auto ofm_index{node.getOutputs().at(0)};
- if (_ctx.at(ofm_index).info().isDynamic())
- return;
-
- const auto ifm_index{node.getInputs().at(ir::operation::TransposeConv::Input::INPUT)};
- const auto ker_index{node.getInputs().at(ir::operation::TransposeConv::Input::KERNEL)};
-
- // Only 4D tensors are supported
- OP_REQUIRES(_ctx.at(ofm_index).shape().rank() == 4);
- OP_REQUIRES(_ctx.at(ofm_index).shape().rank() == _ctx.at(ifm_index).shape().rank());
- OP_REQUIRES(_ctx.at(ofm_index).shape().rank() == _ctx.at(ker_index).shape().rank());
-
- const auto frontend_layout = _current_op_seq_layout;
- const auto ofm_shape = _ctx.at(ofm_index).shape().asFeature(frontend_layout);
- const auto ifm_shape = _ctx.at(ifm_index).shape().asFeature(frontend_layout);
- // The kernel has only the IHWO layout on the frontend,
- // so ker_shape is interpreted as below:
- // I -> N
- // H -> H
- // W -> W
- // O -> C
- const auto ker_shape = _ctx.at(ker_index).shape().asFeature(ir::Layout::NHWC);
-
- OP_REQUIRES(ifm_shape.N == ofm_shape.N);
- OP_REQUIRES(ifm_shape.C == ker_shape.C);
- OP_REQUIRES(ker_shape.N == ofm_shape.C);
-}
-
-void OperationValidator::visit(const ir::operation::Gather &node)
-{
- const auto ofm_index{node.getOutputs().at(0)};
- if (_ctx.at(ofm_index).info().isDynamic())
- return;
-
- const auto ifm_index{node.getInputs().at(ir::operation::Gather::Input::INPUT)};
- const auto indices_index{node.getInputs().at(ir::operation::Gather::Input::INDICES)};
-
- const auto ifm_shape = _ctx.at(ifm_index).shape();
- const auto indices_shape = _ctx.at(indices_index).shape();
- const auto ofm_shape = _ctx.at(ofm_index).shape();
-
- OP_REQUIRES(ifm_shape.rank() <= 4);
- OP_REQUIRES(indices_shape.rank() <= 3);
- OP_REQUIRES(ofm_shape.rank() <= 4);
-}
-
-void OperationValidator::visit(const ir::operation::DepthToSpace &node)
-{
- // param check
- int32_t block_size = node.param().block_size;
-
- OP_REQUIRES(block_size > 0);
-
- // shape check
- const auto output_index{node.getOutputs().at(0)};
- if (_ctx.at(output_index).info().isDynamic())
- return;
-
- const auto input_index{node.getInputs().at(ir::operation::DepthToSpace::Input::INPUT)};
-
- const auto frontend_layout = _current_op_seq_layout;
- const auto output_shape = _ctx.at(output_index).shape().asFeature(frontend_layout);
- const auto input_shape = _ctx.at(input_index).shape().asFeature(frontend_layout);
-
- OP_REQUIRES(_ctx.at(input_index).shape().rank() == 4);
- OP_REQUIRES(_ctx.at(output_index).shape().rank() == 4);
-
- {
- OP_REQUIRES(output_shape.N == input_shape.N);
- OP_REQUIRES(output_shape.H == input_shape.H * block_size);
- OP_REQUIRES(output_shape.W == input_shape.W * block_size);
- OP_REQUIRES(input_shape.C % (block_size * block_size) == 0);
- OP_REQUIRES(output_shape.C == input_shape.C / (block_size * block_size));
- }
-}
-
-void OperationValidator::visit(const ir::operation::Pack &node)
-{
- // param check
- const auto num{node.param().num};
- const auto axis{node.param().axis};
- OP_REQUIRES(num == static_cast<int32_t>(node.getInputs().size()));
-
- const auto output_index{node.getOutputs().at(0)};
- if (_ctx.at(output_index).info().isDynamic())
- return;
-
- // shape check
- const auto &output_shape = _ctx.at(output_index).shape();
- const auto output_rank = static_cast<int32_t>(output_shape.rank());
-
- const auto input1_index{node.getInputs().at(0)};
- const auto input_shape = _ctx.at(input1_index).shape();
-
- OP_REQUIRES(axis >= -output_rank && axis < output_rank);
- for (const auto &index : node.getInputs())
- {
- OP_REQUIRES(input_shape == _ctx.at(index).shape());
- }
-}
-
-void OperationValidator::visit(const ir::operation::LSTM &node)
-{
- // NOTE This validation is for static rnn(non-dynamic shape), but not for dynamic rnn
- // TODO Support dynamic rnn
- const auto output_index{node.getOutputs().at(ir::operation::LSTM::Output::OUTPUT)};
- if (_ctx.at(output_index).info().isDynamic())
- return;
-
- const auto scratch_buffer_index{
- node.getOutputs().at(ir::operation::LSTM::Output::SCRATCH_BUFFER)};
- const auto output_state_out_index{
- node.getOutputs().at(ir::operation::LSTM::Output::OUTPUT_STATE_OUT)};
- const auto cell_state_out_index{
- node.getOutputs().at(ir::operation::LSTM::Output::CELL_STATE_OUT)};
-
- const auto input_index{node.getInputs().at(ir::operation::LSTM::Input::INPUT)};
- const auto input_to_input_weights_index{
- node.getInputs().at(ir::operation::LSTM::Input::INPUT_TO_INPUT_WEIGHTS)};
- const auto input_to_forget_weights_index{
- node.getInputs().at(ir::operation::LSTM::Input::INPUT_TO_FORGET_WEIGHTS)};
- const auto input_to_cell_weights_index{
- node.getInputs().at(ir::operation::LSTM::Input::INPUT_TO_CELL_WEIGHTS)};
- const auto input_to_output_weights_index{
- node.getInputs().at(ir::operation::LSTM::Input::INPUT_TO_OUTPUT_WEIGHTS)};
- const auto recurrent_to_input_weights_index{
- node.getInputs().at(ir::operation::LSTM::Input::RECURRENT_TO_INPUT_WEIGHTS)};
- const auto recurrent_to_forget_weights_index{
- node.getInputs().at(ir::operation::LSTM::Input::RECURRENT_TO_FORGET_WEIGHTS)};
- const auto recurrent_to_cell_weights_index{
- node.getInputs().at(ir::operation::LSTM::Input::RECURRENT_TO_CELL_WEIGHTS)};
- const auto recurrent_to_output_weights_index{
- node.getInputs().at(ir::operation::LSTM::Input::RECURRENT_TO_OUTPUT_WEIGHTS)};
- const auto cell_to_input_weights_index{
- node.getInputs().at(ir::operation::LSTM::Input::CELL_TO_INPUT_WEIGHTS)};
- const auto cell_to_forget_weights_index{
- node.getInputs().at(ir::operation::LSTM::Input::CELL_TO_FORGET_WEIGHTS)};
- const auto cell_to_output_weights_index{
- node.getInputs().at(ir::operation::LSTM::Input::CELL_TO_OUTPUT_WEIGHTS)};
- const auto input_gate_bias_index{
- node.getInputs().at(ir::operation::LSTM::Input::INPUT_GATE_BIAS)};
- const auto forget_gate_bias_index{
- node.getInputs().at(ir::operation::LSTM::Input::FORGET_GATE_BIAS)};
- const auto cell_bias_index{node.getInputs().at(ir::operation::LSTM::Input::CELL_BIAS)};
- const auto output_gate_bias_index{
- node.getInputs().at(ir::operation::LSTM::Input::OUTPUT_GATE_BIAS)};
- const auto projection_weights_index{
- node.getInputs().at(ir::operation::LSTM::Input::PROJECTION_WEIGHTS)};
- const auto projection_bias_index{
- node.getInputs().at(ir::operation::LSTM::Input::PROJECTION_BIAS)};
- const auto output_state_in_index{
- node.getInputs().at(ir::operation::LSTM::Input::OUTPUT_STATE_IN)};
- const auto cell_state_in_index{node.getInputs().at(ir::operation::LSTM::Input::CELL_STATE_IN)};
-
- OP_REQUIRES(_ctx.at(scratch_buffer_index).shape().rank() == 2 &&
- _ctx.at(output_state_out_index).shape().rank() == 2 &&
- _ctx.at(cell_state_out_index).shape().rank() == 2 &&
- _ctx.at(output_index).shape().rank() == 2 &&
- _ctx.at(input_index).shape().rank() == 2 &&
- _ctx.at(input_to_input_weights_index).shape().rank() == 2 &&
- _ctx.at(input_to_forget_weights_index).shape().rank() == 2 &&
- _ctx.at(input_to_cell_weights_index).shape().rank() == 2 &&
- _ctx.at(input_to_output_weights_index).shape().rank() == 2 &&
- _ctx.at(recurrent_to_input_weights_index).shape().rank() == 2 &&
- _ctx.at(recurrent_to_forget_weights_index).shape().rank() == 2 &&
- _ctx.at(recurrent_to_cell_weights_index).shape().rank() == 2 &&
- _ctx.at(recurrent_to_output_weights_index).shape().rank() == 2 &&
- _ctx.at(projection_weights_index).shape().rank() == 2 &&
- _ctx.at(output_state_in_index).shape().rank() == 2 &&
- _ctx.at(cell_state_in_index).shape().rank() == 2);
-
- OP_REQUIRES(_ctx.at(cell_to_input_weights_index).shape().rank() == 1 &&
- _ctx.at(cell_to_forget_weights_index).shape().rank() == 1 &&
- _ctx.at(cell_to_output_weights_index).shape().rank() == 1 &&
- _ctx.at(input_gate_bias_index).shape().rank() == 1 &&
- _ctx.at(forget_gate_bias_index).shape().rank() == 1 &&
- _ctx.at(cell_bias_index).shape().rank() == 1 &&
- _ctx.at(output_gate_bias_index).shape().rank() == 1 &&
- _ctx.at(projection_bias_index).shape().rank() == 1);
-
- // CIFG assertion
- OP_REQUIRES((_ctx.at(input_to_input_weights_index).shape().dim(0) == 0 &&
- _ctx.at(input_to_input_weights_index).shape().dim(1) == 0 &&
- _ctx.at(recurrent_to_input_weights_index).shape().dim(0) == 0 &&
- _ctx.at(recurrent_to_input_weights_index).shape().dim(1) == 0 &&
- _ctx.at(input_gate_bias_index).shape().dim(0) == 0 &&
- _ctx.at(cell_to_input_weights_index).shape().dim(0) == 0) ||
- (_ctx.at(input_to_input_weights_index).shape().dim(0) != 0 &&
- _ctx.at(input_to_input_weights_index).shape().dim(1) != 0 &&
- _ctx.at(recurrent_to_input_weights_index).shape().dim(0) != 0 &&
- _ctx.at(recurrent_to_input_weights_index).shape().dim(1) != 0 &&
- _ctx.at(input_gate_bias_index).shape().dim(0) != 0));
-
- // Peephole assertion
- OP_REQUIRES((_ctx.at(cell_to_forget_weights_index).shape().dim(0) == 0 &&
- _ctx.at(cell_to_output_weights_index).shape().dim(0) == 0) ||
- (_ctx.at(cell_to_forget_weights_index).shape().dim(0) != 0 &&
- _ctx.at(cell_to_output_weights_index).shape().dim(0) != 0));
-
- bool has_input_to_input_weights = _ctx.at(input_to_input_weights_index).shape().dim(0) != 0 &&
- _ctx.at(input_to_input_weights_index).shape().dim(1) != 0;
- bool has_recurrent_to_input_weights =
- _ctx.at(recurrent_to_input_weights_index).shape().dim(0) != 0 &&
- _ctx.at(recurrent_to_input_weights_index).shape().dim(1) != 0;
- bool has_input_gate_bias = _ctx.at(input_gate_bias_index).shape().dim(0) != 0;
- bool has_cell_to_input_weights = _ctx.at(cell_to_input_weights_index).shape().dim(0) != 0;
- bool has_cell_to_forget_weights = _ctx.at(cell_to_forget_weights_index).shape().dim(0) != 0;
- bool has_cell_to_output_weights = _ctx.at(cell_to_output_weights_index).shape().dim(0) != 0;
- bool has_projection_weights = _ctx.at(projection_weights_index).shape().dim(0) != 0 &&
- _ctx.at(projection_weights_index).shape().dim(1) != 0;
- bool has_projection_bias = _ctx.at(projection_bias_index).shape().dim(0);
-
- // NOTE The cell_to_input_weights do not exist in non-peephole mode, even for a regular (non-CIFG) LSTM.
- // true: no CIFG
- // false: CIFG
- bool has_cifg_param = has_input_to_input_weights && has_recurrent_to_input_weights;
-
- // NOTE The cell_to_input_weights do not exist in regular CIFG mode, even with peephole.
- // true: peephole
- // false: no peephole
- bool has_peephole_param = has_cell_to_forget_weights && has_cell_to_output_weights;
-
- // NOTE The projection weights may have data but the projection bias may not.
- bool has_projection_param = has_projection_weights;
-
- const auto batch_size = _ctx.at(input_index).shape().dim(0);
- OP_REQUIRES(batch_size == _ctx.at(output_state_in_index).shape().dim(0) &&
- batch_size == _ctx.at(cell_state_in_index).shape().dim(0) &&
- batch_size == _ctx.at(scratch_buffer_index).shape().dim(0) &&
- batch_size == _ctx.at(output_state_out_index).shape().dim(0) &&
- batch_size == _ctx.at(cell_state_out_index).shape().dim(0) &&
- batch_size == _ctx.at(output_index).shape().dim(0));
-
- const auto input_size = _ctx.at(input_index).shape().dim(1);
- OP_REQUIRES(input_size == _ctx.at(input_to_forget_weights_index).shape().dim(1) &&
- input_size == _ctx.at(input_to_cell_weights_index).shape().dim(1) &&
- input_size == _ctx.at(input_to_output_weights_index).shape().dim(1));
-
- const auto num_units = _ctx.at(cell_state_out_index).shape().dim(1);
- OP_REQUIRES(num_units == _ctx.at(input_to_forget_weights_index).shape().dim(0) &&
- num_units == _ctx.at(input_to_cell_weights_index).shape().dim(0) &&
- num_units == _ctx.at(input_to_output_weights_index).shape().dim(0) &&
- num_units == _ctx.at(recurrent_to_forget_weights_index).shape().dim(0) &&
- num_units == _ctx.at(recurrent_to_cell_weights_index).shape().dim(0) &&
- num_units == _ctx.at(recurrent_to_output_weights_index).shape().dim(0) &&
- num_units == _ctx.at(forget_gate_bias_index).shape().dim(0) &&
- num_units == _ctx.at(cell_bias_index).shape().dim(0) &&
- num_units == _ctx.at(output_gate_bias_index).shape().dim(0) &&
- num_units == _ctx.at(cell_state_in_index).shape().dim(1) &&
- (((num_units * 3) == _ctx.at(scratch_buffer_index).shape().dim(1)) ||
- ((num_units * 4) == _ctx.at(scratch_buffer_index).shape().dim(1))));
-
- const auto output_size = _ctx.at(output_index).shape().dim(1);
- OP_REQUIRES(output_size == _ctx.at(recurrent_to_forget_weights_index).shape().dim(1) &&
- output_size == _ctx.at(recurrent_to_cell_weights_index).shape().dim(1) &&
- output_size == _ctx.at(recurrent_to_output_weights_index).shape().dim(1) &&
- output_size == _ctx.at(output_state_in_index).shape().dim(1) &&
- output_size == _ctx.at(output_state_out_index).shape().dim(1));
-
- if (has_cifg_param)
- {
- OP_REQUIRES(input_size == _ctx.at(input_to_input_weights_index).shape().dim(1));
- OP_REQUIRES(num_units == _ctx.at(input_to_input_weights_index).shape().dim(0) &&
- num_units == _ctx.at(recurrent_to_input_weights_index).shape().dim(0) &&
- (num_units == _ctx.at(cell_to_input_weights_index).shape().dim(0) ||
- _ctx.at(cell_to_input_weights_index).shape().dim(0) == 0 /* non-peephole */) &&
- num_units == _ctx.at(input_gate_bias_index).shape().dim(0));
- OP_REQUIRES(output_size == _ctx.at(recurrent_to_input_weights_index).shape().dim(1));
- OP_REQUIRES(has_input_to_input_weights && has_recurrent_to_input_weights &&
- has_input_gate_bias);
- if (has_cell_to_input_weights)
- {
- // NOTE The cell_to_input_weights exist only in case of non-CIFG and peephole.
- OP_REQUIRES(has_peephole_param);
- }
- OP_REQUIRES(_ctx.at(scratch_buffer_index).shape().dim(1) == num_units * 4);
- }
- else
- {
- OP_REQUIRES(_ctx.at(scratch_buffer_index).shape().dim(1) == num_units * 3);
- }
-
- if (has_peephole_param)
- {
- OP_REQUIRES(num_units == _ctx.at(cell_to_forget_weights_index).shape().dim(0) &&
- num_units == _ctx.at(cell_to_output_weights_index).shape().dim(0) &&
- (num_units == _ctx.at(cell_to_input_weights_index).shape().dim(0) ||
- _ctx.at(cell_to_input_weights_index).shape().dim(0) == 0 /* CIFG */));
- }
-
- if (has_projection_param)
- {
- OP_REQUIRES(num_units == _ctx.at(projection_weights_index).shape().dim(1));
- OP_REQUIRES(output_size == _ctx.at(projection_weights_index).shape().dim(0));
- if (has_projection_bias)
- {
- OP_REQUIRES(output_size == _ctx.at(projection_bias_index).shape().dim(0));
- }
- }
-}
-
-void OperationValidator::visit(const ir::operation::L2Normalization &node)
-{
- const auto ofm_index{node.getOutputs().at(0)};
- if (_ctx.at(ofm_index).info().isDynamic())
- return;
-
- const auto ifm_index{node.getInputs().at(ir::operation::L2Normalization::Input::INPUT)};
-
- auto ifm_shape = _ctx.at(ifm_index).shape();
- auto ofm_shape = _ctx.at(ofm_index).shape();
-
- OP_REQUIRES(ifm_shape.rank() == ofm_shape.rank());
-
- for (auto i = 0; i < ifm_shape.rank(); i++)
- {
- OP_REQUIRES(ifm_shape.dim(i) == ofm_shape.dim(i));
- }
-}
-
-void OperationValidator::visit(const ir::operation::Unpack &node)
-{
- const auto num{node.param().num};
- OP_REQUIRES(num == static_cast<int32_t>(node.getOutputs().size()));
- const auto axis{node.param().axis};
-
- const auto output_index{node.getInputs().at(0)};
- if (_ctx.at(output_index).info().isDynamic())
- return;
-
- const auto input_index{node.getInputs().at(ir::operation::Unpack::Input::INPUT)};
-
- const auto &input_shape = _ctx.at(input_index).shape();
- const auto input_rank = static_cast<int32_t>(input_shape.rank());
-
- OP_REQUIRES(axis >= -input_rank && axis < input_rank);
-}
-
-void OperationValidator::visit(const ir::operation::Pad &node)
-{
- const auto pad_index{node.getInputs().at(ir::operation::Pad::Input::PAD)};
- OP_REQUIRES(_ctx.at(pad_index).typeInfo().type() == ir::DataType::INT32);
-
- const auto output_index{node.getInputs().at(0)};
- if (_ctx.at(output_index).info().isDynamic())
- return;
-
- const auto input_index{node.getInputs().at(ir::operation::Pad::Input::INPUT)};
-
- const auto &pad_shape = _ctx.at(pad_index).shape();
- const auto input_rank = static_cast<int32_t>(_ctx.at(input_index).shape().rank());
-
- OP_REQUIRES(pad_shape.rank() == 2);
- OP_REQUIRES(pad_shape.dim(0) == input_rank);
- OP_REQUIRES(pad_shape.dim(1) == 2);
- OP_REQUIRES(_ctx.at(input_index).shape().rank() == _ctx.at(output_index).shape().rank());
-}
-
-void OperationValidator::visit(const ir::operation::Select &node)
-{
- const auto output_index{node.getOutputs().at(0)};
- // This validator does not check shape. So checking isDynamic() is skipped.
-
- const auto condition_index{node.getInputs().at(ir::operation::Select::Input::CONDITION)};
- const auto input_true_index{node.getInputs().at(ir::operation::Select::Input::INPUT_TRUE)};
- const auto input_false_index{node.getInputs().at(ir::operation::Select::Input::INPUT_FALSE)};
- UNUSED_RELEASE(output_index);
- UNUSED_RELEASE(input_true_index);
- UNUSED_RELEASE(input_false_index);
-
- OP_REQUIRES(_ctx.at(condition_index).typeInfo().type() == ir::DataType::BOOL8);
-}
-
-void OperationValidator::visit(const ir::operation::StridedSlice &node)
-{
- const auto output_index{node.getOutputs().at(0)};
- const auto input_index{node.getInputs().at(ir::operation::StridedSlice::Input::INPUT)};
- const auto starts_index{node.getInputs().at(ir::operation::StridedSlice::Input::STARTS)};
- const auto ends_index{node.getInputs().at(ir::operation::StridedSlice::Input::ENDS)};
- const auto strides_index{node.getInputs().at(ir::operation::StridedSlice::Input::STRIDES)};
-
- UNUSED_RELEASE(starts_index);
- UNUSED_RELEASE(ends_index);
- UNUSED_RELEASE(strides_index);
-
- OP_REQUIRES(_ctx.at(output_index).typeInfo().type() == _ctx.at(input_index).typeInfo().type());
-
- if (_ctx.at(output_index).info().isDynamic())
- return;
-
- OP_REQUIRES(_ctx.at(input_index).shape().rank() <= 4);
-}
-
-void OperationValidator::visit(const ir::operation::Split &node)
-{
- const auto input_index{node.getInputs().at(ir::operation::Split::Input::INPUT)};
-
- if (_ctx.at(input_index).info().isDynamic())
- return;
-
- const auto num_splits = node.param().num_splits;
- const auto input_rank = _ctx.at(input_index).shape().rank();
- const auto axis = node.param().axis < 0 ? node.param().axis + input_rank : node.param().axis;
-
- OP_REQUIRES(num_splits > 0 && num_splits <= 0xFFFF);
- OP_REQUIRES(axis >= 0 && axis < input_rank);
- OP_REQUIRES(node.getOutputs().size() == static_cast<uint32_t>(num_splits));
-
- OP_REQUIRES(_ctx.at(input_index).shape().dim(axis) % num_splits == 0);
-}
-
-void OperationValidator::visit(const ir::operation::Shape &node)
-{
- const auto output_index{node.getOutputs().at(0)};
- if (_ctx.at(output_index).info().isDynamic())
- return;
-
- const auto input_index{node.getInputs().at(0)};
- UNUSED_RELEASE(input_index);
- OP_REQUIRES(_ctx.at(output_index).shape().rank() == 1);
-}
-
-void OperationValidator::visit(const ir::operation::ResizeBilinear &node)
-{
- const auto output_index{node.getOutputs().at(0)};
- const auto input_index{node.getInputs().at(ir::operation::ResizeBilinear::Input::INPUT)};
-
- if (_ctx.at(output_index).info().isDynamic())
- {
- return;
- }
- OP_REQUIRES(_ctx.at(input_index).shape().rank() == 4);
- OP_REQUIRES(_ctx.at(output_index).shape().rank() == 4);
-
- auto align_corners = node.param().align_corners;
- auto half_pixel_centers = node.param().half_pixel_centers;
-
- OP_REQUIRES(!align_corners || !half_pixel_centers);
-}
-
-void OperationValidator::visit(const ir::operation::Reverse &node)
-{
- const auto output_index{node.getOutputs().at(0)};
- const auto input_index{node.getInputs().at(ir::operation::Reverse::Input::INPUT)};
- const auto axis_index{node.getInputs().at(ir::operation::Reverse::Input::AXIS)};
-
- OP_REQUIRES(_ctx.at(axis_index).typeInfo().type() == ir::DataType::INT32);
- OP_REQUIRES(_ctx.at(output_index).typeInfo().type() == _ctx.at(input_index).typeInfo().type());
-
- if (_ctx.at(output_index).info().isDynamic())
- return;
- OP_REQUIRES(_ctx.at(output_index).shape() == _ctx.at(input_index).shape());
-}
-
-void OperationValidator::visit(const ir::operation::If &)
-{
- // TODO Add to validate with subgraphs
-}
-
-void OperationValidator::visit(const ir::operation::While &node)
-{
- // This validator does not check shape. So checking isDynamic() is skipped.
-
- OP_REQUIRES(node.getInputs().size() == node.getOutputs().size());
- // TODO Add to validate with subgraphs
-}
-
-void OperationValidator::visit(const ir::operation::SquaredDifference &node)
-{
- const auto output_index{node.getOutputs().at(0)};
- const auto lhs_index{node.getInputs().at(ir::operation::SquaredDifference::Input::LHS)};
- const auto rhs_index{node.getInputs().at(ir::operation::SquaredDifference::Input::RHS)};
-
- // Check for Type equivalence
- OP_REQUIRES(_ctx.at(output_index).typeInfo().type() == _ctx.at(lhs_index).typeInfo().type());
- OP_REQUIRES(_ctx.at(lhs_index).typeInfo().type() == _ctx.at(rhs_index).typeInfo().type());
-
- // Check for dimension constraints
- if (_ctx.at(output_index).info().isDynamic())
- return;
-
- auto output_shape = _ctx.at(output_index).shape();
- auto lhs_shape = _ctx.at(lhs_index).shape();
- auto rhs_shape = _ctx.at(rhs_index).shape();
- // Check for output rank
- OP_REQUIRES(output_shape.rank() == std::max(lhs_shape.rank(), rhs_shape.rank()));
- auto min_rank = std::min(lhs_shape.rank(), rhs_shape.rank());
-
- for (int idx = 1; idx <= min_rank; idx++)
- {
- int l_idx = lhs_shape.rank() - idx;
- int r_idx = rhs_shape.rank() - idx;
- int out_idx = output_shape.rank() - idx;
-
- OP_REQUIRES((l_idx >= 0) && (r_idx >= 0) && (out_idx >= 0));
-
- auto l_dims = lhs_shape.dim(l_idx);
- auto r_dims = rhs_shape.dim(r_idx);
- auto out_dims = output_shape.dim(out_idx);
-
- OP_REQUIRES(((l_dims == r_dims) && (out_dims == l_dims)) ||
- ((l_dims == 1) && (out_dims == r_dims)) || ((r_dims == 1) && (out_dims == l_dims)));
- }
- auto &tmp_shape = (lhs_shape.rank() > rhs_shape.rank()) ? lhs_shape : rhs_shape;
- for (int idx = min_rank + 1; idx <= output_shape.rank(); idx++)
- {
- int out_idx = output_shape.rank() - idx;
- int tmp_idx = tmp_shape.rank() - idx;
-
- OP_REQUIRES((out_idx >= 0) && (tmp_idx >= 0) &&
- (output_shape.dim(out_idx) == tmp_shape.dim(tmp_idx)));
- }
-}
-void OperationValidator::visit(const ir::operation::Tile &node)
-{
- const auto output_index{node.getOutputs().at(0)};
- if (_ctx.at(output_index).info().isDynamic())
- return;
-
- const auto input_index{node.getInputs().at(0)};
- const auto multiple_index{node.getInputs().at(1)};
-
- OP_REQUIRES(_ctx.at(multiple_index).shape().rank() == 1);
- OP_REQUIRES(_ctx.at(multiple_index).shape().dim(0) == _ctx.at(input_index).shape().rank());
- OP_REQUIRES(_ctx.at(input_index).shape().rank() == _ctx.at(output_index).shape().rank());
-}
-
-void OperationValidator::visit(const ir::operation::Range &node)
-{
- const auto output_index{node.getOutputs().at(0)};
- const auto start_index{node.getInputs().at(ir::operation::Range::Input::START)};
- const auto limit_index{node.getInputs().at(ir::operation::Range::Input::LIMIT)};
- const auto delta_index{node.getInputs().at(ir::operation::Range::Input::DELTA)};
-
- // Check for dimension constraints
- if (_ctx.at(output_index).info().isDynamic())
- return;
-
- OP_REQUIRES(_ctx.at(start_index).shape().rank() == 0);
- OP_REQUIRES(_ctx.at(limit_index).shape().rank() == 0);
- OP_REQUIRES(_ctx.at(delta_index).shape().rank() == 0);
-}
-
-void OperationValidator::visit(const ir::operation::MatrixBandPart &node)
-{
- const auto output_index{node.getOutputs().at(0)};
- const auto input_index{node.getInputs().at(ir::operation::MatrixBandPart::Input::INPUT)};
- const auto num_lower_index{
- node.getInputs().at(ir::operation::MatrixBandPart::Input::NUM_LOWER_DIAG)};
- const auto num_upper_index{
- node.getInputs().at(ir::operation::MatrixBandPart::Input::NUM_UPPER_DIAG)};
-
- // Check for dimension constraints
- if (_ctx.at(output_index).info().isDynamic())
- return;
-
- OP_REQUIRES(_ctx.at(input_index).shape().rank() >= 2);     // input must be at least a 2-D matrix
- OP_REQUIRES(_ctx.at(num_upper_index).shape().rank() == 0); // num_upper must be a scalar
- OP_REQUIRES(_ctx.at(num_lower_index).shape().rank() == 0); // num_lower must be a scalar
-}
-
-void OperationValidator::visit(const ir::operation::LogSoftmax &node)
-{
- VERBOSE(LogSoftmax) << "Configure LOGSOFTMAX operation" << std::endl;
-
- const auto output_index{node.getOutputs().at(0)};
- if (_ctx.at(output_index).info().isDynamic())
- return;
-
- const auto input_index{node.getInputs().at(0)};
-
- OP_REQUIRES(_ctx.at(output_index).shape().rank() == _ctx.at(input_index).shape().rank());
-}
-
-} // namespace compiler
-} // namespace onert
diff --git a/runtime/onert/core/src/compiler/ParamChecker.h b/runtime/onert/core/src/compiler/ParamChecker.h
deleted file mode 100644
index 61429d521..000000000
--- a/runtime/onert/core/src/compiler/ParamChecker.h
+++ /dev/null
@@ -1,73 +0,0 @@
-/*
- * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-/**
- * @file ParamChecker.h
- * @brief This file contains ParamChecker, which checks\n
- *        whether operations' parameters are compilable at the machine-independent phase\n
- *        e.g. check whether a parameter is constant
- */
-#ifndef __ONERT_COMPILER_PARAM_CHECKER_H__
-#define __ONERT_COMPILER_PARAM_CHECKER_H__
-
-#include "ir/OperationVisitor.h"
-
-namespace onert
-{
-namespace ir
-{
-class Graph;
-} // namespace ir
-} // namespace onert
-
-namespace onert
-{
-namespace compiler
-{
-
-class ParamChecker : public ir::OperationVisitor
-{
-public:
- /**
- * @brief Construct a new Param Checker object (deleted)
- */
- ParamChecker(void) = delete;
- /**
- * @brief Construct a new Param Checker object
- * @param[in] model Graph model to check
- */
- ParamChecker(std::shared_ptr<ir::Graph> model) : _model{model} {}
-
-public:
- /**
- * @brief Run parameter analysis
- */
- void operator()();
- /**
- * @brief Return the analysis result: whether the model has a non-const parameter
- * @return @c true if there is a non-const parameter, otherwise @c false
- */
- bool haveNoneConstParam(void) { return _nonConstParam; }
-
-private:
- const std::shared_ptr<ir::Graph> _model;
- bool _nonConstParam{false};
-};
-
-} // namespace compiler
-} // namespace onert
-
-#endif // __ONERT_COMPILER_PARAM_CHECKER_H__
diff --git a/runtime/onert/core/src/compiler/PermuteFactor.cc b/runtime/onert/core/src/compiler/PermuteFactor.cc
new file mode 100644
index 000000000..f0081a2a4
--- /dev/null
+++ b/runtime/onert/core/src/compiler/PermuteFactor.cc
@@ -0,0 +1,28 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "compiler/PermuteFactor.h"
+
+#include <assert.h>
+#include <ostream>
+
+#include "backend/Backend.h"
+
+std::ostream &operator<<(std::ostream &os, const onert::compiler::PermuteFactor &obj)
+{
+ assert(obj.backend() && obj.backend()->config());
+ return os << "(" << obj.backend()->config()->id() << "/" << to_string(obj.layout()) << ")";
+}
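+
+// Usage sketch (illustrative only; assumes a (backend, layout) constructor, that
+// to_string(ir::Layout::NHWC) yields "NHWC", and that cpu_backend is a hypothetical
+// backend pointer whose config id is "cpu"):
+//   onert::compiler::PermuteFactor factor{cpu_backend, ir::Layout::NHWC};
+//   std::cout << factor; // prints "(cpu/NHWC)"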
diff --git a/runtime/onert/core/src/compiler/ShapeValidator.cc b/runtime/onert/core/src/compiler/ShapeValidator.cc
new file mode 100644
index 000000000..0cd14c186
--- /dev/null
+++ b/runtime/onert/core/src/compiler/ShapeValidator.cc
@@ -0,0 +1,1132 @@
+/*
+ * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ShapeValidator.h"
+
+#include <typeinfo>
+
+#include "ir/Graph.h"
+#include "util/logging.h"
+#include "util/Utils.h"
+
+#define OP_REQUIRES(EXP) \
+ do \
+ { \
+ if (!(EXP)) \
+ throw std::runtime_error("ShapeValidator failed at line " + std::to_string(__LINE__)); \
+ } while (0)
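+
+// Usage sketch (illustrative): a check written as
+//   OP_REQUIRES(operands.at(index).shape().rank() == 4);
+// throws std::runtime_error with the message "ShapeValidator failed at line <call-site line>"
+// when the condition is false; the line number comes from __LINE__ at the invocation point.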
+
+namespace onert
+{
+namespace compiler
+{
+
+ShapeValidator::ShapeValidator(const ir::Graph &graph) : _graph{graph} {}
+
+void ShapeValidator::checkUnaryOp(const ir::Operation &node)
+{
+ const auto &operands = _graph.operands();
+ const auto output_index{node.getOutputs().at(0)};
+ const auto input_index{node.getInputs().at(0)};
+
+ if (operands.at(output_index).info().isDynamic())
+ return;
+
+ // Check if I/O shapes match
+ OP_REQUIRES(operands.at(output_index).shape() == operands.at(input_index).shape());
+}
+
+void ShapeValidator::operator()()
+{
+ _graph.operations().iterate(
+ [&](const ir::OperationIndex &, const ir::IOperation &node) { node.accept(*this); });
+}
+
+void ShapeValidator::visit(const ir::operation::BatchMatMul &node)
+{
+ const auto &operands = _graph.operands();
+ const auto lhs_index(node.getInputs().at(ir::operation::BatchMatMul::Input::LHS));
+ const auto rhs_index(node.getInputs().at(ir::operation::BatchMatMul::Input::RHS));
+ const auto out_index{node.getOutputs().at(0)};
+
+ if (operands.at(out_index).info().isDynamic())
+ return;
+
+ OP_REQUIRES(operands.at(lhs_index).shape().rank() <= 4);
+ OP_REQUIRES(operands.at(rhs_index).shape().rank() <= 4);
+ OP_REQUIRES(operands.at(lhs_index).shape().rank() >= 2);
+ OP_REQUIRES(operands.at(rhs_index).shape().rank() >= 2);
+}
+
+void ShapeValidator::visit(const ir::operation::BatchToSpaceND &node)
+{
+ const auto &operands = _graph.operands();
+ const auto ofm_index{node.getOutputs().at(0)};
+ if (operands.at(ofm_index).info().isDynamic())
+ return;
+
+ const auto ifm_index{node.getInputs().at(ir::operation::BatchToSpaceND::Input::INPUT)};
+ const auto block_size_index{
+ node.getInputs().at(ir::operation::BatchToSpaceND::Input::BLOCK_SIZE)};
+
+ const auto frontend_layout = _graph.layout();
+ const auto input_shape = operands.at(ifm_index).shape().asFeature(frontend_layout);
+ const auto output_shape = operands.at(ofm_index).shape().asFeature(frontend_layout);
+
+ // All requirements as per the NNAPI specification.
+ OP_REQUIRES(operands.at(ifm_index).shape().rank() == 4);
+ OP_REQUIRES(operands.at(ofm_index).shape().rank() == 4);
+ OP_REQUIRES(operands.at(block_size_index).shape().rank() == 1);
+
+ OP_REQUIRES(operands.at(block_size_index).shape().dim(0) == 2);
+
+ if (node.getInputs().size() != 2)
+ {
+ const auto crops_index{node.getInputs().at(ir::operation::BatchToSpaceND::Input::CROPS_DATA)};
+ OP_REQUIRES(operands.at(crops_index).shape().rank() == 2);
+ OP_REQUIRES(operands.at(crops_index).shape().dim(0) ==
+ (operands.at(ifm_index).shape().rank() - 2));
+ OP_REQUIRES(operands.at(crops_index).shape().dim(1) == 2);
+ }
+
+ OP_REQUIRES(input_shape.C == output_shape.C);
+}
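+
+// Illustrative example of the constraints above: for a 4-D NHWC input of shape {8, 1, 1, 3},
+// block_size must be a rank-1 operand holding two values (e.g. [2, 2]) and crops, when given,
+// must have shape {2, 2} (e.g. [[0, 0], [0, 0]]); the standard BatchToSpaceND result is then
+// {2, 2, 2, 3}, and the check here verifies that the channel count (C == 3) is unchanged.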
+
+void ShapeValidator::visit(const ir::operation::BCQFullyConnected &node)
+{
+ const auto &operands = _graph.operands();
+ const auto ofm_index{node.getOutputs().at(0)};
+ if (operands.at(ofm_index).info().isDynamic())
+ return;
+
+ const auto ifm_index{node.getInputs().at(ir::operation::BCQFullyConnected::Input::INPUT)};
+ const auto weight_scales_index{
+ node.getInputs().at(ir::operation::BCQFullyConnected::Input::WEIGHTS_SCALES)};
+ const auto weight_binary_index{
+ node.getInputs().at(ir::operation::BCQFullyConnected::Input::WEIGHTS_BINARY)};
+ const auto weight_cluster_index{
+ node.getInputs().at(ir::operation::BCQFullyConnected::Input::WEIGHTS_CLUSTERS)};
+ const auto bias_index{node.getInputs().at(ir::operation::BCQFullyConnected::Input::BIAS)};
+
+ OP_REQUIRES(operands.at(ifm_index).shape().rank() == 2);
+ OP_REQUIRES(operands.at(ofm_index).shape().rank() == 2);
+ OP_REQUIRES(operands.at(weight_scales_index).shape().rank() == 1);
+ OP_REQUIRES(operands.at(weight_binary_index).shape().rank() == 2);
+ OP_REQUIRES(operands.at(weight_cluster_index).shape().rank() == 2);
+
+ OP_REQUIRES(operands.at(ifm_index).shape().dim(1) == operands.at(ofm_index).shape().dim(1));
+
+ OP_REQUIRES(operands.at(weight_cluster_index).shape().dim(0) > 0);
+ OP_REQUIRES(operands.at(weight_cluster_index).shape().dim(1) == 2);
+
+ // More shape validation will be done inside the kernel.
+
+ OP_REQUIRES(!operands.exist(bias_index) || operands.at(bias_index).shape().rank() == 1);
+}
+
+void ShapeValidator::visit(const ir::operation::BCQGather &node)
+{
+ const auto &operands = _graph.operands();
+ const auto ofm_index{node.getOutputs().at(0)};
+ if (operands.at(ofm_index).info().isDynamic())
+ return;
+
+ const auto indices_index{node.getInputs().at(ir::operation::BCQGather::Input::INDICES)};
+ const auto input_binary_index{node.getInputs().at(ir::operation::BCQGather::Input::INPUT_BINARY)};
+ const auto input_scales_index{node.getInputs().at(ir::operation::BCQGather::Input::INPUT_SCALES)};
+ const auto input_clusters_index{
+ node.getInputs().at(ir::operation::BCQGather::Input::INPUT_CLUSTERS)};
+
+ OP_REQUIRES(operands.at(indices_index).shape().rank() <=
+ 2); // TODO : support rank up to 4 or more
+ OP_REQUIRES(operands.at(input_binary_index).shape().rank() == 2);
+ OP_REQUIRES(operands.at(input_scales_index).shape().rank() == 1);
+ OP_REQUIRES(operands.at(input_clusters_index).shape().rank() == 2);
+
+ OP_REQUIRES(operands.at(input_clusters_index).shape().dim(0) > 0);
+ OP_REQUIRES(operands.at(input_clusters_index).shape().dim(1) == 2);
+
+ // More shape validation will be done inside the kernel.
+}
+
+void ShapeValidator::visit(const ir::operation::Conv2D &node)
+{
+ const auto &operands = _graph.operands();
+ const auto ofm_index{node.getOutputs().at(0)};
+ if (operands.at(ofm_index).info().isDynamic())
+ return;
+
+ const auto ifm_index{node.getInputs().at(ir::operation::Conv2D::Input::INPUT)};
+ const auto ker_index{node.getInputs().at(ir::operation::Conv2D::Input::KERNEL)};
+ const auto bias_index{node.getInputs().at(ir::operation::Conv2D::Input::BIAS)};
+
+ OP_REQUIRES(operands.at(ifm_index).shape().rank() == 4);
+ OP_REQUIRES(operands.at(ker_index).shape().rank() == 4);
+ OP_REQUIRES(!operands.exist(bias_index) || operands.at(bias_index).shape().rank() == 1);
+ OP_REQUIRES(operands.at(ofm_index).shape().rank() == 4);
+}
+
+void ShapeValidator::visit(const ir::operation::Comparison &)
+{
+ // TODO Shape validation of comparison
+}
+
+void ShapeValidator::visit(const ir::operation::DepthwiseConv2D &node)
+{
+ const auto &operands = _graph.operands();
+ const auto ofm_index{node.getOutputs().at(0)};
+ if (operands.at(ofm_index).info().isDynamic())
+ return;
+
+ const auto ifm_index{node.getInputs().at(ir::operation::DepthwiseConv2D::Input::INPUT)};
+ const auto ker_index{node.getInputs().at(ir::operation::DepthwiseConv2D::Input::KERNEL)};
+ const auto bias_index{node.getInputs().at(ir::operation::DepthwiseConv2D::Input::BIAS)};
+
+ OP_REQUIRES(operands.at(ifm_index).shape().rank() == 4);
+ OP_REQUIRES(operands.at(ker_index).shape().rank() == 4);
+ OP_REQUIRES(!operands.exist(bias_index) || operands.at(bias_index).shape().rank() == 1);
+ OP_REQUIRES(operands.at(ofm_index).shape().rank() == 4);
+}
+
+void ShapeValidator::visit(const ir::operation::FullyConnected &node)
+{
+ const auto &operands = _graph.operands();
+ const auto ofm_index{node.getOutputs().at(0)};
+ if (operands.at(ofm_index).info().isDynamic())
+ return;
+
+ const auto ifm_index{node.getInputs().at(ir::operation::FullyConnected::Input::INPUT)};
+ const auto ker_index{node.getInputs().at(ir::operation::FullyConnected::Input::WEIGHT)};
+ const auto bias_index{node.getInputs().at(ir::operation::FullyConnected::Input::BIAS)};
+
+ OP_REQUIRES(operands.at(ifm_index).shape().rank() >= 2);
+ OP_REQUIRES(operands.at(ker_index).shape().rank() == 2);
+ OP_REQUIRES(!operands.exist(bias_index) || operands.at(bias_index).shape().rank() == 1);
+}
+
+void ShapeValidator::visit(const ir::operation::Softmax &node)
+{
+ const auto &operands = _graph.operands();
+ const auto output_index{node.getOutputs().at(0)};
+ if (operands.at(output_index).info().isDynamic())
+ return;
+
+ const auto input_index{node.getInputs().at(0)};
+
+ OP_REQUIRES(operands.at(output_index).shape().rank() == operands.at(input_index).shape().rank());
+}
+
+void ShapeValidator::visit(const ir::operation::InstanceNorm &node)
+{
+ const auto &operands = _graph.operands();
+ const auto ofm_index{node.getOutputs().at(0)};
+ if (operands.at(ofm_index).info().isDynamic())
+ return;
+
+ const auto ifm_index{node.getInputs().at(ir::operation::InstanceNorm::Input::INPUT)};
+ const auto gamma_index{node.getInputs().at(ir::operation::InstanceNorm::Input::GAMMA)};
+ const auto beta_index{node.getInputs().at(ir::operation::InstanceNorm::Input::BETA)};
+
+ OP_REQUIRES(operands.at(ifm_index).shape().rank() == 4);
+ OP_REQUIRES(operands.at(ifm_index).shape() == operands.at(ofm_index).shape());
+ OP_REQUIRES(operands.at(gamma_index).shape().rank() == 1);
+ OP_REQUIRES(operands.at(beta_index).shape().rank() == 1);
+}
+
+void ShapeValidator::visit(const ir::operation::Pool2D &node)
+{
+ const auto &operands = _graph.operands();
+ const auto ofm_index{node.getOutputs().at(0)};
+ if (operands.at(ofm_index).info().isDynamic())
+ return;
+
+ const auto ifm_index{node.getInputs().at(ir::operation::Pool2D::Input::INPUT)};
+
+ OP_REQUIRES(operands.at(ifm_index).shape().rank() == 4);
+}
+
+void ShapeValidator::visit(const ir::operation::Permute &node)
+{
+ const auto &operands = _graph.operands();
+ const auto output_index{node.getOutputs().at(0)};
+ if (operands.at(output_index).info().isDynamic())
+ return;
+
+ const auto input_index{node.getInputs().at(0)};
+
+ OP_REQUIRES(operands.at(output_index).shape().rank() == operands.at(input_index).shape().rank());
+}
+
+void ShapeValidator::visit(const ir::operation::Reduce &node)
+{
+ const auto &operands = _graph.operands();
+ const auto output_index{node.getOutputs().at(0)};
+ if (operands.at(output_index).info().isDynamic())
+ return;
+
+ const auto &input_index{node.getInputs().at(ir::operation::Reduce::Input::INPUT)};
+ const auto &input_shape = operands.at(input_index).shape();
+ const auto &output_shape = operands.at(output_index).shape();
+
+ OP_REQUIRES(input_shape.rank() <= 4);
+ OP_REQUIRES(output_shape.rank() <= input_shape.rank());
+
+ // NOTE For 4 dimensions, if the ranks of input and output differ, this runtime only
+ // supports the cases of reducing height and width, or reducing depth.
+ // TODO We have to support all cases of dimensions up to 4.
+ // For correct permuting, the output's shape has to keep the dimension positions of the
+ // input. But the positions of the same dimensions in the input and output may be set differently.
+ // For example {2,3,4,5}(input's shape) can be reduced to {3,5}(output's shape). The original
+ // output shape should be {1,3,1,5}, but the real output shape may be {3,5}. If you simply
+ // extend it to 4 dimensions, it becomes {1,1,3,5}.
+ // Even if the output shape is changed to {1,3,1,5}, there is another problem: the shape of the
+ // output tensor seen by the next operation becomes {1,3,1,5} after this operation, even if the
+ // next operation does not expect that shape.
+ if (input_shape.rank() == 4 && input_shape.rank() != output_shape.rank())
+ {
+ if (output_shape.rank() == 2)
+ {
+ // Reducing HW
+ OP_REQUIRES(input_shape.dim(0) == output_shape.dim(0) &&
+ input_shape.dim(3) == output_shape.dim(1));
+ }
+ else if (output_shape.rank() == 3)
+ {
+ // Reducing C, or
+ // (reducing H with C == 1 for both input and output) or (reducing W with C == 1 for both input and output)
+ OP_REQUIRES(
+ (input_shape.dim(0) == output_shape.dim(0) && input_shape.dim(1) == output_shape.dim(1) &&
+ input_shape.dim(2) == output_shape.dim(2)) ||
+ (input_shape.dim(0) == output_shape.dim(0) &&
+ (input_shape.dim(1) == output_shape.dim(1) || input_shape.dim(2) == output_shape.dim(1)) &&
+ input_shape.dim(3) == 1 && output_shape.dim(2) == 1));
+ }
+ }
+}
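+
+// Worked example for the note above (illustrative): reducing H and W of a {2, 3, 4, 5} input
+// gives a real output shape of {2, 5}; the rank-2 branch then only checks that
+// output.dim(0) == input.dim(0) (2) and output.dim(1) == input.dim(3) (5), without requiring
+// the rank-4 form {2, 1, 1, 5}.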
+
+void ShapeValidator::visit(const ir::operation::Transpose &node)
+{
+ const auto &operands = _graph.operands();
+ const auto output_index{node.getOutputs().at(0)};
+ if (operands.at(output_index).info().isDynamic())
+ return;
+
+ const auto input_index{node.getInputs().at(ir::operation::Transpose::Input::INPUT)};
+ const auto perm_index{node.getInputs().at(ir::operation::Transpose::Input::PERMUTATION)};
+
+ const auto &output_shape = operands.at(output_index).shape();
+ const auto &input_shape = operands.at(input_index).shape();
+
+ OP_REQUIRES(operands.at(perm_index).shape().num_elements() == 0 ||
+ input_shape.rank() ==
+ static_cast<int>(operands.at(perm_index).shape().num_elements()));
+ OP_REQUIRES(input_shape.rank() == output_shape.rank());
+}
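+
+// Illustrative example: for a rank-3 input, a PERMUTATION operand holding {0, 2, 1} passes
+// (3 elements == input rank), and an empty PERMUTATION operand also passes (num_elements() == 0),
+// while one holding {0, 1} is rejected; the output rank must match the input rank in all cases.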
+
+void ShapeValidator::visit(const ir::operation::RNN &node)
+{
+ // NOTE This validation is for static rnn(non-dynamic shape), but not for dynamic rnn
+ // TODO Support dynamic rnn
+ const auto &operands = _graph.operands();
+ const auto output_index{node.getOutputs().at(ir::operation::RNN::Output::OUTPUT)};
+ if (operands.at(output_index).info().isDynamic())
+ return;
+
+ const auto hidden_state_out_index{
+ node.getOutputs().at(ir::operation::RNN::Output::HIDDEN_STATE_OUT)};
+
+ const auto input_index{node.getInputs().at(ir::operation::RNN::Input::INPUT)};
+ const auto weights_index{node.getInputs().at(ir::operation::RNN::Input::WEIGHTS)};
+ const auto recurrent_weights_index{
+ node.getInputs().at(ir::operation::RNN::Input::RECURRENT_WEIGHTS)};
+ const auto bias_index{node.getInputs().at(ir::operation::RNN::Input::BIAS)};
+ const auto hidden_state_in_index{node.getInputs().at(ir::operation::RNN::Input::HIDDEN_STATE_IN)};
+
+ const auto batch_size = operands.at(output_index).shape().dim(0);
+ const auto num_units = operands.at(output_index).shape().dim(1);
+
+ OP_REQUIRES(operands.at(output_index).shape().rank() == 2 &&
+ operands.at(hidden_state_out_index).shape().rank() == 2 &&
+ operands.at(input_index).shape().rank() == 2 &&
+ operands.at(weights_index).shape().rank() == 2 &&
+ operands.at(recurrent_weights_index).shape().rank() == 2 &&
+ operands.at(hidden_state_in_index).shape().rank() == 2);
+ OP_REQUIRES(operands.at(bias_index).shape().rank() == 1);
+
+ OP_REQUIRES(batch_size == operands.at(input_index).shape().dim(0) &&
+ batch_size == operands.at(hidden_state_in_index).shape().dim(0) &&
+ batch_size == operands.at(hidden_state_out_index).shape().dim(0));
+ OP_REQUIRES(operands.at(input_index).shape().dim(1) == operands.at(weights_index).shape().dim(1));
+
+ OP_REQUIRES(num_units == operands.at(weights_index).shape().dim(0) &&
+ num_units == operands.at(recurrent_weights_index).shape().dim(0) &&
+ num_units == operands.at(bias_index).shape().dim(0));
+ OP_REQUIRES(num_units == operands.at(output_index).shape().dim(1) &&
+ num_units == operands.at(recurrent_weights_index).shape().dim(1) &&
+ num_units == operands.at(hidden_state_in_index).shape().dim(1) &&
+ num_units == operands.at(hidden_state_out_index).shape().dim(1));
+}
+
+void ShapeValidator::visit(const ir::operation::SpaceToBatchND &node)
+{
+ const auto &operands = _graph.operands();
+ const auto ofm_index{node.getOutputs().at(0)};
+ if (operands.at(ofm_index).info().isDynamic())
+ return;
+
+ const auto ifm_index{node.getInputs().at(ir::operation::SpaceToBatchND::Input::INPUT)};
+ const auto block_size_index{
+ node.getInputs().at(ir::operation::SpaceToBatchND::Input::BLOCK_SIZE)};
+ const auto paddings_index{node.getInputs().at(ir::operation::SpaceToBatchND::Input::PADDINGS)};
+
+ const auto frontend_layout = _graph.layout();
+ const auto input_shape = operands.at(ifm_index).shape().asFeature(frontend_layout);
+ const auto output_shape = operands.at(ofm_index).shape().asFeature(frontend_layout);
+
+ // All requirements as per the NNAPI specification.
+ OP_REQUIRES(operands.at(ifm_index).shape().rank() == 4);
+ OP_REQUIRES(operands.at(ofm_index).shape().rank() == 4);
+ OP_REQUIRES(operands.at(block_size_index).shape().rank() == 1);
+ OP_REQUIRES(operands.at(paddings_index).shape().rank() == 2);
+
+ OP_REQUIRES(operands.at(block_size_index).shape().dim(0) == 2);
+ OP_REQUIRES(operands.at(paddings_index).shape().dim(0) == 2);
+ OP_REQUIRES(operands.at(paddings_index).shape().dim(1) == 2);
+
+ OP_REQUIRES(input_shape.C == output_shape.C);
+}
+
+void ShapeValidator::visit(const ir::operation::SpaceToDepth &node)
+{
+ const auto &operands = _graph.operands();
+ const auto ofm_index{node.getOutputs().at(0)};
+ if (operands.at(ofm_index).info().isDynamic())
+ return;
+
+ const auto ifm_index{node.getInputs().at(ir::operation::SpaceToDepth::Input::INPUT)};
+
+ const auto frontend_layout = _graph.layout();
+ const auto input_shape = operands.at(ifm_index).shape().asFeature(frontend_layout);
+ const auto output_shape = operands.at(ofm_index).shape().asFeature(frontend_layout);
+ const auto block_size = node.param().block_size;
+
+ // All assertions as per NNAPI specification.
+ OP_REQUIRES(operands.at(ifm_index).shape().rank() == 4);
+ OP_REQUIRES(operands.at(ofm_index).shape().rank() == 4);
+ OP_REQUIRES((input_shape.H % block_size == 0) && (input_shape.W % block_size == 0));
+ OP_REQUIRES(input_shape.N == output_shape.N);
+ OP_REQUIRES(input_shape.C * block_size * block_size == output_shape.C);
+}
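+
+// Illustrative example: with block_size == 2 and an NHWC input of {1, 4, 4, 3}, H and W must be
+// divisible by 2; the standard SpaceToDepth output is then {1, 2, 2, 12}, and the checks verify
+// that N is unchanged and that output C == 3 * 2 * 2 == 12.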
+
+void ShapeValidator::visit(const ir::operation::ElementwiseActivation &node) { checkUnaryOp(node); }
+
+void ShapeValidator::visit(const ir::operation::ElementwiseBinary &)
+{
+ // TODO Shape validation of ElementwiseBinary
+}
+
+void ShapeValidator::visit(const ir::operation::ElementwiseUnary &node)
+{
+ const auto &operands = _graph.operands();
+ const auto output_index{node.getOutputs().at(0)};
+ const auto input_index{node.getInputs().at(ir::operation::ElementwiseUnary::Input::INPUT)};
+
+ if (operands.at(output_index).info().isDynamic())
+ return;
+
+ OP_REQUIRES(operands.at(output_index).shape() == operands.at(input_index).shape());
+}
+
+void ShapeValidator::visit(const ir::operation::EmbeddingLookup &node)
+{
+ const auto &operands = _graph.operands();
+ const auto output_index{node.getOutputs().at(0)};
+ const auto lookups_index{node.getInputs().at(ir::operation::EmbeddingLookup::Input::LOOKUPS)};
+ const auto values_index{node.getInputs().at(ir::operation::EmbeddingLookup::Input::VALUES)};
+
+ const auto &output_obj = operands.at(output_index);
+ const auto &lookups_obj = operands.at(lookups_index);
+ const auto &values_obj = operands.at(values_index);
+
+ // Verify the operands here, not in SimpleEmbeddingLookup::configure(), to avoid ACL occasionally
+ // modifying TensorShape (Issue: https://github.sec.samsung.net/STAR/nnfw/issues/729)
+ {
+ if (operands.at(output_index).info().isDynamic())
+ return;
+
+ const auto &output_shape = output_obj.shape();
+ const auto &lookups_shape = lookups_obj.shape();
+ const auto &values_shape = values_obj.shape();
+
+ OP_REQUIRES(lookups_shape.rank() == 1);
+ OP_REQUIRES(values_shape.rank() >= 2);
+
+ // The output should be an n-D tensor with the same rank and shape as the values tensor, except
+ // for the first dimension, which has the same size as lookups' only dimension.
+ OP_REQUIRES(output_shape.rank() == values_shape.rank());
+ OP_REQUIRES(output_shape.dim(0) == lookups_shape.dim(0));
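+ // Illustrative example: lookups [5] into values [100, 16] must produce an output [5, 16].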
+ for (int n = 1; n < output_shape.rank(); ++n)
+ {
+ OP_REQUIRES(output_shape.dim(n) == values_shape.dim(n));
+ }
+ }
+}
+
+void ShapeValidator::visit(const ir::operation::ExpandDims &node)
+{
+ const auto &operands = _graph.operands();
+ const auto axis_index{node.getInputs().at(ir::operation::ExpandDims::Input::AXIS)};
+
+ if (operands.at(axis_index).info().isDynamic())
+ return;
+ OP_REQUIRES(operands.at(axis_index).shape().rank() <= 1);
+}
+
+void ShapeValidator::visit(const ir::operation::HashtableLookup &node)
+{
+ const auto &operands = _graph.operands();
+ const auto output_index{node.getOutputs().at(ir::operation::HashtableLookup::Output::OUTPUT)};
+ const auto lookups_index{node.getInputs().at(ir::operation::HashtableLookup::Input::LOOKUPS)};
+ const auto keys_index{node.getInputs().at(ir::operation::HashtableLookup::Input::KEYS)};
+ const auto values_index{node.getInputs().at(ir::operation::HashtableLookup::Input::VALUES)};
+
+ const auto &output_obj = operands.at(output_index);
+ const auto &lookups_obj = operands.at(lookups_index);
+ const auto &keys_obj = operands.at(keys_index);
+ const auto &values_obj = operands.at(values_index);
+
+ if (operands.at(output_index).info().isDynamic())
+ return;
+
+ const auto &output_shape = output_obj.shape();
+ const auto &lookups_shape = lookups_obj.shape();
+ const auto &keys_shape = keys_obj.shape();
+ const auto &values_shape = values_obj.shape();
+
+ OP_REQUIRES(values_shape.rank() == output_shape.rank());
+ OP_REQUIRES(lookups_shape.rank() == 1);
+ OP_REQUIRES(keys_shape.rank() == 1);
+ OP_REQUIRES(values_shape.dim(0) == keys_shape.dim(0));
+ OP_REQUIRES(lookups_shape.dim(0) == output_shape.dim(0));
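+ // Illustrative example: keys [50], values [50, 8] and lookups [10] require an output [10, 8].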
+}
+
+void ShapeValidator::visit(const ir::operation::TransposeConv &node)
+{
+ // shape check
+ const auto &operands = _graph.operands();
+ const auto ofm_index{node.getOutputs().at(0)};
+
+ if (operands.at(ofm_index).info().isDynamic())
+ return;
+
+ const auto ifm_index{node.getInputs().at(ir::operation::TransposeConv::Input::INPUT)};
+ const auto ker_index{node.getInputs().at(ir::operation::TransposeConv::Input::KERNEL)};
+
+ // Only 4D tensors are supported
+ OP_REQUIRES(operands.at(ofm_index).shape().rank() == 4);
+ OP_REQUIRES(operands.at(ofm_index).shape().rank() == operands.at(ifm_index).shape().rank());
+ OP_REQUIRES(operands.at(ofm_index).shape().rank() == operands.at(ker_index).shape().rank());
+
+ const auto frontend_layout = _graph.layout();
+ const auto ofm_shape = operands.at(ofm_index).shape().asFeature(frontend_layout);
+ const auto ifm_shape = operands.at(ifm_index).shape().asFeature(frontend_layout);
+ // The kernel always uses the IHWO layout on the frontend,
+ // so ker_shape is mapped below as follows:
+ // I -> N
+ // H -> H
+ // W -> W
+ // O -> C
+ const auto ker_shape = operands.at(ker_index).shape().asFeature(ir::Layout::NHWC);
+
+ OP_REQUIRES(ifm_shape.N == ofm_shape.N);
+ OP_REQUIRES(ifm_shape.C == ker_shape.C);
+ OP_REQUIRES(ker_shape.N == ofm_shape.C);
+}
+
+void ShapeValidator::visit(const ir::operation::Gather &node)
+{
+ const auto &operands = _graph.operands();
+ const auto ofm_index{node.getOutputs().at(0)};
+ if (operands.at(ofm_index).info().isDynamic())
+ return;
+
+ const auto ifm_index{node.getInputs().at(ir::operation::Gather::Input::INPUT)};
+ const auto indices_index{node.getInputs().at(ir::operation::Gather::Input::INDICES)};
+
+ const auto &ifm_shape = operands.at(ifm_index).shape();
+ const auto &indices_shape = operands.at(indices_index).shape();
+ const auto &ofm_shape = operands.at(ofm_index).shape();
+
+ OP_REQUIRES(ifm_shape.rank() <= 4);
+ OP_REQUIRES(indices_shape.rank() <= 3);
+ OP_REQUIRES(ofm_shape.rank() <= 4);
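+ // Illustrative note: for Gather, the output rank is (input rank + indices rank - 1); the
+ // bounds above are assumed backend rank limits, e.g. a rank-4 input gathered with rank-1
+ // indices keeps the output at rank 4.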
+}
+
+void ShapeValidator::visit(const ir::operation::DepthToSpace &node)
+{
+ const auto &operands = _graph.operands();
+ int32_t block_size = node.param().block_size;
+
+ // shape check
+ const auto output_index{node.getOutputs().at(0)};
+ if (operands.at(output_index).info().isDynamic())
+ return;
+
+ const auto input_index{node.getInputs().at(ir::operation::DepthToSpace::Input::INPUT)};
+
+ const auto frontend_layout = _graph.layout();
+ const auto output_shape = operands.at(output_index).shape().asFeature(frontend_layout);
+ const auto input_shape = operands.at(input_index).shape().asFeature(frontend_layout);
+
+ OP_REQUIRES(operands.at(input_index).shape().rank() == 4);
+ OP_REQUIRES(operands.at(output_index).shape().rank() == 4);
+
+ {
+ OP_REQUIRES(output_shape.N == input_shape.N);
+ OP_REQUIRES(output_shape.H == input_shape.H * block_size);
+ OP_REQUIRES(output_shape.W == input_shape.W * block_size);
+ OP_REQUIRES(input_shape.C % (block_size * block_size) == 0);
+ OP_REQUIRES(output_shape.C == input_shape.C / (block_size * block_size));
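+ // Illustrative example: an NHWC input [1, 2, 2, 12] with block_size 2 passes these checks
+ // when the output is [1, 4, 4, 3].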
+ }
+}
+
+void ShapeValidator::visit(const ir::operation::Pack &node)
+{
+ const auto &operands = _graph.operands();
+ const auto axis{node.param().axis};
+ const auto output_index{node.getOutputs().at(0)};
+ if (operands.at(output_index).info().isDynamic())
+ return;
+
+ // shape check
+ const auto &output_shape = operands.at(output_index).shape();
+ const auto output_rank = static_cast<int32_t>(output_shape.rank());
+
+ const auto input1_index{node.getInputs().at(0)};
+ const auto &input_shape = operands.at(input1_index).shape();
+
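+ // Illustrative example: packing inputs of shape [3, 4] gives an output of rank 3, so any
+ // axis in [-3, 3) is accepted, and every input must share the same [3, 4] shape.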
+ OP_REQUIRES(axis >= -output_rank && axis < output_rank);
+ for (const auto &index : node.getInputs())
+ {
+ OP_REQUIRES(input_shape == operands.at(index).shape());
+ }
+}
+
+void ShapeValidator::visit(const ir::operation::LSTM &node)
+{
+ // NOTE This validation covers static RNN (non-dynamic shape) only, not dynamic RNN
+ // TODO Support dynamic rnn
+ const auto &operands = _graph.operands();
+ const auto output_index{node.getOutputs().at(ir::operation::LSTM::Output::OUTPUT)};
+ if (operands.at(output_index).info().isDynamic())
+ return;
+
+ const auto scratch_buffer_index{
+ node.getOutputs().at(ir::operation::LSTM::Output::SCRATCH_BUFFER)}; // Optional
+ const auto output_state_out_index{
+ node.getOutputs().at(ir::operation::LSTM::Output::OUTPUT_STATE_OUT)}; // Optional
+ const auto cell_state_out_index{
+ node.getOutputs().at(ir::operation::LSTM::Output::CELL_STATE_OUT)}; // Optional
+
+ const auto input_index{node.getInputs().at(ir::operation::LSTM::Input::INPUT)};
+ const auto input_to_input_weights_index{
+ node.getInputs().at(ir::operation::LSTM::Input::INPUT_TO_INPUT_WEIGHTS)}; // Optional
+ const auto input_to_forget_weights_index{
+ node.getInputs().at(ir::operation::LSTM::Input::INPUT_TO_FORGET_WEIGHTS)};
+ const auto input_to_cell_weights_index{
+ node.getInputs().at(ir::operation::LSTM::Input::INPUT_TO_CELL_WEIGHTS)};
+ const auto input_to_output_weights_index{
+ node.getInputs().at(ir::operation::LSTM::Input::INPUT_TO_OUTPUT_WEIGHTS)};
+ const auto recurrent_to_input_weights_index{
+ node.getInputs().at(ir::operation::LSTM::Input::RECURRENT_TO_INPUT_WEIGHTS)}; // Optional
+ const auto recurrent_to_forget_weights_index{
+ node.getInputs().at(ir::operation::LSTM::Input::RECURRENT_TO_FORGET_WEIGHTS)};
+ const auto recurrent_to_cell_weights_index{
+ node.getInputs().at(ir::operation::LSTM::Input::RECURRENT_TO_CELL_WEIGHTS)};
+ const auto recurrent_to_output_weights_index{
+ node.getInputs().at(ir::operation::LSTM::Input::RECURRENT_TO_OUTPUT_WEIGHTS)};
+ const auto cell_to_input_weights_index{
+ node.getInputs().at(ir::operation::LSTM::Input::CELL_TO_INPUT_WEIGHTS)}; // Optional
+ const auto cell_to_forget_weights_index{
+ node.getInputs().at(ir::operation::LSTM::Input::CELL_TO_FORGET_WEIGHTS)}; // Optional
+ const auto cell_to_output_weights_index{
+ node.getInputs().at(ir::operation::LSTM::Input::CELL_TO_OUTPUT_WEIGHTS)}; // Optional
+ const auto input_gate_bias_index{
+ node.getInputs().at(ir::operation::LSTM::Input::INPUT_GATE_BIAS)}; // Optional
+ const auto forget_gate_bias_index{
+ node.getInputs().at(ir::operation::LSTM::Input::FORGET_GATE_BIAS)};
+ const auto cell_bias_index{node.getInputs().at(ir::operation::LSTM::Input::CELL_BIAS)};
+ const auto output_gate_bias_index{
+ node.getInputs().at(ir::operation::LSTM::Input::OUTPUT_GATE_BIAS)};
+ const auto projection_weights_index{
+ node.getInputs().at(ir::operation::LSTM::Input::PROJECTION_WEIGHTS)}; // Optional
+ const auto projection_bias_index{
+ node.getInputs().at(ir::operation::LSTM::Input::PROJECTION_BIAS)}; // Optional
+ const auto output_state_in_index{
+ node.getInputs().at(ir::operation::LSTM::Input::OUTPUT_STATE_IN)};
+ const auto cell_state_in_index{node.getInputs().at(ir::operation::LSTM::Input::CELL_STATE_IN)};
+
+ OP_REQUIRES(operands.at(input_index).shape().rank() == operands.at(output_index).shape().rank());
+ for (int i = 0; i < operands.at(input_index).shape().rank() - 1; ++i)
+ {
+ OP_REQUIRES(operands.at(input_index).shape().dim(i) ==
+ operands.at(output_index).shape().dim(i));
+ }
+ OP_REQUIRES((operands.at(output_index).shape().rank() == 2 ||
+ operands.at(output_index).shape().rank() == 3) &&
+ (operands.at(input_index).shape().rank() == 2 ||
+ operands.at(input_index).shape().rank() == 3) &&
+ (!operands.exist(input_to_input_weights_index) ||
+ operands.at(input_to_input_weights_index).shape().rank() == 2) &&
+ operands.at(input_to_forget_weights_index).shape().rank() == 2 &&
+ operands.at(input_to_cell_weights_index).shape().rank() == 2 &&
+ operands.at(input_to_output_weights_index).shape().rank() == 2 &&
+ (!operands.exist(recurrent_to_input_weights_index) ||
+ operands.at(recurrent_to_input_weights_index).shape().rank() == 2) &&
+ operands.at(recurrent_to_forget_weights_index).shape().rank() == 2 &&
+ operands.at(recurrent_to_cell_weights_index).shape().rank() == 2 &&
+ operands.at(recurrent_to_output_weights_index).shape().rank() == 2 &&
+ (!operands.exist(projection_weights_index) ||
+ operands.at(projection_weights_index).shape().rank() == 2) &&
+ operands.at(output_state_in_index).shape().rank() == 2 &&
+ operands.at(cell_state_in_index).shape().rank() == 2);
+
+ OP_REQUIRES((!operands.exist(cell_to_input_weights_index) ||
+ operands.at(cell_to_input_weights_index).shape().rank() == 1) &&
+ (!operands.exist(cell_to_forget_weights_index) ||
+ operands.at(cell_to_forget_weights_index).shape().rank() == 1) &&
+ (!operands.exist(cell_to_output_weights_index) ||
+ operands.at(cell_to_output_weights_index).shape().rank() == 1) &&
+ (!operands.exist(input_gate_bias_index) ||
+ operands.at(input_gate_bias_index).shape().rank() == 1) &&
+ operands.at(forget_gate_bias_index).shape().rank() == 1 &&
+ operands.at(cell_bias_index).shape().rank() == 1 &&
+ operands.at(output_gate_bias_index).shape().rank() == 1 &&
+ (!operands.exist(projection_bias_index) ||
+ operands.at(projection_bias_index).shape().rank() == 1));
+
+ // CIFG assertion
+ OP_REQUIRES(((!operands.exist(input_to_input_weights_index) ||
+ (operands.at(input_to_input_weights_index).shape().dim(0) == 0 &&
+ operands.at(input_to_input_weights_index).shape().dim(1) == 0)) &&
+ (!operands.exist(recurrent_to_input_weights_index) ||
+ (operands.at(recurrent_to_input_weights_index).shape().dim(0) == 0 &&
+ operands.at(recurrent_to_input_weights_index).shape().dim(1) == 0)) &&
+ (!operands.exist(input_gate_bias_index) ||
+ operands.at(input_gate_bias_index).shape().dim(0) == 0) &&
+ (!operands.exist(cell_to_input_weights_index) ||
+ operands.at(cell_to_input_weights_index).shape().dim(0) == 0)) ||
+ ((operands.exist(input_to_input_weights_index) &&
+ (operands.at(input_to_input_weights_index).shape().dim(0) != 0 &&
+ operands.at(input_to_input_weights_index).shape().dim(1) != 0)) &&
+ (operands.exist(recurrent_to_input_weights_index) &&
+ (operands.at(recurrent_to_input_weights_index).shape().dim(0) != 0 &&
+ operands.at(recurrent_to_input_weights_index).shape().dim(1) != 0)) &&
+ (operands.exist(input_gate_bias_index) &&
+ operands.at(input_gate_bias_index).shape().dim(0) != 0)));
+
+ // Peephole assertion
+ OP_REQUIRES(((!operands.exist(cell_to_forget_weights_index) ||
+ operands.at(cell_to_forget_weights_index).shape().dim(0) == 0) &&
+ (!operands.exist(cell_to_output_weights_index) ||
+ operands.at(cell_to_output_weights_index).shape().dim(0) == 0)) ||
+ ((operands.exist(cell_to_forget_weights_index) &&
+ operands.at(cell_to_forget_weights_index).shape().dim(0) != 0) &&
+ (operands.exist(cell_to_output_weights_index) &&
+ operands.at(cell_to_output_weights_index).shape().dim(0) != 0)));
+
+ bool has_input_to_input_weights =
+ operands.exist(input_to_input_weights_index) &&
+ (operands.at(input_to_input_weights_index).shape().dim(0) != 0 &&
+ operands.at(input_to_input_weights_index).shape().dim(1) != 0);
+ bool has_recurrent_to_input_weights =
+ operands.exist(recurrent_to_input_weights_index) &&
+ (operands.at(recurrent_to_input_weights_index).shape().dim(0) != 0 &&
+ operands.at(recurrent_to_input_weights_index).shape().dim(1) != 0);
+ bool has_input_gate_bias =
+ operands.exist(input_gate_bias_index) && operands.at(input_gate_bias_index).shape().dim(0) != 0;
+ bool has_cell_to_input_weights = operands.exist(cell_to_input_weights_index) &&
+ operands.at(cell_to_input_weights_index).shape().dim(0) != 0;
+ bool has_cell_to_forget_weights = operands.exist(cell_to_forget_weights_index) &&
+ operands.at(cell_to_forget_weights_index).shape().dim(0) != 0;
+ bool has_cell_to_output_weights = operands.exist(cell_to_output_weights_index) &&
+ operands.at(cell_to_output_weights_index).shape().dim(0) != 0;
+ bool has_projection_weights = operands.exist(projection_weights_index) &&
+ (operands.at(projection_weights_index).shape().dim(0) != 0 &&
+ operands.at(projection_weights_index).shape().dim(1) != 0);
+ bool has_projection_bias =
+ operands.exist(projection_bias_index) && operands.at(projection_bias_index).shape().dim(0) != 0;
+
+ // NOTE The cell_to_input_weights are absent when peephole connections are not used, even in a
+ // regular (non-CIFG) LSTM.
+ // true: no CIFG
+ // false: CIFG
+ bool has_cifg_param = has_input_to_input_weights && has_recurrent_to_input_weights;
+
+ // NOTE The cell_to_input_weights are absent in a CIFG LSTM, even when peephole connections are
+ // used.
+ // true: peephole
+ // false: no peephole
+ bool has_peephole_param = has_cell_to_forget_weights && has_cell_to_output_weights;
+
+ // NOTE The projection weights may have data but the projection bias may not.
+ bool has_projection_param = has_projection_weights;
+
+ const auto batch_size = (operands.at(input_index).shape().rank() == 3 && node.param().time_major)
+ ? operands.at(input_index).shape().dim(1)
+ : operands.at(input_index).shape().dim(0);
+ OP_REQUIRES(batch_size == operands.at(output_state_in_index).shape().dim(0) &&
+ batch_size == operands.at(cell_state_in_index).shape().dim(0));
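+ // Illustrative note: for a rank-3, time-major input [max_time, batch, input_size] the batch
+ // size is dim 1; otherwise (batch-major or rank-2 input) it is dim 0.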
+
+ const auto input_size =
+ operands.at(input_index).shape().dim(operands.at(input_index).shape().rank() - 1);
+ OP_REQUIRES(input_size == operands.at(input_to_forget_weights_index).shape().dim(1) &&
+ input_size == operands.at(input_to_cell_weights_index).shape().dim(1) &&
+ input_size == operands.at(input_to_output_weights_index).shape().dim(1));
+
+ const auto num_units = operands.at(input_to_output_weights_index).shape().dim(0);
+ OP_REQUIRES(num_units == operands.at(input_to_cell_weights_index).shape().dim(0) &&
+ num_units == operands.at(input_to_output_weights_index).shape().dim(0) &&
+ num_units == operands.at(recurrent_to_forget_weights_index).shape().dim(0) &&
+ num_units == operands.at(recurrent_to_cell_weights_index).shape().dim(0) &&
+ num_units == operands.at(recurrent_to_output_weights_index).shape().dim(0) &&
+ num_units == operands.at(forget_gate_bias_index).shape().dim(0) &&
+ num_units == operands.at(cell_bias_index).shape().dim(0) &&
+ num_units == operands.at(output_gate_bias_index).shape().dim(0) &&
+ num_units == operands.at(cell_state_in_index).shape().dim(1));
+
+ const auto output_size =
+ operands.at(output_index).shape().dim(operands.at(output_index).shape().rank() - 1);
+ OP_REQUIRES(output_size == operands.at(recurrent_to_forget_weights_index).shape().dim(1) &&
+ output_size == operands.at(recurrent_to_cell_weights_index).shape().dim(1) &&
+ output_size == operands.at(recurrent_to_output_weights_index).shape().dim(1) &&
+ output_size == operands.at(output_state_in_index).shape().dim(1));
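+ // Illustrative note (assumed NNAPI LSTM layout): input_to_* weights are [num_units, input_size]
+ // and recurrent_to_* weights are [num_units, output_size], which the input_size, num_units and
+ // output_size checks above rely on.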
+
+ if (has_cifg_param)
+ {
+ OP_REQUIRES(input_size == operands.at(input_to_input_weights_index).shape().dim(1));
+ OP_REQUIRES(
+ num_units == operands.at(input_to_input_weights_index).shape().dim(0) &&
+ num_units == operands.at(recurrent_to_input_weights_index).shape().dim(0) &&
+ ((operands.exist(cell_to_input_weights_index) &&
+ num_units == operands.at(cell_to_input_weights_index).shape().dim(0)) ||
+ (!operands.exist(cell_to_input_weights_index) ||
+ operands.at(cell_to_input_weights_index).shape().dim(0) == 0) /* non-peephole */) &&
+ num_units == operands.at(input_gate_bias_index).shape().dim(0));
+ OP_REQUIRES(output_size == operands.at(recurrent_to_input_weights_index).shape().dim(1));
+ OP_REQUIRES(has_input_to_input_weights && has_recurrent_to_input_weights &&
+ has_input_gate_bias);
+ if (has_cell_to_input_weights)
+ {
+ // NOTE The cell_to_input_weights exist only in the non-CIFG case with peephole connections.
+ OP_REQUIRES(has_peephole_param);
+ }
+ if (operands.exist(scratch_buffer_index))
+ OP_REQUIRES(operands.at(scratch_buffer_index).shape().dim(1) == num_units * 4);
+ }
+ else
+ {
+ if (operands.exist(scratch_buffer_index))
+ OP_REQUIRES(operands.at(scratch_buffer_index).shape().dim(1) == num_units * 3);
+ }
+
+ if (has_peephole_param)
+ {
+ OP_REQUIRES(num_units == operands.at(cell_to_forget_weights_index).shape().dim(0) &&
+ num_units == operands.at(cell_to_output_weights_index).shape().dim(0) &&
+ (num_units == operands.at(cell_to_input_weights_index).shape().dim(0) ||
+ operands.at(cell_to_input_weights_index).shape().dim(0) == 0 /* CIFG */));
+ }
+
+ if (has_projection_param)
+ {
+ OP_REQUIRES(num_units == operands.at(projection_weights_index).shape().dim(1));
+ OP_REQUIRES(output_size == operands.at(projection_weights_index).shape().dim(0));
+ if (has_projection_bias)
+ {
+ OP_REQUIRES(output_size == operands.at(projection_bias_index).shape().dim(0));
+ }
+ }
+
+ if (operands.exist(scratch_buffer_index))
+ {
+ OP_REQUIRES(operands.at(scratch_buffer_index).shape().rank() == 2);
+ OP_REQUIRES(batch_size == operands.at(scratch_buffer_index).shape().dim(0));
+ }
+
+ if (operands.exist(output_state_out_index))
+ {
+ OP_REQUIRES(operands.at(output_state_out_index).shape().rank() == 2);
+ OP_REQUIRES(batch_size == operands.at(output_state_out_index).shape().dim(0));
+ OP_REQUIRES(output_size == operands.at(output_state_out_index).shape().dim(1));
+ }
+
+ if (operands.exist(cell_state_out_index))
+ {
+ OP_REQUIRES(operands.at(cell_state_out_index).shape().rank() == 2);
+ OP_REQUIRES(batch_size == operands.at(cell_state_out_index).shape().dim(0));
+ OP_REQUIRES(num_units == operands.at(cell_state_out_index).shape().dim(1));
+ }
+}
+
+void ShapeValidator::visit(const ir::operation::L2Normalization &node)
+{
+ const auto &operands = _graph.operands();
+ const auto ofm_index{node.getOutputs().at(0)};
+ if (operands.at(ofm_index).info().isDynamic())
+ return;
+
+ const auto ifm_index{node.getInputs().at(ir::operation::L2Normalization::Input::INPUT)};
+
+ auto ifm_shape = operands.at(ifm_index).shape();
+ auto ofm_shape = operands.at(ofm_index).shape();
+
+ OP_REQUIRES(ifm_shape.rank() == ofm_shape.rank());
+
+ for (auto i = 0; i < ifm_shape.rank(); i++)
+ {
+ OP_REQUIRES(ifm_shape.dim(i) == ofm_shape.dim(i));
+ }
+}
+
+void ShapeValidator::visit(const ir::operation::Unpack &node)
+{
+ const auto &operands = _graph.operands();
+ const auto axis{node.param().axis};
+ const auto output_index{node.getInputs().at(0)};
+ if (operands.at(output_index).info().isDynamic())
+ return;
+
+ const auto input_index{node.getInputs().at(ir::operation::Unpack::Input::INPUT)};
+
+ const auto &input_shape = operands.at(input_index).shape();
+ const auto input_rank = static_cast<int32_t>(input_shape.rank());
+
+ OP_REQUIRES(axis >= -input_rank && axis < input_rank);
+}
+
+void ShapeValidator::visit(const ir::operation::Pad &node)
+{
+ const auto &operands = _graph.operands();
+ const auto pad_index{node.getInputs().at(ir::operation::Pad::Input::PAD)};
+ OP_REQUIRES(operands.at(pad_index).typeInfo().type() == ir::DataType::INT32);
+
+ const auto output_index{node.getInputs().at(0)};
+ if (operands.at(output_index).info().isDynamic())
+ return;
+
+ const auto input_index{node.getInputs().at(ir::operation::Pad::Input::INPUT)};
+
+ const auto &pad_shape = operands.at(pad_index).shape();
+ const auto input_rank = static_cast<int32_t>(operands.at(input_index).shape().rank());
+
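+ // Illustrative example: a rank-3 input requires a pad tensor of shape [3, 2], i.e. one
+ // (front, back) padding pair per input dimension.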
+ OP_REQUIRES(pad_shape.rank() == 2);
+ OP_REQUIRES(pad_shape.dim(0) == input_rank);
+ OP_REQUIRES(pad_shape.dim(1) == 2);
+ OP_REQUIRES(operands.at(input_index).shape().rank() == operands.at(output_index).shape().rank());
+}
+
+void ShapeValidator::visit(const ir::operation::Select &)
+{
+ // TODO Shape validation of select
+}
+
+void ShapeValidator::visit(const ir::operation::StridedSlice &node)
+{
+ const auto &operands = _graph.operands();
+ const auto output_index{node.getOutputs().at(0)};
+ const auto input_index{node.getInputs().at(ir::operation::StridedSlice::Input::INPUT)};
+
+ if (operands.at(output_index).info().isDynamic())
+ return;
+
+ OP_REQUIRES(operands.at(input_index).shape().rank() <= 4);
+}
+
+void ShapeValidator::visit(const ir::operation::Split &node)
+{
+ const auto &operands = _graph.operands();
+ const auto output_index{node.getOutputs().at(0)};
+ if (operands.at(output_index).info().isDynamic())
+ return;
+
+ const auto input_index{node.getInputs().at(ir::operation::Split::Input::INPUT)};
+ const auto axis_index{node.getInputs().at(ir::operation::Split::Input::AXIS)};
+
+ const auto num_splits = node.param().num_splits;
+ const auto input_rank = operands.at(input_index).shape().rank();
+ auto axis = *reinterpret_cast<const int32_t *>(operands.at(axis_index).data()->base());
+ axis = axis < 0 ? axis + input_rank : axis;
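+ // Illustrative example: for a rank-4 input, axis -1 is normalized to 3, and that dimension's
+ // extent must be divisible by num_splits.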
+
+ OP_REQUIRES(axis >= 0 && axis < input_rank);
+ OP_REQUIRES(operands.at(input_index).shape().dim(axis) % num_splits == 0);
+}
+
+void ShapeValidator::visit(const ir::operation::Shape &node)
+{
+ const auto &operands = _graph.operands();
+ const auto output_index{node.getOutputs().at(0)};
+ if (operands.at(output_index).info().isDynamic())
+ return;
+
+ const auto input_index{node.getInputs().at(0)};
+ UNUSED_RELEASE(input_index);
+ OP_REQUIRES(operands.at(output_index).shape().rank() == 1);
+}
+
+void ShapeValidator::visit(const ir::operation::ResizeBilinear &node)
+{
+ const auto &operands = _graph.operands();
+ const auto output_index{node.getOutputs().at(0)};
+ const auto input_index{node.getInputs().at(ir::operation::ResizeBilinear::Input::INPUT)};
+
+ if (operands.at(output_index).info().isDynamic())
+ {
+ return;
+ }
+ OP_REQUIRES(operands.at(input_index).shape().rank() == 4);
+ OP_REQUIRES(operands.at(output_index).shape().rank() == 4);
+}
+
+void ShapeValidator::visit(const ir::operation::Reverse &node)
+{
+ const auto &operands = _graph.operands();
+ const auto output_index{node.getOutputs().at(0)};
+ const auto input_index{node.getInputs().at(ir::operation::Reverse::Input::INPUT)};
+
+ if (operands.at(output_index).info().isDynamic())
+ return;
+ OP_REQUIRES(operands.at(output_index).shape() == operands.at(input_index).shape());
+}
+
+void ShapeValidator::visit(const ir::operation::If &)
+{
+ // TODO Add to validate with subgraphs
+}
+
+void ShapeValidator::visit(const ir::operation::While &)
+{
+ // This validator does not check shapes here, so the isDynamic() check is skipped.
+ // TODO Add to validate with subgraphs
+}
+
+void ShapeValidator::visit(const ir::operation::SquaredDifference &node)
+{
+ const auto &operands = _graph.operands();
+ const auto output_index{node.getOutputs().at(0)};
+ const auto lhs_index{node.getInputs().at(ir::operation::SquaredDifference::Input::LHS)};
+ const auto rhs_index{node.getInputs().at(ir::operation::SquaredDifference::Input::RHS)};
+
+ // Check for dimension constraints
+ if (operands.at(output_index).info().isDynamic())
+ return;
+
+ auto output_shape = operands.at(output_index).shape();
+ auto lhs_shape = operands.at(lhs_index).shape();
+ auto rhs_shape = operands.at(rhs_index).shape();
+ // Check for output rank
+ OP_REQUIRES(output_shape.rank() == std::max(lhs_shape.rank(), rhs_shape.rank()));
+ auto min_rank = std::min(lhs_shape.rank(), rhs_shape.rank());
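+ // Illustrative example: lhs [2, 3, 4] and rhs [3, 1] broadcast to an output [2, 3, 4]; the
+ // first loop below checks the overlapping trailing dims pairwise and the second checks the
+ // leading dims against the higher-rank operand.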
+
+ for (int idx = 1; idx <= min_rank; idx++)
+ {
+ int l_idx = lhs_shape.rank() - idx;
+ int r_idx = rhs_shape.rank() - idx;
+ int out_idx = output_shape.rank() - idx;
+
+ OP_REQUIRES((l_idx >= 0) && (r_idx >= 0) && (out_idx >= 0));
+
+ auto l_dims = lhs_shape.dim(l_idx);
+ auto r_dims = rhs_shape.dim(r_idx);
+ auto out_dims = output_shape.dim(out_idx);
+
+ OP_REQUIRES(((l_dims == r_dims) && (out_dims == l_dims)) ||
+ ((l_dims == 1) && (out_dims == r_dims)) || ((r_dims == 1) && (out_dims == l_dims)));
+ }
+ auto &tmp_shape = (lhs_shape.rank() > rhs_shape.rank()) ? lhs_shape : rhs_shape;
+ for (int idx = min_rank + 1; idx <= output_shape.rank(); idx++)
+ {
+ int out_idx = output_shape.rank() - idx;
+ int tmp_idx = tmp_shape.rank() - idx;
+
+ OP_REQUIRES((out_idx >= 0) && (tmp_idx >= 0) &&
+ (output_shape.dim(out_idx) == tmp_shape.dim(tmp_idx)));
+ }
+}
+void ShapeValidator::visit(const ir::operation::Tile &node)
+{
+ const auto &operands = _graph.operands();
+ const auto output_index{node.getOutputs().at(0)};
+ if (operands.at(output_index).info().isDynamic())
+ return;
+
+ const auto input_index{node.getInputs().at(0)};
+ const auto multiple_index{node.getInputs().at(1)};
+
+ OP_REQUIRES(operands.at(multiple_index).shape().rank() == 1);
+ OP_REQUIRES(operands.at(multiple_index).shape().dim(0) ==
+ operands.at(input_index).shape().rank());
+ OP_REQUIRES(operands.at(input_index).shape().rank() == operands.at(output_index).shape().rank());
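+ // Illustrative example: a rank-2 input needs a 2-element multiples vector, and tiling
+ // preserves the rank, so the output is also rank 2.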
+}
+
+void ShapeValidator::visit(const ir::operation::Range &node)
+{
+ const auto &operands = _graph.operands();
+ const auto output_index{node.getOutputs().at(0)};
+ const auto start_index{node.getInputs().at(ir::operation::Range::Input::START)};
+ const auto limit_index{node.getInputs().at(ir::operation::Range::Input::LIMIT)};
+ const auto delta_index{node.getInputs().at(ir::operation::Range::Input::DELTA)};
+
+ // Check for dimension constraints
+ if (operands.at(output_index).info().isDynamic())
+ return;
+
+ OP_REQUIRES(operands.at(start_index).shape().rank() == 0);
+ OP_REQUIRES(operands.at(limit_index).shape().rank() == 0);
+ OP_REQUIRES(operands.at(delta_index).shape().rank() == 0);
+}
+
+void ShapeValidator::visit(const ir::operation::MatrixBandPart &node)
+{
+ const auto &operands = _graph.operands();
+ const auto output_index{node.getOutputs().at(0)};
+ const auto input_index{node.getInputs().at(ir::operation::MatrixBandPart::Input::INPUT)};
+ const auto num_lower_index{
+ node.getInputs().at(ir::operation::MatrixBandPart::Input::NUM_LOWER_DIAG)};
+ const auto num_upper_index{
+ node.getInputs().at(ir::operation::MatrixBandPart::Input::NUM_UPPER_DIAG)};
+
+ // Check for dimension constraints
+ if (operands.at(output_index).info().isDynamic())
+ return;
+
+ OP_REQUIRES(operands.at(input_index).shape().rank() >= 2); // input must be at least a rank-2 matrix
+ OP_REQUIRES(operands.at(num_upper_index).shape().rank() == 0); // num_upper must be scalar
+ OP_REQUIRES(operands.at(num_lower_index).shape().rank() == 0); // num_lower must be scalar
+}
+
+void ShapeValidator::visit(const ir::operation::LogSoftmax &node)
+{
+ const auto &operands = _graph.operands();
+ const auto output_index{node.getOutputs().at(0)};
+ if (operands.at(output_index).info().isDynamic())
+ return;
+
+ const auto input_index{node.getInputs().at(0)};
+
+ OP_REQUIRES(operands.at(output_index).shape().rank() == operands.at(input_index).shape().rank());
+}
+
+} // namespace compiler
+} // namespace onert
diff --git a/runtime/onert/core/src/compiler/OperationValidator.h b/runtime/onert/core/src/compiler/ShapeValidator.h
index deb6357bb..da83a432a 100644
--- a/runtime/onert/core/src/compiler/OperationValidator.h
+++ b/runtime/onert/core/src/compiler/ShapeValidator.h
@@ -14,8 +14,8 @@
* limitations under the License.
*/
-#ifndef __ONERT_COMPILER_OPERATION_VALIDATOR_H__
-#define __ONERT_COMPILER_OPERATION_VALIDATOR_H__
+#ifndef __ONERT_COMPILER_SHAPE_VALIDATOR_H__
+#define __ONERT_COMPILER_SHAPE_VALIDATOR_H__
#include "ir/Layout.h"
#include "ir/OperationVisitor.h"
@@ -34,19 +34,29 @@ namespace onert
namespace compiler
{
-class OperationValidator : public ir::OperationVisitor
+class ShapeValidator : public ir::OperationVisitor
{
public:
- OperationValidator(void) = delete;
- OperationValidator(const ir::Graph &graph);
+ ShapeValidator(void) = delete;
+ ShapeValidator(const ir::Graph &graph);
+ ShapeValidator(const ShapeValidator &) = delete;
+ ShapeValidator(ShapeValidator &&) = delete;
+ ~ShapeValidator() = default;
public:
+ ShapeValidator &operator=(const ShapeValidator &) = delete;
+ ShapeValidator &operator=(ShapeValidator &&) = delete;
void operator()();
public:
void visit(const ir::operation::BatchMatMul &node) override;
void visit(const ir::operation::BatchToSpaceND &node) override;
+ void visit(const ir::operation::BCQFullyConnected &node) override;
+ void visit(const ir::operation::BCQGather &node) override;
+ void visit(const ir::operation::Conv2D &node) override;
void visit(const ir::operation::Comparison &node) override;
+ void visit(const ir::operation::DepthwiseConv2D &node) override;
+ void visit(const ir::operation::FullyConnected &node) override;
void visit(const ir::operation::Softmax &node) override;
void visit(const ir::operation::InstanceNorm &node) override;
void visit(const ir::operation::Permute &node) override;
@@ -88,13 +98,10 @@ private:
void checkUnaryOp(const ir::Operation &node);
private:
- // TODO Remove _ctx field
const ir::Graph &_graph;
- const ir::Operands &_ctx;
- ir::Layout _current_op_seq_layout;
};
} // namespace compiler
} // namespace onert
-#endif // __ONERT_COMPILER_OPERATION_VALIDATOR_H__
+#endif // __ONERT_COMPILER_SHAPE_VALIDATOR_H__
diff --git a/runtime/onert/core/src/compiler/StaticShapeInference.cc b/runtime/onert/core/src/compiler/StaticShapeInference.cc
deleted file mode 100644
index 4eba1ff49..000000000
--- a/runtime/onert/core/src/compiler/StaticShapeInference.cc
+++ /dev/null
@@ -1,1096 +0,0 @@
-/*
- * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "compiler/StaticShapeInference.h"
-#include "util/ShapeInference.h"
-#include "util/logging.h"
-
-#include <sstream>
-
-namespace onert
-{
-namespace compiler
-{
-
-bool StaticShapeInferer::infer(const ir::OpSequence &op_seq)
-{
- bool has_dynamic_tensor = false;
-
- for (const auto &operation_idx : op_seq.operations())
- {
- auto &op = _operations.at(operation_idx);
- auto opcode = op.opcode();
-
- _return_has_dynamic_tensor = false; // this is used as a return value inside operation's visit()
-
- // IF: need shape inference for then, else
- // While: need shape inference for condition, body
- if (opcode == ir::OpCode::If || opcode == ir::OpCode::While)
- {
- op.accept(*this);
- }
- else
- {
- _return_has_dynamic_tensor = checkDynamicInput(op);
-
- if (_return_has_dynamic_tensor)
- {
- setDynamicOutput(op);
- }
- else
- {
- op.accept(*this);
- }
- }
-
- has_dynamic_tensor = has_dynamic_tensor || _return_has_dynamic_tensor;
- }
-
- return has_dynamic_tensor;
-}
-
-bool StaticShapeInferer::checkDynamicInput(const ir::Operation &op)
-{
- for (auto input_idx : op.getInputs() | ir::Remove::UNDEFINED | ir::Remove::DUPLICATED)
- {
- if (_operands.at(input_idx).info().isDynamic())
- {
- return true;
- }
- }
-
- return false;
-}
-
-void StaticShapeInferer::setDynamicOutput(const ir::Operation &op)
-{
- for (auto output_idx : op.getOutputs())
- {
- _operands.at(output_idx).info().setDynamic();
- }
-}
-
-void StaticShapeInferer::handleBinaryArithmeticOp(const ir::Operation &op,
- const ir::OperandIndex lhs_idx,
- const ir::OperandIndex rhs_idx)
-{
- const auto &lhs = _operands.at(lhs_idx);
- const auto &rhs = _operands.at(rhs_idx);
-
- const auto output_idx = op.getOutputs().at(0);
- ir::Operand &output = _operands.at(output_idx);
-
- // re-sizing output shape
- ir::Shape new_shape = shape_inference::inferEltwiseShape(lhs.info().shape(), rhs.info().shape());
- output.info().shape(new_shape);
-}
-
-void StaticShapeInferer::handleSimpleUnaryOp(const ir::Operation &op,
- const ir::OperandIndex input_idx)
-{
- const auto &input = _operands.at(input_idx);
-
- // get mutable output operand
- const auto output_idx = op.getOutputs().at(0);
- ir::Operand &output = _operands.at(output_idx);
-
- // re-sizing output shape
- ir::Shape new_shape = input.info().shape();
- output.info().shape(new_shape);
-}
-
-void StaticShapeInferer::dump()
-{
- auto get_shape_str = [](const ir::Shape &shape) {
- std::stringstream sstream;
- sstream << "shape : {";
- for (int i = 0; i < shape.rank(); i++)
- {
- if (i == 0)
- sstream << shape.dim(i);
- else
- sstream << " " << shape.dim(i);
- }
- sstream << "}";
- return sstream.str();
- };
-
- for (const auto &pair : _lowered_subgs)
- {
- const auto index = pair.first;
- const auto &lowered_subg = pair.second;
- VERBOSE(StaticShapeInferer) << "SubGraph #" << index.value() << std::endl;
- lowered_subg->graph().operands().iterate(
- [&](const ir::OperandIndex &ind, const ir::Operand &operand) {
- VERBOSE(StaticShapeInferer) << "Operand #" << ind.value() << ", "
- << (operand.info().isDynamic() ? "Dynamic" : "Static") << ", "
- << get_shape_str(operand.info().shape()) << std::endl;
- });
- }
-}
-
-void StaticShapeInferer::visit(const ir::operation::ArgMax &op)
-{
- const auto input_idx{op.getInputs().at(ir::operation::ArgMax::Input::INPUT)};
- const auto &input = _operands.at(input_idx);
-
- // get mutable output operand
- const auto output_idx = op.getOutputs().at(0);
- ir::Operand &output = _operands.at(output_idx);
- const auto rank = input.info().shape().rank();
- const auto axis = ((op.param().axis < 0) ? rank + op.param().axis : op.param().axis);
-
- assert(0 <= axis && axis < rank);
-
- // re-sizing output shape
- ir::Shape new_shape = shape_inference::inferArgMaxShape(input.info().shape(), axis, rank);
- output.info().shape(new_shape);
-}
-
-void StaticShapeInferer::visit(const ir::operation::BatchMatMul &op)
-{
- const auto lhs_index = op.getInputs().at(ir::operation::BatchMatMul::Input::LHS);
- const auto rhs_index = op.getInputs().at(ir::operation::BatchMatMul::Input::RHS);
- const auto output_index = op.getOutputs().at(0);
- const auto lhs = _operands.at(lhs_index);
- const auto rhs = _operands.at(rhs_index);
- auto &output = _operands.at(output_index);
- auto new_shape = shape_inference::inferBatchMatMulShape(lhs.shape(), rhs.shape(), op.param());
- output.info().shape(new_shape);
-}
-
-void StaticShapeInferer::visit(const ir::operation::BinaryArithmetic &op)
-{
- handleBinaryArithmeticOp(op, op.getInputs().at(ir::operation::BinaryArithmetic::Input::LHS),
- op.getInputs().at(ir::operation::BinaryArithmetic::Input::RHS));
-}
-
-void StaticShapeInferer::visit(const ir::operation::BroadcastTo &op)
-{
- // get mutable output operand
- const auto output_idx = op.getOutputs().at(0);
- ir::Operand &output = _operands.at(output_idx);
-
- const auto shape_idx{op.getInputs().at(ir::operation::BroadcastTo::Input::SHAPE)};
- const auto &shape = _operands.at(shape_idx);
-
- if (!shape.isConstant())
- {
- output.info().setDynamic();
- _return_has_dynamic_tensor = true;
- return;
- }
-
- // assert(shape.typeInfo().type() == ir::DataType::INT32);
- auto shape_buffer = reinterpret_cast<const int32_t *>(shape.data()->base());
-
- // re-sizing output shape
- ir::Shape new_shape = shape_inference::inferBroadcastToShape(shape.info().shape(), shape_buffer);
- output.info().shape(new_shape);
-}
-
-void StaticShapeInferer::visit(const ir::operation::Comparison &op)
-{
- handleBinaryArithmeticOp(op, op.getInputs().at(ir::operation::Comparison::Input::INPUT0),
- op.getInputs().at(ir::operation::Comparison::Input::INPUT1));
-}
-
-void StaticShapeInferer::visit(const ir::operation::Concat &op)
-{
- const auto input_count = op.getInputs().size();
-
- const auto output_idx = op.getOutputs().at(0);
- ir::Operand &output = _operands.at(output_idx);
-
- shape_inference::Shapes input_shapes;
- for (uint32_t i = 0; i < input_count; i++)
- {
- const auto input_idx{op.getInputs().at(i)};
- const auto &input = _operands.at(input_idx);
- input_shapes.emplace_back(input.shape());
- }
-
- ir::Shape out_shape = shape_inference::inferConcatShape(input_shapes, op.param());
-
- // re-sizing output shape
- output.info().shape(out_shape);
-}
-
-void StaticShapeInferer::visit(const ir::operation::Conv2D &op)
-{
- const auto input_idx{op.getInputs().at(ir::operation::Conv2D::Input::INPUT)};
- const auto &input = _operands.at(input_idx);
- const auto ker_idx{op.getInputs().at(ir::operation::Conv2D::Input::KERNEL)};
- const auto &ker = _operands.at(ker_idx);
- const auto output_idx = op.getOutputs().at(0);
- ir::Operand &output = _operands.at(output_idx);
-
- // re-sizing output shape
- ir::Shape new_shape =
- shape_inference::inferConv2DShape(input.info().shape(), ker.info().shape(), op.param());
- output.info().shape(new_shape);
-}
-
-void StaticShapeInferer::visit(const ir::operation::ElementwiseActivation &op)
-{
- handleSimpleUnaryOp(op, op.getInputs().at(ir::operation::ElementwiseActivation::Input::INPUT));
-}
-
-void StaticShapeInferer::visit(const ir::operation::ElementwiseBinary &op)
-{
- handleBinaryArithmeticOp(op, op.getInputs().at(ir::operation::ElementwiseBinary::Input::LHS),
- op.getInputs().at(ir::operation::ElementwiseBinary::Input::RHS));
-}
-
-void StaticShapeInferer::visit(const ir::operation::ElementwiseUnary &op)
-{
- handleSimpleUnaryOp(op, op.getInputs().at(ir::operation::ElementwiseUnary::Input::INPUT));
-}
-
-void StaticShapeInferer::visit(const ir::operation::ExpandDims &op)
-{
- const auto input_idx{op.getInputs().at(ir::operation::ExpandDims::Input::INPUT)};
- const auto &input = _operands.at(input_idx);
- const auto axis_idx{op.getInputs().at(ir::operation::ExpandDims::Input::AXIS)};
- const auto &axis = _operands.at(axis_idx);
- const auto output_idx = op.getOutputs().at(0);
- ir::Operand &output = _operands.at(output_idx);
-
- if (!axis.isConstant())
- {
- output.info().setDynamic();
- _return_has_dynamic_tensor = true;
- return;
- }
-
- // even when axis is constant, output shape should be recalculated since user might call
- // nnfw_set_input_tensorinfo(input, some_new_shape)
- auto axis_buf = reinterpret_cast<const int32_t *>(axis.data()->base());
- assert(axis_buf);
-
- // re-sizing output shape
- ir::Shape new_shape = shape_inference::inferExpandDimsShape(input.info().shape(), axis_buf[0]);
- output.info().shape(new_shape);
-}
-
-void StaticShapeInferer::visit(const ir::operation::Fill &op)
-{
- const auto input_idx{op.getInputs().at(ir::operation::Fill::Input::INPUT)};
- const auto &input = _operands.at(input_idx);
- const auto output_idx = op.getOutputs().at(0);
- ir::Operand &output = _operands.at(output_idx);
-
- if (!input.isConstant())
- {
- output.info().setDynamic();
- _return_has_dynamic_tensor = true;
- return;
- }
-
- assert(input.typeInfo().type() == ir::DataType::INT32);
-
- auto input_buf = reinterpret_cast<const int32_t *>(input.data()->base());
- assert(input_buf);
-
- // re-sizing output shape
- ir::Shape new_shape = shape_inference::inferFillShape(input.info().shape(), input_buf);
- output.info().shape(new_shape);
-}
-
-void StaticShapeInferer::visit(const ir::operation::FullyConnected &op)
-{
- const auto input_idx{op.getInputs().at(ir::operation::FullyConnected::Input::INPUT)};
- const auto &input = _operands.at(input_idx);
-
- const auto ker_idx{op.getInputs().at(ir::operation::FullyConnected::Input::WEIGHT)};
- const auto &ker = _operands.at(ker_idx);
-
- // get mutable output operand
- const auto output_idx = op.getOutputs().at(0);
- ir::Operand &output = _operands.at(output_idx);
- // re-sizing output shape
- ir::Shape new_shape =
- shape_inference::inferFullyConnectedShape(input.info().shape(), ker.info().shape());
- output.info().shape(new_shape);
-}
-
-void StaticShapeInferer::visit(const ir::operation::FusedBatchNorm &op)
-{
- handleSimpleUnaryOp(op, op.getInputs().at(ir::operation::FusedBatchNorm::Input::INPUT));
-}
-
-void StaticShapeInferer::visit(const ir::operation::Gather &op)
-{
- const auto input_idx{op.getInputs().at(ir::operation::Gather::Input::INPUT)};
- const auto &input = _operands.at(input_idx);
-
- // get mutable output operand
- const auto output_idx = op.getOutputs().at(0);
- ir::Operand &output = _operands.at(output_idx);
-
- const auto indices_idx{op.getInputs().at(ir::operation::Gather::Input::INDICES)};
- const auto &indices = _operands.at(indices_idx);
- const auto rank = input.info().shape().rank();
- const auto axis = ((op.param().axis < 0) ? rank + op.param().axis : op.param().axis);
-
- assert(0 <= axis && axis < rank);
-
- // re-sizing output shape
- ir::Shape new_shape =
- shape_inference::inferGatherShape(input.info().shape(), indices.info().shape(), axis, rank);
- output.info().shape(new_shape);
-}
-
-void StaticShapeInferer::visit(const ir::operation::If &op)
-{
- auto &then_graph = _lowered_subgs.at(op.param().then_subg_index)->graph();
- auto &else_graph = _lowered_subgs.at(op.param().else_subg_index)->graph();
- const std::vector<ir::OperandIndex> inputs{op.getInputs().begin() + 1, op.getInputs().end()};
- const auto &outputs = op.getOutputs();
-
- // re-sizing input shapes of then subgraph
- const auto &then_inputs = then_graph.getInputs();
- assert(inputs.size() == then_inputs.size());
- for (size_t i = 0; i < inputs.size(); ++i)
- {
- auto &then_input = then_graph.operands().at(then_inputs.at(i));
- if (_operands.at(inputs.at(i)).info().isDynamic())
- {
- then_input.info().setDynamic();
- }
- else
- {
- auto new_shape = _operands.at(inputs.at(i)).info().shape();
- then_input.info().shape(new_shape);
- }
- }
-
- // re-sizing input shapes of else subgraph
- const auto &else_inputs = else_graph.getInputs();
- assert(inputs.size() == else_inputs.size());
- for (size_t i = 0; i < inputs.size(); ++i)
- {
- auto &else_input = else_graph.operands().at(else_inputs.at(i));
- if (_operands.at(inputs.at(i)).info().isDynamic())
- {
- else_input.info().setDynamic();
- }
- else
- {
- const auto &new_shape = _operands.at(inputs.at(i)).info().shape();
- else_input.info().shape(new_shape);
- }
- }
-
- // re-sizing operands of then subgraph
- StaticShapeInferer then_inferer(op.param().then_subg_index, _lowered_subgs);
- _lowered_subgs.at(op.param().then_subg_index)
- ->iterateTopolOpSeqs([&](const ir::OpSequenceIndex &, ir::OpSequence &op_seq) {
- bool has_dynamic_tensor = then_inferer.infer(op_seq);
- op_seq.has_dynamic_tensor(has_dynamic_tensor);
- });
-
- // re-sizing operands of else subgraph
- StaticShapeInferer else_inferer(op.param().else_subg_index, _lowered_subgs);
- _lowered_subgs.at(op.param().else_subg_index)
- ->iterateTopolOpSeqs([&](const ir::OpSequenceIndex &, ir::OpSequence &op_seq) {
- bool has_dynamic_tensor = else_inferer.infer(op_seq);
- op_seq.has_dynamic_tensor(has_dynamic_tensor);
- });
-
- // re-sizing output shapes
- const auto &then_outputs = _lowered_subgs.at(op.param().then_subg_index)->graph().getOutputs();
- const auto &else_outputs = _lowered_subgs.at(op.param().else_subg_index)->graph().getOutputs();
- assert(outputs.size() == then_outputs.size());
- assert(outputs.size() == else_outputs.size());
- for (size_t i = 0; i < outputs.size(); ++i)
- {
- const auto &then_output = then_graph.operands().at(then_outputs.at(i));
- const auto &else_output = else_graph.operands().at(else_outputs.at(i));
- auto &output = _operands.at(outputs.at(i));
- if (!then_output.info().isDynamic() && !else_output.info().isDynamic() &&
- then_output.shape() == else_output.shape())
- {
- output.info().shape(then_output.shape());
- }
- else
- {
- output.info().setDynamic();
- _return_has_dynamic_tensor = true;
- }
- }
-}
-
-void StaticShapeInferer::visit(const ir::operation::L2Normalization &op)
-{
- handleSimpleUnaryOp(op, op.getInputs().at(ir::operation::L2Normalization::Input::INPUT));
-}
-
-void StaticShapeInferer::visit(const ir::operation::MatrixBandPart &op)
-{
- handleSimpleUnaryOp(op, op.getInputs().at(ir::operation::MatrixBandPart::Input::INPUT));
-}
-
-void StaticShapeInferer::visit(const ir::operation::OneHot &op)
-{
- const auto indice_idx{op.getInputs().at(ir::operation::OneHot::Input::INDICES)};
- const auto &indice = _operands.at(indice_idx);
- const auto depth_idx{op.getInputs().at(ir::operation::OneHot::Input::DEPTH)};
- const auto &depth = _operands.at(depth_idx);
-
- const auto axis = op.param().axis;
-
- auto output_idx = op.getOutputs().at(0);
- ir::Operand &output = _operands.at(output_idx);
-
- if (!depth.isConstant())
- {
- output.info().setDynamic();
- _return_has_dynamic_tensor = true;
- return;
- }
-
- const auto *depth_buf = reinterpret_cast<const int32_t *>(depth.data()->base());
- assert(depth_buf);
- // re-sizing output shape
- ir::Shape new_shape = shape_inference::inferOnehotShape(indice.info().shape(), *depth_buf, axis);
- output.info().shape(new_shape);
-}
-
-void StaticShapeInferer::visit(const ir::operation::Pack &op)
-{
- const auto input_idx{op.getInputs().at(0)};
- const auto &input = _operands.at(input_idx);
-
- // get mutable output operand
- const auto output_idx = op.getOutputs().at(0);
- ir::Operand &output = _operands.at(output_idx);
-
- const auto rank = input.shape().rank() + 1;
- const auto axis = ((op.param().axis < 0) ? rank + op.param().axis : op.param().axis);
- const auto num = op.param().num;
-
- assert(0 <= axis && axis < rank);
-
- // re-sizing output shape
- ir::Shape new_shape = shape_inference::inferPackShape(input.info().shape(), axis, rank, num);
- output.info().shape(new_shape);
-}
-
-void StaticShapeInferer::visit(const ir::operation::Pad &op)
-{
- const auto input_idx{op.getInputs().at(ir::operation::Pad::Input::INPUT)};
- const auto &input = _operands.at(input_idx);
-
- const auto pad_idx{op.getInputs().at(ir::operation::Pad::Input::PAD)};
- const auto &pad = _operands.at(pad_idx);
-
- // get mutable output operand
- const auto output_idx = op.getOutputs().at(0);
- ir::Operand &output = _operands.at(output_idx);
-
- // if pad is not constant, output also becomes dynamic
- if (!pad.isConstant())
- {
- output.info().setDynamic();
- _return_has_dynamic_tensor = true;
- return;
- }
-
- // re-sizing output shape
- const auto new_shape = shape_inference::inferPadShape(
- input.shape(), reinterpret_cast<const int32_t *>(pad.data()->base()),
- pad.shape().num_elements());
- output.info().shape(new_shape);
-}
-
-void StaticShapeInferer::visit(const ir::operation::Permute &op)
-{
- const auto input_idx{op.getInputs().at(0)};
- const auto &input = _operands.at(input_idx);
- const auto output_idx = op.getOutputs().at(0);
- ir::Operand &output = _operands.at(output_idx);
-
- // re-sizing output shape
- // Permute is a special operation that layouts of input/output may be different on backend
- // However, it is not applied here, so input/output have the same layout of frontend. Because
- // "ExecutorFactory" would convert shape of input/output accoding to the layouts when registering
- // operand info to "TensorBuilder" after calling "StaticShapeInferer"
- const auto new_shape = input.info().shape();
- output.info().shape(new_shape);
-}
-
-void StaticShapeInferer::visit(const ir::operation::Pow &op)
-{
- handleBinaryArithmeticOp(op, op.getInputs().at(ir::operation::Pow::Input::LHS),
- op.getInputs().at(ir::operation::Pow::Input::RHS));
-}
-
-void StaticShapeInferer::visit(const ir::operation::Range &op)
-{
- const auto start_idx{op.getInputs().at(ir::operation::Range::Input::START)};
- const auto limit_idx{op.getInputs().at(ir::operation::Range::Input::LIMIT)};
- const auto delta_idx{op.getInputs().at(ir::operation::Range::Input::DELTA)};
- const auto &start_op = _operands.at(start_idx);
- const auto &limit_op = _operands.at(limit_idx);
- const auto &delta_op = _operands.at(delta_idx);
-
- // get mutable output operand
- const auto output_idx = op.getOutputs().at(0);
- ir::Operand &output = _operands.at(output_idx);
-
- ir::Shape new_shape;
- if (start_op.isConstant() && limit_op.isConstant() && delta_op.isConstant())
- {
- assert(start_op.typeInfo().type() == limit_op.typeInfo().type() &&
- start_op.typeInfo().type() == delta_op.typeInfo().type());
- if (output.typeInfo().type() == ir::DataType::FLOAT32)
- {
- new_shape = shape_inference::inferRangeShape<float>(
- start_op.asScalar<float>(), limit_op.asScalar<float>(), delta_op.asScalar<float>());
- }
- else if (output.typeInfo().type() == ir::DataType::INT32)
- {
- new_shape = shape_inference::inferRangeShape<int32_t>(
- start_op.asScalar<int32_t>(), limit_op.asScalar<int32_t>(), delta_op.asScalar<int32_t>());
- }
- assert(output.shape() == new_shape);
- }
- else
- {
- output.info().setDynamic();
- _return_has_dynamic_tensor = true;
- }
-}
-
-void StaticShapeInferer::visit(const ir::operation::Reduce &op)
-{
- const auto input_idx{op.getInputs().at(ir::operation::Reduce::Input::INPUT)};
- const auto &input = _operands.at(input_idx);
-
- const auto axes_idx{op.getInputs().at(ir::operation::Reduce::Input::AXES)};
- const auto &axes = _operands.at(axes_idx);
-
- // get mutable output operand
- const auto output_idx = op.getOutputs().at(0);
- ir::Operand &output = _operands.at(output_idx);
-
- std::vector<int32_t> axes_vec;
- for (size_t i = 0; i < axes.shape().num_elements(); ++i)
- {
- switch (axes.typeInfo().type())
- {
- case ir::DataType::INT32:
- {
- axes_vec.emplace_back(reinterpret_cast<const int32_t *>(axes.data()->base())[i]);
- break;
- }
- case ir::DataType::INT64:
- {
- axes_vec.emplace_back(reinterpret_cast<const int64_t *>(axes.data()->base())[i]);
- break;
- }
- default:
- throw std::runtime_error("StaticShapeInferer " + op.name() + ": Not supported data type");
- break;
- }
- }
- const auto keep_dims = op.param().keep_dims;
-
- // re-sizing output shape
- ir::Shape new_shape =
- shape_inference::inferReduceShape(input.info().shape(), axes_vec, keep_dims);
- output.info().shape(new_shape);
-}
-
-void StaticShapeInferer::visit(const ir::operation::Reshape &op)
-{
- const auto input_idx{op.getInputs().at(ir::operation::Reshape::Input::INPUT)};
- const auto &input = _operands.at(input_idx);
-
- // get mutable output operand
- const auto output_idx = op.getOutputs().at(0);
- ir::Operand &output = _operands.at(output_idx);
-
- // New shape is given by second input tensor
- if (op.getInputs().size() == 2)
- {
- // Let's check the second input
- const auto shape_idx{op.getInputs().at(ir::operation::Reshape::Input::SHAPE)};
- const auto &shape = _operands.at(shape_idx);
-
- if (shape.isConstant())
- {
- const auto *shape_buf = reinterpret_cast<const int32_t *>(shape.data()->base());
- assert(shape_buf);
-
- ir::Shape new_shape = shape_inference::inferReshapeShape(
- shape_buf, shape.shape().num_elements(), input.shape().num_elements());
-
- // if shape is from Const, TFLC put the shape of output into tensor
- if (new_shape != output.shape())
- {
- // change on output shape
- output.info().shape(new_shape);
- }
- }
- else
- {
- // if shape is NOT Const, set output shape to be dynamic_
- output.info().setDynamic();
- _return_has_dynamic_tensor = true;
- }
- }
- // New shape is given by option
- else if (op.param().new_shape.size() != 0)
- {
- // Let's check the new_shape option
- auto shape = op.param().new_shape;
- ir::Shape new_shape = shape_inference::inferReshapeShape(shape.data(), shape.size(),
- input.shape().num_elements());
-
- if (new_shape != output.shape())
- {
- // change on output shape
- output.info().shape(new_shape);
- }
- }
- else
- {
- throw std::runtime_error("Reshape: new shape is missing");
- }
-}
-
-void StaticShapeInferer::visit(const ir::operation::ResizeBilinear &op)
-{
- const auto input_idx{op.getInputs().at(ir::operation::ResizeBilinear::Input::INPUT)};
- const auto &input = _operands.at(input_idx);
-
- // get mutable output operand
- const auto output_idx = op.getOutputs().at(0);
- ir::Operand &output = _operands.at(output_idx);
-
- // Shape inferencing logic based on Params
- ir::Shape new_shape = shape_inference::inferResizeBilinearShape(
- input.shape(), op.param().height_out, op.param().width_out);
-
- // if size_op is from Const, TFLC put the shape of output into tensor
- if (new_shape != output.shape())
- {
- // change on output shape
- output.info().shape(new_shape);
- }
-}
-
-void StaticShapeInferer::visit(const ir::operation::Reverse &op)
-{
- handleSimpleUnaryOp(op, op.getInputs().at(ir::operation::Reverse::Input::INPUT));
-}
-
-void StaticShapeInferer::visit(const ir::operation::Select &op)
-{
- const auto input_cond_idx{op.getInputs().at(ir::operation::Select::Input::CONDITION)};
- const auto &input_cond = _operands.at(input_cond_idx);
-
- const auto input_true_idx{op.getInputs().at(ir::operation::Select::Input::INPUT_TRUE)};
- const auto &input_true = _operands.at(input_true_idx);
-
- const auto input_false_idx{op.getInputs().at(ir::operation::Select::Input::INPUT_FALSE)};
- const auto &input_false = _operands.at(input_false_idx);
-
- auto output_idx = op.getOutputs().at(0);
- ir::Operand &output = _operands.at(output_idx);
-
- // Select output shape
- ir::Shape new_shape = shape_inference::inferSelectShape(
- input_cond.info().shape(), input_true.info().shape(), input_false.info().shape());
- output.info().shape(new_shape);
-}
-
-void StaticShapeInferer::visit(const ir::operation::Shape &op)
-{
- const auto input_idx{op.getInputs().at(0)};
- const auto &input = _operands.at(input_idx);
-
- // get mutable output operand
- const auto output_idx = op.getOutputs().at(0);
- ir::Operand &output = _operands.at(output_idx);
-
- // re-sizing output shape
- ir::Shape output_shape;
- output_shape.append(input.info().shape().rank());
-
- output.info().shape(output_shape);
-}
-
-void StaticShapeInferer::visit(const ir::operation::Slice &op)
-{
- const auto input_index{op.getInputs().at(ir::operation::Slice::Input::INPUT)};
- const auto &input = _operands.at(input_index);
- const auto begins_index{op.getInputs().at(ir::operation::Slice::Input::BEGINS)};
- const auto &begins = _operands.at(begins_index);
- const auto sizes_index{op.getInputs().at(ir::operation::Slice::Input::SIZES)};
- const auto &sizes = _operands.at(sizes_index);
- const auto output_index = op.getOutputs().at(0);
- ir::Operand &output = _operands.at(output_index);
-
- // Whether input is constant or not does not affect whether output is dynamic or not
- if (!(begins.isConstant() && sizes.isConstant()))
- {
- output.info().setDynamic();
- _return_has_dynamic_tensor = true;
- return;
- }
-
- auto begins_buf = reinterpret_cast<const int32_t *>(begins.data()->base());
- auto sizes_buf = reinterpret_cast<const int32_t *>(sizes.data()->base());
-
- ir::Shape new_shape =
- shape_inference::inferSliceShape(input.info().shape(), begins_buf, sizes_buf);
- output.info().shape(new_shape);
-}
-
-void StaticShapeInferer::visit(const ir::operation::Softmax &op)
-{
- handleSimpleUnaryOp(op, op.getInputs().at(ir::operation::Softmax::Input::INPUT));
-}
-
-void StaticShapeInferer::visit(const ir::operation::SpaceToBatchND &op)
-{
- const auto output_index = op.getOutputs().at(0);
- const auto input_idx{op.getInputs().at(ir::operation::SpaceToBatchND::Input::INPUT)};
- const auto block_shape_idx{op.getInputs().at(ir::operation::SpaceToBatchND::Input::BLOCK_SIZE)};
- const auto padding_idx{op.getInputs().at(ir::operation::SpaceToBatchND::Input::PADDINGS)};
-
- ir::Operand &output = _operands.at(output_index);
- const auto &input = _operands.at(input_idx);
- const auto &block_shape = _operands.at(block_shape_idx);
- const auto &padding = _operands.at(padding_idx);
-
- // Whether input is constant or not does not affect whether output is dynamic or not
- if (!(block_shape.isConstant() && padding.isConstant()))
- {
- output.info().setDynamic();
- _return_has_dynamic_tensor = true;
- return;
- }
-
- auto input_shape = input.info().shape();
- auto block_shape_shape = block_shape.info().shape();
- auto padding_shape = padding.info().shape();
-
- auto block_shape_data = reinterpret_cast<const int32_t *>(block_shape.data()->base());
- auto padding_data = reinterpret_cast<const int32_t *>(padding.data()->base());
-
- ir::Shape new_shape = shape_inference::inferSpaceToBatchNDShape(
- input_shape, block_shape_shape, padding_shape, block_shape_data, padding_data);
-
- output.info().shape(new_shape);
-}
-
-void StaticShapeInferer::visit(const ir::operation::Split &op)
-{
- const auto input_idx{op.getInputs().at(0)};
- const auto &input = _operands.at(input_idx);
-
- const auto axis = op.param().axis;
- const auto num_splits = op.param().num_splits;
-
- const auto rank = input.info().shape().rank();
- auto axis_resolved = axis < 0 ? axis + rank : axis;
-
- assert(0 <= axis_resolved && axis_resolved < rank);
-
- ir::Shape new_shape =
- shape_inference::inferSplitShape(input.info().shape(), axis_resolved, num_splits);
- auto output_tensors = op.getOutputs();
- for (auto output_idx : output_tensors)
- {
- ir::Operand &output = _operands.at(output_idx);
- output.info().shape(new_shape);
- }
-}
-
-void StaticShapeInferer::visit(const ir::operation::SquaredDifference &op)
-{
- handleBinaryArithmeticOp(op, op.getInputs().at(ir::operation::SquaredDifference::Input::LHS),
- op.getInputs().at(ir::operation::SquaredDifference::Input::RHS));
-}
-
-void StaticShapeInferer::visit(const ir::operation::Squeeze &op)
-{
- const auto input_idx{op.getInputs().at(ir::operation::Squeeze::Input::INPUT)};
- const auto &input = _operands.at(input_idx);
-
- const auto output_idx = op.getOutputs().at(0);
- ir::Operand &output = _operands.at(output_idx);
-
- if (input.info().isDynamic())
- {
- output.info().setDynamic();
- _return_has_dynamic_tensor = true;
- return;
- }
-
- // Squeeze output shpae
- ir::Shape new_shape = shape_inference::inferSqueezeShape(input.info().shape(), op.param());
- output.info().shape(new_shape);
-}
-
-void StaticShapeInferer::visit(const ir::operation::StridedSlice &op)
-{
- const auto input_index{op.getInputs().at(ir::operation::StridedSlice::Input::INPUT)};
- const auto &input = _operands.at(input_index);
- const auto starts_index{op.getInputs().at(ir::operation::StridedSlice::Input::STARTS)};
- const auto &starts = _operands.at(starts_index);
- const auto ends_index{op.getInputs().at(ir::operation::StridedSlice::Input::ENDS)};
- const auto &ends = _operands.at(ends_index);
- const auto strides_index{op.getInputs().at(ir::operation::StridedSlice::Input::STRIDES)};
- const auto &strides = _operands.at(strides_index);
- const auto output_index = op.getOutputs().at(0);
- ir::Operand &output = _operands.at(output_index);
-
- if (!(starts.isConstant() && ends.isConstant() && strides.isConstant()))
- {
- output.info().setDynamic();
- _return_has_dynamic_tensor = true;
- return;
- }
-
- const auto begin_mask = op.param().begin_mask;
- const auto end_mask = op.param().end_mask;
- const auto shrink_axis_mask = op.param().shrink_axis_mask;
- const auto rank = input.info().shape().rank();
-
- auto starts_buf = reinterpret_cast<const uint32_t *>(starts.data()->base());
- auto ends_buf = reinterpret_cast<const uint32_t *>(ends.data()->base());
- auto strides_buf = reinterpret_cast<const uint32_t *>(strides.data()->base());
-
- auto op_params = shape_inference::buildStridedSliceParams(
- starts_buf, ends_buf, strides_buf, begin_mask, end_mask, shrink_axis_mask, rank);
-
- ir::Shape new_shape =
- shape_inference::inferStridedSliceShape(input.info().shape(), op_params, rank);
- output.info().shape(new_shape);
-}
-
-void StaticShapeInferer::visit(const ir::operation::Tile &op)
-{
- const auto input_idx{op.getInputs().at(ir::operation::Tile::Input::INPUT)};
- const auto &input = _operands.at(input_idx);
-
- const auto multiplier_idx{op.getInputs().at(ir::operation::Tile::Input::MULTIPLES)};
- const auto &multiplier = _operands.at(multiplier_idx);
-
- const auto output_idx = op.getOutputs().at(0);
- ir::Operand &output = _operands.at(output_idx);
-
- if (!multiplier.isConstant())
- {
- output.info().setDynamic();
- _return_has_dynamic_tensor = true;
- return;
- }
-
- auto multiplier_buffer = reinterpret_cast<const int32_t *>(multiplier.data()->base());
- assert(multiplier_buffer);
-
- // re-sizing output shape
- auto new_shape = shape_inference::inferTileShape(input.info().shape(), multiplier_buffer);
- output.info().shape(new_shape);
-}
-
-void StaticShapeInferer::visit(const ir::operation::Transpose &op)
-{
- const auto input_idx{op.getInputs().at(ir::operation::Transpose::Input::INPUT)};
- const auto &input = _operands.at(input_idx);
-
- // get mutable output operand
- const auto output_idx = op.getOutputs().at(0);
- ir::Operand &output = _operands.at(output_idx);
- const auto perm{op.param().perm};
- // const auto rank{op.param().rank};
-
- // set output shape, based on input and params
- ir::Shape new_shape = shape_inference::inferTransposeShape(input.info().shape(), perm);
- output.info().shape(new_shape);
-}
-
-void StaticShapeInferer::visit(const ir::operation::Unpack &op)
-{
- const auto input_idx{op.getInputs().at(0)};
- const auto &input = _operands.at(input_idx);
- const auto num = op.param().num;
- const auto rank = input.shape().rank();
- const auto axis = ((op.param().axis < 0) ? rank + op.param().axis : op.param().axis);
-
- assert(axis < rank);
- if (axis < 0)
- {
- for (int out_tensor_idx = 0; out_tensor_idx < num; out_tensor_idx++)
- {
- const auto output_idx = op.getOutputs().at(out_tensor_idx);
- ir::Operand &output = _operands.at(output_idx);
- output.info().setDynamic();
- }
- _return_has_dynamic_tensor = true;
- return;
- }
-
- ir::Shape new_shape = shape_inference::inferUnpackShape(input.info().shape(), axis, rank);
-
- // re-sizing output shape
- for (int out_tensor_idx = 0; out_tensor_idx < num; out_tensor_idx++)
- {
- const auto output_idx = op.getOutputs().at(out_tensor_idx);
- ir::Operand &output = _operands.at(output_idx);
- output.info().shape(new_shape);
- }
-}
-
-void StaticShapeInferer::visit(const ir::operation::While &op)
-{
- auto &cond_graph = _lowered_subgs.at(op.param().cond_subg_index)->graph();
- auto &body_graph = _lowered_subgs.at(op.param().body_subg_index)->graph();
- const auto inputs = op.getInputs();
- const auto &outputs = op.getOutputs();
-
- // re-sizing input shapes of then subgraph
- const auto &cond_inputs = cond_graph.getInputs();
- assert(inputs.size() == cond_inputs.size());
- for (size_t i = 0; i < inputs.size(); ++i)
- {
- const auto &input = _operands.at(inputs.at(i));
- auto &cond_input = cond_graph.operands().at(cond_inputs.at(i));
- if (input.info().isDynamic())
- {
- cond_input.info().setDynamic();
- }
- else
- {
- auto new_shape = input.info().shape();
- cond_input.info().shape(new_shape);
- }
- }
-
- // re-sizing input shapes of body subgraph
- const auto &body_inputs = body_graph.getInputs();
- assert(cond_inputs.size() == body_inputs.size());
- for (size_t i = 0; i < cond_inputs.size(); ++i)
- {
- const auto &cond_input = cond_graph.operands().at(cond_inputs.at(i));
- auto &body_input = body_graph.operands().at(body_inputs.at(i));
- if (cond_input.info().isDynamic())
- {
- body_input.info().setDynamic();
- }
- else
- {
- const auto &new_shape = cond_input.info().shape();
- body_input.info().shape(new_shape);
- }
- }
-
- // re-sizing operands of body subgraph
- StaticShapeInferer body_inferer(op.param().body_subg_index, _lowered_subgs);
- _lowered_subgs.at(op.param().body_subg_index)
- ->iterateTopolOpSeqs([&](const ir::OpSequenceIndex &, ir::OpSequence &op_seq) {
- bool has_dynamic_tensor = body_inferer.infer(op_seq);
- op_seq.has_dynamic_tensor(has_dynamic_tensor);
- });
-
- // Check whether while operation's shapes are predictable
- // If any of shape of body outputs and cond inputs are different, non-constant operands would be
- // set to dynamic
- bool check_unpredictable_dynamic = false;
- const auto &body_outputs = body_graph.getOutputs();
- assert(body_outputs.size() == cond_inputs.size());
- for (size_t i = 0; i < body_outputs.size(); ++i)
- {
- const auto &body_output = body_graph.operands().at(body_outputs.at(i));
- auto &cond_input = cond_graph.operands().at(cond_inputs.at(i));
- if ((cond_input.info().isDynamic() != body_output.info().isDynamic()) ||
- (cond_input.shape() != body_output.shape()))
- {
- check_unpredictable_dynamic = true;
- break;
- }
- }
-
- if (check_unpredictable_dynamic)
- {
- // Set inputs of body subgraph
- for (const auto &input_index : body_inputs)
- {
- auto &input = body_graph.operands().at(input_index);
- if (!input.isConstant())
- {
- input.info().setDynamic();
- }
- }
-
- // Set inputs of cond subgraph
- for (const auto &input_index : cond_inputs)
- {
- auto &input = cond_graph.operands().at(input_index);
- if (!input.isConstant())
- {
- input.info().setDynamic();
- }
- }
-
- // Set non-constant operands of body subgraph to dynamic
- StaticShapeInferer body_inferer(op.param().body_subg_index, _lowered_subgs);
- _lowered_subgs.at(op.param().body_subg_index)
- ->iterateTopolOpSeqs([&](const ir::OpSequenceIndex &, ir::OpSequence &op_seq) {
- bool has_dynamic_tensor = body_inferer.infer(op_seq);
- op_seq.has_dynamic_tensor(has_dynamic_tensor);
- });
- }
-
- // re-sizing operands of cond subgraph
- // If check_unpredictable_dynamic is true, non-constant operands of cond subgraph would be set to
- // dynamic
- StaticShapeInferer cond_inferer(op.param().cond_subg_index, _lowered_subgs);
- _lowered_subgs.at(op.param().cond_subg_index)
- ->iterateTopolOpSeqs([&](const ir::OpSequenceIndex &, ir::OpSequence &op_seq) {
- bool has_dynamic_tensor = cond_inferer.infer(op_seq);
- op_seq.has_dynamic_tensor(has_dynamic_tensor);
- });
-
- // re-sizing outputs of while operation
- // If check_unpredictable_dynamic is true, outputs of while operation would be set to dynamic
- assert(cond_inputs.size() == outputs.size());
- for (size_t i = 0; i < cond_inputs.size(); ++i)
- {
- const auto &cond_input = cond_graph.operands().at(cond_inputs.at(i));
- auto &output = _operands.at(outputs.at(i));
- if (cond_input.info().isDynamic())
- {
- output.info().setDynamic();
- _return_has_dynamic_tensor = true;
- }
- else
- {
- const auto new_shape = cond_input.info().shape();
- output.info().shape(new_shape);
- }
- }
-}
-
-} // namespace compiler
-
-} // namespace onert
diff --git a/runtime/onert/core/src/compiler/StaticShapeInferer.cc b/runtime/onert/core/src/compiler/StaticShapeInferer.cc
new file mode 100644
index 000000000..622edbab4
--- /dev/null
+++ b/runtime/onert/core/src/compiler/StaticShapeInferer.cc
@@ -0,0 +1,1487 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "compiler/StaticShapeInferer.h"
+#include "util/ShapeInference.h"
+#include "util/logging.h"
+
+#include <misc/polymorphic_downcast.h>
+
+#include <sstream>
+#include <stdexcept>
+
+namespace onert
+{
+namespace compiler
+{
+void OperandObserver::updateShapes(const std::vector<ir::OperandInfo> &changed_operands_info,
+ bool unpredictable)
+{
+ assert(changed_operands_info.size() == _operands.size());
+ for (size_t i = 0; i < changed_operands_info.size(); ++i)
+ {
+ const auto &changed_operand_info = changed_operands_info.at(i);
+ auto &operand = _operands.at(i);
+    // assert(changed_operand_info.typeInfo() == operand->typeInfo());
+    // This error check may be replaced by an assertion once this function is only called after
+    // model validation has completed.
+ if (changed_operand_info.typeInfo() != operand->typeInfo())
+ {
+ throw std::runtime_error("OperandObserver: The types of operands are mismatched");
+ }
+ if (!operand->info().isConstant() && (changed_operand_info.isDynamic() || unpredictable))
+ {
+ operand->info().setDynamic();
+ }
+ else
+ {
+ const auto &new_shape = changed_operands_info.at(i).shape();
+ operand->info().shape(new_shape);
+ }
+ }
+}
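The observer above is the only channel through which a parent inferer resizes a child subgraph's inputs (and, in reverse, through which a child pushes its outputs back to the parent's controlflow op). A minimal standalone sketch of that idea, using hypothetical ToyInfo/ToyObserver stand-ins rather than onert's ir::OperandInfo and OperandObserver:

#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

struct ToyInfo
{
  std::vector<int> shape;
  bool is_dynamic = false;
  bool is_constant = false;
};

class ToyObserver
{
public:
  explicit ToyObserver(std::vector<ToyInfo *> operands) : _operands{std::move(operands)} {}

  // Mirrors the update rule above: constants keep their shape; other operands become
  // dynamic when the new info is dynamic or unpredictable, and otherwise simply adopt
  // the new static shape.
  void updateShapes(const std::vector<ToyInfo> &changed, bool unpredictable = false)
  {
    assert(changed.size() == _operands.size());
    for (std::size_t i = 0; i < changed.size(); ++i)
    {
      auto *operand = _operands[i];
      if (!operand->is_constant && (changed[i].is_dynamic || unpredictable))
        operand->is_dynamic = true;
      else
        operand->shape = changed[i].shape;
    }
  }

private:
  std::vector<ToyInfo *> _operands;
};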
+
+void StaticShapeInferer::infer()
+{
+ for (const auto &op_idx : _lowered_subg->graph().topolSortOperations())
+ {
+ const auto &op = _lowered_subg->graph().operations().at(op_idx);
+ bool has_dynamic_tensor = false;
+ const auto opcode = op.opcode();
+ // IF: requires shape inference for then, else
+ // While: requires shape inference for condition, body
+ if (opcode == ir::OpCode::If || opcode == ir::OpCode::While)
+ {
+ op.accept(*this);
+ }
+ else
+ {
+ has_dynamic_tensor = checkDynamicInput(op);
+ if (has_dynamic_tensor)
+ {
+ setDynamicOutput(op);
+ }
+ else
+ {
+ op.accept(*this);
+ }
+ }
+ has_dynamic_tensor = has_dynamic_tensor || checkDynamicOutput(op);
+ _lowered_subg->setHasDynamicTensor(op_idx, has_dynamic_tensor);
+ }
+
+ if (_controlflow_output_observer != nullptr)
+ {
+    // re-sizing output shapes of the controlflow operation branching to this subgraph
+ std::vector<ir::OperandInfo> outputs_info;
+ const auto &graph = _lowered_subg->graph();
+ const auto &outputs = graph.getOutputs();
+ for (size_t i = 0; i < outputs.size(); ++i)
+ {
+ const auto &operand_info = graph.operands().at(outputs.at(i)).info();
+ outputs_info.emplace_back(operand_info);
+ }
+ _controlflow_output_observer->updateShapes(outputs_info);
+ }
+}
+
+bool StaticShapeInferer::checkDynamicInput(const ir::IOperation &op)
+{
+ const auto &operands = _lowered_subg->graph().operands();
+ for (auto &&input_idx : op.getInputs() | ir::Remove::UNDEFINED | ir::Remove::DUPLICATED)
+ {
+ if (operands.at(input_idx).info().isDynamic())
+ {
+ return true;
+ }
+ }
+
+ return false;
+}
+
+bool StaticShapeInferer::checkDynamicOutput(const ir::IOperation &op)
+{
+ auto &operands = _lowered_subg->graph().operands();
+ for (auto &&output_idx : op.getOutputs() | ir::Remove::UNDEFINED)
+ {
+ if (operands.at(output_idx).info().isDynamic())
+ {
+ return true;
+ }
+ }
+ return false;
+}
+
+void StaticShapeInferer::setDynamicOutput(const ir::IOperation &op)
+{
+ auto &operands = _lowered_subg->graph().operands();
+ for (auto &&output_idx : op.getOutputs() | ir::Remove::UNDEFINED)
+ {
+ operands.at(output_idx).info().setDynamic();
+ }
+}
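checkDynamicInput(), checkDynamicOutput() and setDynamicOutput() implement the simple forward rule used by infer(): once any defined input of an operation is dynamic, all of its outputs become dynamic and the per-operation visit() is skipped. A standalone sketch of that rule over a topologically sorted operation list (hypothetical Op/Operand types, not onert's IR):

#include <vector>

struct Operand { bool is_dynamic = false; };

struct Op
{
  std::vector<Operand *> inputs;
  std::vector<Operand *> outputs;
};

// Visit operations in topological order; dynamic-ness flows from inputs to outputs,
// matching the helpers above.
inline void propagateDynamic(const std::vector<Op> &topo_sorted_ops)
{
  for (const auto &op : topo_sorted_ops)
  {
    bool has_dynamic_input = false;
    for (const auto *in : op.inputs)
      has_dynamic_input |= in->is_dynamic;
    if (has_dynamic_input)
      for (auto *out : op.outputs)
        out->is_dynamic = true;
    // otherwise the per-operation visit() would compute a static output shape
  }
}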
+
+void StaticShapeInferer::handleBinaryArithmeticOp(const ir::Operation &op,
+ const ir::OperandIndex lhs_idx,
+ const ir::OperandIndex rhs_idx)
+{
+ auto &operands = _lowered_subg->graph().operands();
+ const auto &lhs = operands.at(lhs_idx);
+ const auto &rhs = operands.at(rhs_idx);
+
+ const auto output_idx = op.getOutputs().at(0);
+ ir::Operand &output = operands.at(output_idx);
+
+ // re-sizing output shape
+ ir::Shape new_shape = shape_inference::inferEltwiseShape(lhs.info().shape(), rhs.info().shape());
+ output.info().shape(new_shape);
+}
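handleBinaryArithmeticOp() delegates to shape_inference::inferEltwiseShape(), which conceptually is the usual NumPy-style broadcast rule: shapes are aligned from the trailing dimension, and each aligned pair must be equal or contain a 1. A rough standalone illustration of that rule (not onert's actual implementation):

#include <algorithm>
#include <cstddef>
#include <stdexcept>
#include <vector>

// Rough sketch of broadcast shape inference, e.g. {2, 1, 3} and {4, 3} -> {2, 4, 3}.
inline std::vector<int> broadcastShape(const std::vector<int> &lhs, const std::vector<int> &rhs)
{
  const std::size_t rank = std::max(lhs.size(), rhs.size());
  std::vector<int> out(rank, 1);
  for (std::size_t i = 0; i < rank; ++i)
  {
    const int l = i < rank - lhs.size() ? 1 : lhs[i - (rank - lhs.size())];
    const int r = i < rank - rhs.size() ? 1 : rhs[i - (rank - rhs.size())];
    if (l != r && l != 1 && r != 1)
      throw std::runtime_error("broadcastShape: incompatible dimensions");
    out[i] = std::max(l, r);
  }
  return out;
}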
+
+void StaticShapeInferer::handleSimpleUnaryOp(const ir::Operation &op,
+ const ir::OperandIndex input_idx)
+{
+ auto &operands = _lowered_subg->graph().operands();
+ const auto &input = operands.at(input_idx);
+
+ // get mutable output operand
+ const auto output_idx = op.getOutputs().at(0);
+ ir::Operand &output = operands.at(output_idx);
+
+ // re-sizing output shape
+ ir::Shape new_shape = input.info().shape();
+ output.info().shape(new_shape);
+}
+
+void StaticShapeInferer::dump()
+{
+ auto get_shape_str = [](const ir::Shape &shape) {
+ std::stringstream sstream;
+ sstream << "shape : {";
+ for (int i = 0; i < shape.rank(); i++)
+ {
+ if (i == 0)
+ sstream << shape.dim(i);
+ else
+ sstream << " " << shape.dim(i);
+ }
+ sstream << "}";
+ return sstream.str();
+ };
+
+ _lowered_subg->graph().operands().iterate(
+ [&](const ir::OperandIndex &ind, const ir::Operand &operand) {
+ VERBOSE(StaticShapeInferer) << " " << ind << ", "
+ << (operand.info().isDynamic() ? "Dynamic" : "Static") << ", "
+ << get_shape_str(operand.info().shape()) << std::endl;
+ });
+}
+
+std::unordered_map<ir::SubgraphIndex, std::unique_ptr<StaticShapeInferer>>
+StaticShapeInferer::createStaticShapeInferers(
+ const std::unordered_map<ir::SubgraphIndex, ILoweredGraph *> &lowered_subgs)
+{
+ // Allocate StaticShapeInferer per each subgraph
+ std::unordered_map<ir::SubgraphIndex, std::unique_ptr<StaticShapeInferer>> inferers;
+ for (auto &&pair : lowered_subgs)
+ {
+ const auto &subg_index = pair.first;
+ auto &lowered_subg = pair.second;
+ inferers[subg_index] = std::make_unique<StaticShapeInferer>(lowered_subg);
+ }
+
+ // Append observers in all StaticShapeInferers
+ for (auto &&pair : lowered_subgs)
+ {
+ const auto &subg_index = pair.first;
+ auto &lowered_subg = pair.second;
+
+    // TODO: Iterate over controlflow operations only, instead of over all operations
+ lowered_subg->graph().operations().iterate(
+ [&](const ir::OperationIndex &, const ir::IOperation &op) {
+        // A function to append child inferers. These make it possible for a StaticShapeInferer to
+        // call the StaticShapeInferers of child subgraphs recursively
+ auto appendChildInferer = [&](const ir::SubgraphIndex &child_subg_idx) {
+ auto *child_inferer = inferers.at(child_subg_idx).get();
+ inferers.at(subg_index)->appendChildInferer(child_subg_idx, child_inferer);
+ };
+
+        // A function to append subgraph input observers. This makes it possible for a
+        // StaticShapeInferer to update the inputs of child subgraphs
+ auto appendSubgraphInputObserver = [&](const ir::SubgraphIndex &child_subg_idx) {
+ std::vector<ir::Operand *> child_subg_inputs;
+ auto &child_subg = lowered_subgs.at(child_subg_idx)->graph();
+ for (const auto &input_idx : child_subg.getInputs())
+ {
+ auto operand_ptr = child_subg.operands().getRawPtr(input_idx);
+ child_subg_inputs.emplace_back(operand_ptr);
+ }
+ inferers.at(subg_index)
+ ->appendSubgInputObserver(child_subg_idx,
+ std::make_unique<OperandObserver>(child_subg_inputs));
+ };
+
+        // A function to set controlflow output observers. This makes it possible for a
+        // StaticShapeInferer to update the outputs of parent controlflow operations
+ auto setControlFlowOutputObserver = [&](const ir::SubgraphIndex &child_subg_idx) {
+ std::vector<ir::Operand *> cf_outputs;
+ auto &subg = lowered_subg->graph();
+ for (const auto &output_idx : op.getOutputs())
+ {
+ auto operand_ptr = subg.operands().getRawPtr(output_idx);
+ cf_outputs.emplace_back(operand_ptr);
+ }
+ inferers.at(child_subg_idx)
+ ->setControlflowOutputObserver(std::make_unique<OperandObserver>(cf_outputs));
+ };
+
+ // Append Observers in a StaticShapeInferer
+ if (op.opcode() == ir::OpCode::If)
+ {
+ // TODO Remove dynamic_cast
+          // A virtual base class cannot be downcast with static_cast
+ try
+ {
+ const auto &if_op = dynamic_cast<const ir::operation::If &>(op);
+
+ appendChildInferer(if_op.param().then_subg_index);
+ appendChildInferer(if_op.param().else_subg_index);
+
+ appendSubgraphInputObserver(if_op.param().then_subg_index);
+ appendSubgraphInputObserver(if_op.param().else_subg_index);
+
+ setControlFlowOutputObserver(if_op.param().then_subg_index);
+ }
+ catch (const std::bad_cast &)
+ {
+ throw std::runtime_error("StaticShapeInferer: Invalid If operation");
+ }
+ }
+ else if (op.opcode() == ir::OpCode::While)
+ {
+ // TODO Remove dynamic_cast
+ try
+ {
+ const auto &while_op = dynamic_cast<const ir::operation::While &>(op);
+
+ appendChildInferer(while_op.param().cond_subg_index);
+ appendChildInferer(while_op.param().body_subg_index);
+
+ appendSubgraphInputObserver(while_op.param().cond_subg_index);
+ appendSubgraphInputObserver(while_op.param().body_subg_index);
+
+ setControlFlowOutputObserver(while_op.param().body_subg_index);
+ }
+ catch (const std::bad_cast &)
+ {
+ throw std::runtime_error("StaticShapeInferer: Invalid While operation");
+ }
+ }
+ });
+ }
+
+ return inferers;
+}
+
+void StaticShapeInferer::visit(const ir::operation::ArgMinMax &op)
+{
+ auto &operands = _lowered_subg->graph().operands();
+
+ const auto input_idx{op.getInputs().at(ir::operation::ArgMinMax::Input::INPUT)};
+ const auto &input = operands.at(input_idx);
+
+ const auto axis_idx{op.getInputs().at(ir::operation::ArgMinMax::Input::AXIS)};
+ const auto &axis = operands.at(axis_idx);
+
+ // get mutable output operand
+ const auto output_idx = op.getOutputs().at(0);
+ ir::Operand &output = operands.at(output_idx);
+
+ if (!axis.isConstant())
+ {
+ output.info().setDynamic();
+ return;
+ }
+
+ const auto rank = input.info().shape().rank();
+ auto axis_value = axis.asScalar<int32_t>();
+ axis_value = axis_value < 0 ? axis_value + rank : axis_value;
+
+ // re-sizing output shape
+ ir::Shape new_shape =
+ shape_inference::inferArgMinMaxShape(input.info().shape(), axis_value, rank);
+ output.info().shape(new_shape);
+}
+
+void StaticShapeInferer::visit(const ir::operation::BatchMatMul &op)
+{
+ auto &operands = _lowered_subg->graph().operands();
+
+ const auto lhs_index = op.getInputs().at(ir::operation::BatchMatMul::Input::LHS);
+ const auto rhs_index = op.getInputs().at(ir::operation::BatchMatMul::Input::RHS);
+ const auto output_index = op.getOutputs().at(0);
+ const auto &lhs = operands.at(lhs_index);
+ const auto &rhs = operands.at(rhs_index);
+ auto &output = operands.at(output_index);
+ auto new_shape = shape_inference::inferBatchMatMulShape(lhs.shape(), rhs.shape(), op.param());
+ output.info().shape(new_shape);
+}
+
+void StaticShapeInferer::visit(const ir::operation::BCQFullyConnected &op)
+{
+ auto &operands = _lowered_subg->graph().operands();
+
+ const auto input_idx{op.getInputs().at(ir::operation::BCQFullyConnected::Input::INPUT)};
+ const auto &input = operands.at(input_idx);
+
+ const auto cluster_idx{
+ op.getInputs().at(ir::operation::BCQFullyConnected::Input::WEIGHTS_CLUSTERS)};
+ const auto &cluster = operands.at(cluster_idx);
+
+ const auto output_idx = op.getOutputs().at(0);
+ ir::Operand &output = operands.at(output_idx);
+
+ auto cluster_buf = reinterpret_cast<const int32_t *>(cluster.data()->base());
+ assert(cluster_buf);
+
+ // re-sizing output shape
+ ir::Shape new_shape = shape_inference::inferBCQFullyConnectedShape(
+ input.info().shape(), cluster.info().shape(), cluster_buf);
+ output.info().shape(new_shape);
+}
+
+void StaticShapeInferer::visit(const ir::operation::BCQGather &op)
+{
+ auto &operands = _lowered_subg->graph().operands();
+
+ const auto indices_idx{op.getInputs().at(ir::operation::BCQGather::Input::INDICES)};
+ const auto &indices = operands.at(indices_idx);
+
+ const auto input_binary_idx{op.getInputs().at(ir::operation::BCQGather::Input::INPUT_BINARY)};
+ const auto &input_binary = operands.at(input_binary_idx);
+
+ const auto cluster_idx{op.getInputs().at(ir::operation::BCQGather::Input::INPUT_CLUSTERS)};
+ const auto &cluster = operands.at(cluster_idx);
+
+ const auto output_idx = op.getOutputs().at(0);
+ ir::Operand &output = operands.at(output_idx);
+
+ auto cluster_buf = reinterpret_cast<const int32_t *>(cluster.data()->base());
+ assert(cluster_buf);
+
+ auto rank = input_binary.shape().rank();
+
+ // re-sizing output shape
+ ir::Shape new_shape = shape_inference::inferBCQGatherShape(
+ indices.info().shape(), cluster.info().shape(), cluster_buf, rank, op.param());
+
+ output.info().shape(new_shape);
+}
+
+void StaticShapeInferer::visit(const ir::operation::BinaryArithmetic &op)
+{
+ handleBinaryArithmeticOp(op, op.getInputs().at(ir::operation::BinaryArithmetic::Input::LHS),
+ op.getInputs().at(ir::operation::BinaryArithmetic::Input::RHS));
+}
+
+void StaticShapeInferer::visit(const ir::operation::BroadcastTo &op)
+{
+ // get mutable output operand
+ auto &operands = _lowered_subg->graph().operands();
+ const auto output_idx = op.getOutputs().at(0);
+ ir::Operand &output = operands.at(output_idx);
+
+ const auto shape_idx{op.getInputs().at(ir::operation::BroadcastTo::Input::SHAPE)};
+ const auto &shape = operands.at(shape_idx);
+
+ if (!shape.isConstant())
+ {
+ output.info().setDynamic();
+ return;
+ }
+
+ // assert(shape.typeInfo().type() == ir::DataType::INT32);
+ auto shape_buffer = reinterpret_cast<const int32_t *>(shape.data()->base());
+
+ // re-sizing output shape
+ ir::Shape new_shape = shape_inference::inferBroadcastToShape(shape.info().shape(), shape_buffer);
+ output.info().shape(new_shape);
+}
+
+void StaticShapeInferer::visit(const ir::operation::Comparison &op)
+{
+ handleBinaryArithmeticOp(op, op.getInputs().at(ir::operation::Comparison::Input::INPUT0),
+ op.getInputs().at(ir::operation::Comparison::Input::INPUT1));
+}
+
+void StaticShapeInferer::visit(const ir::operation::Concat &op)
+{
+ auto &operands = _lowered_subg->graph().operands();
+
+ const auto input_count = op.getInputs().size();
+
+ const auto output_idx = op.getOutputs().at(0);
+ ir::Operand &output = operands.at(output_idx);
+
+ shape_inference::Shapes input_shapes;
+ for (uint32_t i = 0; i < input_count; i++)
+ {
+ const auto input_idx{op.getInputs().at(i)};
+ const auto &input = operands.at(input_idx);
+ input_shapes.emplace_back(input.shape());
+ }
+
+ ir::Shape out_shape = shape_inference::inferConcatShape(input_shapes, op.param());
+
+ // re-sizing output shape
+ output.info().shape(out_shape);
+}
+
+void StaticShapeInferer::visit(const ir::operation::Conv2D &op)
+{
+ auto &operands = _lowered_subg->graph().operands();
+
+ const auto input_idx{op.getInputs().at(ir::operation::Conv2D::Input::INPUT)};
+ const auto &input = operands.at(input_idx);
+ const auto ker_idx{op.getInputs().at(ir::operation::Conv2D::Input::KERNEL)};
+ const auto &ker = operands.at(ker_idx);
+ const auto output_idx = op.getOutputs().at(0);
+ ir::Operand &output = operands.at(output_idx);
+
+ // re-sizing output shape
+ ir::Shape new_shape =
+ shape_inference::inferConv2DShape(input.info().shape(), ker.info().shape(), op.param());
+ output.info().shape(new_shape);
+}
+
+void StaticShapeInferer::visit(const ir::operation::DepthwiseConv2D &op)
+{
+ auto &operands = _lowered_subg->graph().operands();
+
+ const auto input_idx{op.getInputs().at(ir::operation::DepthwiseConv2D::Input::INPUT)};
+ const auto &input = operands.at(input_idx);
+ const auto ker_idx{op.getInputs().at(ir::operation::DepthwiseConv2D::Input::KERNEL)};
+ const auto &ker = operands.at(ker_idx);
+ const auto output_idx = op.getOutputs().at(0);
+ ir::Operand &output = operands.at(output_idx);
+
+ // re-sizing output shape
+ ir::Shape new_shape = shape_inference::inferDepthwiseConv2DShape(input.info().shape(),
+ ker.info().shape(), op.param());
+ output.info().shape(new_shape);
+}
+
+void StaticShapeInferer::visit(const ir::operation::ElementwiseActivation &op)
+{
+ handleSimpleUnaryOp(op, op.getInputs().at(ir::operation::ElementwiseActivation::Input::INPUT));
+}
+
+void StaticShapeInferer::visit(const ir::operation::ElementwiseBinary &op)
+{
+ handleBinaryArithmeticOp(op, op.getInputs().at(ir::operation::ElementwiseBinary::Input::LHS),
+ op.getInputs().at(ir::operation::ElementwiseBinary::Input::RHS));
+}
+
+void StaticShapeInferer::visit(const ir::operation::ElementwiseUnary &op)
+{
+ handleSimpleUnaryOp(op, op.getInputs().at(ir::operation::ElementwiseUnary::Input::INPUT));
+}
+
+void StaticShapeInferer::visit(const ir::operation::ExpandDims &op)
+{
+ auto &operands = _lowered_subg->graph().operands();
+
+ const auto input_idx{op.getInputs().at(ir::operation::ExpandDims::Input::INPUT)};
+ const auto &input = operands.at(input_idx);
+ const auto axis_idx{op.getInputs().at(ir::operation::ExpandDims::Input::AXIS)};
+ const auto &axis = operands.at(axis_idx);
+ const auto output_idx = op.getOutputs().at(0);
+ ir::Operand &output = operands.at(output_idx);
+
+ if (!axis.isConstant())
+ {
+ output.info().setDynamic();
+ return;
+ }
+
+ // even when axis is constant, output shape should be recalculated since user might call
+ // nnfw_set_input_tensorinfo(input, some_new_shape)
+ auto axis_type = axis.typeInfo().type();
+ assert(axis_type == ir::DataType::INT32 || axis_type == ir::DataType::INT64);
+
+ assert(axis.data()->base());
+ int32_t axis_value =
+ (axis_type == ir::DataType::INT32)
+ ? reinterpret_cast<const int32_t *>(axis.data()->base())[0]
+ : static_cast<int32_t>(reinterpret_cast<const int64_t *>(axis.data()->base())[0]);
+
+ // re-sizing output shape
+ ir::Shape new_shape = shape_inference::inferExpandDimsShape(input.info().shape(), axis_value);
+ output.info().shape(new_shape);
+}
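For reference, inferExpandDimsShape() only needs to insert a single dimension of size 1 at the normalized axis; a hypothetical standalone version of that rule:

#include <cassert>
#include <vector>

// Sketch: expand_dims inserts a single dimension of size 1 at `axis`, where a negative
// axis counts from the end (e.g. {2, 3} with axis -1 -> {2, 3, 1}).
inline std::vector<int> expandDims(std::vector<int> shape, int axis)
{
  const int out_rank = static_cast<int>(shape.size()) + 1;
  if (axis < 0)
    axis += out_rank;
  assert(axis >= 0 && axis < out_rank);
  shape.insert(shape.begin() + axis, 1);
  return shape;
}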
+
+void StaticShapeInferer::visit(const ir::operation::Fill &op)
+{
+ auto &operands = _lowered_subg->graph().operands();
+
+ const auto shape_idx{op.getInputs().at(ir::operation::Fill::Input::SHAPE)};
+ const auto &shape = operands.at(shape_idx);
+ const auto output_idx = op.getOutputs().at(0);
+ ir::Operand &output = operands.at(output_idx);
+
+ if (!shape.isConstant())
+ {
+ output.info().setDynamic();
+ return;
+ }
+
+ const auto dims_type = shape.typeInfo().type();
+ assert(dims_type == ir::DataType::INT32 || dims_type == ir::DataType::INT64);
+
+ auto dims_buf = shape.data()->base();
+ assert(dims_buf);
+
+ const auto &dims_shape = shape.info().shape();
+ const auto &new_shape = ((dims_type == ir::DataType::INT32)
+ ? shape_inference::inferFillShape<int32_t>(
+ dims_shape, reinterpret_cast<const int32_t *>(dims_buf))
+ : shape_inference::inferFillShape<int64_t>(
+ dims_shape, reinterpret_cast<const int64_t *>(dims_buf)));
+
+ output.info().shape(new_shape);
+}
+
+void StaticShapeInferer::visit(const ir::operation::FullyConnected &op)
+{
+ auto &operands = _lowered_subg->graph().operands();
+
+ const auto input_idx{op.getInputs().at(ir::operation::FullyConnected::Input::INPUT)};
+ const auto &input = operands.at(input_idx);
+
+ const auto ker_idx{op.getInputs().at(ir::operation::FullyConnected::Input::WEIGHT)};
+ const auto &ker = operands.at(ker_idx);
+
+ // get mutable output operand
+ const auto output_idx = op.getOutputs().at(0);
+ ir::Operand &output = operands.at(output_idx);
+ // re-sizing output shape
+ ir::Shape new_shape =
+ shape_inference::inferFullyConnectedShape(input.info().shape(), ker.info().shape());
+ output.info().shape(new_shape);
+}
+
+void StaticShapeInferer::visit(const ir::operation::FusedBatchNorm &op)
+{
+ handleSimpleUnaryOp(op, op.getInputs().at(ir::operation::FusedBatchNorm::Input::INPUT));
+}
+
+void StaticShapeInferer::visit(const ir::operation::Gather &op)
+{
+ auto &operands = _lowered_subg->graph().operands();
+
+ const auto input_idx{op.getInputs().at(ir::operation::Gather::Input::INPUT)};
+ const auto &input = operands.at(input_idx);
+
+ // get mutable output operand
+ const auto output_idx = op.getOutputs().at(0);
+ ir::Operand &output = operands.at(output_idx);
+
+ const auto indices_idx{op.getInputs().at(ir::operation::Gather::Input::INDICES)};
+ const auto &indices = operands.at(indices_idx);
+ const auto rank = input.info().shape().rank();
+ const auto axis = ((op.param().axis < 0) ? rank + op.param().axis : op.param().axis);
+
+ assert(0 <= axis && axis < rank);
+
+ // re-sizing output shape
+ ir::Shape new_shape =
+ shape_inference::inferGatherShape(input.info().shape(), indices.info().shape(), axis, rank);
+ output.info().shape(new_shape);
+}
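The gathered output shape is the input shape with the axis dimension replaced by the whole indices shape. A standalone sketch of that standard gather rule (hypothetical helper, not onert's inferGatherShape):

#include <cassert>
#include <vector>

// Sketch: gather(input {5, 4, 3}, indices {2, 6}, axis 1) -> {5, 2, 6, 3};
// the axis dimension of the input is replaced by the full indices shape.
inline std::vector<int> gatherShape(const std::vector<int> &input,
                                    const std::vector<int> &indices, int axis)
{
  assert(axis >= 0 && axis < static_cast<int>(input.size()));
  std::vector<int> out(input.begin(), input.begin() + axis);
  out.insert(out.end(), indices.begin(), indices.end());
  out.insert(out.end(), input.begin() + axis + 1, input.end());
  return out;
}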
+
+void StaticShapeInferer::visit(const ir::operation::If &op)
+{
+ // re-sizing input shapes of then/else subgraph
+ const std::vector<ir::OperandIndex> inputs{op.getInputs().begin() + 1, op.getInputs().end()};
+
+ std::vector<ir::OperandInfo> inputs_info;
+ const auto &graph = _lowered_subg->graph();
+ for (size_t i = 0; i < inputs.size(); ++i)
+ {
+ const auto &operand_info = graph.operands().at(inputs.at(i)).info();
+ inputs_info.emplace_back(operand_info);
+ }
+ _subg_input_observers.at(op.param().then_subg_index)->updateShapes(inputs_info);
+ _child_inferers.at(op.param().then_subg_index)->infer();
+
+ _subg_input_observers.at(op.param().else_subg_index)->updateShapes(inputs_info);
+ _child_inferers.at(op.param().else_subg_index)->infer();
+}
+
+void StaticShapeInferer::visit(const ir::operation::L2Normalization &op)
+{
+ handleSimpleUnaryOp(op, op.getInputs().at(ir::operation::L2Normalization::Input::INPUT));
+}
+
+void StaticShapeInferer::visit(const ir::operation::Loss &op)
+{
+ // TODO Consider SparseCategoricalCrossentropy case
+
+ auto &operands = _lowered_subg->graph().operands();
+
+ const auto input_index{op.getInputs().at(ir::operation::Loss::Input::Y_PRED)};
+ auto &input = operands.at(input_index);
+
+ const auto output_index{op.getOutputs().at(0)};
+ auto &output = operands.at(output_index);
+
+ ir::Shape new_shape = output.info().shape();
+ new_shape.dim(0) = input.info().shape().dim(0);
+
+ output.info().shape(new_shape);
+}
+
+void StaticShapeInferer::visit(const ir::operation::LSTM &op)
+{
+ auto &operands = _lowered_subg->graph().operands();
+
+ const auto output_index{op.getOutputs().at(ir::operation::LSTM::Output::OUTPUT)};
+ auto &output = operands.at(output_index);
+
+ const auto output_state_out_index{
+ op.getOutputs().at(ir::operation::LSTM::Output::OUTPUT_STATE_OUT)};
+
+ const auto cell_state_out_index{op.getOutputs().at(ir::operation::LSTM::Output::CELL_STATE_OUT)};
+
+ const auto scratch_buffer_index{op.getOutputs().at(ir::operation::LSTM::Output::SCRATCH_BUFFER)};
+
+ if (output.info().isDynamic() ||
+ (operands.exist(output_state_out_index) &&
+ operands.at(output_state_out_index).info().isDynamic()) ||
+ (operands.exist(cell_state_out_index) &&
+ operands.at(cell_state_out_index).info().isDynamic()) ||
+ (operands.exist(scratch_buffer_index) &&
+ operands.at(scratch_buffer_index).info().isDynamic()))
+ return;
+
+ const auto input_index{op.getInputs().at(ir::operation::LSTM::Input::INPUT)};
+ const auto &input = operands.at(input_index);
+
+ const auto input_to_output_weights_index{
+ op.getInputs().at(ir::operation::LSTM::Input::INPUT_TO_OUTPUT_WEIGHTS)};
+ const auto &input_to_output_weights = operands.at(input_to_output_weights_index);
+
+ const auto recurrent_to_output_weights_index{
+ op.getInputs().at(ir::operation::LSTM::Input::RECURRENT_TO_OUTPUT_WEIGHTS)};
+ const auto &recurrent_to_output_weights = operands.at(recurrent_to_output_weights_index);
+
+ // re-sizing outputs
+ const int n_batch = (input.shape().rank() == 3 && op.param().time_major) ? input.shape().dim(1)
+ : input.shape().dim(0);
+ const int n_cell = input_to_output_weights.shape().dim(0);
+ const int n_output = recurrent_to_output_weights.shape().dim(1);
+ if (input.shape().rank() == 3)
+ {
+ if (op.param().time_major)
+ output.info().shape(ir::Shape{input.shape().dim(0), n_batch, n_output});
+ else
+ output.info().shape(ir::Shape{n_batch, input.shape().dim(1), n_output});
+ }
+ else
+ {
+ assert(input.shape().rank() == 2);
+ output.info().shape(ir::Shape{n_batch, n_output});
+ }
+
+ if (operands.exist(output_state_out_index))
+ {
+ auto &output_state_out = operands.at(output_state_out_index);
+ output_state_out.info().shape(ir::Shape{n_batch, n_output});
+ }
+
+ if (operands.exist(cell_state_out_index))
+ {
+ auto &cell_state_out = operands.at(cell_state_out_index);
+ cell_state_out.info().shape(ir::Shape{n_batch, n_cell});
+ }
+
+ if (operands.exist(scratch_buffer_index))
+ {
+ auto &scratch_buffer = operands.at(scratch_buffer_index);
+
+ const auto input_to_input_weights_index{
+ op.getInputs().at(ir::operation::LSTM::Input::INPUT_TO_INPUT_WEIGHTS)};
+ const auto recurrent_to_input_weights_index{
+ op.getInputs().at(ir::operation::LSTM::Input::RECURRENT_TO_INPUT_WEIGHTS)};
+
+ bool has_input_to_input_weights =
+ operands.at(input_to_input_weights_index).shape().dim(0) != 0 &&
+ operands.at(input_to_input_weights_index).shape().dim(1) != 0;
+ bool has_recurrent_to_input_weights =
+ operands.at(recurrent_to_input_weights_index).shape().dim(0) != 0 &&
+ operands.at(recurrent_to_input_weights_index).shape().dim(1) != 0;
+
+    // NOTE cell_to_input_weights may be absent even in a regular (non-CIFG) LSTM
+    // when peephole connections are not used.
+ // true: no CIFG
+ // false: CIFG
+ bool has_cifg_param = has_input_to_input_weights && has_recurrent_to_input_weights;
+ if (has_cifg_param)
+ {
+ scratch_buffer.info().shape(ir::Shape{n_batch, n_cell * 4});
+ }
+ else
+ {
+ scratch_buffer.info().shape(ir::Shape{n_batch, n_cell * 3});
+ }
+ }
+}
+
+void StaticShapeInferer::visit(const ir::operation::MatrixBandPart &op)
+{
+ handleSimpleUnaryOp(op, op.getInputs().at(ir::operation::MatrixBandPart::Input::INPUT));
+}
+
+void StaticShapeInferer::visit(const ir::operation::OneHot &op)
+{
+ auto &operands = _lowered_subg->graph().operands();
+
+ const auto indice_idx{op.getInputs().at(ir::operation::OneHot::Input::INDICES)};
+ const auto &indice = operands.at(indice_idx);
+ const auto depth_idx{op.getInputs().at(ir::operation::OneHot::Input::DEPTH)};
+ const auto &depth = operands.at(depth_idx);
+
+ const auto axis = op.param().axis;
+
+ auto output_idx = op.getOutputs().at(0);
+ ir::Operand &output = operands.at(output_idx);
+
+ if (!depth.isConstant())
+ {
+ output.info().setDynamic();
+ return;
+ }
+
+ const auto *depth_buf = reinterpret_cast<const int32_t *>(depth.data()->base());
+ assert(depth_buf);
+ // re-sizing output shape
+ ir::Shape new_shape = shape_inference::inferOnehotShape(indice.info().shape(), *depth_buf, axis);
+ output.info().shape(new_shape);
+}
+
+void StaticShapeInferer::visit(const ir::operation::Pack &op)
+{
+ auto &operands = _lowered_subg->graph().operands();
+
+ const auto input_idx{op.getInputs().at(0)};
+ const auto &input = operands.at(input_idx);
+
+ // get mutable output operand
+ const auto output_idx = op.getOutputs().at(0);
+ ir::Operand &output = operands.at(output_idx);
+
+ const auto rank = input.shape().rank() + 1;
+ const auto axis = ((op.param().axis < 0) ? rank + op.param().axis : op.param().axis);
+ const auto num = op.param().num;
+
+ assert(0 <= axis && axis < rank);
+
+ // re-sizing output shape
+ ir::Shape new_shape = shape_inference::inferPackShape(input.info().shape(), axis, rank, num);
+ output.info().shape(new_shape);
+}
+
+void StaticShapeInferer::visit(const ir::operation::Pad &op)
+{
+ auto &operands = _lowered_subg->graph().operands();
+
+ const auto input_idx{op.getInputs().at(ir::operation::Pad::Input::INPUT)};
+ const auto &input = operands.at(input_idx);
+
+ const auto pad_idx{op.getInputs().at(ir::operation::Pad::Input::PAD)};
+ const auto &pad = operands.at(pad_idx);
+
+ // get mutable output operand
+ const auto output_idx = op.getOutputs().at(0);
+ ir::Operand &output = operands.at(output_idx);
+
+ // if pad is not constant, output also becomes dynamic
+ if (!pad.isConstant())
+ {
+ output.info().setDynamic();
+ return;
+ }
+
+ // re-sizing output shape
+ const auto &new_shape = shape_inference::inferPadShape(
+ input.shape(), reinterpret_cast<const int32_t *>(pad.data()->base()),
+ pad.shape().num_elements());
+ output.info().shape(new_shape);
+}
+
+void StaticShapeInferer::visit(const ir::operation::Permute &op)
+{
+ auto &operands = _lowered_subg->graph().operands();
+
+ const auto input_idx{op.getInputs().at(0)};
+ const auto &input = operands.at(input_idx);
+ const auto output_idx = op.getOutputs().at(0);
+ ir::Operand &output = operands.at(output_idx);
+
+ // re-sizing output shape
+  // Permute is a special operation whose input/output layouts may differ per backend.
+  // That is not applied here, so input and output keep the frontend layout, because
+  // "ExecutorFactory" converts the input/output shapes according to those layouts when it
+  // registers operand info to "TensorBuilder" after "StaticShapeInferer" has run
+ const auto &new_shape = input.info().shape();
+ output.info().shape(new_shape);
+}
+
+void StaticShapeInferer::visit(const ir::operation::Pool2D &op)
+{
+ auto &operands = _lowered_subg->graph().operands();
+
+ const auto layout = _lowered_subg->graph().layout();
+
+ const auto input_idx{op.getInputs().at(ir::operation::Pool2D::Input::INPUT)};
+ const auto &input = operands.at(input_idx);
+ if (input.info().shape().rank() != 4)
+ {
+ throw std::runtime_error(op.name() + ": supports only 4D tensor as input");
+ }
+
+ const auto output_idx = op.getOutputs().at(0);
+ ir::Operand &output = operands.at(output_idx);
+
+ ir::Shape new_shape = shape_inference::inferPoolShape(input.info().shape(), op.param(), layout);
+ output.info().shape(new_shape);
+}
+
+void StaticShapeInferer::visit(const ir::operation::Pow &op)
+{
+ handleBinaryArithmeticOp(op, op.getInputs().at(ir::operation::Pow::Input::LHS),
+ op.getInputs().at(ir::operation::Pow::Input::RHS));
+}
+
+void StaticShapeInferer::visit(const ir::operation::Range &op)
+{
+ auto &operands = _lowered_subg->graph().operands();
+
+ const auto start_idx{op.getInputs().at(ir::operation::Range::Input::START)};
+ const auto limit_idx{op.getInputs().at(ir::operation::Range::Input::LIMIT)};
+ const auto delta_idx{op.getInputs().at(ir::operation::Range::Input::DELTA)};
+ const auto &start_op = operands.at(start_idx);
+ const auto &limit_op = operands.at(limit_idx);
+ const auto &delta_op = operands.at(delta_idx);
+
+ // get mutable output operand
+ const auto output_idx = op.getOutputs().at(0);
+ ir::Operand &output = operands.at(output_idx);
+
+ ir::Shape new_shape;
+ if (start_op.isConstant() && limit_op.isConstant() && delta_op.isConstant())
+ {
+ assert(start_op.typeInfo().type() == limit_op.typeInfo().type() &&
+ start_op.typeInfo().type() == delta_op.typeInfo().type());
+ if (output.typeInfo().type() == ir::DataType::FLOAT32)
+ {
+ new_shape = shape_inference::inferRangeShape<float>(
+ start_op.asScalar<float>(), limit_op.asScalar<float>(), delta_op.asScalar<float>());
+ }
+ else if (output.typeInfo().type() == ir::DataType::INT32)
+ {
+ new_shape = shape_inference::inferRangeShape<int32_t>(
+ start_op.asScalar<int32_t>(), limit_op.asScalar<int32_t>(), delta_op.asScalar<int32_t>());
+ }
+ assert(output.shape() == new_shape);
+ }
+ else
+ {
+ output.info().setDynamic();
+ }
+}
+
+void StaticShapeInferer::visit(const ir::operation::Reduce &op)
+{
+ auto &operands = _lowered_subg->graph().operands();
+
+ const auto input_idx{op.getInputs().at(ir::operation::Reduce::Input::INPUT)};
+ const auto &input = operands.at(input_idx);
+
+ const auto axes_idx{op.getInputs().at(ir::operation::Reduce::Input::AXES)};
+ const auto &axes = operands.at(axes_idx);
+
+ // get mutable output operand
+ const auto output_idx = op.getOutputs().at(0);
+ ir::Operand &output = operands.at(output_idx);
+
+ std::vector<int32_t> axes_vec;
+ for (size_t i = 0; i < axes.shape().num_elements(); ++i)
+ {
+ switch (axes.typeInfo().type())
+ {
+ case ir::DataType::INT32:
+ {
+ axes_vec.emplace_back(reinterpret_cast<const int32_t *>(axes.data()->base())[i]);
+ break;
+ }
+ case ir::DataType::INT64:
+ {
+ axes_vec.emplace_back(reinterpret_cast<const int64_t *>(axes.data()->base())[i]);
+ break;
+ }
+ default:
+ throw std::runtime_error("StaticShapeInferer " + op.name() + ": Not supported data type");
+ break;
+ }
+ }
+ const auto keep_dims = op.param().keep_dims;
+
+ // re-sizing output shape
+ ir::Shape new_shape =
+ shape_inference::inferReduceShape(input.info().shape(), axes_vec, keep_dims);
+ output.info().shape(new_shape);
+}
+
+void StaticShapeInferer::visit(const ir::operation::Reshape &op)
+{
+ auto &operands = _lowered_subg->graph().operands();
+
+ const auto input_idx{op.getInputs().at(ir::operation::Reshape::Input::INPUT)};
+ const auto &input = operands.at(input_idx);
+
+ // get mutable output operand
+ const auto output_idx = op.getOutputs().at(0);
+ ir::Operand &output = operands.at(output_idx);
+
+ // New shape is given by second input tensor
+ if (op.getInputs().size() == 2)
+ {
+ // Let's check the second input
+ const auto shape_idx{op.getInputs().at(ir::operation::Reshape::Input::SHAPE)};
+ const auto &shape = operands.at(shape_idx);
+
+ if (shape.isConstant())
+ {
+ const auto *shape_buf = reinterpret_cast<const int32_t *>(shape.data()->base());
+ assert(shape_buf);
+
+ ir::Shape new_shape =
+ shape_inference::inferReshapeShape(input.shape(), shape_buf, shape.shape().num_elements());
+
+      // if shape is from Const, TFLC puts the shape of output into tensor
+ if (new_shape != output.shape())
+ {
+ // change on output shape
+ output.info().shape(new_shape);
+ }
+ }
+ else
+ {
+      // if shape is NOT Const, set output shape to be dynamic
+ output.info().setDynamic();
+ }
+ }
+ // New shape is given by option
+ else if (op.param().new_shape.size() != 0)
+ {
+ // Let's check the new_shape option
+ auto shape = op.param().new_shape;
+ ir::Shape new_shape =
+ shape_inference::inferReshapeShape(input.shape(), shape.data(), shape.size());
+
+ if (new_shape != output.shape())
+ {
+ // change on output shape
+ output.info().shape(new_shape);
+ }
+ }
+ else
+ {
+ throw std::runtime_error("Reshape: new shape is missing");
+ }
+}
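Both branches above rely on what is presumably the usual TFLite reshape convention: at most one requested dimension may be -1, and it is resolved so that the element count is preserved. A hedged standalone sketch of that convention (not the exact onert code):

#include <cstddef>
#include <stdexcept>
#include <vector>

// Sketch: reshape {2, 3, 4} with requested {4, -1} -> {4, 6}; the single -1 is
// resolved so that the total number of elements is preserved.
inline std::vector<int> resolveReshape(const std::vector<int> &in, std::vector<int> requested)
{
  long total = 1;
  for (int d : in) total *= d;

  long known = 1;
  int wildcard = -1;
  for (std::size_t i = 0; i < requested.size(); ++i)
  {
    if (requested[i] == -1)
    {
      if (wildcard != -1) throw std::runtime_error("reshape: more than one -1");
      wildcard = static_cast<int>(i);
    }
    else
      known *= requested[i];
  }

  if (wildcard != -1)
    requested[wildcard] = static_cast<int>(total / known);

  long out_total = 1;
  for (int d : requested) out_total *= d;
  if (out_total != total) throw std::runtime_error("reshape: element count mismatch");
  return requested;
}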
+
+void StaticShapeInferer::visit(const ir::operation::ResizeBilinear &op)
+{
+ auto &operands = _lowered_subg->graph().operands();
+
+ const auto input_idx{op.getInputs().at(ir::operation::ResizeBilinear::Input::INPUT)};
+ const auto &input = operands.at(input_idx);
+
+ // get mutable output operand
+ const auto output_idx = op.getOutputs().at(0);
+ ir::Operand &output = operands.at(output_idx);
+
+ int32_t height_out, width_out;
+ if (op.getInputs().size() == 2)
+ {
+ auto &size = operands.at(op.getInputs().at(ir::operation::ResizeBilinear::Input::SIZE));
+ if (!size.isConstant())
+ {
+ output.info().setDynamic();
+ return;
+ }
+ const auto size_v = size.asVector<std::int32_t>();
+ height_out = size_v[0];
+ width_out = size_v[1];
+ }
+ else
+ {
+ height_out = op.param().height_out;
+ width_out = op.param().width_out;
+ }
+
+ // Shape inferencing logic based on Params
+ ir::Shape new_shape =
+ shape_inference::inferResizeBilinearShape(input.shape(), height_out, width_out);
+
+  // if size_op is from Const, TFLC puts the shape of output into tensor
+ if (new_shape != output.shape())
+ {
+ // change on output shape
+ output.info().shape(new_shape);
+ }
+}
+
+void StaticShapeInferer::visit(const ir::operation::Reverse &op)
+{
+ handleSimpleUnaryOp(op, op.getInputs().at(ir::operation::Reverse::Input::INPUT));
+}
+
+void StaticShapeInferer::visit(const ir::operation::Select &op)
+{
+ auto &operands = _lowered_subg->graph().operands();
+
+ const auto input_cond_idx{op.getInputs().at(ir::operation::Select::Input::CONDITION)};
+ const auto &input_cond = operands.at(input_cond_idx);
+
+ const auto input_true_idx{op.getInputs().at(ir::operation::Select::Input::INPUT_TRUE)};
+ const auto &input_true = operands.at(input_true_idx);
+
+ const auto input_false_idx{op.getInputs().at(ir::operation::Select::Input::INPUT_FALSE)};
+ const auto &input_false = operands.at(input_false_idx);
+
+ auto output_idx = op.getOutputs().at(0);
+ ir::Operand &output = operands.at(output_idx);
+
+  // Select output shape
+ ir::Shape new_shape = shape_inference::inferSelectShape(
+ input_cond.info().shape(), input_true.info().shape(), input_false.info().shape());
+ output.info().shape(new_shape);
+}
+
+void StaticShapeInferer::visit(const ir::operation::Shape &op)
+{
+ auto &operands = _lowered_subg->graph().operands();
+
+ const auto input_idx{op.getInputs().at(0)};
+ const auto &input = operands.at(input_idx);
+
+ // get mutable output operand
+ const auto output_idx = op.getOutputs().at(0);
+ ir::Operand &output = operands.at(output_idx);
+
+ // re-sizing output shape
+ ir::Shape output_shape;
+ output_shape.append(input.info().shape().rank());
+
+ output.info().shape(output_shape);
+}
+
+void StaticShapeInferer::visit(const ir::operation::Slice &op)
+{
+ auto &operands = _lowered_subg->graph().operands();
+
+ const auto input_index{op.getInputs().at(ir::operation::Slice::Input::INPUT)};
+ const auto &input = operands.at(input_index);
+ const auto begins_index{op.getInputs().at(ir::operation::Slice::Input::BEGINS)};
+ const auto &begins = operands.at(begins_index);
+ const auto sizes_index{op.getInputs().at(ir::operation::Slice::Input::SIZES)};
+ const auto &sizes = operands.at(sizes_index);
+ const auto output_index = op.getOutputs().at(0);
+ ir::Operand &output = operands.at(output_index);
+
+ // Whether input is constant or not does not affect whether output is dynamic or not
+ if (!(begins.isConstant() && sizes.isConstant()))
+ {
+ output.info().setDynamic();
+ return;
+ }
+
+ auto begins_buf = begins.data()->base();
+ auto sizes_buf = sizes.data()->base();
+
+ const auto begins_type = begins.typeInfo().type();
+ assert(begins_type == ir::DataType::INT32 || begins_type == ir::DataType::INT64);
+ assert(begins_type == sizes.typeInfo().type());
+
+ ir::Shape new_shape =
+ (begins_type == ir::DataType::INT32)
+ ? shape_inference::inferSliceShape<int32_t>(input.info().shape(),
+ reinterpret_cast<const int32_t *>(begins_buf),
+ reinterpret_cast<const int32_t *>(sizes_buf))
+ : shape_inference::inferSliceShape<int64_t>(input.info().shape(),
+ reinterpret_cast<const int64_t *>(begins_buf),
+ reinterpret_cast<const int64_t *>(sizes_buf));
+ output.info().shape(new_shape);
+}
+
+void StaticShapeInferer::visit(const ir::operation::Softmax &op)
+{
+ handleSimpleUnaryOp(op, op.getInputs().at(ir::operation::Softmax::Input::INPUT));
+}
+
+void StaticShapeInferer::visit(const ir::operation::SpaceToBatchND &op)
+{
+ auto &operands = _lowered_subg->graph().operands();
+
+ const auto output_index = op.getOutputs().at(0);
+ const auto input_idx{op.getInputs().at(ir::operation::SpaceToBatchND::Input::INPUT)};
+ const auto &block_shape_idx{op.getInputs().at(ir::operation::SpaceToBatchND::Input::BLOCK_SIZE)};
+ const auto &padding_idx{op.getInputs().at(ir::operation::SpaceToBatchND::Input::PADDINGS)};
+
+ ir::Operand &output = operands.at(output_index);
+ const auto &input = operands.at(input_idx);
+ const auto &block_shape = operands.at(block_shape_idx);
+ const auto &padding = operands.at(padding_idx);
+
+ // Whether input is constant or not does not affect whether output is dynamic or not
+ if (!(block_shape.isConstant() && padding.isConstant()))
+ {
+ output.info().setDynamic();
+ return;
+ }
+
+ const auto &input_shape = input.info().shape();
+ const auto &block_shape_shape = block_shape.info().shape();
+ const auto &padding_shape = padding.info().shape();
+
+ auto block_shape_data = reinterpret_cast<const int32_t *>(block_shape.data()->base());
+ auto padding_data = reinterpret_cast<const int32_t *>(padding.data()->base());
+
+ ir::Shape new_shape = shape_inference::inferSpaceToBatchNDShape(
+ input_shape, block_shape_shape, padding_shape, block_shape_data, padding_data);
+
+ output.info().shape(new_shape);
+}
+
+void StaticShapeInferer::visit(const ir::operation::Split &op)
+{
+ auto &operands = _lowered_subg->graph().operands();
+
+ const auto input_idx{op.getInputs().at(ir::operation::Split::Input::INPUT)};
+ const auto &input = operands.at(input_idx);
+
+ const auto axis_idx{op.getInputs().at(ir::operation::Split::Input::AXIS)};
+ const auto &axis = operands.at(axis_idx);
+
+ auto outputs = op.getOutputs();
+ if (!axis.isConstant())
+ {
+ for (auto &&output_idx : outputs)
+ {
+ ir::Operand &output = operands.at(output_idx);
+ output.info().setDynamic();
+ }
+ return;
+ }
+
+ const auto num_splits = op.param().num_splits;
+
+ const auto rank = input.info().shape().rank();
+ auto axis_value = axis.asScalar<int32_t>();
+ axis_value = axis_value < 0 ? axis_value + rank : axis_value;
+
+ assert(0 <= axis_value && axis_value < rank);
+
+ ir::Shape new_shape =
+ shape_inference::inferSplitShape(input.info().shape(), axis_value, num_splits);
+ for (auto &&output_idx : outputs)
+ {
+ ir::Operand &output = operands.at(output_idx);
+ output.info().shape(new_shape);
+ }
+}
+
+void StaticShapeInferer::visit(const ir::operation::SquaredDifference &op)
+{
+ handleBinaryArithmeticOp(op, op.getInputs().at(ir::operation::SquaredDifference::Input::LHS),
+ op.getInputs().at(ir::operation::SquaredDifference::Input::RHS));
+}
+
+void StaticShapeInferer::visit(const ir::operation::Squeeze &op)
+{
+ auto &operands = _lowered_subg->graph().operands();
+
+ const auto input_idx{op.getInputs().at(ir::operation::Squeeze::Input::INPUT)};
+ const auto &input = operands.at(input_idx);
+
+ const auto output_idx = op.getOutputs().at(0);
+ ir::Operand &output = operands.at(output_idx);
+
+  // Squeeze output shape
+ ir::Shape new_shape = shape_inference::inferSqueezeShape(input.info().shape(), op.param());
+ output.info().shape(new_shape);
+}
+
+void StaticShapeInferer::visit(const ir::operation::StridedSlice &op)
+{
+ auto &operands = _lowered_subg->graph().operands();
+
+ const auto input_index{op.getInputs().at(ir::operation::StridedSlice::Input::INPUT)};
+ const auto &input = operands.at(input_index);
+ const auto starts_index{op.getInputs().at(ir::operation::StridedSlice::Input::STARTS)};
+ const auto &starts = operands.at(starts_index);
+ const auto ends_index{op.getInputs().at(ir::operation::StridedSlice::Input::ENDS)};
+ const auto &ends = operands.at(ends_index);
+ const auto strides_index{op.getInputs().at(ir::operation::StridedSlice::Input::STRIDES)};
+ const auto &strides = operands.at(strides_index);
+ const auto output_index = op.getOutputs().at(0);
+ ir::Operand &output = operands.at(output_index);
+
+ if (!(starts.isConstant() && ends.isConstant() && strides.isConstant()))
+ {
+ output.info().setDynamic();
+ return;
+ }
+
+ const auto begin_mask = op.param().begin_mask;
+ const auto end_mask = op.param().end_mask;
+ const auto shrink_axis_mask = op.param().shrink_axis_mask;
+ const auto rank = input.info().shape().rank();
+
+ auto starts_buf = reinterpret_cast<const uint32_t *>(starts.data()->base());
+ auto ends_buf = reinterpret_cast<const uint32_t *>(ends.data()->base());
+ auto strides_buf = reinterpret_cast<const uint32_t *>(strides.data()->base());
+
+ auto op_params = shape_inference::buildStridedSliceParams(
+ starts_buf, ends_buf, strides_buf, begin_mask, end_mask, shrink_axis_mask, rank);
+
+ ir::Shape new_shape =
+ shape_inference::inferStridedSliceShape(input.info().shape(), op_params, rank);
+ output.info().shape(new_shape);
+}
+
+void StaticShapeInferer::visit(const ir::operation::Tile &op)
+{
+ auto &operands = _lowered_subg->graph().operands();
+
+ const auto input_idx{op.getInputs().at(ir::operation::Tile::Input::INPUT)};
+ const auto &input = operands.at(input_idx);
+
+ const auto multiplier_idx{op.getInputs().at(ir::operation::Tile::Input::MULTIPLES)};
+ const auto &multiplier = operands.at(multiplier_idx);
+
+ const auto output_idx = op.getOutputs().at(0);
+ ir::Operand &output = operands.at(output_idx);
+
+ if (!multiplier.isConstant())
+ {
+ output.info().setDynamic();
+ return;
+ }
+
+ auto multiplier_buffer = reinterpret_cast<const int32_t *>(multiplier.data()->base());
+ assert(multiplier_buffer);
+
+ // re-sizing output shape
+ auto new_shape = shape_inference::inferTileShape(input.info().shape(), multiplier_buffer,
+ multiplier.shape().num_elements());
+ output.info().shape(new_shape);
+}
+
+void StaticShapeInferer::visit(const ir::operation::Transpose &op)
+{
+ auto &operands = _lowered_subg->graph().operands();
+
+ const auto input_idx{op.getInputs().at(ir::operation::Transpose::Input::INPUT)};
+ const auto &input = operands.at(input_idx);
+
+ const auto perm_idx{op.getInputs().at(ir::operation::Transpose::Input::PERMUTATION)};
+ const auto &perm = operands.at(perm_idx);
+
+  // perm.shape() == ir::Shape{0} means that perm is (n-1...0), i.e. the axes are simply reversed
+ // TODO This condition changes to perm.num_elements() == 0
+ const auto is_regular_transpose = perm.shape() == ir::Shape{0};
+
+ // get mutable output operand
+ const auto output_idx = op.getOutputs().at(0);
+ auto &output = operands.at(output_idx);
+ if (!perm.isConstant() && !is_regular_transpose)
+ {
+ output.info().setDynamic();
+ return;
+ }
+
+ ir::Shape new_shape;
+ if (is_regular_transpose)
+ {
+ // Call by (n-1...0)
+ new_shape = shape_inference::inferTransposeShape(input.info().shape(), nullptr, 0);
+ }
+ else
+ {
+ // Check rank
+ if (input.info().shape().rank() != static_cast<int>(perm.info().shape().num_elements()))
+ {
+ throw std::runtime_error("StaticShapeInferer failed, bad rank size: " +
+ std::to_string(perm.info().shape().num_elements()));
+ }
+
+ // set output shape, based on input and params
+ const auto perm_buf = reinterpret_cast<const int32_t *>(perm.data()->base());
+ new_shape = shape_inference::inferTransposeShape(input.info().shape(), perm_buf,
+ perm.shape().num_elements());
+ }
+ output.info().shape(new_shape);
+}
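In other words, an empty permutation is treated as the "regular" transpose that reverses all axes, and an explicit permutation reorders them. A standalone sketch of that rule (hypothetical helper, not onert's inferTransposeShape):

#include <cassert>
#include <cstddef>
#include <vector>

// Sketch: transpose({2, 3, 4}, {})        -> {4, 3, 2}  (reverse all axes)
//         transpose({2, 3, 4}, {0, 2, 1}) -> {2, 4, 3}
inline std::vector<int> transposeShape(const std::vector<int> &in, const std::vector<int> &perm)
{
  if (perm.empty())
    return std::vector<int>(in.rbegin(), in.rend());

  assert(perm.size() == in.size());
  std::vector<int> out(in.size());
  for (std::size_t i = 0; i < perm.size(); ++i)
    out[i] = in[perm[i]];
  return out;
}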
+
+void StaticShapeInferer::visit(const ir::operation::Unpack &op)
+{
+ auto &operands = _lowered_subg->graph().operands();
+
+ const auto input_idx{op.getInputs().at(0)};
+ const auto &input = operands.at(input_idx);
+ const auto num = op.param().num;
+ const auto rank = input.shape().rank();
+ const auto axis = ((op.param().axis < 0) ? rank + op.param().axis : op.param().axis);
+
+ assert(axis < rank);
+ if (axis < 0)
+ {
+ for (int out_tensor_idx = 0; out_tensor_idx < num; out_tensor_idx++)
+ {
+ const auto output_idx = op.getOutputs().at(out_tensor_idx);
+ ir::Operand &output = operands.at(output_idx);
+ output.info().setDynamic();
+ }
+ return;
+ }
+
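+  // Illustrative example (not part of this change): unpacking an input of shape {3, 4, 5} along
+  // axis 1 yields `num` outputs, each of shape {3, 5}.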
+ ir::Shape new_shape = shape_inference::inferUnpackShape(input.info().shape(), axis, rank);
+
+ // re-sizing output shape
+ for (int out_tensor_idx = 0; out_tensor_idx < num; out_tensor_idx++)
+ {
+ const auto output_idx = op.getOutputs().at(out_tensor_idx);
+ ir::Operand &output = operands.at(output_idx);
+ output.info().shape(new_shape);
+ }
+}
+
+void StaticShapeInferer::visit(const ir::operation::While &op)
+{
+ auto body_input_observer = _subg_input_observers.at(op.param().body_subg_index).get();
+ auto cond_input_observer = _subg_input_observers.at(op.param().cond_subg_index).get();
+ // re-sizing input shapes of body subgraph
+ const auto &inputs = op.getInputs();
+ std::vector<ir::OperandInfo> inputs_info;
+ const auto &graph = _lowered_subg->graph();
+ for (size_t i = 0; i < inputs.size(); ++i)
+ {
+ const auto &operand_info = graph.operands().at(inputs.at(i)).info();
+ inputs_info.emplace_back(operand_info);
+ }
+
+ body_input_observer->updateShapes(inputs_info);
+ _child_inferers.at(op.param().body_subg_index)->infer();
+
+  // Check whether the While operation's shapes are predictable
+  // This While op's outputs are also updated in the call above,
+  // "_child_inferers.at(op.param().body_subg_index)->infer()". That means the body's outputs and
+  // this op's outputs must have the same shape. So we can predict whether the body subgraph's
+  // shapes will change at every iteration by comparing the shapes of its inputs and outputs.
+  // If any of them differ, non-constant operands will be set to dynamic.
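+  // Illustrative example (not part of this change): if a body output shape changes from {1, 4} to
+  // {1, 8} after one inference of the body, the loop's shapes cannot be predicted statically and
+  // the related non-constant operands are marked dynamic below.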
+ bool check_unpredictable_dynamic = false;
+ const auto &updated_outputs = op.getOutputs();
+ assert(inputs_info.size() == updated_outputs.size());
+ for (size_t i = 0; i < updated_outputs.size(); ++i)
+ {
+ const auto &input_info = inputs_info.at(i);
+ const auto &output_info = graph.operands().at(updated_outputs.at(i)).info();
+ if (input_info.isDynamic() != output_info.isDynamic() ||
+ input_info.shape() != output_info.shape())
+ {
+ check_unpredictable_dynamic = true;
+ break;
+ }
+ }
+
+ if (check_unpredictable_dynamic)
+ {
+ body_input_observer->updateShapes(inputs_info, check_unpredictable_dynamic);
+ _child_inferers.at(op.param().body_subg_index)->infer();
+ }
+ cond_input_observer->updateShapes(inputs_info, check_unpredictable_dynamic);
+ _child_inferers.at(op.param().cond_subg_index)->infer();
+}
+
+void StaticShapeInferer::visit(const ir::operation::DetectionPostProcess &op)
+{
+ // TODO: NMS supports very limited input/output size.
+ ir::operation::DetectionPostProcess::Param param = op.param();
+
+ auto &operands = _lowered_subg->graph().operands();
+ const int num_detected_boxes = param.max_detections * param.max_classes_per_detection;
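+  // Illustrative example (not part of this change): with max_detections 10 and
+  // max_classes_per_detection 1, num_detected_boxes is 10, so the box output below gets shape
+  // {1, 10, 4}.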
+
+ const auto output_idx1 = op.getOutputs().at(0);
+ auto &output1 = operands.at(output_idx1);
+ output1.info().shape({1, num_detected_boxes, 4});
+
+ const auto output_idx2 = op.getOutputs().at(1);
+ auto &output2 = operands.at(output_idx2);
+ output2.info().shape({1, num_detected_boxes});
+
+ const auto output_idx3 = op.getOutputs().at(2);
+ auto &output3 = operands.at(output_idx3);
+ output3.info().shape({1, num_detected_boxes});
+
+ const auto output_idx4 = op.getOutputs().at(3);
+ auto &output4 = operands.at(output_idx4);
+ output4.info().shape({1});
+}
+void StaticShapeInferer::visit(const ir::operation::Bulk &op)
+{
+ auto &operands = _lowered_subg->graph().operands();
+
+ // TODO: support multiple inputs/outputs
+ const auto input_idx{op.getInputs().at(0)};
+ const auto &input = operands.at(input_idx);
+ const auto output_idx = op.getOutputs().at(0);
+ ir::Operand &output = operands.at(output_idx);
+
+ const auto &cur_input_shape = input.info().shape();
+ auto origin_output_shape = op.param().origin_output_shapes[0];
+
+  // TODO: Add more checks for valid batch requests
+ if ((cur_input_shape.dim(0) < origin_output_shape.dim(0)) ||
+ (cur_input_shape.dim(0) % origin_output_shape.dim(0) != 0))
+ {
+ throw std::runtime_error("StaticShapeInferer " + op.name() + ": Not supported batch size");
+ }
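+  // Illustrative example (not part of this change): if the origin output shape is {1, 1001} and
+  // the current input batch (dim 0) is 4, batch_multiplier becomes 4 and the inferred output shape
+  // is {4, 1001}.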
+ size_t batch_multiplier = cur_input_shape.dim(0) / origin_output_shape.dim(0);
+
+ ir::Shape new_shape;
+ new_shape.append(origin_output_shape.dim(0) * batch_multiplier);
+ for (int32_t d = 1; d < origin_output_shape.rank(); ++d)
+ new_shape.append(origin_output_shape.dim(d));
+
+ output.info().shape(new_shape);
+}
+
+} // namespace compiler
+
+} // namespace onert
diff --git a/runtime/onert/core/src/compiler/TensorBuilders.h b/runtime/onert/core/src/compiler/TensorBuilders.h
deleted file mode 100644
index 3b0360b4b..000000000
--- a/runtime/onert/core/src/compiler/TensorBuilders.h
+++ /dev/null
@@ -1,78 +0,0 @@
-/*
- * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef __ONERT_COMPILER_TENSOR_BUILDERS_H__
-#define __ONERT_COMPILER_TENSOR_BUILDERS_H__
-
-#include <unordered_set>
-#include <memory>
-#include "backend/BackendContext.h"
-#include "backend/Backend.h"
-#include "backend/controlflow/Config.h"
-#include "backend/controlflow/TensorBuilder.h"
-#include "util/logging.h"
-
-namespace onert
-{
-namespace compiler
-{
-
-class TensorBuilders
-{
-public:
- TensorBuilders() = default;
-
- TensorBuilders(const onert::backend::BackendContexts &backend_contexts, bool include_controlflow)
- {
- for (const auto &e : backend_contexts)
- {
- if (e.first->config()->id() == backend::controlflow::Config::ID)
- {
- _cf_tensor_builder = std::dynamic_pointer_cast<backend::controlflow::TensorBuilder>(
- e.second->tensor_builder);
- if (include_controlflow)
- _tensor_builders.insert(e.second->tensor_builder);
- }
- else
- {
- _tensor_builders.insert(e.second->tensor_builder);
- }
- }
- }
-
- std::unordered_set<std::shared_ptr<onert::backend::ITensorBuilder>>::const_iterator begin() const
- {
- return _tensor_builders.cbegin();
- }
- std::unordered_set<std::shared_ptr<onert::backend::ITensorBuilder>>::const_iterator end() const
- {
- return _tensor_builders.cend();
- }
-
- std::shared_ptr<backend::controlflow::TensorBuilder> getControlflowTensorBuilder() const
- {
- return _cf_tensor_builder;
- }
-
-private:
- std::unordered_set<std::shared_ptr<backend::ITensorBuilder>> _tensor_builders;
- std::shared_ptr<backend::controlflow::TensorBuilder> _cf_tensor_builder;
-};
-
-} // namespace compiler
-} // namespace onert
-
-#endif // __ONERT_COMPILER_TENSOR_BUILDERS_H__
diff --git a/runtime/onert/core/src/compiler/TensorRegistries.h b/runtime/onert/core/src/compiler/TensorRegistries.h
index 8be87b081..4c30785df 100644
--- a/runtime/onert/core/src/compiler/TensorRegistries.h
+++ b/runtime/onert/core/src/compiler/TensorRegistries.h
@@ -17,13 +17,14 @@
#ifndef __ONERT_COMPILER_TENSOR_REGISTRIES_H__
#define __ONERT_COMPILER_TENSOR_REGISTRIES_H__
-#include <unordered_set>
-#include <memory>
-#include "backend/BackendContext.h"
+#include "../backend/builtin/Config.h"
+#include "../backend/builtin/TensorRegistry.h"
+
#include "backend/Backend.h"
-#include "backend/controlflow/Config.h"
-#include "backend/controlflow/TensorBuilder.h"
-#include "backend/controlflow/TensorRegistry.h"
+#include "backend/BackendContext.h"
+
+#include <memory>
+#include <unordered_set>
namespace onert
{
@@ -35,17 +36,16 @@ class TensorRegistries
public:
TensorRegistries() = default;
- TensorRegistries(const onert::backend::BackendContexts &backend_contexts,
- bool include_controlflow)
+ TensorRegistries(const onert::backend::BackendContexts &backend_contexts, bool include_builtin)
{
for (const auto &e : backend_contexts)
{
auto tensor_reg = e.second->tensor_registry;
- if (e.first->config()->id() == backend::controlflow::Config::ID)
+ if (e.first->config()->id() == backend::builtin::Config::ID)
{
- _cf_tensor_reg =
- std::dynamic_pointer_cast<backend::controlflow::TensorRegistry>(tensor_reg);
- if (include_controlflow)
+ _builtin_tensor_reg =
+ std::dynamic_pointer_cast<backend::builtin::TensorRegistry>(tensor_reg);
+ if (include_builtin)
_tensor_regs.insert(tensor_reg);
}
else
@@ -64,14 +64,14 @@ public:
return _tensor_regs.cend();
}
- std::shared_ptr<backend::controlflow::TensorRegistry> getControlflowTensorRegistry() const
+ std::shared_ptr<backend::builtin::TensorRegistry> getBuiltinTensorRegistry() const
{
- return _cf_tensor_reg;
+ return _builtin_tensor_reg;
}
- std::shared_ptr<backend::ITensor> getITensor(ir::OperandIndex ind) const
+ backend::ITensor *getITensor(ir::OperandIndex ind) const
{
- for (auto &tensor_reg : _tensor_regs)
+ for (const auto &tensor_reg : _tensor_regs)
{
auto tensor = tensor_reg->getITensor(ind);
if (tensor)
@@ -82,7 +82,7 @@ public:
private:
std::unordered_set<std::shared_ptr<backend::ITensorRegistry>> _tensor_regs;
- std::shared_ptr<backend::controlflow::TensorRegistry> _cf_tensor_reg;
+ std::shared_ptr<backend::builtin::TensorRegistry> _builtin_tensor_reg;
};
} // namespace compiler
diff --git a/runtime/onert/core/src/compiler/pass/ConstantInsertionPass.cc b/runtime/onert/core/src/compiler/pass/ConstantInsertionPass.cc
index 647669e46..ac131803f 100644
--- a/runtime/onert/core/src/compiler/pass/ConstantInsertionPass.cc
+++ b/runtime/onert/core/src/compiler/pass/ConstantInsertionPass.cc
@@ -17,8 +17,9 @@
#include "ConstantInsertionPass.h"
#include "backend/Backend.h"
-#include <ir/Graph.h>
-#include <util/Utils.h>
+#include "ir/Graph.h"
+#include "util/Utils.h"
+#include "util/logging.h"
namespace onert
{
@@ -27,39 +28,30 @@ namespace compiler
namespace pass
{
-void ConstantInsertionPass::callback(const ir::OperationIndex &node_index, ir::Operation &node)
+void ConstantInsertionPass::callback(const ir::OperationIndex &node_index, ir::IOperation &node)
{
- const auto &op_sequence_index = _lowered_graph.op_seqs().getOperation(node_index);
- const auto op_seq_lower_info = _lowered_graph.getLowerInfo(op_sequence_index);
- const auto backend = op_seq_lower_info->backend();
- const auto layout = op_seq_lower_info->layout();
- const auto factor = ir::operand::PermuteFactor{backend, layout};
+ const auto op_lower_info = _lowered_graph.lower_info().operation.getRawPtr(node_index);
+ const auto backend = op_lower_info->backend();
+ const auto layout = op_lower_info->layout();
+ const auto factor = PermuteFactor{backend, layout};
- for (const auto input : node.getInputs() | ir::Remove::DUPLICATED | ir::Remove::UNDEFINED)
+ for (const auto &input : node.getInputs() | ir::Remove::DUPLICATED | ir::Remove::UNDEFINED)
{
auto &object = _graph.operands().at(input);
- if (object.isConstant())
+ const auto key = ReplaceKey{input, factor};
+ if (object.isConstant() && (object.getUses().size() >= 2 ||
+ _replace_operands_map.find(key) != _replace_operands_map.end()))
{
- const auto key = ReplaceKey{input, factor};
if (_replace_operands_map.count(key) == 0)
{
- auto new_object = object;
- new_object.unsetDef();
- // TODO Remove const_case
- const_cast<ir::OperationIndexSet &>(new_object.getUses()).clear();
+ ir::Operand new_object(object);
+ new_object.clearDefUse();
const auto new_index = _graph.operands().emplace(new_object);
_replace_operands_map[key] = new_index;
}
const auto replaced_input = _replace_operands_map[key];
- // Update op_seq
- if (_lowered_graph.op_seqs().at(op_sequence_index).getInputs().contains(input))
- {
- // All inputs of op_seq have the same PermuteFactor because those inputs are inputs of first
- // operation
- _lowered_graph.op_seqs().at(op_sequence_index).replaceInputs(input, replaced_input);
- }
// Update the same inputs of a node at once because inputs of an operation have the same
// PermuteFactor
@@ -69,6 +61,8 @@ void ConstantInsertionPass::callback(const ir::OperationIndex &node_index, ir::O
auto &replaced_object = _graph.operands().at(replaced_input);
replaced_object.insertUse(node_index);
+ VERBOSE(ConstInsertPass) << "New operand " << replaced_input << " added(copy of " << input
+ << ") for " << factor << std::endl;
// Remove this node from uses of origin operand
// Constant operand has no def.
assert(!object.getDef().valid());
@@ -76,12 +70,16 @@ void ConstantInsertionPass::callback(const ir::OperationIndex &node_index, ir::O
// Remove origin operand
if (object.getUses().size() == 0)
+ {
_graph.removeOperand(input);
+ VERBOSE(ConstInsertPass) << "Original operand " << input << " removed - no uses"
+ << std::endl;
+ }
}
}
// Now this runtime does not support the node making output as constant
- for (const auto &output : node.getOutputs())
+ for (const auto &output : node.getOutputs() | ir::Remove::DUPLICATED | ir::Remove::UNDEFINED)
{
UNUSED_RELEASE(output);
assert(!_graph.operands().at(output).isConstant());
diff --git a/runtime/onert/core/src/compiler/pass/ConstantInsertionPass.h b/runtime/onert/core/src/compiler/pass/ConstantInsertionPass.h
index 052883c92..d5b9aa14e 100644
--- a/runtime/onert/core/src/compiler/pass/ConstantInsertionPass.h
+++ b/runtime/onert/core/src/compiler/pass/ConstantInsertionPass.h
@@ -17,7 +17,7 @@
#ifndef __ONERT_COMPILER_PASS_CONSTANT_INSERTION_PASS_H__
#define __ONERT_COMPILER_PASS_CONSTANT_INSERTION_PASS_H__
-#include <ir/operand/PermuteFactor.h>
+#include <compiler/PermuteFactor.h>
#include <ir/Index.h>
#include "LoweredOperationPass.h"
#include <unordered_map>
@@ -39,13 +39,13 @@ public:
std::string id() final { return "ConstantInsertionPass"; }
public:
- void callback(const ir::OperationIndex &index, ir::Operation &node) final;
+ void callback(const ir::OperationIndex &index, ir::IOperation &node) final;
private:
struct ReplaceKey
{
ir::OperandIndex index;
- ir::operand::PermuteFactor factor;
+ PermuteFactor factor;
bool operator==(const ReplaceKey &other) const
{
@@ -61,8 +61,7 @@ private:
std::size_t operator()(const ReplaceKey &key) const noexcept
{
using std::hash;
- return hash<ir::OperandIndex>()(key.index) ^
- (hash<ir::operand::PermuteFactor>()(key.factor) << 1);
+ return hash<ir::OperandIndex>()(key.index) ^ (hash<PermuteFactor>()(key.factor) << 1);
}
};
diff --git a/runtime/onert/core/src/compiler/pass/ConstantLoweringPass.cc b/runtime/onert/core/src/compiler/pass/ConstantLoweringPass.cc
index 1c1dbe0ee..32e32d0ef 100644
--- a/runtime/onert/core/src/compiler/pass/ConstantLoweringPass.cc
+++ b/runtime/onert/core/src/compiler/pass/ConstantLoweringPass.cc
@@ -18,8 +18,9 @@
#include "backend/Backend.h"
#include <ir/Graph.h>
-#include <ir/operand/PermuteFactor.h>
+#include <compiler/PermuteFactor.h>
#include <util/Utils.h>
+#include "util/logging.h"
namespace onert
{
@@ -28,25 +29,25 @@ namespace compiler
namespace pass
{
-void ConstantLoweringPass::callback(const ir::OperationIndex &node_index, ir::Operation &node)
+void ConstantLoweringPass::callback(const ir::OperationIndex &node_index, ir::IOperation &node)
{
- const auto &op_sequence_index = _lowered_graph.op_seqs().getOperation(node_index);
- const auto op_seq_lower_info = _lowered_graph.getLowerInfo(op_sequence_index);
- const auto backend = op_seq_lower_info->backend();
- const auto layout = op_seq_lower_info->layout();
- const auto factor = ir::operand::PermuteFactor{backend, layout};
+ const auto op_lower_info = _lowered_graph.lower_info().operation.getRawPtr(node_index);
+ const auto backend = op_lower_info->backend();
+ const auto layout = op_lower_info->layout();
+ const auto factor = PermuteFactor{backend, layout};
// Now this runtime does not support the node making output of operation as constant
- for (const auto input : node.getInputs() | ir::Remove::DUPLICATED | ir::Remove::UNDEFINED)
+ for (const auto &input : node.getInputs() | ir::Remove::DUPLICATED | ir::Remove::UNDEFINED)
{
auto &object = _graph.operands().at(input);
if (object.isConstant())
{
       // All constant operands are already assigned to each backend by ConstantInsertionPass. So a
       // constant has `def` and `use` with the same PermuteFactor
- _lowered_graph.setLowerInfo(input, std::make_unique<ir::operand::LowerInfo>());
- _lowered_graph.getLowerInfo(input)->addDefPermuteFactor(factor);
- _lowered_graph.getLowerInfo(input)->addUsePermuteFactor(factor);
+ auto operand_li = std::make_unique<compiler::OperandLowerInfo>();
+ operand_li->addDefPermuteFactor(factor);
+ operand_li->addUsePermuteFactor(factor);
+ _lowered_graph.lower_info().operand.set(input, std::move(operand_li));
}
}
}
diff --git a/runtime/onert/core/src/compiler/pass/ConstantLoweringPass.h b/runtime/onert/core/src/compiler/pass/ConstantLoweringPass.h
index e17d776d1..d60a1033f 100644
--- a/runtime/onert/core/src/compiler/pass/ConstantLoweringPass.h
+++ b/runtime/onert/core/src/compiler/pass/ConstantLoweringPass.h
@@ -36,7 +36,7 @@ public:
std::string id() final { return "ConstantLoweringPass"; }
public:
- void callback(const ir::OperationIndex &index, ir::Operation &node) final;
+ void callback(const ir::OperationIndex &index, ir::IOperation &node) final;
};
} // namespace pass
diff --git a/runtime/onert/core/src/compiler/pass/ConstantOutputPass.cc b/runtime/onert/core/src/compiler/pass/ConstantOutputPass.cc
new file mode 100644
index 000000000..1448de473
--- /dev/null
+++ b/runtime/onert/core/src/compiler/pass/ConstantOutputPass.cc
@@ -0,0 +1,68 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConstantOutputPass.h"
+
+#include "ir/Graph.h"
+#include "ir/operation/Permute.h"
+#include "util/logging.h"
+
+namespace onert
+{
+namespace compiler
+{
+namespace pass
+{
+
+void ConstantOutputPass::callback(const ir::OperandIndex &ind, ir::Operand &obj)
+{
+ if (!_graph.getOutputs().contains(ind) || !obj.isConstant())
+ return;
+
+ auto permute_input_ind = _graph.addOperand(obj.shape(), obj.typeInfo());
+ auto &permute_input_obj = _graph.operands().at(permute_input_ind);
+
+ // Move the const data
+ permute_input_obj.data(obj.shareData());
+ obj.releaseData();
+ obj.info().setAsNonConst();
+
+ using ir::operation::Permute;
+ auto permute_obj = std::make_unique<Permute>(permute_input_ind, ind, Permute::Type::COPY);
+ auto permute_ind = _graph.operations().push(std::move(permute_obj));
+
+ permute_input_obj.insertUse(permute_ind);
+ obj.setDef(permute_ind);
+
+  // Make the operations that use this operand use the generated operand
+ auto orig_uses = obj.getUses();
+ for (auto &&use : orig_uses)
+ {
+ permute_input_obj.insertUse(use);
+ obj.removeUse(use);
+ _graph.operations().at(use).replaceInputs(ind, permute_input_ind);
+ }
+
+  VERBOSE(ConstantOutputPass) << "Permute Op inserted for a constant output, node index : "
+ << permute_ind << std::endl;
+ VERBOSE(ConstantOutputPass) << " - Input (inserted) Operand : " << permute_input_ind
+ << std::endl;
+ VERBOSE(ConstantOutputPass) << " - Output(original) Operand : " << ind << std::endl;
+}
+
+} // namespace pass
+} // namespace compiler
+} // namespace onert
diff --git a/runtime/onert/core/src/compiler/pass/ConstantOutputPass.h b/runtime/onert/core/src/compiler/pass/ConstantOutputPass.h
new file mode 100644
index 000000000..193dd3a68
--- /dev/null
+++ b/runtime/onert/core/src/compiler/pass/ConstantOutputPass.h
@@ -0,0 +1,63 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_COMPILER_PASS_CONSTANT_OUTPUT_PASS_H__
+#define __ONERT_COMPILER_PASS_CONSTANT_OUTPUT_PASS_H__
+
+#include "OperandPass.h"
+
+namespace onert
+{
+namespace compiler
+{
+namespace pass
+{
+
+/**
+ * @brief Pass to specially handle constant model outputs
+ *
+ * As an output buffer is given right before an execution while constant initialization is done at
+ * the prepare phase, the current runtime structure cannot handle the case where an output is
+ * constant. To resolve this, this pass inserts a Permute layer with a const input and makes the
+ * model output tensor its output.
+ *
+ * e.g.)
+ *
+ * ((Const Output))
+ *
+ * becomes
+ *
+ * (Const) -> [Permute] -> ((Output))
+ *
+ * Note that this is a mandatory pass for Graph.
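+ *
+ * A minimal usage sketch (illustrative; how the compiler actually schedules this pass is not
+ * shown here):
+ *
+ * ```
+ * ConstantOutputPass{graph}.run();
+ * ```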
+ */
+class ConstantOutputPass : public OperandPass
+{
+public:
+ using OperandPass::OperandPass;
+
+public:
+ std::string id() final { return "ConstantOutputPass"; }
+
+public:
+ void callback(const ir::OperandIndex &i, ir::Operand &o) final;
+};
+
+} // namespace pass
+} // namespace compiler
+} // namespace onert
+
+#endif // __ONERT_COMPILER_PASS_CONSTANT_OUTPUT_PASS_H__
diff --git a/runtime/onert/core/src/compiler/pass/IPass.h b/runtime/onert/core/src/compiler/pass/IPass.h
new file mode 100644
index 000000000..77f5916fd
--- /dev/null
+++ b/runtime/onert/core/src/compiler/pass/IPass.h
@@ -0,0 +1,41 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_COMPILER_PASS_IPASS_H__
+#define __ONERT_COMPILER_PASS_IPASS_H__
+
+#include <string>
+
+namespace onert
+{
+namespace compiler
+{
+namespace pass
+{
+
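+/**
+ * @brief Common interface of compiler passes so that PassRunner can run them uniformly
+ */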
+struct IPass
+{
+ virtual ~IPass() = default;
+
+ virtual std::string id() = 0;
+ virtual void run() = 0;
+};
+
+} // namespace pass
+} // namespace compiler
+} // namespace onert
+
+#endif // __ONERT_COMPILER_PASS_IPASS_H__
diff --git a/runtime/onert/core/src/compiler/pass/LoweredOperandPass.h b/runtime/onert/core/src/compiler/pass/LoweredOperandPass.h
index 0c5f7d745..64831a0ac 100644
--- a/runtime/onert/core/src/compiler/pass/LoweredOperandPass.h
+++ b/runtime/onert/core/src/compiler/pass/LoweredOperandPass.h
@@ -18,7 +18,7 @@
#define __ONERT_IR_PASS_LOWERED_OPERAND_PASS_H__
#include "OperandPass.h"
-#include "compiler/LoweredGraph.h"
+#include "compiler/ILoweredGraph.h"
namespace onert
{
@@ -30,8 +30,8 @@ namespace pass
class LoweredOperandPass : public OperandPass
{
public:
- LoweredOperandPass(compiler::LoweredGraph &lowered_graph)
- : OperandPass{lowered_graph.graph()}, _lowered_graph{lowered_graph}
+ LoweredOperandPass(compiler::ILoweredGraph &lowered_graph)
+ : OperandPass{lowered_graph.graph()}, _lowered_graph{lowered_graph}
{
// DO NOTHING
}
@@ -42,7 +42,7 @@ public:
void callback(const ir::OperandIndex &i, ir::Operand &o) override = 0;
protected:
- compiler::LoweredGraph &_lowered_graph;
+ compiler::ILoweredGraph &_lowered_graph;
};
} // namespace pass
diff --git a/runtime/onert/core/src/compiler/pass/LoweredOperationPass.h b/runtime/onert/core/src/compiler/pass/LoweredOperationPass.h
index 5c8569be2..27ca77c91 100644
--- a/runtime/onert/core/src/compiler/pass/LoweredOperationPass.h
+++ b/runtime/onert/core/src/compiler/pass/LoweredOperationPass.h
@@ -18,7 +18,7 @@
#define __ONERT_IR_PASS_LOWERED_OPERATION_PASS_H__
#include "OperationPass.h"
-#include "compiler/LoweredGraph.h"
+#include "compiler/ILoweredGraph.h"
namespace onert
{
@@ -30,8 +30,8 @@ namespace pass
class LoweredOperationPass : public OperationPass
{
public:
- LoweredOperationPass(LoweredGraph &lowered_graph)
- : OperationPass{lowered_graph.graph()}, _lowered_graph{lowered_graph}
+ LoweredOperationPass(ILoweredGraph &lowered_graph)
+ : OperationPass{lowered_graph.graph()}, _lowered_graph{lowered_graph}
{
// DO NOTHING
}
@@ -39,10 +39,10 @@ public:
virtual ~LoweredOperationPass() = default;
std::string id() override = 0;
- void callback(const ir::OperationIndex &i, ir::Operation &o) override = 0;
+ void callback(const ir::OperationIndex &i, ir::IOperation &o) override = 0;
protected:
- LoweredGraph &_lowered_graph;
+ ILoweredGraph &_lowered_graph;
};
} // namespace pass
diff --git a/runtime/onert/core/src/compiler/pass/OddOutputPass.cc b/runtime/onert/core/src/compiler/pass/OddOutputPass.cc
new file mode 100644
index 000000000..e2b3f6111
--- /dev/null
+++ b/runtime/onert/core/src/compiler/pass/OddOutputPass.cc
@@ -0,0 +1,90 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "OddOutputPass.h"
+
+#include "ir/Graph.h"
+#include "ir/operation/Permute.h"
+#include "util/logging.h"
+#include "util/Utils.h"
+
+namespace onert
+{
+namespace compiler
+{
+namespace pass
+{
+
+void OddOutputPass::run()
+{
+ auto &outputs = _graph.getOutputs();
+
+ VERBOSE(OddOutputPass) << "Case 1 : An operand which is a model output and a model input"
+ << std::endl;
+ for (const auto &ind : outputs)
+ {
+ if (_graph.getInputs().contains(ind))
+ {
+ auto permute_output_ind = insertPermute(ind);
+ // Update the output to be newly added operand
+ _graph.getOutputs().replace(ind, permute_output_ind);
+ }
+ }
+
+ VERBOSE(OddOutputPass) << "Case 2 : Two or more duplicated outputs" << std::endl;
+  std::unordered_set<ir::OperandIndex> occurrence;
+  for (auto &&ind : outputs)
+  {
+    auto &obj = _graph.operands().at(ind);
+    if (occurrence.count(ind) == 0)
+    {
+      occurrence.insert(ind);
+ continue;
+ }
+
+    // Panic when it is const; it must have been handled earlier in another pass
+ UNUSED_RELEASE(obj);
+ assert(!obj.isConstant());
+
+ auto permute_output_ind = insertPermute(ind);
+ ind = permute_output_ind; // Replace output index to fix output duplication
+ }
+}
+
+ir::OperandIndex OddOutputPass::insertPermute(ir::OperandIndex ind)
+{
+ auto &obj = _graph.operands().at(ind);
+ auto output_ind = _graph.addOperand(obj.shape(), obj.typeInfo());
+ auto &output_obj = _graph.operands().at(output_ind);
+
+ using ir::operation::Permute;
+ auto permute_obj = std::make_unique<Permute>(ind, output_ind, Permute::Type::COPY);
+ auto permute_ind = _graph.operations().push(std::move(permute_obj));
+
+ output_obj.setDef(permute_ind);
+ obj.insertUse(permute_ind);
+
+  VERBOSE(OddOutputPass) << "Permute Op inserted for an odd output, node index : "
+ << permute_ind << std::endl;
+ VERBOSE(OddOutputPass) << " - Input (original) Operand : " << ind << std::endl;
+ VERBOSE(OddOutputPass) << " - Output(inserted) Operand : " << output_ind << std::endl;
+
+ return output_ind;
+}
+
+} // namespace pass
+} // namespace compiler
+} // namespace onert
diff --git a/runtime/onert/core/src/compiler/pass/OddOutputPass.h b/runtime/onert/core/src/compiler/pass/OddOutputPass.h
new file mode 100644
index 000000000..2accbac60
--- /dev/null
+++ b/runtime/onert/core/src/compiler/pass/OddOutputPass.h
@@ -0,0 +1,89 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_COMPILER_PASS_ODD_OUTPUT_PASS_H__
+#define __ONERT_COMPILER_PASS_ODD_OUTPUT_PASS_H__
+
+#include <unordered_set>
+
+#include "Pass.h"
+#include "ir/Index.h"
+
+namespace onert
+{
+namespace compiler
+{
+namespace pass
+{
+
+/**
+ * @brief Pass to specially handle odd outputs in a subgraph
+ *
+ * Runtime Graph IR requires that every input and output have a distinct tensor index; this is
+ * onert's restriction. However, duplicated indices are allowed in models (or via the API), so the
+ * graph must be transformed after model loading.
+ *
+ * This is necessary since our API lets users set different buffers for each input and output, so
+ * it is unavoidable that the value must be copied at runtime.
+ *
+ * Note that this is a mandatory pass for Graph.
+ *
+ * Case 1 : An operand which is a model output and a model input
+ *
+ * Create a new operand, insert a Permute(copy) op from the original operand to it, and change the
+ * model output to be the newly generated operand.
+ *
+ * e.g.)
+ *
+ * ```
+ * ((#0 Input0 and also Output0))
+ * becomes
+ * ((#0 Input0)) -> [#0 Permute] -> ((#1 Output0))
+ * ```
+ *
+ * Case 2 : Two or more duplicated outputs
+ *
+ * Do the same as in Case 1, but between two outputs that share the same tensor index.
+ *
+ * e.g.)
+ *
+ * ```
+ * ((#0 Input0)) -> [#0 Some Operation] -> ((#1 Output0 and also Output1))
+ * becomes
+ * ((#0 Input0)) -> [#0 Some Operation] -> ((#1 Output0)) [#1 Permute] -> ((#2 Output1))
+ * ```
+ *
+ */
+class OddOutputPass : public Pass
+{
+public:
+ using Pass::Pass;
+
+public:
+ std::string id() final { return "OddOutputPass"; }
+
+public:
+ void run() override;
+
+private:
+ ir::OperandIndex insertPermute(ir::OperandIndex input);
+};
+
+} // namespace pass
+} // namespace compiler
+} // namespace onert
+
+#endif // __ONERT_COMPILER_PASS_ODD_OUTPUT_PASS_H__
diff --git a/runtime/onert/core/src/compiler/pass/OperandPass.cc b/runtime/onert/core/src/compiler/pass/OperandPass.cc
index 50c001c30..db8ebedcd 100644
--- a/runtime/onert/core/src/compiler/pass/OperandPass.cc
+++ b/runtime/onert/core/src/compiler/pass/OperandPass.cc
@@ -28,7 +28,7 @@ namespace pass
void OperandPass::run()
{
_graph.operands().iterate(
- [&](const ir::OperandIndex &index, ir::Operand &object) { callback(index, object); });
+ [&](const ir::OperandIndex &index, ir::Operand &object) { callback(index, object); });
}
} // namespace pass
diff --git a/runtime/onert/core/src/compiler/pass/OperationPass.cc b/runtime/onert/core/src/compiler/pass/OperationPass.cc
index d7a55cb22..bd9bcb4a4 100644
--- a/runtime/onert/core/src/compiler/pass/OperationPass.cc
+++ b/runtime/onert/core/src/compiler/pass/OperationPass.cc
@@ -17,7 +17,7 @@
#include "OperationPass.h"
#include "ir/Index.h"
-#include "ir/Operation.h"
+#include "ir/IOperation.h"
#include "ir/Graph.h"
namespace onert
@@ -30,7 +30,7 @@ namespace pass
void OperationPass::run()
{
_graph.operations().iterate(
- [&](const ir::OperationIndex &index, ir::Operation &node) { callback(index, node); });
+ [&](const ir::OperationIndex &index, ir::IOperation &node) { callback(index, node); });
}
} // namespace pass
diff --git a/runtime/onert/core/src/compiler/pass/OperationPass.h b/runtime/onert/core/src/compiler/pass/OperationPass.h
index ac4d818a2..0a00b11d1 100644
--- a/runtime/onert/core/src/compiler/pass/OperationPass.h
+++ b/runtime/onert/core/src/compiler/pass/OperationPass.h
@@ -29,7 +29,7 @@ namespace onert
{
namespace ir
{
-class Operation;
+struct IOperation;
} // namespace ir
} // namespace onert
@@ -62,7 +62,7 @@ public:
* @param index is the index of a node in graph
* @param node is the node in graph
*/
- virtual void callback(const ir::OperationIndex &index, ir::Operation &node) = 0;
+ virtual void callback(const ir::OperationIndex &index, ir::IOperation &node) = 0;
/**
* @brief Run the pass
diff --git a/runtime/onert/core/src/compiler/pass/Pass.h b/runtime/onert/core/src/compiler/pass/Pass.h
index 3f356c337..b34695c97 100644
--- a/runtime/onert/core/src/compiler/pass/Pass.h
+++ b/runtime/onert/core/src/compiler/pass/Pass.h
@@ -17,6 +17,8 @@
#ifndef __ONERT_COMPILER_PASS_PASS_H__
#define __ONERT_COMPILER_PASS_PASS_H__
+#include "IPass.h"
+
#include <string>
namespace onert
@@ -24,7 +26,7 @@ namespace onert
namespace ir
{
class Graph;
-} // namespace compiler
+} // namespace ir
} // namespace onert
namespace onert
@@ -34,7 +36,7 @@ namespace compiler
namespace pass
{
-class Pass
+class Pass : public IPass
{
public:
Pass(ir::Graph &graph) : _graph{graph} {}
diff --git a/runtime/onert/core/src/compiler/pass/PassRunner.cc b/runtime/onert/core/src/compiler/pass/PassRunner.cc
new file mode 100644
index 000000000..cd1b82bb2
--- /dev/null
+++ b/runtime/onert/core/src/compiler/pass/PassRunner.cc
@@ -0,0 +1,45 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "PassRunner.h"
+
+namespace onert
+{
+namespace compiler
+{
+namespace pass
+{
+
+PassRunner &PassRunner::append(std::unique_ptr<IPass> pass)
+{
+ _passes.emplace_back(std::move(pass));
+ return *this;
+}
+
+void PassRunner::run()
+{
+ for (auto &&pass : _passes)
+ {
+ VERBOSE(PassRunner) << "Start running '" << pass->id() << "'" << std::endl;
+ pass->run();
+ VERBOSE(PassRunner) << "Finished running '" << pass->id() << "'" << std::endl;
+ // TODO Dump graph?
+ }
+}
+
+} // namespace pass
+} // namespace compiler
+} // namespace onert
diff --git a/runtime/onert/core/src/compiler/pass/PassRunner.h b/runtime/onert/core/src/compiler/pass/PassRunner.h
new file mode 100644
index 000000000..03bfbe220
--- /dev/null
+++ b/runtime/onert/core/src/compiler/pass/PassRunner.h
@@ -0,0 +1,53 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_COMPILER_PASS_PASS_RUNNER_H__
+#define __ONERT_COMPILER_PASS_PASS_RUNNER_H__
+
+#include <initializer_list>
+#include <memory>
+#include <vector>
+
+#include "IPass.h"
+#include "util/logging.h"
+
+namespace onert
+{
+namespace compiler
+{
+namespace pass
+{
+
+/**
+ * @brief Composite passes with logging
+ */
+class PassRunner
+{
+public:
+ PassRunner() = default;
+ PassRunner &append(std::unique_ptr<IPass> pass);
+
+ void run();
+
+private:
+ std::vector<std::unique_ptr<IPass>> _passes;
+};
+
+} // namespace pass
+} // namespace compiler
+} // namespace onert
+
+#endif // __ONERT_COMPILER_PASS_PASS_RUNNER_H__
diff --git a/runtime/onert/core/src/compiler/pass/PermutationEliminationPass.cc b/runtime/onert/core/src/compiler/pass/PermutationEliminationPass.cc
index f01697034..d9452c7f9 100644
--- a/runtime/onert/core/src/compiler/pass/PermutationEliminationPass.cc
+++ b/runtime/onert/core/src/compiler/pass/PermutationEliminationPass.cc
@@ -15,8 +15,8 @@
*/
#include "PermutationEliminationPass.h"
-#include "backend/controlflow/Config.h"
+#include "backend/Backend.h"
#include "util/logging.h"
namespace onert
@@ -26,7 +26,7 @@ namespace compiler
namespace pass
{
-void PermutationEliminationPass::callback(const ir::OperationIndex &ind, ir::Operation &node)
+void PermutationEliminationPass::callback(const ir::OperationIndex &ind, ir::IOperation &node)
{
_op_ind = ind;
node.accept(*this);
@@ -39,8 +39,9 @@ void PermutationEliminationPass::visit(const ir::operation::Permute &node)
// Check if two tensors are both portable if not, we can't eliminate the node
{
- auto in_def_factor = _lowered_graph.getLowerInfo(in_operand)->def_factors().getOnlyElement();
- auto out_def_factor = _lowered_graph.getLowerInfo(out_operand)->def_factors().getOnlyElement();
+ auto &operand_li_map = _lowered_graph.lower_info().operand;
+ auto in_def_factor = operand_li_map.getRawPtr(in_operand)->def_factors().getOnlyElement();
+ auto out_def_factor = operand_li_map.getRawPtr(out_operand)->def_factors().getOnlyElement();
auto in_config = in_def_factor.backend()->config();
auto out_config = out_def_factor.backend()->config();
@@ -53,59 +54,50 @@ void PermutationEliminationPass::visit(const ir::operation::Permute &node)
if (_graph.getOutputs().contains(out_operand))
{
+ // If the input is a const, we cannot remove it since we cannot put the constant data in the
+    // output buffer during the prepare phase.
+ auto permute_input = node.getInputs().at(0);
+ if (_graph.operands().at(permute_input).isConstant())
+ return;
+    // If the input is a model input, we cannot remove it since our API lets users set different
+ // buffers for inputs and outputs even though one tensor is both at the same time.
+ auto permute_output = node.getOutputs().at(0);
+ if (_graph.getInputs().contains(permute_input) && _graph.getOutputs().contains(permute_output))
+ return;
+    // Likewise, if it copies from one model output to another model output, keep it.
+ if (_graph.getOutputs().contains(permute_input) && _graph.getOutputs().contains(permute_output))
+ return;
+
// Exceptional case : When the output operand is a model output
// In this case we keep the output and remove the input
auto &out_operand_obj = _graph.operands().at(out_operand);
assert(out_operand_obj.getDef() == _op_ind);
out_operand_obj.unsetDef();
- _lowered_graph.op_seqs().iterate([&](const ir::OpSequenceIndex &, ir::OpSequence &op_seq) {
- if (!op_seq.getOutputs().contains(in_operand))
+ _graph.operations().iterate([&](const ir::OperationIndex &op_ind, ir::IOperation &op) {
+ if (!op.getOutputs().contains(in_operand))
return;
-
- // Update OpSequence/ir::Operation edges and ir::Operand edges
- op_seq.replaceOutputs(in_operand, out_operand);
- for (auto op : op_seq.operations())
- {
- auto &operation_obj = _graph.operations().at(op);
- if (operation_obj.getOutputs().contains(in_operand))
- {
- operation_obj.replaceOutputs(in_operand, out_operand);
- out_operand_obj.setDef(op);
- }
- }
+ // Update Operation and Operand edges
+ op.replaceOutputs(in_operand, out_operand);
+ out_operand_obj.setDef(op_ind);
});
- // Remove Permute operation, enclosing OpSequence and the operand
+ // Remove Permute operation and the operand
{
_graph.removeOperand(in_operand);
-
- auto op_seq_ind = _lowered_graph.op_seqs().getOperation(_op_ind);
- // Assumes enclosing OpSequence contatins just this Permute operation
- assert(_lowered_graph.op_seqs().at(op_seq_ind).size() == 1);
- _lowered_graph.op_seqs().remove(op_seq_ind);
_graph.operations().remove(_op_ind);
}
- _lowered_graph.op_seqs().iterate([&](const ir::OpSequenceIndex &, ir::OpSequence &op_seq) {
- if (!op_seq.getInputs().contains(in_operand))
+ _graph.operations().iterate([&](const ir::OperationIndex &op_ind, ir::IOperation &op) {
+ if (!op.getInputs().contains(in_operand))
return;
-
- op_seq.replaceInputs(in_operand, out_operand);
- for (auto op : op_seq.operations())
- {
- auto &operation_obj = _graph.operations().at(op);
- if (operation_obj.getInputs().contains(in_operand))
- {
- operation_obj.replaceInputs(in_operand, out_operand);
- out_operand_obj.insertUse(op);
- }
- }
+ op.replaceInputs(in_operand, out_operand);
+ out_operand_obj.insertUse(op_ind);
});
VERBOSE(removePermute) << "Permute Op removed, node index : " << _op_ind << std::endl;
- VERBOSE(removePermute) << " - Input (removed) ir::Operand : " << in_operand << std::endl;
- VERBOSE(removePermute) << " - Output(kept) ir::Operand : " << out_operand << std::endl;
+ VERBOSE(removePermute) << " - Input (removed) Operand : " << in_operand << std::endl;
+ VERBOSE(removePermute) << " - Output(kept) Operand : " << out_operand << std::endl;
}
else
{
@@ -114,37 +106,23 @@ void PermutationEliminationPass::visit(const ir::operation::Permute &node)
auto &in_operand_obj = _graph.operands().at(in_operand);
in_operand_obj.removeUse(_op_ind);
- // Make OpSequences(that use the output) use the input
- _lowered_graph.op_seqs().iterate([&](const ir::OpSequenceIndex &, ir::OpSequence &op_seq) {
- if (!op_seq.getInputs().contains(out_operand))
+ // Make operations(that use the output) use the input
+ _graph.operations().iterate([&](const ir::OperationIndex &op_ind, ir::IOperation &op) {
+ if (!op.getInputs().contains(out_operand))
return;
-
- op_seq.replaceInputs(out_operand, in_operand);
- for (auto op : op_seq.operations())
- {
- auto &operation_obj = _graph.operations().at(op);
- if (operation_obj.getInputs().contains(out_operand))
- {
- operation_obj.replaceInputs(out_operand, in_operand);
- in_operand_obj.insertUse(op);
- }
- }
+ op.replaceInputs(out_operand, in_operand);
+ in_operand_obj.insertUse(op_ind);
});
- // Remove Permute operation, enclosing OpSequence and the operand
+ // Remove the Permute operation and out_operand
{
_graph.removeOperand(out_operand);
-
- auto op_seq_ind = _lowered_graph.op_seqs().getOperation(_op_ind);
- // Assumes enclosing OpSequence contatins just this Permute operation
- assert(_lowered_graph.op_seqs().at(op_seq_ind).size() == 1);
- _lowered_graph.op_seqs().remove(op_seq_ind);
_graph.operations().remove(_op_ind);
}
- VERBOSE(removePermute) << "Permute Op removed, node index : " << _op_ind << std::endl;
- VERBOSE(removePermute) << " - Input (kept) ir::Operand : " << in_operand << std::endl;
- VERBOSE(removePermute) << " - Output(removed) ir::Operand : " << out_operand << std::endl;
+ VERBOSE(removePermute) << "Permute Op removed : " << _op_ind << std::endl;
+ VERBOSE(removePermute) << " - Input (kept) Operand : " << in_operand << std::endl;
+ VERBOSE(removePermute) << " - Output(removed) Operand : " << out_operand << std::endl;
}
}
diff --git a/runtime/onert/core/src/compiler/pass/PermutationEliminationPass.h b/runtime/onert/core/src/compiler/pass/PermutationEliminationPass.h
index 29daf1a82..18ba99804 100644
--- a/runtime/onert/core/src/compiler/pass/PermutationEliminationPass.h
+++ b/runtime/onert/core/src/compiler/pass/PermutationEliminationPass.h
@@ -35,7 +35,7 @@ namespace pass
* are compatible and layouts match.
*
* Permute input tensor is kept and the output is removed for all the cases, except model outputs.
- * As all output tensors have to be controlflow backend, so the output is kept.
+ * Since all model output tensors have to be on the builtin backend, the output is kept in that case.
*
* @note This is an optimization pass which means that everything should work fine even if this pass
* was skipped.
@@ -49,7 +49,7 @@ public:
std::string id() final { return "PermutationEliminationPass"; }
public:
- void callback(const ir::OperationIndex &i, ir::Operation &n) final;
+ void callback(const ir::OperationIndex &i, ir::IOperation &n) final;
private:
void visit(const ir::operation::Permute &) final;
diff --git a/runtime/onert/core/src/compiler/pass/PermutationInsertionPass.cc b/runtime/onert/core/src/compiler/pass/PermutationInsertionPass.cc
index c83a72ada..f5ad7e636 100644
--- a/runtime/onert/core/src/compiler/pass/PermutationInsertionPass.cc
+++ b/runtime/onert/core/src/compiler/pass/PermutationInsertionPass.cc
@@ -9,6 +9,7 @@
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
@@ -16,18 +17,16 @@
#include "PermutationInsertionPass.h"
-#include <cassert>
-#include <utility>
-#include <unordered_map>
+#include "../../backend/builtin/Config.h"
-#include "backend/controlflow/Config.h"
-#include "ir/Operand.h"
-#include "ir/operation/LowerInfo.h"
-#include "ir/Graph.h"
-#include "backend/IConfig.h"
+#include "compiler/OperationLowerInfo.h"
+#include "ir/operation/Permute.h"
#include "util/logging.h"
+
+#include <cassert>
#include <memory>
-#include "ir/operation/Permute.h"
+#include <unordered_map>
+#include <utility>
namespace onert
{
@@ -38,7 +37,8 @@ namespace pass
void PermutationInsertionPass::callback(const ir::OperandIndex &index, ir::Operand &object)
{
- auto &&operand_li = _lowered_graph.getLowerInfo(index);
+ auto &operand_li_map = _lowered_graph.lower_info().operand;
+ auto &&operand_li = operand_li_map.getRawPtr(index);
assert(operand_li);
// NOTE Later, constants also will have Def
@@ -51,16 +51,16 @@ void PermutationInsertionPass::callback(const ir::OperandIndex &index, ir::Opera
std::list<ir::OperationIndex> permute_indexes;
// Build a map for all necessary type of operands
- std::unordered_map<ir::operand::PermuteFactor, ir::OperandIndex> factor_to_index;
+ std::unordered_map<PermuteFactor, ir::OperandIndex> factor_to_index;
{
assert(operand_li->def_factors().size() == 1);
- for (auto factor : operand_li->def_factors())
+ for (auto &&factor : operand_li->def_factors())
{
factor_to_index.emplace(factor, index);
}
auto insert_set = operand_li->use_factors() - operand_li->def_factors();
- for (auto factor : insert_set)
+ for (auto &&factor : insert_set)
{
const auto permute_operation_index = insertPermute(index, factor);
permute_indexes.push_back(permute_operation_index);
@@ -75,33 +75,23 @@ void PermutationInsertionPass::callback(const ir::OperandIndex &index, ir::Opera
std::list<ir::OperationIndex> remove_list;
auto uses = object.getUses();
- for (auto use : uses)
+ for (auto &&use : uses)
{
// If permute operation, ignore it
if (std::find(permute_indexes.begin(), permute_indexes.end(), use) != permute_indexes.end())
continue;
auto &operation = _graph.operations().at(use);
- assert(_lowered_graph.op_seqs().containsOperation(use));
- auto op_seq_index = _lowered_graph.op_seqs().getOperation(use);
- auto op_seq_li = _lowered_graph.getLowerInfo(op_seq_index);
- assert(op_seq_li);
- const auto op_seq_layout = op_seq_li->layout();
- const backend::Backend *backend = op_seq_li->backend();
+ auto op_li = _lowered_graph.lower_info().operation.getRawPtr(use);
+ assert(op_li);
+ const auto op_layout = op_li->layout();
+ const backend::Backend *backend = op_li->backend();
assert(backend);
- auto use_node_inputs = operation.getInputs();
- assert(use_node_inputs.contains(index));
+ assert(operation.getInputs().contains(index));
- auto new_index = factor_to_index.at({backend, op_seq_layout});
+ auto new_index = factor_to_index.at({backend, op_layout});
if (index != new_index)
{
- // Update from op_seq
- // Replace the same inputs of an OpSequence at once for the following reasons:
- // 1. An OpSequence's inputs are the same inputs of first operation
- // 2. An OpSequence may have inputs as the same operand (2 or more).
- // 3. The same inputs of OpSequence have the same PermuteFactor.
- _lowered_graph.op_seqs().at(op_seq_index).replaceInputs(index, new_index);
-
// Update from operation
// Replace the same inputs of an operation at once for the following reasons:
// No. 2 and 3 above
@@ -109,63 +99,69 @@ void PermutationInsertionPass::callback(const ir::OperandIndex &index, ir::Opera
// Update from operand
remove_list.push_back(
- use); // Removal should be done in another loop since we are in the loop
+ use); // Removal should be done in another loop since we are in the loop
_graph.operands().at(new_index).insertUse(use);
}
}
- for (auto &operation : remove_list)
+ for (const auto &operation_index : remove_list)
{
- object.removeUse(operation);
+ object.removeUse(operation_index);
}
}
}
ir::OperationIndex PermutationInsertionPass::insertPermute(const ir::OperandIndex &operand_index,
- const ir::operand::PermuteFactor &factor)
+ const PermuteFactor &factor)
{
- assert(!_graph.isBuildingPhase());
-
auto &operand = _graph.operands().at(operand_index);
// Generate output operand and permute operation
auto out_operand_index = _graph.addOperand(operand.shape(), operand.typeInfo());
- // change model output if operand_index is model output index
+  // Change the model output if operand_index is a model output index and the target factor's
+  // backend is the builtin backend
auto &model_outputs = _graph.getOutputs();
- if (model_outputs.contains(operand_index))
+ const backend::Backend *builtin_backend = compiler::BackendManager::get().getBuiltin();
+ assert(builtin_backend->config()->id() == onert::backend::builtin::Config::ID);
+
+ if (model_outputs.contains(operand_index) && factor.backend() == builtin_backend)
{
model_outputs.replace(operand_index, out_operand_index);
}
+ auto &operand_li_map = _lowered_graph.lower_info().operand;
+
// Find Permute information
- auto input_factor = _lowered_graph.getLowerInfo(operand_index)->def_factors().getOnlyElement();
+ auto input_factor = operand_li_map.getRawPtr(operand_index)->def_factors().getOnlyElement();
auto input_backend = input_factor.backend();
auto output_backend = factor.backend();
// NOTE Permute may not have specific layout because the layout of input and output may be
// different.
const auto permute_node_layout = ir::Layout::UNKNOWN;
// NOTE If one backend supports several layout, the backend must support Permute operation
- const backend::Backend *permute_node_backend = compiler::BackendManager::get().getControlflow();
+ const backend::Backend *permute_node_backend = compiler::BackendManager::get().getBuiltin();
+ assert(permute_node_backend->config()->id() == onert::backend::builtin::Config::ID);
+
if (input_backend == output_backend)
{
permute_node_backend = input_backend;
}
- const ir::operand::PermuteFactor permute_node_factor{permute_node_backend, permute_node_layout};
+ const PermuteFactor permute_node_factor{permute_node_backend, permute_node_layout};
// Update LowerInfo of input operand
- auto operand_lower_info = _lowered_graph.getLowerInfo(operand_index);
+ auto operand_lower_info = operand_li_map.getRawPtr(operand_index);
operand_lower_info->removeUsePermuteFactor(factor);
operand_lower_info->addUsePermuteFactor(permute_node_factor);
// Update LowerInfo of output operand
- auto out_operand_li = std::make_unique<ir::operand::LowerInfo>();
+ auto out_operand_li = std::make_unique<compiler::OperandLowerInfo>();
// The input and output factors of all nodes will be the same except Permute. So Tensor's
// allocators allocates memory using only the information of def permutation factor now.
// TODO Change param to permute_node_factor
out_operand_li->addDefPermuteFactor(factor);
out_operand_li->addUsePermuteFactor(factor);
- _lowered_graph.setLowerInfo(out_operand_index, std::move(out_operand_li));
+ operand_li_map.set(out_operand_index, std::move(out_operand_li));
// Insert permute operation to the graph
const auto input_layout = input_factor.layout();
@@ -188,20 +184,18 @@ ir::OperationIndex PermutationInsertionPass::insertPermute(const ir::OperandInde
auto insert_node = std::make_unique<Permute>(operand_index, out_operand_index, permute_type);
auto node_index = _graph.operations().push(std::move(insert_node));
- const auto &node = _graph.operations().at(node_index);
VERBOSE_F() << "Permute Op inserted, node index : " << node_index << std::endl;
- VERBOSE_F() << " - Input (original) Operand : " << operand_index << std::endl;
- VERBOSE_F() << " - Output(inserted) Operand : " << out_operand_index << std::endl;
+ VERBOSE_F() << " - Input (original) Operand : " << operand_index << "("
+ << input_factor.backend()->config()->id() << ")" << std::endl;
+ VERBOSE_F() << " - Output(inserted) Operand : " << out_operand_index << "("
+ << factor.backend()->config()->id() << ")" << std::endl;
- // OpSequence
+ // Operation LowerInfo
{
- auto op_seq_index = _lowered_graph.op_seqs().emplace(node_index, permute_node_layout);
- auto &op_seq = _lowered_graph.op_seqs().at(op_seq_index);
- op_seq.setInputs(node.getInputs());
- op_seq.setOutputs(node.getOutputs());
- _lowered_graph.setLowerInfo(op_seq_index, std::make_unique<ir::operation::LowerInfo>(
- permute_node_backend, permute_node_layout));
+ auto &operation_li_map = _lowered_graph.lower_info().operation;
+ operation_li_map.set(node_index, std::make_unique<compiler::OperationLowerInfo>(
+ permute_node_backend, permute_node_layout));
}
// Update Use/Def info
diff --git a/runtime/onert/core/src/compiler/pass/PermutationInsertionPass.h b/runtime/onert/core/src/compiler/pass/PermutationInsertionPass.h
index 758515385..ee0a1464c 100644
--- a/runtime/onert/core/src/compiler/pass/PermutationInsertionPass.h
+++ b/runtime/onert/core/src/compiler/pass/PermutationInsertionPass.h
@@ -20,7 +20,7 @@
#include "LoweredOperandPass.h"
#include "compiler/BackendManager.h"
#include "ir/Operand.h"
-#include "ir/operand/PermuteFactor.h"
+#include "compiler/PermuteFactor.h"
namespace onert
{
@@ -48,7 +48,7 @@ private:
* @return ir::OperationIndex
*/
ir::OperationIndex insertPermute(const ir::OperandIndex &operand_index,
- const ir::operand::PermuteFactor &factor);
+ const PermuteFactor &factor);
};
} // namespace pass
diff --git a/runtime/onert/core/src/compiler/pass/PermutationOperationPass.cc b/runtime/onert/core/src/compiler/pass/PermutationOperationPass.cc
index c5c95c726..f014d29d3 100644
--- a/runtime/onert/core/src/compiler/pass/PermutationOperationPass.cc
+++ b/runtime/onert/core/src/compiler/pass/PermutationOperationPass.cc
@@ -30,10 +30,10 @@ namespace pass
using namespace ir;
-void PermutationOperationPass::callback(const OperationIndex &, Operation &node)
+void PermutationOperationPass::callback(const OperationIndex &, IOperation &node)
{
node.accept(*this);
-};
+}
// TODO Remove this. Expanding ranks of Operand is dangerous
void PermutationOperationPass::applyExpandRanks(const Operation &node)
@@ -43,9 +43,8 @@ void PermutationOperationPass::applyExpandRanks(const Operation &node)
assert(output.getDef().valid());
const auto node_index = output.getDef();
- const auto &op_seq_index = _lowered_graph.op_seqs().getOperation(node_index);
- const auto frontend_layout = _lowered_graph.op_seqs().at(op_seq_index).getLayout();
- const auto backend_layout = _lowered_graph.getLowerInfo(op_seq_index)->layout();
+ const auto frontend_layout = _graph.layout();
+ const auto backend_layout = _lowered_graph.lower_info().operation.getRawPtr(node_index)->layout();
if (frontend_layout == backend_layout)
{
@@ -84,10 +83,11 @@ void PermutationOperationPass::changeToKeepLayout(const Operation &node)
assert(output_obj.getDef().valid());
const auto node_index = output_obj.getDef();
- const auto &op_seq_index = _lowered_graph.op_seqs().getOperation(node_index);
- const auto frontend_layout = _lowered_graph.op_seqs().at(op_seq_index).getLayout();
- const auto backend_layout = _lowered_graph.getLowerInfo(op_seq_index)->layout();
+ auto &operation_li_map = _lowered_graph.lower_info().operation;
+ auto &operand_li_map = _lowered_graph.lower_info().operand;
+ const auto frontend_layout = _graph.layout();
+ const auto backend_layout = operation_li_map.getRawPtr(node_index)->layout();
if (frontend_layout == backend_layout)
{
@@ -97,96 +97,27 @@ void PermutationOperationPass::changeToKeepLayout(const Operation &node)
// Permutation changing layout beyond 4-D is not supported yet
assert(output_obj.shape().rank() <= 4);
- // Divide op_seq based on target operation
- {
- auto &prev_op_seq = _lowered_graph.op_seqs().at(op_seq_index);
- auto &operations = _lowered_graph.graph().operations();
-
- // Create new op_seq and move information from existing op_seq to new op_seq if target
- // node is the end of op_seq
- auto it = prev_op_seq.begin();
- // Find iterator of target node in op_seq
- while (*(it++) != node_index)
- ;
- if (it != prev_op_seq.end())
- {
- const auto &target_op_idx = *it;
- const auto &target_node = operations.at(target_op_idx);
- const auto &next_op_seq_index =
- _lowered_graph.op_seqs().emplace(target_op_idx, prev_op_seq.getLayout());
- auto &next_op_seq = _lowered_graph.op_seqs().at(next_op_seq_index);
- next_op_seq.setInputs(target_node.getInputs());
- next_op_seq.setOutputs(target_node.getOutputs());
-
- std::vector<OperationIndex> remove_list;
- remove_list.emplace_back(target_op_idx);
- while (++it != prev_op_seq.end())
- {
- next_op_seq.appendOperation(target_op_idx);
- next_op_seq.setOutputs(target_node.getOutputs());
- remove_list.emplace_back(target_op_idx);
- }
-
- prev_op_seq.setOutputs(node.getOutputs());
- for (const auto &index : remove_list)
- {
- prev_op_seq.remove(index);
- }
-
- const auto op_seq_li = _lowered_graph.getLowerInfo(op_seq_index);
- _lowered_graph.setLowerInfo(
- next_op_seq_index,
- std::make_unique<ir::operation::LowerInfo>(op_seq_li->backend(), op_seq_li->layout()));
- }
- }
-
- // Remove target operation from op_seq and insert the target operation to new op_seq
+ // Change PermuteFactors of operands and the operation of target node
{
- const auto backend = _lowered_graph.getLowerInfo(op_seq_index)->backend();
+ const auto op_li = operation_li_map.getRawPtr(node_index);
+ const auto backend = op_li->backend();
- // Remove target operation from op_sequence
- _lowered_graph.op_seqs().removeFromOpSequence(node_index);
+ operation_li_map.set(node_index,
+ std::make_unique<compiler::OperationLowerInfo>(backend, frontend_layout));
- if (!_lowered_graph.op_seqs().exist(op_seq_index))
- {
- // Remove lowerinfo for op_seq of target operation if the op_seq does not exist
- _lowered_graph.removeLowerInfo(op_seq_index);
- }
- else
- {
- // Update op_seq of target operation if the op_seq exists
- auto &prev_op_seq = _lowered_graph.op_seqs().at(op_seq_index);
- const auto &last_node_idx = *(--prev_op_seq.end());
- const auto &last_node = _lowered_graph.graph().operations().at(last_node_idx);
- prev_op_seq.setOutputs(last_node.getOutputs());
- }
-
- // Create new op_seq and set information to the op_seq
- auto new_op_seq_index = _lowered_graph.op_seqs().emplace(node_index, frontend_layout);
- auto &new_op_seq = _lowered_graph.op_seqs().at(new_op_seq_index);
- new_op_seq.setInputs(node.getInputs());
- new_op_seq.setOutputs(node.getOutputs());
- _lowered_graph.setLowerInfo(
- new_op_seq_index, std::make_unique<ir::operation::LowerInfo>(backend, frontend_layout));
- }
-
- // Change PermuteFactors of operands of target node
- {
- const auto &op_seq_index = _lowered_graph.op_seqs().getOperation(node_index);
- const auto op_seq_li = _lowered_graph.getLowerInfo(op_seq_index);
- const auto backend = op_seq_li->backend();
- const operand::PermuteFactor removed_factor{backend, backend_layout};
- const operand::PermuteFactor new_factor{backend, frontend_layout};
+ const PermuteFactor removed_factor{backend, backend_layout};
+ const PermuteFactor new_factor{backend, frontend_layout};
for (const auto &input : node.getInputs() | Remove::DUPLICATED | Remove::UNDEFINED)
{
+ // Check whether the use factor can actually be removed: it must be kept if another
+ // operation still uses this operand with the same backend and layout
bool canRemove = true;
for (const auto &use : _graph.operands().at(input).getUses())
{
if (use != node_index)
{
- const auto &use_op_seq_index = _lowered_graph.op_seqs().getOperation(use);
- auto use_op_seq_li = _lowered_graph.getLowerInfo(use_op_seq_index);
- if (use_op_seq_li->backend() == backend && use_op_seq_li->layout() == backend_layout)
+ auto use_op_li = operation_li_map.getRawPtr(use);
+ if (use_op_li->backend() == backend && use_op_li->layout() == backend_layout)
{
canRemove = false;
break;
@@ -194,27 +125,27 @@ void PermutationOperationPass::changeToKeepLayout(const Operation &node)
}
}
- auto lower_info = _lowered_graph.getLowerInfo(input);
+ auto input_li = operand_li_map.getRawPtr(input);
if (canRemove)
{
- lower_info->removeUsePermuteFactor(removed_factor);
+ input_li->removeUsePermuteFactor(removed_factor);
}
- lower_info->addUsePermuteFactor(new_factor);
+ input_li->addUsePermuteFactor(new_factor);
// Check whether the node's input is a model input or a constant
if (!_graph.operands().at(input).getDef().valid() &&
- (lower_info->def_factors().size() == 1 &&
- lower_info->def_factors().getOnlyElement() == removed_factor))
+ (input_li->def_factors().size() == 1 &&
+ input_li->def_factors().getOnlyElement() == removed_factor))
{
assert(_graph.getInputs().contains(input) || _graph.operands().at(input).isConstant());
- lower_info->removeDefPermuteFactor(removed_factor);
- lower_info->addDefPermuteFactor(new_factor);
+ input_li->removeDefPermuteFactor(removed_factor);
+ input_li->addDefPermuteFactor(new_factor);
}
}
- for (const auto &output : node.getOutputs() | Remove::DUPLICATED)
+ for (const auto &output : node.getOutputs() | Remove::DUPLICATED | Remove::UNDEFINED)
{
- auto lower_info = _lowered_graph.getLowerInfo(output);
+ auto lower_info = operand_li_map.getRawPtr(output);
lower_info->removeDefPermuteFactor(removed_factor);
lower_info->addDefPermuteFactor(new_factor);
@@ -279,6 +210,18 @@ void PermutationOperationPass::visit(const ir::operation::Gather &node)
}
}
+void PermutationOperationPass::visit(const ir::operation::OneHot &node)
+{
+ const auto &output_ind = node.getOutputs().at(0);
+ const auto &output_obj = _graph.operands().at(output_ind);
+ const auto &output_shape = output_obj.shape();
+
+ if (output_shape.rank() >= 4)
+ {
+ changeToKeepLayout(node);
+ }
+}
+
void PermutationOperationPass::visit(const ir::operation::Pack &node)
{
const auto &input_ind = node.getInputs().at(ir::operation::Reshape::Input::INPUT);
diff --git a/runtime/onert/core/src/compiler/pass/PermutationOperationPass.h b/runtime/onert/core/src/compiler/pass/PermutationOperationPass.h
index 2dd76b971..e253a77ad 100644
--- a/runtime/onert/core/src/compiler/pass/PermutationOperationPass.h
+++ b/runtime/onert/core/src/compiler/pass/PermutationOperationPass.h
@@ -36,7 +36,7 @@ public:
std::string id() final { return "PermutationOperationPass"; }
public:
- void callback(const ir::OperationIndex &i, ir::Operation &n) final;
+ void callback(const ir::OperationIndex &i, ir::IOperation &n) final;
public:
void visit(const ir::operation::BinaryArithmetic &) final;
@@ -44,6 +44,7 @@ public:
void visit(const ir::operation::Concat &) final;
void visit(const ir::operation::ElementwiseBinary &) final;
void visit(const ir::operation::ElementwiseUnary &) final;
+ void visit(const ir::operation::OneHot &) final;
void visit(const ir::operation::Pack &) final;
void visit(const ir::operation::PReLU &) final;
void visit(const ir::operation::SquaredDifference &) final;
diff --git a/runtime/onert/core/src/compiler/pass/UnusedOperandEliminationPass.cc b/runtime/onert/core/src/compiler/pass/UnusedOperandEliminationPass.cc
new file mode 100644
index 000000000..162c4e7ef
--- /dev/null
+++ b/runtime/onert/core/src/compiler/pass/UnusedOperandEliminationPass.cc
@@ -0,0 +1,64 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "Pass.h"
+
+#include "UnusedOperandEliminationPass.h"
+#include "ir/Index.h"
+#include "util/Set.h"
+#include "ir/Graph.h"
+
+/**
+ * @file UnusedOperandEliminationPass.cc
+ * @brief This file contains UnusedOperandEliminationPass class implementation
+ */
+
+namespace onert
+{
+namespace compiler
+{
+namespace pass
+{
+
+void UnusedOperandEliminationPass::run()
+{
+ util::Set<ir::OperandIndex> used;
+
+ _graph.operations().iterate([&](const ir::OperationIndex &, const ir::IOperation &node) {
+ for (auto &&ind : (node.getInputs() + node.getOutputs()) | ir::Remove::UNDEFINED)
+ {
+ used.add(ind);
+ }
+ });
+
+ // Graph's inputs/outputs are always considered as used
+ for (auto &&ind : (_graph.getInputs() + _graph.getOutputs()) | ir::Remove::UNDEFINED)
+ {
+ used.add(ind);
+ }
+
+ _graph.operands().iterate([&](const ir::OperandIndex &ind, const ir::Operand &) {
+ if (!used.contains(ind))
+ {
+ VERBOSE() << "Remove unused operand " << ind << std::endl;
+ _graph.operands().remove(ind);
+ }
+ });
+}
+
+} // namespace pass
+} // namespace compiler
+} // namespace onert
diff --git a/runtime/onert/core/src/compiler/pass/UnusedOperandEliminationPass.h b/runtime/onert/core/src/compiler/pass/UnusedOperandEliminationPass.h
new file mode 100644
index 000000000..8078f4246
--- /dev/null
+++ b/runtime/onert/core/src/compiler/pass/UnusedOperandEliminationPass.h
@@ -0,0 +1,54 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * @file UnusedOperandEliminationPass.h
+ * @brief This file contains UnusedOperandEliminationPass class
+ */
+
+#ifndef __ONERT_COMPILER_PASS_UNUSED_OPERAND_ELIMINATION_PASS_H__
+#define __ONERT_COMPILER_PASS_UNUSED_OPERAND_ELIMINATION_PASS_H__
+
+#include "Pass.h"
+
+namespace onert
+{
+namespace compiler
+{
+namespace pass
+{
+
+/**
+ * @brief A pass to eliminate unused operands from the graph
+ *
+ * Remove operands that are not used by any operations, except Graph inputs/outputs.
+ *
+ */
+class UnusedOperandEliminationPass : public Pass
+{
+public:
+ using Pass::Pass;
+
+public:
+ std::string id() override { return "UnusedOperandEliminationPass"; }
+ void run() final;
+};
+
+} // namespace pass
+} // namespace compiler
+} // namespace onert
+
+#endif // __ONERT_COMPILER_PASS_UNUSED_OPERAND_ELIMINATION_PASS_H__
diff --git a/runtime/onert/core/src/compiler/pass/UnusedOperandEliminationPass.test.cc b/runtime/onert/core/src/compiler/pass/UnusedOperandEliminationPass.test.cc
new file mode 100644
index 000000000..572b4df24
--- /dev/null
+++ b/runtime/onert/core/src/compiler/pass/UnusedOperandEliminationPass.test.cc
@@ -0,0 +1,47 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "UnusedOperandEliminationPass.h"
+
+#include "ir/Graph.h"
+
+#include <gtest/gtest.h>
+
+using namespace onert::ir;
+using namespace onert::compiler::pass;
+
+TEST(UnusedOperandEliminationPass, Simple)
+{
+ Graph graph;
+
+ // Add tensors
+ Shape shape{1, 2, 2, 1};
+ TypeInfo type{DataType::FLOAT32};
+ auto in = graph.addOperand(shape, type);
+ auto out = graph.addOperand(shape, type);
+
+ auto unused = graph.addOperand(shape, type);
+
+ // Set model inputs/outputs
+ graph.addInput(in);
+ graph.addOutput(out);
+
+ UnusedOperandEliminationPass{graph}.run();
+
+ ASSERT_TRUE(graph.operands().exist(in));
+ ASSERT_TRUE(graph.operands().exist(out));
+ ASSERT_FALSE(graph.operands().exist(unused));
+}
diff --git a/runtime/onert/core/src/compiler/train/LoweredTrainableGraph.cc b/runtime/onert/core/src/compiler/train/LoweredTrainableGraph.cc
new file mode 100644
index 000000000..8b368c440
--- /dev/null
+++ b/runtime/onert/core/src/compiler/train/LoweredTrainableGraph.cc
@@ -0,0 +1,286 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "compiler/train/LoweredTrainableGraph.h"
+
+#include "../ManualScheduler.h"
+#include "../pass/ConstantInsertionPass.h"
+#include "../pass/ConstantLoweringPass.h"
+#include "../pass/PassRunner.h"
+#include "../pass/PermutationEliminationPass.h"
+#include "../pass/PermutationInsertionPass.h"
+#include "../pass/PermutationOperationPass.h"
+#include "../../backend/builtin/Config.h"
+#include "../../dumper/text/GraphDumper.h"
+#include "../../ir/verifier/Verifier.h"
+#include "TrainableOperationConverter.h"
+
+#include "backend/Backend.h"
+#include "backend/train/ITrainableBackend.h"
+#include "compiler/BackendResolver.h"
+#include "util/logging.h"
+
+#include <cassert>
+#include <sstream>
+
+namespace onert
+{
+namespace compiler
+{
+namespace train
+{
+
+LoweredTrainableGraph::LoweredTrainableGraph(ir::train::TrainableGraph &graph,
+ const CompilerOptions &options)
+ : _trainable_graph{graph}
+{
+ lowerGraph(options);
+}
+
+void LoweredTrainableGraph::lowerGraph(const CompilerOptions &options)
+{
+ // Build backend contexts
+ auto &backend_manager = BackendManager::get();
+ // Create contexts for other backends
+ for (auto &&backend_str : options.backend_list)
+ {
+ backend_manager.loadBackend(backend_str);
+ auto backend = backend_manager.get(backend_str);
+
+ // TODO The default backend list contains "cpu", "acl_cl" and "acl_neon", and some of them
+ // are not available on x64 or other platforms. So this is a workaround for x64; we should
+ // change it back (throw if a backend cannot be loaded) later.
+ if (!backend)
+ {
+ VERBOSE(LoweredTrainableGraph) << "Cannot load backend - " << backend_str << std::endl;
+ continue;
+ }
+ }
+ if (backend_manager.num_backends() == 0)
+ throw std::runtime_error{"No available backends loaded."};
+
+ // TODO Move "schedule" phase out of here
+ // TODO Scheduling
+ std::unique_ptr<BackendResolver> backend_resolver;
+ auto all_backends = backend_manager.getAll();
+
+ auto scheduler = ManualScheduler(all_backends, options);
+ backend_resolver = scheduler.schedule(_trainable_graph.graph());
+
+ // Check if backends are trainable
+ _trainable_graph.operations().iterate(
+ [&](const ir::OperationIndex &op_ind, const ir::IOperation &) {
+ const auto backend = backend_resolver->getBackend(op_ind);
+
+ // TODO Remove dynamic_cast
+ if (dynamic_cast<const backend::train::ITrainableBackend *>(backend) == nullptr)
+ {
+ throw std::runtime_error(backend->config()->id() + " backend does not support training");
+ }
+ });
+
+ makeLowerInfo(*backend_resolver);
+ VERBOSE(LoweredTrainableGraph) << "dump before mandatory passes" << std::endl;
+ dumper::text::dumpLoweredGraph(*this);
+
+ // Mandatory passes - kind of legalization(?)
+ compiler::pass::PassRunner{}
+ .append(std::make_unique<compiler::pass::ConstantInsertionPass>(*this))
+ .append(std::make_unique<compiler::pass::ConstantLoweringPass>(*this))
+ .append(std::make_unique<compiler::pass::PermutationOperationPass>(*this))
+ .append(std::make_unique<compiler::pass::PermutationInsertionPass>(*this))
+ .run();
+
+ // TODO Move converting Permute op into PermutationInsertionPass
+ auto op_converter = TrainableOperationConverter{_trainable_graph, nullptr};
+ _trainable_graph.operations().iterate(
+ [&](const onert::ir::OperationIndex &index, const onert::ir::IOperation &op) {
+ if (op.opcode() == ir::OpCode::Permute)
+ {
+ auto trainable_op = op_converter(op);
+ trainable_op->enableBackward();
+ auto gen_index = _trainable_graph.replaceOperation(index, std::move(trainable_op));
+ UNUSED_RELEASE(gen_index);
+ assert(gen_index == index);
+ }
+ });
+
+ dumpLowerInfo();
+
+ // Optimization passes (optional)
+ compiler::pass::PassRunner{}
+ .append(std::make_unique<compiler::pass::PermutationEliminationPass>(*this))
+ .run();
+
+ // TODO Update LowerInfo for training
+
+ VERBOSE(LoweredTrainableGraph) << "Dump after all the passes" << std::endl;
+ for (auto &&operand : _trainable_graph.getInputs())
+ VERBOSE(LoweredTrainableGraph) << "Graph Input : " << operand << std::endl;
+ for (auto &&operand : _trainable_graph.getOutputs())
+ VERBOSE(LoweredTrainableGraph) << "Graph Output : " << operand << std::endl;
+ dumper::text::dumpLoweredGraph(*this);
+
+ // Graph verifications
+ {
+ assert(ir::verifier::InputOutputChecker().verify(_trainable_graph.graph()));
+ assert(ir::verifier::DAGChecker().verify(_trainable_graph.graph()));
+ assert(ir::verifier::EdgeChecker().verify(_trainable_graph.graph()));
+ }
+}
+
+void LoweredTrainableGraph::makeLowerInfo(const compiler::BackendResolver &backend_resolver)
+{
+ _trainable_graph.operands().iterate([&](const ir::OperandIndex &index, const ir::Operand &) {
+ lower_info().operand.set(index, std::make_unique<OperandLowerInfo>());
+ });
+
+ // Set operand lower info using assigned backends to operations
+ _trainable_graph.operations().iterate(
+ [&](const ir::OperationIndex &op_ind, const ir::IOperation &op) {
+ auto backend = backend_resolver.getBackend(op_ind);
+ if (!backend)
+ {
+ throw std::runtime_error{"Fail to find backend for " + op.name() + " operation"};
+ }
+
+ auto frontend_layout = _trainable_graph.layout();
+
+ // The layout of each backend should be set at another place
+ // TODO Change setting layout of each backend at another place
+ auto backend_layout = backend->config()->supportLayout(op, frontend_layout);
+
+ for (auto &&ind : op.getInputs() | ir::Remove::UNDEFINED)
+ {
+ auto &operand_li = lower_info().operand.at(ind);
+ operand_li.addUsePermuteFactor(PermuteFactor{backend, backend_layout});
+ }
+ for (auto &&ind : op.getOutputs() | ir::Remove::UNDEFINED)
+ {
+ auto &operand_li = lower_info().operand.at(ind);
+ operand_li.addDefPermuteFactor(PermuteFactor{backend, backend_layout});
+ }
+ lower_info().operation.set(
+ op_ind, std::make_unique<compiler::OperationLowerInfo>(backend, backend_layout));
+ });
+
+ // Handle graph inputs and outputs
+ const auto builtin_backend = BackendManager::get().getBuiltin();
+ auto factor = PermuteFactor{builtin_backend, _trainable_graph.layout()};
+ for (auto &&index : _trainable_graph.getInputs() | ir::Remove::UNDEFINED)
+ {
+ auto &operand_li = lower_info().operand.at(index);
+ assert(operand_li.def_factors().empty());
+ operand_li.addDefPermuteFactor(factor);
+ }
+ for (auto &&index : _trainable_graph.getOutputs() | ir::Remove::UNDEFINED)
+ {
+ auto &operand_li = lower_info().operand.at(index);
+ operand_li.addUsePermuteFactor(factor);
+ }
+
+ // Handle variable tensors
+ _trainable_graph.operands().iterate([&](const ir::OperandIndex &index, ir::Operand &operand) {
+ // Some inputs of an operation may be non-constant operands that appear neither in the
+ // graph inputs/outputs nor as undefined operands - these are variable tensors. For example,
+ // UnidirectionalSequenceLSTM has such inputs.
+ if (operand.info().isVariable())
+ {
+ // The variable operand with buffer is not supported yet
+ assert(operand.data() == nullptr);
+ assert(operand.getUses().size() == 1 && !operand.getDef().valid());
+ auto &operand_li = lower_info().operand.at(index);
+ assert(operand_li.def_factors().empty());
+ operand_li.addDefPermuteFactor(operand_li.use_factors().getOnlyElement());
+ }
+ });
+}
+
+void LoweredTrainableGraph::dumpLowerInfo()
+{
+ if (::onert::util::logging::ctx.enabled() == false)
+ return;
+
+ std::map<uint32_t, std::string> dumps;
+
+ _trainable_graph.operands().iterate([&](const ir::OperandIndex &index, ir::Operand &object) {
+ const auto operand_lower_info = lower_info().operand.getRawPtr(index);
+ assert(operand_lower_info);
+ if (!operand_lower_info->def_factors().empty() || !operand_lower_info->use_factors().empty())
+ {
+ auto shape_to_string = [](const ir::Shape &shape) {
+ std::stringstream sstream;
+ sstream << "{ ";
+ for (auto i = 0; i < shape.rank(); ++i)
+ sstream << (shape.dim(i)) << " ";
+ sstream << "}";
+ return sstream.str();
+ };
+
+ auto factors_to_string = [](const PermuteFactorSet &factors) {
+ std::string str;
+ for (auto &&factor : factors)
+ {
+ str += factor.backend()->config()->id();
+ str += "(" + to_string(factor.layout()) + ")";
+ str += " ";
+ }
+ return "{ " + str + "}";
+ };
+
+ auto operation_index_set_to_string = [](const ir::OperationIndexSet &operations) {
+ std::stringstream sstream;
+ sstream << "{ ";
+ for (auto &&op : operations)
+ sstream << op << " ";
+ sstream << "}";
+ return sstream.str();
+ };
+
+ auto data_to_str = [](const ir::Data *data) {
+ return (data ? (std::to_string(data->size()) + " bytes") : "N/A");
+ };
+
+ std::string shape_str = shape_to_string(object.shape());
+ std::string def_op = operation_index_set_to_string({object.getDef()});
+ std::string use_ops = operation_index_set_to_string(object.getUses());
+ std::string def_factors = factors_to_string(operand_lower_info->def_factors());
+ std::string use_factors = factors_to_string(operand_lower_info->use_factors());
+ std::stringstream sstream;
+ sstream << "Operand " << index << " Info" << std::endl;
+ sstream << " - Shape : " << shape_str << std::endl;
+ sstream << " - Def/Uses : Def " << def_op << " Uses " << use_ops << std::endl;
+ sstream << " - Data : " << data_to_str(object.data()) << std::endl;
+ sstream << " - LowerInfo : Def " << def_factors << " Uses " << use_factors << std::endl;
+ dumps.emplace(index.value(), sstream.str());
+ }
+ });
+
+ for (const auto &e : dumps)
+ {
+ if (!e.second.empty())
+ {
+ std::istringstream iss(e.second);
+ std::string line;
+ while (std::getline(iss, line))
+ VERBOSE(Lower) << line << std::endl;
+ }
+ }
+}
+
+} // namespace train
+} // namespace compiler
+} // namespace onert
diff --git a/runtime/onert/core/src/compiler/train/StaticBackwardShapeInferer.cc b/runtime/onert/core/src/compiler/train/StaticBackwardShapeInferer.cc
new file mode 100644
index 000000000..eae8cdeef
--- /dev/null
+++ b/runtime/onert/core/src/compiler/train/StaticBackwardShapeInferer.cc
@@ -0,0 +1,151 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "StaticBackwardShapeInferer.h"
+#include "util/ShapeInference.h"
+#include "util/logging.h"
+
+#include <misc/polymorphic_downcast.h>
+
+#include <sstream>
+#include <stdexcept>
+
+namespace onert
+{
+namespace compiler
+{
+namespace train
+{
+
+void StaticBackwardShapeInferer::infer()
+{
+ // NOTE It has not been decided whether iterating in reverse (topological) order is required here.
+ auto sorted_ops = _lowered_subg->graph().topolSortOperations();
+ for (auto it = sorted_ops.rbegin(); it != sorted_ops.rend(); ++it)
+ {
+ const auto op_idx = *it;
+ const auto &op = _lowered_subg->trainable_graph().operation(op_idx);
+ if (checkDynamicInput(op))
+ {
+ std::stringstream msg;
+ msg << "StaticBackwardShapeInferer does not support dynamic shape yet, ";
+ msg << op.name() << "(op index: " << op_idx << ") has dynamic shape.";
+ throw std::runtime_error(msg.str());
+ }
+
+ checkOutput(op);
+
+ op.accept(*this);
+ }
+}
+
+void StaticBackwardShapeInferer::dump()
+{
+ // TODO dump
+}
+
+bool StaticBackwardShapeInferer::checkDynamicInput(const ir::IOperation &op)
+{
+ const auto &operands = _lowered_subg->graph().operands();
+ for (const auto &input_idx : op.getInputs() | ir::Remove::UNDEFINED | ir::Remove::DUPLICATED)
+ {
+ if (operands.at(input_idx).info().isDynamic())
+ {
+ return true;
+ }
+ }
+
+ return false;
+}
+
+void StaticBackwardShapeInferer::checkOutput(const ir::IOperation &op)
+{
+ const auto &bwd_operands = _lowered_subg->trainable_graph().backward_operands();
+ for (const auto &output_idx : op.getOutputs() | ir::Remove::UNDEFINED | ir::Remove::DUPLICATED)
+ {
+ if (!bwd_operands.exist(output_idx))
+ {
+ std::stringstream msg;
+ msg << "StaticBackwardShapeInferer : Invalid output, ";
+ msg << op.name() << "'s back propagation output(index: " << output_idx << ") does not exist.";
+ throw std::runtime_error(msg.str());
+ }
+ }
+}
+
+void StaticBackwardShapeInferer::setShape(const ir::OperandIndex &index, const ir::Shape &shape)
+{
+ auto &tgraph = _lowered_subg->trainable_graph();
+
+ if (tgraph.backward_operands().exist(index))
+ tgraph.changeBackwardShape(index, shape);
+ else
+ {
+ // NOTE This code assumes the types are always the same, but I'm not sure.
+ const auto &type = tgraph.operands().at(index).typeInfo();
+ const auto new_index =
+ tgraph.addBackwardOperand(index, std::make_unique<ir::Operand>(shape, type));
+ assert(new_index == index);
+ UNUSED_RELEASE(new_index);
+ }
+}
+
+void StaticBackwardShapeInferer::visit(const ir::train::operation::Conv2D &)
+{
+ // NYI
+}
+
+void StaticBackwardShapeInferer::visit(const ir::train::operation::ElementwiseActivation &)
+{
+ // NYI
+}
+
+void StaticBackwardShapeInferer::visit(const ir::train::operation::Loss &)
+{
+ // NYI
+}
+
+void StaticBackwardShapeInferer::visit(const ir::train::operation::Permute &op)
+{
+ const auto &bwd_operands = _lowered_subg->trainable_graph().backward_operands();
+
+ const auto &output_idx = op.getOutputs().at(0);
+ const auto &output = bwd_operands.at(output_idx);
+
+ // re-sizing shape of back propagation input
+ const auto &input_idx = op.getInputs().at(0);
+ const auto &new_shape = output.info().shape();
+ setShape(input_idx, new_shape);
+}
+
+void StaticBackwardShapeInferer::visit(const ir::train::operation::Pool2D &)
+{
+ // NYI
+}
+
+void StaticBackwardShapeInferer::visit(const ir::train::operation::Reshape &)
+{
+ // NYI
+}
+
+void StaticBackwardShapeInferer::visit(const ir::train::operation::Softmax &)
+{
+ // NYI
+}
+
+} // namespace train
+} // namespace compiler
+} // namespace onert
diff --git a/runtime/onert/core/src/compiler/train/StaticBackwardShapeInferer.h b/runtime/onert/core/src/compiler/train/StaticBackwardShapeInferer.h
new file mode 100644
index 000000000..2ad9bca5e
--- /dev/null
+++ b/runtime/onert/core/src/compiler/train/StaticBackwardShapeInferer.h
@@ -0,0 +1,80 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_COMPILER_TRAIN_STATIC_BACKWARD_SHAPE_INFERER_H__
+#define __ONERT_COMPILER_TRAIN_STATIC_BACKWARD_SHAPE_INFERER_H__
+
+#include "ir/train/TrainableOperationVisitor.h"
+
+#include "compiler/train/LoweredTrainableGraph.h"
+#include "ir/Index.h"
+
+#include <memory>
+#include <unordered_map>
+
+namespace onert
+{
+namespace compiler
+{
+namespace train
+{
+
+/**
+ * @brief Class to infer shape before running kernels. It does the following:
+ * - re-calculate and set output shape at compile time (before running kernels)
+ * - if calculation cannot be done at compile time, mark the outputs to be dynamic, meaning
+ * shapes of outputs will be calculated during running kernels
+ */
+class StaticBackwardShapeInferer : public ir::train::TrainableOperationVisitor
+{
+public:
+ StaticBackwardShapeInferer(compiler::train::LoweredTrainableGraph *lowered_subg)
+ : _lowered_subg{lowered_subg}
+ {
+ }
+
+ /**
+ * @brief Infer shape of operands belonging to ops and set the output shape.
+ * If output shape cannot be known without running op, mark it so that it can be allocated
+ * when running kernel.
+ */
+ void infer(void);
+
+ void dump();
+
+private:
+ bool checkDynamicInput(const ir::IOperation &op);
+ void checkOutput(const ir::IOperation &op);
+ void setShape(const ir::OperandIndex &index, const ir::Shape &shape);
+
+private:
+ void visit(const ir::train::operation::Conv2D &op) override;
+ void visit(const ir::train::operation::ElementwiseActivation &op) override;
+ void visit(const ir::train::operation::Loss &op) override;
+ void visit(const ir::train::operation::Permute &op) override;
+ void visit(const ir::train::operation::Pool2D &op) override;
+ void visit(const ir::train::operation::Reshape &op) override;
+ void visit(const ir::train::operation::Softmax &op) override;
+
+private:
+ compiler::train::LoweredTrainableGraph *_lowered_subg;
+};
+
+} // namespace train
+} // namespace compiler
+} // namespace onert
+
+#endif // __ONERT_COMPILER_TRAIN_STATIC_BACKWARD_SHAPE_INFERER_H__
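A minimal usage sketch for the inferer declared above, assuming a lowered_subg of type std::unique_ptr<compiler::train::LoweredTrainableGraph> prepared by the caller:

// Sketch only: run backward shape inference on one lowered trainable subgraph.
auto inferer = std::make_unique<compiler::train::StaticBackwardShapeInferer>(lowered_subg.get());
inferer->infer(); // throws if any operation input has a dynamic shape
inferer->dump();  // currently a no-op (see the TODO in the .cc)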
diff --git a/runtime/onert/core/src/compiler/train/TensorRegistries.h b/runtime/onert/core/src/compiler/train/TensorRegistries.h
new file mode 100644
index 000000000..8886c9bd4
--- /dev/null
+++ b/runtime/onert/core/src/compiler/train/TensorRegistries.h
@@ -0,0 +1,114 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_COMPILER_TRAIN_TENSOR_REGISTRIES_H__
+#define __ONERT_COMPILER_TRAIN_TENSOR_REGISTRIES_H__
+
+#include "../../backend/builtin/Config.h"
+#include "../../backend/builtin/train/TensorRegistry.h"
+
+#include <backend/train/ITensorRegistry.h>
+#include <backend/train/TrainableBackendContext.h>
+
+#include <memory>
+#include <unordered_set>
+
+namespace onert
+{
+namespace compiler
+{
+namespace train
+{
+
+class TensorRegistries
+{
+public:
+ TensorRegistries() = default;
+
+ TensorRegistries(const backend::train::TrainableBackendContexts &backend_contexts,
+ bool include_builtin)
+ {
+ for (const auto &e : backend_contexts)
+ {
+ auto tensor_reg = e.second->tensor_registry();
+ if (e.first->config()->id() == backend::builtin::Config::ID)
+ {
+ _builtin_tensor_reg =
+ std::dynamic_pointer_cast<backend::builtin::train::TensorRegistry>(tensor_reg);
+ if (include_builtin)
+ _tensor_regs.insert(tensor_reg);
+ }
+ else
+ {
+ _tensor_regs.insert(tensor_reg);
+ }
+ }
+ }
+
+ std::unordered_set<std::shared_ptr<backend::train::ITensorRegistry>>::const_iterator begin() const
+ {
+ return _tensor_regs.cbegin();
+ }
+ std::unordered_set<std::shared_ptr<backend::train::ITensorRegistry>>::const_iterator end() const
+ {
+ return _tensor_regs.cend();
+ }
+
+ std::shared_ptr<backend::builtin::train::TensorRegistry> getBuiltinTensorRegistry() const
+ {
+ return _builtin_tensor_reg;
+ }
+
+ backend::ITensor *getITensor(ir::OperandIndex index) const
+ {
+ for (const auto &tensor_reg : _tensor_regs)
+ {
+ auto tensor = tensor_reg->getITensor(index);
+ if (tensor)
+ return tensor;
+ }
+ return nullptr;
+ }
+
+ backend::ITensor *getBackPropITensor(ir::OperandIndex index) const
+ {
+ for (const auto &tensor_reg : _tensor_regs)
+ {
+ auto tensor = tensor_reg->getBackPropITensor(index);
+ if (tensor)
+ return tensor;
+ }
+ return nullptr;
+ }
+
+ void iterateTrainableTensors(
+ const std::function<void(const ir::OperandIndex &, const backend::train::ITrainableTensor *)>
+ &fn) const
+ {
+ for (const auto &tensor_reg : _tensor_regs)
+ tensor_reg->iterateTrainableTensors(fn);
+ }
+
+private:
+ std::unordered_set<std::shared_ptr<backend::train::ITensorRegistry>> _tensor_regs;
+ std::shared_ptr<backend::builtin::train::TensorRegistry> _builtin_tensor_reg;
+};
+
+} // namespace train
+} // namespace compiler
+} // namespace onert
+
+#endif // __ONERT_COMPILER_TRAIN_TENSOR_REGISTRIES_H__
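A minimal usage sketch for TensorRegistries, assuming backend_contexts is an already-populated backend::train::TrainableBackendContexts map:

// Sketch only: collect per-backend registries and look up tensors by operand index.
compiler::train::TensorRegistries tensor_regs{backend_contexts, /*include_builtin=*/true};
ir::OperandIndex ind{0}; // hypothetical operand index
backend::ITensor *tensor = tensor_regs.getITensor(ind);
backend::ITensor *back_prop = tensor_regs.getBackPropITensor(ind);
if (tensor == nullptr || back_prop == nullptr)
{
  // The operand is not registered in any backend's tensor registry.
}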
diff --git a/runtime/onert/core/src/compiler/train/TrainableOperationConverter.cc b/runtime/onert/core/src/compiler/train/TrainableOperationConverter.cc
new file mode 100644
index 000000000..80ed05aa5
--- /dev/null
+++ b/runtime/onert/core/src/compiler/train/TrainableOperationConverter.cc
@@ -0,0 +1,106 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "TrainableOperationConverter.h"
+
+#include "ir/train/Operations.Include.h"
+#include "util/Utils.h"
+
+namespace onert
+{
+namespace compiler
+{
+namespace train
+{
+
+TrainableOperationConverter::TrainableOperationConverter(
+ ir::train::TrainableGraph &tgraph, const ir::train::TrainingInfo *training_info)
+ : UntrainableOperationConverter{tgraph}, _training_info{training_info}
+{
+ // Avoid unused-private-field error
+ UNUSED_RELEASE(_training_info);
+}
+
+void TrainableOperationConverter::visit(const ir::operation::BinaryArithmetic &node)
+{
+ _return_op = std::make_unique<ir::train::operation::BinaryArithmetic>(node);
+}
+
+void TrainableOperationConverter::visit(const ir::operation::Conv2D &node)
+{
+ _return_op = std::make_unique<ir::train::operation::Conv2D>(node);
+}
+
+void TrainableOperationConverter::visit(const ir::operation::DepthwiseConv2D &node)
+{
+ _return_op = std::make_unique<ir::train::operation::DepthwiseConv2D>(node);
+}
+
+void TrainableOperationConverter::visit(const ir::operation::ElementwiseActivation &node)
+{
+ if (node.param().op_type == ir::operation::ElementwiseActivation::Type::RELU)
+ {
+ _return_op = std::make_unique<ir::train::operation::ElementwiseActivation>(node);
+ }
+ else
+ {
+ UntrainableOperationConverter::visit(node);
+ }
+}
+
+void TrainableOperationConverter::visit(const ir::operation::FullyConnected &node)
+{
+ _return_op = std::make_unique<ir::train::operation::FullyConnected>(node);
+}
+
+void TrainableOperationConverter::visit(const ir::operation::Loss &node)
+{
+ _return_op = std::make_unique<ir::train::operation::Loss>(node, _training_info->lossInfo());
+}
+
+void TrainableOperationConverter::visit(const ir::operation::Pad &node)
+{
+ _return_op = std::make_unique<ir::train::operation::Pad>(node);
+}
+
+void TrainableOperationConverter::visit(const ir::operation::Permute &node)
+{
+ _return_op = std::make_unique<ir::train::operation::Permute>(node);
+}
+
+void TrainableOperationConverter::visit(const ir::operation::Pool2D &node)
+{
+ _return_op = std::make_unique<ir::train::operation::Pool2D>(node);
+}
+
+void TrainableOperationConverter::visit(const ir::operation::Reduce &node)
+{
+ _return_op = std::make_unique<ir::train::operation::Reduce>(node);
+}
+
+void TrainableOperationConverter::visit(const ir::operation::Reshape &node)
+{
+ _return_op = std::make_unique<ir::train::operation::Reshape>(node);
+}
+
+void TrainableOperationConverter::visit(const ir::operation::Softmax &node)
+{
+ _return_op = std::make_unique<ir::train::operation::Softmax>(node);
+}
+
+} // namespace train
+} // namespace compiler
+} // namespace onert
diff --git a/runtime/onert/core/src/compiler/train/TrainableOperationConverter.h b/runtime/onert/core/src/compiler/train/TrainableOperationConverter.h
new file mode 100644
index 000000000..59f92f93e
--- /dev/null
+++ b/runtime/onert/core/src/compiler/train/TrainableOperationConverter.h
@@ -0,0 +1,61 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_COMPILER_TRAIN_TRAINABLE_OPERATION_CONVERTER_H__
+#define __ONERT_COMPILER_TRAIN_TRAINABLE_OPERATION_CONVERTER_H__
+
+#include "UntrainableOperationConverter.h"
+
+#include "ir/train/TrainingInfo.h"
+
+namespace onert
+{
+namespace compiler
+{
+namespace train
+{
+
+class TrainableOperationConverter : public UntrainableOperationConverter
+{
+public:
+ TrainableOperationConverter(ir::train::TrainableGraph &trainable_graph,
+ const ir::train::TrainingInfo *training_info);
+
+ using UntrainableOperationConverter::operator();
+
+private:
+ void visit(const ir::operation::BinaryArithmetic &) override;
+ void visit(const ir::operation::Conv2D &) override;
+ void visit(const ir::operation::DepthwiseConv2D &) override;
+ void visit(const ir::operation::ElementwiseActivation &) override;
+ void visit(const ir::operation::FullyConnected &) override;
+ void visit(const ir::operation::Loss &node) override;
+ void visit(const ir::operation::Pad &node) override;
+ void visit(const ir::operation::Permute &node) override;
+ void visit(const ir::operation::Pool2D &node) override;
+ void visit(const ir::operation::Reduce &node) override;
+ void visit(const ir::operation::Reshape &) override;
+ void visit(const ir::operation::Softmax &) override;
+
+private:
+ const ir::train::TrainingInfo *_training_info;
+};
+
+} // namespace train
+} // namespace compiler
+} // namespace onert
+
+#endif // __ONERT_COMPILER_TRAIN_TRAINABLE_OPERATION_CONVERTER_H__
diff --git a/runtime/onert/core/src/compiler/train/TrainingCompiler.cc b/runtime/onert/core/src/compiler/train/TrainingCompiler.cc
new file mode 100644
index 000000000..ab0de8df9
--- /dev/null
+++ b/runtime/onert/core/src/compiler/train/TrainingCompiler.cc
@@ -0,0 +1,310 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "TrainingCompiler.h"
+
+#include "StaticBackwardShapeInferer.h"
+#include "TrainableOperationConverter.h"
+#include "pass/LossInsertionPass.h"
+#include "../CompilerHelpers.h"
+#include "../ExecutorFactory.h"
+#include "../pass/ConstantOutputPass.h"
+#include "../pass/OddOutputPass.h"
+#include "../pass/PassRunner.h"
+#include "../pass/UnusedOperandEliminationPass.h"
+#include "../ShapeValidator.h"
+#include "../../dumper/dot/DotDumper.h"
+#include "../../exec/train/TrainableExecutors.h"
+#include "../../ir/OperationDumper.h"
+#include "../../ir/verifier/Verifier.h"
+
+#include <compiler/StaticShapeInferer.h>
+#include <compiler/train/LoweredTrainableGraph.h>
+#include <ir/train/TrainableGraph.h>
+
+#include <misc/polymorphic_downcast.h>
+#include <misc/string_helpers.h>
+
+namespace onert
+{
+namespace compiler
+{
+namespace train
+{
+
+TrainingCompiler::TrainingCompiler(const std::shared_ptr<ir::NNPkg> &nnpkg, CompilerOptions *copts,
+ const ir::train::TrainingInfo &training_info)
+ : _model{nnpkg->primary_model()}, _options{copts}, _training_info{training_info}
+{
+ if (nnpkg->model_count() > 1)
+ throw std::runtime_error("TrainingCompiler does not support multiple models yet");
+
+ if (nnpkg->primary_model()->subgraphs_count() > 1)
+ throw std::runtime_error("TrainingCompiler does not support multiple subgraphs yet");
+}
+
+std::shared_ptr<CompilerArtifact> TrainingCompiler::compile(void)
+{
+ /***************************************************
+ * Prepare compilation phase
+ ***************************************************/
+ if (!_options)
+ throw std::runtime_error{"Empty compile option"};
+
+ // Mode check
+ // TODO handle option for each model
+ if (_options->he_profiling_mode)
+ {
+ if (!_options->he_scheduler)
+ throw std::runtime_error("Heterogeneous scheduler must be enabled during profiling.");
+
+ if (_options->executor != "Dataflow")
+ throw std::runtime_error("Profiling mode works only with 'Dataflow' executor");
+ }
+
+ _options->forceInternalOptions();
+ _options->verboseOptions();
+
+ auto custom_kernel_builder = _model->getKernelBuilder();
+
+ _model->iterate([&](const ir::SubgraphIndex &, ir::IGraph &graph) {
+ auto &subg = nnfw::misc::polymorphic_downcast<ir::Graph &>(graph);
+ // Mandatory passes
+ compiler::pass::PassRunner{}
+ .append(std::make_unique<compiler::pass::ConstantOutputPass>(subg))
+ .append(std::make_unique<compiler::pass::OddOutputPass>(subg))
+ .run();
+
+ // Optimizations
+ compiler::pass::PassRunner{}
+ .append(std::make_unique<compiler::pass::UnusedOperandEliminationPass>(subg))
+ .run();
+ });
+
+ std::unordered_map<ir::SubgraphIndex, std::shared_ptr<ir::train::TrainableGraph>>
+ trainable_subgraphs;
+
+ if (_model->hasOnly<ir::Graph>())
+ {
+ // Create trainable subgraphs by copy and converting inference model
+ _model->iterate([&](const ir::SubgraphIndex &subg_index, const ir::IGraph &graph) {
+ const auto &subg = nnfw::misc::polymorphic_downcast<const ir::Graph &>(graph);
+ // Create TrainableGraph by copying Graph
+ auto trainable_subg = std::make_shared<ir::train::TrainableGraph>(subg);
+
+ // Convert operations to trainable operations
+ auto converter = TrainableOperationConverter{*trainable_subg, &_training_info};
+ ir::OperationIndex min_trainable_op_idx;
+ subg.operations().iterate(
+ [&](const onert::ir::OperationIndex &op_index, const onert::ir::IOperation &op) {
+ auto trainable_op = converter(op);
+ if (_training_info.getTrainableOps().find(op_index) !=
+ std::end(_training_info.getTrainableOps()))
+ {
+ trainable_op->enableWeightsUpdate();
+ if (op_index.value() < min_trainable_op_idx.value())
+ {
+ min_trainable_op_idx = op_index;
+ }
+ }
+ auto gen_index = trainable_subg->replaceOperation(op_index, std::move(trainable_op));
+ UNUSED_RELEASE(gen_index);
+ assert(gen_index == op_index);
+ });
+
+ for (ir::OperationIndex idx{min_trainable_op_idx};
+ idx.value() < trainable_subg->operations().size(); idx++)
+ {
+ trainable_subg->enableBackward(idx);
+ }
+
+ trainable_subgraphs[subg_index] = std::move(trainable_subg);
+ });
+ }
+ else
+ {
+ // TODO Support models that have TrainableGraphs
+ throw std::runtime_error("TrainingCompiler: Invalid model");
+ }
+
+ // The original model is not used anymore; release it
+ _model.reset();
+
+ // TODO Handle dump level for each model
+ auto dump_level = static_cast<dumper::dot::DotDumper::Level>(_options->graph_dump_level);
+ onert::dumper::dot::DotDumper dot_dumper(dump_level);
+
+ for (const auto &pair : trainable_subgraphs)
+ {
+ const auto &subg_index = pair.first;
+ const auto &subg = pair.second;
+ dot_dumper.dump(*subg, nnfw::misc::str("before_loss_insertion-", subg_index.value()));
+ }
+
+ // Apply pass for trainable subgraphs
+ for (auto &&pair : trainable_subgraphs)
+ {
+ auto trainable_subg = pair.second;
+ auto subg_index = pair.first;
+
+ compiler::pass::PassRunner{}
+ .append(std::make_unique<train::pass::LossInsertionPass>(*trainable_subg, &_training_info,
+ subg_index))
+ .run();
+ }
+
+ for (const auto &pair : trainable_subgraphs)
+ {
+ const auto &subg_index = pair.first;
+ const auto &subg = pair.second;
+ dot_dumper.dump(*subg, nnfw::misc::str("after_loss_insertion-", subg_index.value()));
+ }
+
+ // Change input shape according to batch_size
+ for (auto &&pair : trainable_subgraphs)
+ {
+ auto trainable_subg = pair.second;
+
+ for (const auto &ind : trainable_subg->getInputs())
+ {
+ auto &input = trainable_subg->operands().at(ind);
+ auto new_shape = input.info().shape();
+ // TODO Consider batch size index
+ if (new_shape.dim(0) != 1)
+ throw std::runtime_error("the first dim is not 1. It is not supported yet.");
+ new_shape.dim(0) = _training_info.batchSize();
+ input.info().shape(new_shape);
+ }
+ }
+
+ /***************************************************
+ * Backend independent analysis & optimization phase
+ ***************************************************/
+ // Tracing context
+ auto tracing_ctx = std::make_unique<util::TracingCtx>();
+
+ // Lower: Assign backend
+ std::unordered_map<ir::SubgraphIndex, std::unique_ptr<compiler::train::LoweredTrainableGraph>>
+ lowered_subgs;
+ {
+ for (auto &&pair : trainable_subgraphs)
+ {
+ auto &subg_index = pair.first;
+ auto trainable_subg = pair.second;
+
+ // Lower: Assign backend
+ lowered_subgs[subg_index] =
+ std::make_unique<compiler::train::LoweredTrainableGraph>(*trainable_subg, *_options);
+ // Set tracing_ctx for copied graph
+ tracing_ctx->setSubgraphIndex(&(lowered_subgs[subg_index]->graph()), subg_index.value());
+ }
+ }
+
+ for (const auto &pair : lowered_subgs)
+ {
+ const auto &subg_index = pair.first;
+ const auto &lowered_subg = pair.second;
+ dot_dumper.dump(*lowered_subg, nnfw::misc::str("after_lower_subg-", subg_index.value()));
+ }
+
+ // Set operands' info for back propagation as default tensor info
+ for (const auto &pair : lowered_subgs)
+ {
+ auto lowered_subg = pair.second.get();
+ auto &tgraph = lowered_subg->trainable_graph();
+ tgraph.operands().iterate([&](const ir::OperandIndex &index, const ir::Operand &obj) {
+ if (!obj.isConstant())
+ {
+ auto bwd_operand = std::make_unique<ir::Operand>(obj);
+ const auto gen_index = tgraph.addBackwardOperand(index, std::move(bwd_operand));
+ assert(gen_index == index);
+ UNUSED_RELEASE(gen_index);
+ }
+ });
+ }
+
+ // Shape inference.
+ {
+ // Run the StaticShapeInfer of primary subg. All child StaticShapeInferers are called
+ // recursively
+ std::unordered_map<ir::SubgraphIndex, std::unique_ptr<StaticShapeInferer>> inferers =
+ createStaticShapeInferers(lowered_subgs);
+
+ const auto primary_subg_idx = ir::SubgraphIndex{0};
+ inferers.at(primary_subg_idx)->infer();
+
+ for (const auto &pair_inferer : inferers)
+ {
+ const auto inferer = pair_inferer.second.get();
+ inferer->dump();
+ }
+
+ // NOTE StaticBackwardShapeInferer is allocated for each subgraph,
+ // so it does not support models that have controlflow operations yet.
+ for (auto &&pair : lowered_subgs)
+ {
+ auto &lowered_subg = pair.second;
+ auto inferer = std::make_unique<StaticBackwardShapeInferer>(lowered_subg.get());
+ inferer->infer();
+ inferer->dump();
+ }
+ }
+
+ // Shape validation
+ for (const auto &pair : lowered_subgs)
+ {
+ auto &lowered_subg = pair.second;
+ compiler::ShapeValidator{lowered_subg->graph()}();
+ }
+
+ // TODO Validate shapes of the tensors for back propagation
+
+ /*************************************************************
+ * Backend independent analysis & optimization phase finished
+ *************************************************************/
+ auto executors = std::make_shared<exec::train::TrainableExecutors>();
+ for (auto &&pair : lowered_subgs)
+ {
+ auto const model_index = ir::ModelIndex{0};
+ auto const subg_index = pair.first;
+ auto &lowered_subg = pair.second;
+ auto const indexed_ranks = lowered_subg->indexed_ranks();
+
+ ir::OperationDumper dumper("Executor generation of Subgraph " +
+ std::to_string(subg_index.value()));
+ lowered_subg->graph().operations().iterate(
+ [&](const ir::OperationIndex &, const ir::IOperation &op) { op.accept(dumper); });
+
+ ExecutorFactoryArgs args;
+ args.tracing_ctx = tracing_ctx.get();
+ args.options = _options;
+ args.model_index = model_index;
+ args.custom_kernel_builder = custom_kernel_builder;
+ auto executor = std::unique_ptr<exec::IExecutor>{
+ ExecutorFactory::get().create(std::move(lowered_subg), executors, args, _training_info)};
+ executor->setIndexedRanks(indexed_ranks);
+ executors->emplace(model_index, subg_index, std::move(executor));
+ }
+
+ /********************************
+ * Code generation phase finished
+ ********************************/
+ return std::make_shared<CompilerArtifact>(executors, std::move(tracing_ctx));
+}
+
+} // namespace train
+} // namespace compiler
+} // namespace onert
diff --git a/runtime/onert/core/src/compiler/train/TrainingCompiler.h b/runtime/onert/core/src/compiler/train/TrainingCompiler.h
new file mode 100644
index 000000000..ab62c0f34
--- /dev/null
+++ b/runtime/onert/core/src/compiler/train/TrainingCompiler.h
@@ -0,0 +1,81 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * @file TrainingCompiler.h
+ * @brief This file contains TrainingCompiler class to define and run compilation phase
+ */
+
+#ifndef __ONERT_COMPILER_TRAIN_TRAINING_COMPILER_H_
+#define __ONERT_COMPILER_TRAIN_TRAINING_COMPILER_H_
+
+#include "compiler/CompilerOptions.h"
+#include "compiler/ICompiler.h"
+#include "ir/NNPkg.h"
+#include "ir/train/TrainingInfo.h"
+
+namespace onert
+{
+namespace compiler
+{
+namespace train
+{
+
+/**
+ * @brief Class to compile NN package
+ */
+class TrainingCompiler : public ICompiler
+{
+public:
+ /**
+ * @brief Construct a new TrainingCompiler object for an nnpkg
+ * @param[in] nnpkg nnpkg to compile
+ * @param[in] copts compiler options
+ * @param[in] training_info training information
+ */
+ explicit TrainingCompiler(const std::shared_ptr<ir::NNPkg> &nnpkg, CompilerOptions *copts,
+ const ir::train::TrainingInfo &training_info);
+
+ /**
+ * @brief Construct a TrainingCompiler object
+ *
+ */
+ TrainingCompiler(void) = delete;
+
+ /**
+ * @brief Destroy the TrainingCompiler object
+ */
+ ~TrainingCompiler() = default;
+
+public:
+ /**
+ * @brief Do compilation with the options
+ *
+ * @return std::shared_ptr<CompilerArtifact> Executors as a result of compilation
+ */
+ std::shared_ptr<CompilerArtifact> compile(void);
+
+private:
+ std::shared_ptr<ir::Model> _model;
+ CompilerOptions *_options;
+ const ir::train::TrainingInfo _training_info;
+};
+
+} // namespace train
+} // namespace compiler
+} // namespace onert
+
+#endif // __ONERT_COMPILER_TRAIN_TRAINING_COMPILER_H_
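A minimal caller-side sketch of the interface above; nnpkg, copts and training_info are assumed to be prepared by the caller:

// Sketch only: compile an NN package for training and obtain the executors.
compiler::train::TrainingCompiler compiler{nnpkg, copts, training_info};
std::shared_ptr<compiler::CompilerArtifact> artifact = compiler.compile();
// The artifact carries exec::train::TrainableExecutors plus the tracing context.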
diff --git a/runtime/onert/core/src/compiler/train/UntrainableOperationConverter.cc b/runtime/onert/core/src/compiler/train/UntrainableOperationConverter.cc
new file mode 100644
index 000000000..22f7604b5
--- /dev/null
+++ b/runtime/onert/core/src/compiler/train/UntrainableOperationConverter.cc
@@ -0,0 +1,53 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "UntrainableOperationConverter.h"
+
+#include "ir/train/operation/UntrainableOperation.h"
+
+namespace onert
+{
+namespace compiler
+{
+namespace train
+{
+
+UntrainableOperationConverter::UntrainableOperationConverter(ir::train::TrainableGraph &tgraph)
+ : _tgraph{tgraph}, _return_op{nullptr}
+{
+}
+
+std::unique_ptr<ir::train::ITrainableOperation>
+UntrainableOperationConverter::operator()(const ir::IOperation &op)
+{
+ op.accept(*this);
+
+ return std::move(_return_op);
+}
+
+#define OP(InternalName) \
+ void UntrainableOperationConverter::visit(const ir::operation::InternalName &node) \
+ { \
+ _return_op = \
+ std::make_unique<ir::train::operation::UntrainableOperation<ir::operation::InternalName>>( \
+ node); \
+ }
+#include "ir/Operations.lst"
+#undef OP
+
+} // namespace train
+} // namespace compiler
+} // namespace onert
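The visit overrides above are generated by an X-macro expansion over ir/Operations.lst. A self-contained toy illustration of the same pattern follows; the list and operation names here are made up and are not onert's real Operations.lst:

// Toy X-macro illustration, independent of onert.
#include <iostream>

#define OPERATION_LIST \
  OP(Add)              \
  OP(Mul)

#define OP(Name) \
  void visit_##Name() { std::cout << "visit " #Name << std::endl; }
OPERATION_LIST
#undef OP

int main()
{
  visit_Add(); // prints "visit Add"
  visit_Mul(); // prints "visit Mul"
}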
diff --git a/runtime/onert/core/src/compiler/train/UntrainableOperationConverter.h b/runtime/onert/core/src/compiler/train/UntrainableOperationConverter.h
new file mode 100644
index 000000000..e960b3831
--- /dev/null
+++ b/runtime/onert/core/src/compiler/train/UntrainableOperationConverter.h
@@ -0,0 +1,52 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_COMPILER_TRAIN_UNTRAINABLE_OPERATION_CONVERTER_H__
+#define __ONERT_COMPILER_TRAIN_UNTRAINABLE_OPERATION_CONVERTER_H__
+
+#include "ir/Operations.Include.h"
+#include "ir/OperationVisitor.h"
+#include "ir/train/TrainableGraph.h"
+
+#include <memory>
+
+namespace onert
+{
+namespace compiler
+{
+namespace train
+{
+
+class UntrainableOperationConverter : public ir::OperationVisitor
+{
+public:
+ UntrainableOperationConverter(ir::train::TrainableGraph &tgraph);
+ std::unique_ptr<ir::train::ITrainableOperation> operator()(const ir::IOperation &op);
+
+#define OP(InternalName) void visit(const ir::operation::InternalName &node);
+#include "ir/Operations.lst"
+#undef OP
+
+protected:
+ ir::train::TrainableGraph &_tgraph;
+ std::unique_ptr<ir::train::ITrainableOperation> _return_op;
+};
+
+} // namespace train
+} // namespace compiler
+} // namespace onert
+
+#endif // __ONERT_COMPILER_TRAIN_UNTRAINABLE_OPERATION_CONVERTER_H__
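(For reference, a minimal sketch of driving this converter when populating a trainable graph; the wrapper function, variable names, and include paths are illustrative, and only the converter construction, operator(), operations().iterate(), and addOperation() calls appear in this patch.)

    #include "UntrainableOperationConverter.h"
    #include "ir/Graph.h"

    // Sketch only: wrap every operation of a source graph as "untrainable"
    // while filling a TrainableGraph that already holds the matching operands.
    void convertAll(const onert::ir::Graph &src, onert::ir::train::TrainableGraph &tgraph)
    {
      onert::compiler::train::UntrainableOperationConverter convert{tgraph};
      src.operations().iterate(
        [&](const onert::ir::OperationIndex &, const onert::ir::IOperation &op) {
          // operator() visits the node and returns UntrainableOperation<Op>
          tgraph.addOperation(convert(op));
        });
    }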
diff --git a/runtime/onert/core/src/compiler/train/pass/LossInsertionPass.cc b/runtime/onert/core/src/compiler/train/pass/LossInsertionPass.cc
new file mode 100644
index 000000000..ea1f21e30
--- /dev/null
+++ b/runtime/onert/core/src/compiler/train/pass/LossInsertionPass.cc
@@ -0,0 +1,82 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "LossInsertionPass.h"
+
+#include "ir/train/TrainableGraph.h"
+#include "ir/train/TrainingInfo.h"
+#include "ir/train/operation/Loss.h"
+
+namespace onert
+{
+namespace compiler
+{
+namespace train
+{
+namespace pass
+{
+
+void LossInsertionPass::run()
+{
+ const auto &loss_info = _training_info->lossInfo();
+
+ if (_trainable_graph.getOutputs().size() != 1)
+ {
+ throw std::runtime_error("LossInsertionPass: Multiple outputs are not supported");
+ }
+
+ // TODO Consider SparseCategoricalCrossentropy y_true shape
+ // SparseCategoricalCrossentropy loss has a different y_true shape than y_pred.
+
+ // TODO Implement Loop [0, getOutputs().size())
+ // index: a loop index
+ const auto index = 0;
+ const auto &y_pred_index = _trainable_graph.getOutputs().at(index);
+ const auto &y_pred = _trainable_graph.operands().at(y_pred_index);
+ auto y_true_index = _trainable_graph.addOperand(y_pred.shape(), y_pred.typeInfo());
+ ir::OperandIndexSequence inputs{y_pred_index, y_true_index};
+
+ ir::Shape output_shape;
+ if (loss_info.reduction_type == ir::train::LossReductionType::Sum ||
+ loss_info.reduction_type == ir::train::LossReductionType::SumOverBatchSize)
+ {
+ output_shape = ir::Shape{1};
+ }
+ else
+ {
+ throw std::runtime_error("LossInsertionPass: Unsupported reduction type");
+ }
+
+ const ir::TypeInfo float_op(ir::DataType::FLOAT32);
+ auto output_index = _trainable_graph.addOperand(output_shape, float_op);
+ ir::OperandIndexSequence outputs{output_index};
+
+ auto loss_op = std::make_unique<ir::operation::Loss>(inputs, outputs);
+ auto trainable_loss_op = std::make_unique<ir::train::operation::Loss>(*loss_op, loss_info);
+ trainable_loss_op->enableBackward();
+
+ _trainable_graph.addOperation(std::move(trainable_loss_op));
+
+ _trainable_graph.addInput(y_true_index);
+
+ // TODO Add as many losses as there are outputs
+ _trainable_graph.addLoss(output_index, ir::IOIndex{index});
+}
+
+} // namespace pass
+} // namespace train
+} // namespace compiler
+} // namespace onert
diff --git a/runtime/onert/core/src/compiler/train/pass/LossInsertionPass.h b/runtime/onert/core/src/compiler/train/pass/LossInsertionPass.h
new file mode 100644
index 000000000..1a313fb11
--- /dev/null
+++ b/runtime/onert/core/src/compiler/train/pass/LossInsertionPass.h
@@ -0,0 +1,56 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_COMPILER_TRAIN_PASS_LOSS_INSERTION_PASS_H__
+#define __ONERT_COMPILER_TRAIN_PASS_LOSS_INSERTION_PASS_H__
+
+#include "Pass.h"
+
+#include "ir/Index.h"
+
+namespace onert
+{
+namespace compiler
+{
+namespace train
+{
+namespace pass
+{
+
+class LossInsertionPass : public Pass
+{
+public:
+ LossInsertionPass(ir::train::TrainableGraph &trainable_graph,
+ const ir::train::TrainingInfo *training_info,
+ const ir::SubgraphIndex &subg_index)
+ : Pass{trainable_graph, training_info}, _subg_index{subg_index}
+ {
+ }
+
+public:
+ std::string id() final { return "LossInsertionPass"; }
+ void run() final;
+
+private:
+ ir::SubgraphIndex _subg_index;
+};
+
+} // namespace pass
+} // namespace train
+} // namespace compiler
+} // namespace onert
+
+#endif // __ONERT_COMPILER_TRAIN_PASS_LOSS_INSERTION_PASS_H__
diff --git a/runtime/onert/core/src/compiler/train/pass/Pass.h b/runtime/onert/core/src/compiler/train/pass/Pass.h
new file mode 100644
index 000000000..0e835e19e
--- /dev/null
+++ b/runtime/onert/core/src/compiler/train/pass/Pass.h
@@ -0,0 +1,62 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_COMPILER_TRAIN_PASS_PASS_H__
+#define __ONERT_COMPILER_TRAIN_PASS_PASS_H__
+
+#include "../../pass/IPass.h"
+
+namespace onert
+{
+namespace ir
+{
+namespace train
+{
+class TrainableGraph;
+class TrainingInfo;
+} // namespace train
+} // namespace ir
+} // namespace onert
+
+namespace onert
+{
+namespace compiler
+{
+namespace train
+{
+namespace pass
+{
+
+class Pass : public compiler::pass::IPass
+{
+public:
+ Pass(ir::train::TrainableGraph &trainable_graph, const ir::train::TrainingInfo *training_info)
+ : _trainable_graph{trainable_graph}, _training_info{training_info}
+ {
+ }
+ virtual ~Pass() = default;
+
+protected:
+ ir::train::TrainableGraph &_trainable_graph;
+ const ir::train::TrainingInfo *_training_info;
+};
+
+} // namespace pass
+} // namespace train
+} // namespace compiler
+} // namespace onert
+
+#endif // __ONERT_COMPILER_TRAIN_PASS_PASS_H__
diff --git a/runtime/onert/core/src/dumper/dot/DotBuilder.cc b/runtime/onert/core/src/dumper/dot/DotBuilder.cc
index 38a69696e..9257434fa 100644
--- a/runtime/onert/core/src/dumper/dot/DotBuilder.cc
+++ b/runtime/onert/core/src/dumper/dot/DotBuilder.cc
@@ -29,31 +29,12 @@ DotBuilder::DotBuilder() {}
void DotBuilder::update(const Node &node_info)
{
add(node_info);
- for (auto edge : node_info.out_edges())
+ for (auto &&edge : node_info.out_edges())
{
addEdge(node_info, *edge);
}
}
-void DotBuilder::addOpSequence(const DotSubgraphInfo &subgraph_info)
-{
- _dot << "subgraph cluster_" << subgraph_info.index().value() << " {\n";
- _dot << " label=\"" << subgraph_info.label() << "\";\n";
- _dot << " style=filled;\n";
- _dot << " color=lightgrey;\n";
- _dot << " ";
- for (auto op : subgraph_info.operations())
- {
- _dot << "operation" << op.value() << "; ";
- }
- for (auto op : subgraph_info.operands())
- {
- _dot << "operand" << op.value() << "; ";
- }
- _dot << "\n";
- _dot << "}\n";
-}
-
void DotBuilder::writeDot(std::ostream &os)
{
os << "digraph D {\n"
@@ -66,7 +47,7 @@ void DotBuilder::add(const Node &node)
_dot << node.id();
std::stringstream ss;
_dot << "[";
- for (auto attr : node.attributes())
+ for (auto &&attr : node.attributes())
{
_dot << attr.first << "=\"" << attr.second << "\" ";
}
diff --git a/runtime/onert/core/src/dumper/dot/DotBuilder.h b/runtime/onert/core/src/dumper/dot/DotBuilder.h
index 681cbbf5d..30f32f8f9 100644
--- a/runtime/onert/core/src/dumper/dot/DotBuilder.h
+++ b/runtime/onert/core/src/dumper/dot/DotBuilder.h
@@ -25,7 +25,6 @@
#include "OperationNode.h"
#include "OperandNode.h"
-#include "DotSubgraphInfo.h"
using Operation = onert::ir::Operation;
using Object = onert::ir::Operand;
@@ -44,7 +43,6 @@ public:
public:
void update(const Node &dotinfo);
- void addOpSequence(const DotSubgraphInfo &subgraph_info);
void writeDot(std::ostream &os);
diff --git a/runtime/onert/core/src/dumper/dot/DotDumper.cc b/runtime/onert/core/src/dumper/dot/DotDumper.cc
index 118057f09..98524d8d1 100644
--- a/runtime/onert/core/src/dumper/dot/DotDumper.cc
+++ b/runtime/onert/core/src/dumper/dot/DotDumper.cc
@@ -19,8 +19,7 @@
#include "DotDumper.h"
#include "DotBuilder.h"
-#include "DotSubgraphInfo.h"
-#include "ir/OpSequence.h"
+#include "ir/OperandIndexMap.h"
#include "ir/OperationIndexMap.h"
#include "backend/Backend.h"
#include "backend/IConfig.h"
@@ -33,151 +32,153 @@ namespace dumper
namespace dot
{
-void DotDumper::dump(const std::string &tag)
+namespace
{
- if (_level == Level::OFF)
- {
- return;
- }
-
- onert::dumper::dot::DotBuilder dot_builder;
-
- auto &operations = _graph.operations();
- auto &operands = _graph.operands();
-
- ir::OperationIndexMap<std::unique_ptr<Operation>> operation_nodes;
- std::unordered_map<ir::OperandIndex, std::unique_ptr<Operand>> operand_nodes;
-
- auto backend_to_fillcolor = [](const backend::Backend *backend) {
- static const auto map = []() {
- std::unordered_map<const backend::Backend *, std::string> ret;
- uint32_t index = 1; // Start from 1 to avoid 0(red) which is too dark :(
- for (const auto backend : compiler::BackendManager::get().getAll())
- {
- ret.emplace(backend, Node::BG_COLORS[index]);
- index = (index + 1) % (sizeof(Node::BG_COLORS) / sizeof(Node::BG_COLORS[0]));
- }
- return ret;
- }();
-
- auto itr = map.find(backend);
- if (itr == map.end())
- {
- return Node::DEFAULT_FILLCOLOR;
- }
- else
+std::string backend_to_fillcolor(const backend::Backend *backend)
+{
+ static const auto map = []() {
+ std::unordered_map<const backend::Backend *, std::string> ret;
+ uint32_t index = 1; // Start from 1 to avoid 0(red) which is too dark :(
+ for (const auto backend : compiler::BackendManager::get().getAll())
{
- return itr->second;
+ ret.emplace(backend, Node::BG_COLORS[index]);
+ index = (index + 1) % (sizeof(Node::BG_COLORS) / sizeof(Node::BG_COLORS[0]));
}
- };
+ return ret;
+ }();
+ auto itr = map.find(backend);
+ if (itr == map.end())
+ {
+ return Node::DEFAULT_FILLCOLOR;
+ }
+ else
+ {
+ return itr->second;
+ }
+}
- util::Set<ir::OperandIndex> shown_operand_set;
+std::unordered_map<ir::OperandIndex, std::unique_ptr<Operand>>
+generate_dot_operands(const ir::Graph &graph, const DotDumper::Level level)
+{
+ std::unordered_map<ir::OperandIndex, std::unique_ptr<Operand>> dot_operands;
+ const auto &operands = graph.operands();
operands.iterate([&](const ir::OperandIndex &index, const ir::Operand &object) {
- bool showing_cond = false;
- if (_level == Level::ALL)
- {
- showing_cond = true;
- }
- else
- {
- showing_cond = !object.isConstant();
- }
- if (object.isConstant() || _graph.getInputs().contains(index))
- {
- showing_cond = showing_cond && (object.getUses().size() > 0);
- }
+ bool showing_cond =
+ level == DotDumper::Level::ALL
+ ? true
+ : !object.isConstant() || (graph.getInputs() + graph.getOutputs()).contains(index);
if (showing_cond)
{
- shown_operand_set.add(index);
-
auto type = [&]() {
using onert::dumper::dot::Operand;
- if (_graph.getInputs().contains(index))
+ if (graph.getInputs().contains(index))
return Operand::Type::MODEL_INPUT;
- if (_graph.getOutputs().contains(index))
+ if (graph.getOutputs().contains(index))
return Operand::Type::MODEL_OUTPUT;
return Operand::Type::INTERNAL;
}();
auto node = std::make_unique<Operand>(index, type);
+ std::string label = std::to_string(index.value());
+ std::string fillcolor = "";
+ node->setAttribute("label", label);
+ node->setAttribute("fillcolor", fillcolor);
- {
- // Display LowerInfo attributes
- std::string label = std::to_string(index.value());
- std::string fillcolor = "";
- if (_lowered_graph)
- {
- auto lower_info = _lowered_graph->getLowerInfo(index);
- const auto &def_factors = lower_info->def_factors();
- if (def_factors.size() > 0)
- {
- label += "\\n[";
- label += def_factors.getOnlyElement().backend()->config()->id();
- label += "]";
-
- fillcolor = backend_to_fillcolor(lower_info->def_factors().getOnlyElement().backend());
- }
- }
- node->setAttribute("label", label);
- node->setAttribute("fillcolor", fillcolor);
- }
-
- operand_nodes.emplace(index, std::move(node));
+ dot_operands.emplace(index, std::move(node));
}
});
- operations.iterate([&](const ir::OperationIndex &index, const ir::Operation &op) {
+ return dot_operands;
+}
+
+ir::OperationIndexMap<std::unique_ptr<Operation>>
+generate_dot_operations(const ir::Graph &graph,
+ const ir::OperandIndexMap<std::unique_ptr<Operand>> &dot_operands)
+{
+ ir::OperationIndexMap<std::unique_ptr<Operation>> dot_operations;
+ const auto &operations = graph.operations();
+ operations.iterate([&](const ir::OperationIndex &index, const ir::IOperation &op) {
auto node = std::make_unique<Operation>(index, op);
- for (auto input : op.getInputs())
+ for (auto &&input : op.getInputs())
{
using onert::dumper::dot::Operand;
// Constant input and dump level is ALL_BUT_CONSTANTS
- if (operand_nodes.find(input) == operand_nodes.end())
+ if (dot_operands.find(input) == dot_operands.end())
continue;
- auto &input_node = operand_nodes.at(input);
+ auto &input_node = dot_operands.at(input);
input_node->addOutEdge(node.get());
}
- for (auto output : op.getOutputs())
+ for (auto &&output : op.getOutputs() | ir::Remove::UNDEFINED)
{
using onert::dumper::dot::Operand;
- auto &output_node = operand_nodes.at(output);
+ auto &output_node = dot_operands.at(output);
node->addOutEdge(output_node.get());
}
- operation_nodes.emplace(index, std::move(node));
+ dot_operations.emplace(index, std::move(node));
});
- if (_lowered_graph)
- {
- const auto &op_seqs = _lowered_graph->op_seqs();
- op_seqs.iterate([&](const ir::OpSequenceIndex &index, const ir::OpSequence &op_seq) {
- const auto lower_info = _lowered_graph->getLowerInfo(index);
+ return dot_operations;
+}
+
+void update_lower_info(const compiler::ILoweredGraph &lowered_graph,
+ ir::OperandIndexMap<std::unique_ptr<Operand>> *dot_operands)
+{
+ const auto &operands = lowered_graph.graph().operands();
+ operands.iterate([&](const ir::OperandIndex &index, const ir::Operand &) {
+ auto itr = dot_operands->find(index);
+ if (itr != dot_operands->end())
+ {
+ auto &node = itr->second;
+ // Display LowerInfo attributes
+ std::string label = node->getAttribute("label");
+ std::string fillcolor = node->getAttribute("fillcolor");
+ auto lower_info = lowered_graph.lower_info().operand.getRawPtr(index);
+ const auto &def_factors = lower_info->def_factors();
+ if (def_factors.size() > 0)
+ {
+ label += "\\n[";
+ label += def_factors.getOnlyElement().backend()->config()->id();
+ label += "]";
+ fillcolor = backend_to_fillcolor(lower_info->def_factors().getOnlyElement().backend());
+ }
+ node->setAttribute("label", label);
+ node->setAttribute("fillcolor", fillcolor);
+ }
+ });
+}
+
+void update_lower_info(const compiler::ILoweredGraph &lowered_graph,
+ ir::OperationIndexMap<std::unique_ptr<Operation>> *dot_operations)
+{
+ const auto &operations = lowered_graph.graph().operations();
+ operations.iterate([&](const ir::OperationIndex &index, const ir::IOperation &) {
+ const auto lower_info = lowered_graph.lower_info().operation.getRawPtr(index);
+ if (lower_info)
+ {
auto fillcolor = backend_to_fillcolor(lower_info->backend());
- std::string label =
- std::to_string(index.value()) + " [" + lower_info->backend()->config()->id() + "]";
- DotSubgraphInfo subgraph_info{index, op_seq, shown_operand_set, _graph.operations()};
- subgraph_info.label(label);
- subgraph_info.fillcolor(fillcolor);
- dot_builder.addOpSequence(subgraph_info);
-
- // Set fillcolor of all operations in the op_seq
- for (const auto &op_idx : op_seq.operations())
+ std::string backend_label = "[" + lower_info->backend()->config()->id() + "]";
+ auto itr = dot_operations->find(index);
+ if (itr != dot_operations->end())
{
- auto found = operation_nodes.find(op_idx);
- if (found != operation_nodes.end())
- {
- auto &&op = found->second;
- op->setAttribute("fillcolor", fillcolor);
- }
+ auto &node = itr->second;
+ node->setAttribute("label", node->getAttribute("label") + "\n" + backend_label);
+ node->setAttribute("fillcolor", fillcolor);
}
- });
- }
+ }
+ });
+}
+void dump_to_file(const ir::OperandIndexMap<std::unique_ptr<Operand>> &operand_nodes,
+ const ir::OperationIndexMap<std::unique_ptr<Operation>> &operation_nodes,
+ const std::string &tag)
+{
+ onert::dumper::dot::DotBuilder dot_builder;
for (const auto &e : operation_nodes)
dot_builder.update(*e.second);
for (const auto &e : operand_nodes)
@@ -198,6 +199,39 @@ void DotDumper::dump(const std::string &tag)
fb.close();
}
}
+} // namespace
+
+void DotDumper::dump(const ir::Graph &graph, const std::string &tag)
+{
+ if (_level == Level::OFF)
+ {
+ return;
+ }
+
+ const auto dot_operands = generate_dot_operands(graph, _level);
+ const auto dot_operations = generate_dot_operations(graph, dot_operands);
+ dump_to_file(dot_operands, dot_operations, tag);
+}
+
+// TODO Support tensors for training
+void DotDumper::dump(const compiler::ILoweredGraph &lowered_graph, const std::string &tag)
+{
+ if (_level == Level::OFF)
+ {
+ return;
+ }
+
+ auto dot_operands = generate_dot_operands(lowered_graph.graph(), _level);
+ auto dot_operations = generate_dot_operations(lowered_graph.graph(), dot_operands);
+ update_lower_info(lowered_graph, &dot_operands);
+ update_lower_info(lowered_graph, &dot_operations);
+ dump_to_file(dot_operands, dot_operations, tag);
+}
+
+void DotDumper::dump(const ir::train::TrainableGraph &graph, const std::string &tag)
+{
+ dump(graph.graph(), tag);
+}
} // namespace dot
} // namespace dumper
diff --git a/runtime/onert/core/src/dumper/dot/DotDumper.h b/runtime/onert/core/src/dumper/dot/DotDumper.h
index fdbca1642..59f4b3bda 100644
--- a/runtime/onert/core/src/dumper/dot/DotDumper.h
+++ b/runtime/onert/core/src/dumper/dot/DotDumper.h
@@ -15,7 +15,8 @@
*/
#include "ir/Graph.h"
-#include "compiler/LoweredGraph.h"
+#include "ir/train/TrainableGraph.h"
+#include "compiler/ILoweredGraph.h"
#ifndef __ONERT_DUMPER_DOT_DOT_DUMPER_H__
#define __ONERT_DUMPER_DOT_DOT_DUMPER_H__
@@ -38,27 +39,37 @@ public:
};
public:
- DotDumper(const ir::Graph &graph, Level level)
- : _lowered_graph{nullptr}, _graph(graph), _level{level}
- {
- }
- DotDumper(const compiler::LoweredGraph *lowered_graph, Level level)
- : _lowered_graph{lowered_graph}, _graph(_lowered_graph->graph()), _level{level}
- {
- }
+ DotDumper(Level level) : _level{level} {}
public:
/**
- * @brief Dump to dot file as tag name if "GRAPH_DOT_DUMP" is set
+ * @brief Dump graph information to dot file as tag name if "GRAPH_DOT_DUMP" is set
*
+ * @param[in] graph The graph that would be used to get operations and operands
* @param[in] tag The name of dot file that would be created
* @return N/A
*/
- void dump(const std::string &tag);
+ void dump(const ir::Graph &graph, const std::string &tag);
+
+ /**
+ * @brief Dump lowered graph information to a dot file named after tag if "GRAPH_DOT_DUMP" is set
+ *
+ * @param[in] lowered_graph The lowered graph used to get operations, operands, and lower info
+ * @param[in] tag The name of dot file that would be created
+ * @return N/A
+ */
+ void dump(const compiler::ILoweredGraph &lowered_graph, const std::string &tag);
+
+ /**
+ * @brief Dump trainable graph information to a dot file named after tag if "GRAPH_DOT_DUMP" is set
+ *
+ * @param[in] graph TrainableGraph to be dumped
+ * @param[in] tag The name of dot file to be dumped
+ * @return N/A
+ */
+ void dump(const ir::train::TrainableGraph &graph, const std::string &tag);
private:
- const compiler::LoweredGraph *_lowered_graph;
- const ir::Graph &_graph;
Level _level;
};
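(As a usage note, a small sketch of the now-stateless DotDumper; variable names and the include path are illustrative, while the constructor and dump() overloads are the ones declared above.)

    #include "dumper/dot/DotDumper.h"

    // Sketch only: dump the plain graph and its lowered form for debugging.
    // A dot file named after the tag is written unless the level is OFF.
    void dumpForDebug(const onert::ir::Graph &graph,
                      const onert::compiler::ILoweredGraph &lowered)
    {
      onert::dumper::dot::DotDumper dumper{onert::dumper::dot::DotDumper::Level::ALL};
      dumper.dump(graph, "before_lower");  // operands and operations only
      dumper.dump(lowered, "after_lower"); // adds backend labels from lower info
    }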
diff --git a/runtime/onert/core/src/dumper/dot/DotSubgraphInfo.cc b/runtime/onert/core/src/dumper/dot/DotSubgraphInfo.cc
deleted file mode 100644
index 52e9c758d..000000000
--- a/runtime/onert/core/src/dumper/dot/DotSubgraphInfo.cc
+++ /dev/null
@@ -1,58 +0,0 @@
-/*
- * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "DotSubgraphInfo.h"
-
-#include <sstream>
-
-namespace onert
-{
-namespace dumper
-{
-namespace dot
-{
-
-DotSubgraphInfo::DotSubgraphInfo(const ir::OpSequenceIndex &index, const ir::OpSequence &op_seq,
- const util::Set<ir::OperandIndex> &shown_operands,
- const ir::Operations &operations_ctx)
- : _index{index}
-{
- for (const auto &op_idx : op_seq.operations())
- {
- _operations.insert(op_idx);
- const auto &node = operations_ctx.at(op_idx);
- for (auto o : node.getInputs())
- {
- // Must be a shown operand, not op_seq's inputs
- if (shown_operands.contains(o) && !op_seq.getInputs().contains(o))
- {
- _operands.insert(o);
- }
- }
- for (auto o : node.getOutputs())
- {
- // Must be a shown operand, not op_seq's inputs
- if (shown_operands.contains(o) && !op_seq.getOutputs().contains(o))
- {
- _operands.insert(o);
- }
- }
- }
-}
-
-} // namespace dot
-} // namespace dumper
-} // namespace onert
diff --git a/runtime/onert/core/src/dumper/dot/DotSubgraphInfo.h b/runtime/onert/core/src/dumper/dot/DotSubgraphInfo.h
deleted file mode 100644
index 95ba8953e..000000000
--- a/runtime/onert/core/src/dumper/dot/DotSubgraphInfo.h
+++ /dev/null
@@ -1,61 +0,0 @@
-/*
- * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef __ONERT_CORE_DUMPER_DOT_DOT_SUBGRAPH_INFO_H__
-#define __ONERT_CORE_DUMPER_DOT_DOT_SUBGRAPH_INFO_H__
-
-#include <unordered_set>
-
-#include "ir/Index.h"
-#include <ir/Operations.h>
-#include "ir/OpSequence.h"
-#include "util/Set.h"
-
-namespace onert
-{
-namespace dumper
-{
-namespace dot
-{
-
-class DotSubgraphInfo
-{
-public:
- DotSubgraphInfo(const ir::OpSequenceIndex &index, const ir::OpSequence &op_seq,
- const util::Set<ir::OperandIndex> &shown_operands,
- const ir::Operations &operations_ctx);
-
- ir::OpSequenceIndex index() const { return _index; }
- std::string label() const { return _label; }
- void label(const std::string &val) { _label = val; }
- std::string fillcolor() const { return _fillcolor; }
- void fillcolor(const std::string &val) { _fillcolor = val; }
- const std::unordered_set<ir::OperationIndex> &operations() const { return _operations; }
- const std::unordered_set<ir::OperandIndex> &operands() const { return _operands; }
-
-private:
- ir::OpSequenceIndex _index;
- std::string _label;
- std::string _fillcolor;
- std::unordered_set<ir::OperationIndex> _operations;
- std::unordered_set<ir::OperandIndex> _operands;
-};
-
-} // namespace dot
-} // namespace dumper
-} // namespace onert
-
-#endif // __ONERT_CORE_DUMPER_DOT_DOT_SUBGRAPH_INFO_H__
diff --git a/runtime/onert/core/src/dumper/dot/OperandNode.cc b/runtime/onert/core/src/dumper/dot/OperandNode.cc
index 5a6015ca9..cbc73878f 100644
--- a/runtime/onert/core/src/dumper/dot/OperandNode.cc
+++ b/runtime/onert/core/src/dumper/dot/OperandNode.cc
@@ -18,7 +18,6 @@
#include "OperandNode.h"
#include "ir/Graph.h"
-#include "ir/operand/LowerInfo.h"
namespace onert
{
@@ -33,10 +32,10 @@ const std::string Operand::OPERAND_SHAPE = "ellipse";
const std::string Operand::BG_COLOR_SCHEME = "set18";
Operand::Operand(const ir::OperandIndex &index, Type type)
- : Node{"operand" + std::to_string(index.value())}
+ : Node{"operand" + std::to_string(index.value())}
{
{
- auto type_to_shape = [](Type type) {
+ auto type_to_shape = [](Type type) -> const std::string & {
switch (type)
{
case Type::MODEL_INPUT:
diff --git a/runtime/onert/core/src/dumper/dot/OperandNode.h b/runtime/onert/core/src/dumper/dot/OperandNode.h
index 2e7cc5861..f2aea80ad 100644
--- a/runtime/onert/core/src/dumper/dot/OperandNode.h
+++ b/runtime/onert/core/src/dumper/dot/OperandNode.h
@@ -64,7 +64,6 @@ public:
*
* @param[in] index Operand index
* @param[in] type Operand type
- * @param[in] lower_info Operand LowerInfo
*/
Operand(const ir::OperandIndex &index, Type type);
diff --git a/runtime/onert/core/src/dumper/dot/OperationNode.cc b/runtime/onert/core/src/dumper/dot/OperationNode.cc
index bee137e7c..2ef08c9c6 100644
--- a/runtime/onert/core/src/dumper/dot/OperationNode.cc
+++ b/runtime/onert/core/src/dumper/dot/OperationNode.cc
@@ -18,7 +18,6 @@
#include "OperationNode.h"
#include "ir/Graph.h"
-#include "ir/operation/LowerInfo.h"
#include "backend/IConfig.h"
#include "backend/Backend.h"
@@ -32,8 +31,8 @@ namespace dot
const std::string Operation::OPERATION_SHAPE = "rect";
const std::string Operation::BG_COLOR_SCHEME = "pastel18";
-Operation::Operation(const ir::OperationIndex &index, const ir::Operation &node)
- : Node{"operation" + std::to_string(index.value())}
+Operation::Operation(const ir::OperationIndex &index, const ir::IOperation &node)
+ : Node{"operation" + std::to_string(index.value())}
{
setAttribute("label", std::to_string(index.value()) + " : " + node.name());
setAttribute("shape", OPERATION_SHAPE);
diff --git a/runtime/onert/core/src/dumper/dot/OperationNode.h b/runtime/onert/core/src/dumper/dot/OperationNode.h
index 74a37d3fb..d9292ad0c 100644
--- a/runtime/onert/core/src/dumper/dot/OperationNode.h
+++ b/runtime/onert/core/src/dumper/dot/OperationNode.h
@@ -25,7 +25,7 @@
#define __ONERT_DUMPER_DOT_DOT_NODE_INFO_H__
#include "Node.h"
-#include "ir/Operation.h"
+#include "ir/IOperation.h"
#include "ir/Index.h"
namespace onert
@@ -52,7 +52,7 @@ public:
* @param[in] index operation index
* @param[in] node operation object
*/
- Operation(const ir::OperationIndex &index, const ir::Operation &node);
+ Operation(const ir::OperationIndex &index, const ir::IOperation &node);
};
} // namespace dot
diff --git a/runtime/onert/core/src/compiler/ParamChecker.cc b/runtime/onert/core/src/dumper/h5/Dumper.cc
index c4f80f087..5e12c2dbb 100644
--- a/runtime/onert/core/src/compiler/ParamChecker.cc
+++ b/runtime/onert/core/src/dumper/h5/Dumper.cc
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -14,20 +14,21 @@
* limitations under the License.
*/
-#include "ParamChecker.h"
+#include "Dumper.h"
-#include "ir/Graph.h"
+#include <iostream>
+#include <sstream>
+#include <stdexcept>
namespace onert
{
-namespace compiler
+namespace dumper
{
-
-void ParamChecker::operator()()
+namespace h5
{
- _model->operations().iterate(
- [&](const ir::OperationIndex &, const ir::Operation &node) { node.accept(*this); });
-}
-} // namespace compiler
+Dumper::Dumper(const std::string &filepath) : _file{filepath, H5F_ACC_CREAT | H5F_ACC_RDWR} {}
+
+} // namespace h5
+} // namespace dumper
} // namespace onert
diff --git a/runtime/onert/core/src/dumper/h5/Dumper.h b/runtime/onert/core/src/dumper/h5/Dumper.h
new file mode 100644
index 000000000..53d0e0332
--- /dev/null
+++ b/runtime/onert/core/src/dumper/h5/Dumper.h
@@ -0,0 +1,51 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_DUMPER_H5_DUMPER_H__
+#define __ONERT_DUMPER_H5_DUMPER_H__
+
+#include "exec/MinMaxMap.h"
+
+#include <H5Cpp.h>
+#include <string>
+
+namespace onert
+{
+namespace dumper
+{
+namespace h5
+{
+
+class Dumper
+{
+public:
+ /**
+ * @brief Construct dumper
+ *
+ * @param[in] filepath Path of the h5 file to create or open for dumping
+ * @throw H5::FileIException on error during file open/create
+ */
+ Dumper(const std::string &filepath);
+
+protected:
+ H5::H5File _file;
+};
+
+} // namespace h5
+} // namespace dumper
+} // namespace onert
+
+#endif // __ONERT_DUMPER_H5_DUMPER_H__
diff --git a/runtime/onert/core/src/dumper/h5/MinMaxDumper.cc b/runtime/onert/core/src/dumper/h5/MinMaxDumper.cc
new file mode 100644
index 000000000..e353ed5cb
--- /dev/null
+++ b/runtime/onert/core/src/dumper/h5/MinMaxDumper.cc
@@ -0,0 +1,87 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "MinMaxDumper.h"
+
+#include <iostream>
+#include <sstream>
+#include <stdexcept>
+
+namespace onert
+{
+namespace dumper
+{
+namespace h5
+{
+
+static const char *h5_value_grpname = "value";
+
+/*
+ * Ensure the child group exists in parent, creating it if it does not
+ */
+H5::Group ensureGroup(H5::Group parent, const std::string &child)
+{
+ H5::Exception::dontPrint();
+ try
+ {
+ return parent.openGroup(child.c_str());
+ }
+ catch (H5::Exception &e)
+ {
+ return parent.createGroup(child.c_str());
+ }
+}
+
+MinMaxDumper::MinMaxDumper(const std::string &filepath) : Dumper(filepath)
+{
+ auto root_grp = _file.openGroup("/");
+ ensureGroup(root_grp, h5_value_grpname);
+}
+
+void MinMaxDumper::dump(const exec::IOMinMaxMap &input_minmax,
+ const exec::OpMinMaxMap &op_minmax) const
+{
+ auto val_grp = _file.openGroup(h5_value_grpname);
+ auto num_run = val_grp.getNumObjs();
+ auto run_grp = val_grp.createGroup(std::string("run_") + std::to_string(num_run));
+ auto model_grp = ensureGroup(run_grp, std::string("model_") + "0");
+ hsize_t dims[] = {2};
+ H5::DataSpace dspace(1, dims); // rank=1, dim(0)=2, {min, max}
+ for (auto &&e : input_minmax)
+ {
+ // key = {subg_idx, io_idx} = e.first
+ const auto subg_idx = e.first.first.value();
+ const auto io_idx = e.first.second.value();
+ auto subg_grp = ensureGroup(model_grp, std::string("subg_") + std::to_string(subg_idx));
+ auto input_dset = subg_grp.createDataSet(std::string("input_") + std::to_string(io_idx),
+ H5::PredType::IEEE_F32BE, dspace);
+ input_dset.write(e.second.data, H5::PredType::NATIVE_FLOAT);
+ }
+ for (auto &&e : op_minmax)
+ {
+ // key = {subg_idx, op_idx} = e.first
+ const auto subg_idx = e.first.first.value();
+ const auto op_idx = e.first.second.value();
+ auto subg_grp = ensureGroup(model_grp, std::string("subg_") + std::to_string(subg_idx));
+ auto op_dset = subg_grp.createDataSet(std::string("op_") + std::to_string(op_idx),
+ H5::PredType::IEEE_F32BE, dspace);
+ op_dset.write(e.second.data, H5::PredType::NATIVE_FLOAT);
+ }
+}
+
+} // namespace h5
+} // namespace dumper
+} // namespace onert
diff --git a/runtime/onert/core/src/dumper/h5/MinMaxDumper.h b/runtime/onert/core/src/dumper/h5/MinMaxDumper.h
new file mode 100644
index 000000000..d7e2c1c31
--- /dev/null
+++ b/runtime/onert/core/src/dumper/h5/MinMaxDumper.h
@@ -0,0 +1,78 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_DUMPER_H5_MINMAX_DUMPER_H__
+#define __ONERT_DUMPER_H5_MINMAX_DUMPER_H__
+
+#include "exec/MinMaxMap.h"
+#include "Dumper.h"
+
+#include <H5Cpp.h>
+#include <string>
+
+namespace onert
+{
+namespace dumper
+{
+namespace h5
+{
+
+// The hierarchy of a single-model minmax h5 file
+//
+// GROUP /
+// GROUP value
+// └── GROUP run_{idx}
+// └── GROUP model_{idx}
+// └── GROUP subg_{idx}
+// ├── DATASET op_{idx}
+// │ DATATYPE Float32
+// │ DATASPACE (2)
+// │ DATA { min, max }
+// └── DATASET input_{idx}
+// DATATYPE Float32
+// DATASPACE (2)
+// DATA { min, max }
+// GROUP name (optional, for debug)
+// └── GROUP model_{idx}
+// └── GROUP subg_{idx}
+// ├── ATTRIBUTE op_{idx}
+// │ DATATYPE String
+// │ DATA { "op/name"}
+// └── ATTRIBUTE input_{idx}
+// DATATYPE String
+// DATA { "input/name"}
+//
+class MinMaxDumper : private Dumper
+{
+public:
+ MinMaxDumper(const std::string &filepath);
+ /**
+ * @brief Dump input and operation minmax maps
+ *
+ * @param[in] in_minmax input minmax map
+ * @param[in] op_minmax op minmax map
+ */
+ void dump(const exec::IOMinMaxMap &in_minmax, const exec::OpMinMaxMap &op_minmax) const;
+
+private:
+ H5::Group _val_grp;
+};
+
+} // namespace h5
+} // namespace dumper
+} // namespace onert
+
+#endif // __ONERT_DUMPER_H5_MINMAX_DUMPER_H__
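(For reference, a minimal sketch of driving the dumper once min/max statistics are available; only the constructor and dump() call come from this patch, and the wrapper and file name are illustrative.)

    #include "dumper/h5/MinMaxDumper.h"

    // Sketch only: append one recorded run to the h5 layout described above.
    // The maps are assumed to be filled by an instrumented execution, keyed by
    // {subgraph index, input/op index} with {min, max} float pairs as values.
    void writeMinMax(const onert::exec::IOMinMaxMap &input_minmax,
                     const onert::exec::OpMinMaxMap &op_minmax)
    {
      onert::dumper::h5::MinMaxDumper dumper{"minmax.h5"};
      dumper.dump(input_minmax, op_minmax); // creates /value/run_{N}/model_0/subg_{i} datasets
    }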
diff --git a/runtime/onert/core/src/dumper/text/GraphDumper.cc b/runtime/onert/core/src/dumper/text/GraphDumper.cc
new file mode 100644
index 000000000..c89253bda
--- /dev/null
+++ b/runtime/onert/core/src/dumper/text/GraphDumper.cc
@@ -0,0 +1,108 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "GraphDumper.h"
+
+#include "ir/Graph.h"
+#include "compiler/LoweredGraph.h"
+#include "compiler/train/LoweredTrainableGraph.h"
+#include "util/logging.h"
+#include "misc/string_helpers.h"
+
+namespace onert
+{
+namespace dumper
+{
+namespace text
+{
+
+namespace
+{
+
+std::string formatOperandIndexSequence(const ir::OperandIndexSequence &seq)
+{
+ std::vector<std::string> strs;
+ for (auto &&ind : seq)
+ strs.push_back(dumper::text::formatOperandBrief(ind));
+ return nnfw::misc::join(strs.begin(), strs.end(), ",");
+}
+
+} // namespace
+
+std::string formatOperandBrief(ir::OperandIndex ind)
+{
+ std::stringstream ss;
+ ss << ind;
+ return ss.str();
+}
+
+std::string formatOperand(const ir::Graph &, ir::OperandIndex ind)
+{
+ std::stringstream ss;
+ ss << ind;
+ // TODO Print shape, type and maybe more
+ return ss.str();
+}
+
+std::string formatOperation(const ir::IOperation &op, ir::OperationIndex ind)
+{
+ std::stringstream ss;
+
+ ss << formatOperandIndexSequence(op.getOutputs());
+ ss << " = ";
+ ss << ind << "_" << op.name() << "(";
+ ss << formatOperandIndexSequence(op.getInputs());
+ ss << ")";
+ return ss.str();
+}
+
+std::string formatOperation(const ir::Graph &graph, ir::OperationIndex ind)
+{
+ const auto &op = graph.operations().at(ind);
+ return formatOperation(op, ind);
+}
+
+void dumpGraph(const ir::Graph &graph)
+{
+ VERBOSE(GraphDumper) << "{\n";
+ auto ops_topol = graph.topolSortOperations();
+ for (auto &&op_ind : ops_topol)
+ {
+ const auto &op = graph.operations().at(op_ind);
+ VERBOSE(GraphDumper) << " " << formatOperation(op, op_ind) << "\n";
+ }
+ graph.operands().iterate([&](const ir::OperandIndex &idx, const ir::Operand &oprd) {
+ VERBOSE(GraphDumper) << " Origin(" << idx << "): " << oprd.originIndex() << std::endl;
+ });
+ VERBOSE(GraphDumper) << "}\n";
+}
+
+void dumpLoweredGraph(const compiler::LoweredGraph &lgraph)
+{
+ // TODO Graph dump with backend info
+ dumpGraph(lgraph.graph());
+}
+
+void dumpLoweredGraph(const compiler::train::LoweredTrainableGraph &lgraph)
+{
+ // TODO Graph dump with backend info
+ dumpGraph(lgraph.graph());
+}
+
+} // namespace text
+} // namespace dumper
+} // namespace onert
diff --git a/runtime/onert/core/src/dumper/text/GraphDumper.h b/runtime/onert/core/src/dumper/text/GraphDumper.h
new file mode 100644
index 000000000..3cc13c92e
--- /dev/null
+++ b/runtime/onert/core/src/dumper/text/GraphDumper.h
@@ -0,0 +1,62 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_DUMPER_TEXT_GRAPH_DUMPER_H__
+#define __ONERT_DUMPER_TEXT_GRAPH_DUMPER_H__
+
+#include <ir/Index.h>
+
+namespace onert
+{
+namespace ir
+{
+class Graph;
+struct IOperation;
+} // namespace ir
+} // namespace onert
+
+namespace onert
+{
+namespace compiler
+{
+class LoweredGraph;
+
+namespace train
+{
+class LoweredTrainableGraph;
+} // namespace train
+} // namespace compiler
+} // namespace onert
+
+namespace onert
+{
+namespace dumper
+{
+namespace text
+{
+
+std::string formatOperandBrief(ir::OperandIndex ind);
+std::string formatOperand(const ir::Graph &, ir::OperandIndex ind);
+std::string formatOperation(const ir::Graph &graph, ir::OperationIndex ind);
+void dumpGraph(const ir::Graph &graph);
+void dumpLoweredGraph(const compiler::LoweredGraph &lgraph);
+void dumpLoweredGraph(const compiler::train::LoweredTrainableGraph &lgraph);
+
+} // namespace text
+} // namespace dumper
+} // namespace onert
+
+#endif // __ONERT_DUMPER_TEXT_GRAPH_DUMPER_H__
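(As a usage note, these free functions log through VERBOSE(GraphDumper), so output appears only when that logging tag is enabled; the wrapper below is illustrative.)

    #include "dumper/text/GraphDumper.h"

    // Sketch only: print each operation as "outputs = index_Name(inputs)" plus
    // operand origin info, then the same for the lowered graph.
    void logGraphs(const onert::ir::Graph &graph,
                   const onert::compiler::LoweredGraph &lowered)
    {
      onert::dumper::text::dumpGraph(graph);
      onert::dumper::text::dumpLoweredGraph(lowered); // backend info is still a TODO here
    }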
diff --git a/runtime/onert/core/src/exec/DataflowExecutor.cc b/runtime/onert/core/src/exec/DataflowExecutor.cc
index a69ae9cdb..50984cefc 100644
--- a/runtime/onert/core/src/exec/DataflowExecutor.cc
+++ b/runtime/onert/core/src/exec/DataflowExecutor.cc
@@ -54,14 +54,13 @@ void DataflowExecutor::emplaceToReadyJobs(const uint32_t &id)
{
auto &job = _waiting_jobs[id];
assert(job != nullptr);
- auto &op_seq = _lowered_graph->op_seqs().at(_job_to_op_seq[job->index()]);
- auto rank = calculateRank(op_seq.operations());
+ auto rank = calculateRank({_job_to_op[job->index()]});
_ready_jobs.emplace(rank, std::move(job));
}
void DataflowExecutor::notify(uint32_t finished_job_id)
{
- for (auto id : _output_info[finished_job_id])
+ for (auto &&id : _output_info[finished_job_id])
{
assert(_input_info[id] > 0);
auto count = --_input_info[id];
@@ -77,57 +76,54 @@ bool DataflowExecutor::noWaitingJobs()
[](const std::unique_ptr<Job> &job) { return job == nullptr; });
}
-DataflowExecutor::DataflowExecutor(
- std::unique_ptr<compiler::LoweredGraph> lowered_graph,
- const std::vector<std::shared_ptr<backend::ITensor>> &input_tensors,
- const std::vector<std::shared_ptr<backend::ITensor>> &output_tensors,
- const compiler::TensorRegistries &tensor_regs, backend::TensorManagerSet &&tensor_mgrs,
- compiler::CodeMap &&code_map)
- : ExecutorBase{std::move(lowered_graph), input_tensors, output_tensors, tensor_regs,
- std::move(tensor_mgrs)},
- _code_map{std::move(code_map)}
+DataflowExecutor::DataflowExecutor(std::unique_ptr<compiler::LoweredGraph> lowered_graph,
+ backend::BackendContexts &&backend_contexts,
+ const compiler::TensorRegistries &tensor_regs,
+ compiler::CodeMap &&code_map,
+ const util::TracingCtx *tracing_ctx)
+ : ExecutorBase{std::move(lowered_graph), std::move(backend_contexts), tensor_regs, tracing_ctx},
+ _code_map{std::move(code_map)}
{
VERBOSE(DataflowExecutor) << "Constructing Dataflow Executor" << std::endl;
- const auto &op_seqs = _lowered_graph->op_seqs();
- // Assign jobs convert OpSequenceIndex to job index(uint32_t)
+ // Assign jobs: convert each OperationIndex to a job index (uint32_t)
uint32_t next_job_index = 0;
- std::unordered_map<ir::OpSequenceIndex, uint32_t> op_seq_to_job;
- op_seqs.iterate([&](const ir::OpSequenceIndex &op_seq_index, const ir::OpSequence &) {
- VERBOSE(DataflowExecutor) << "Create a job #" << next_job_index << " with OpSequenceIndex "
- << op_seq_index.value() << std::endl;
+ std::unordered_map<ir::OperationIndex, uint32_t> op_to_job;
+ const auto &operations = _lowered_graph->graph().operations();
+ operations.iterate([&](const ir::OperationIndex &op_ind, const ir::IOperation &) {
+ VERBOSE(DataflowExecutor) << "Create a job " << next_job_index << " with Operation " << op_ind
+ << std::endl;
_finished_jobs.emplace_back(
- std::make_unique<Job>(next_job_index, _code_map.at(op_seq_index).fn_seq.get()));
- op_seq_to_job[op_seq_index] = next_job_index++;
+ std::make_unique<Job>(next_job_index, _code_map.at(op_ind).fn_seq.get()));
+ op_to_job[op_ind] = next_job_index++;
});
_waiting_jobs.resize(next_job_index);
_output_info.resize(next_job_index);
_initial_input_info.resize(next_job_index, 0);
- op_seqs.iterate([&](const ir::OpSequenceIndex &op_seq_index, const ir::OpSequence &op_seq) {
- auto job_index = op_seq_to_job[op_seq_index];
- for (auto output : op_seq.getOutputs())
+ operations.iterate([&](const ir::OperationIndex &op_ind, const ir::IOperation &op) {
+ auto job_index = op_to_job[op_ind];
+ for (auto &&output : op.getOutputs())
{
// Update output and input info
- op_seqs.iterate(
- [&](const ir::OpSequenceIndex &op_seq_cur_index, const ir::OpSequence &op_seq_cur) {
- if (op_seq_cur.getInputs().contains(output))
- {
- auto dep_index = op_seq_to_job[op_seq_cur_index];
- ++_initial_input_info[dep_index];
- _output_info[job_index].push_back(dep_index);
- }
- });
+ operations.iterate([&](const ir::OperationIndex &op_cur_ind, const ir::IOperation &op_cur) {
+ if (op_cur.getInputs().contains(output))
+ {
+ auto dep_index = op_to_job[op_cur_ind];
+ ++_initial_input_info[dep_index];
+ _output_info[job_index].push_back(dep_index);
+ }
+ });
}
});
- for (const auto &s : op_seq_to_job)
- _job_to_op_seq.emplace(s.second, s.first);
+ for (const auto &s : op_to_job)
+ _job_to_op.emplace(s.second, s.first);
_input_info = _initial_input_info;
}
-void DataflowExecutor::executeImpl()
+void DataflowExecutor::executeImpl(const ExecutionObservee &subject)
{
assert(noWaitingJobs());
@@ -145,35 +141,38 @@ void DataflowExecutor::executeImpl()
}
assert(!_ready_jobs.empty()); // Cannot begin if there is no initial jobs
- _subject.notifyModelBegin(this);
+ auto profiling_subg_index = _tracing_ctx->getSubgraphIndex(&_graph);
+
+ subject.notifySubgraphBegin(profiling_subg_index);
while (!_ready_jobs.empty())
{
auto job = std::move((_ready_jobs.begin())->second);
_ready_jobs.erase(_ready_jobs.begin());
auto job_index = job->index();
- VERBOSE(DataflowExecutor) << "Run job #" << job_index << std::endl;
+ VERBOSE(DataflowExecutor) << "Run job " << job_index << std::endl;
+
+ auto op_ind = _job_to_op[job_index];
+ const backend::Backend *backend = _lowered_graph->lower_info().operation.at(op_ind).backend();
- auto op_seq_index = _job_to_op_seq[job_index];
- auto op_seq = &_lowered_graph->op_seqs().at(op_seq_index);
- const backend::Backend *backend =
- _lowered_graph->getLowerInfo()->op_seq.at(op_seq_index)->backend();
+ subject.notifyJobBegin(this, profiling_subg_index, op_ind, backend);
- _subject.notifyJobBegin(this, op_seq, backend);
+ job->fn_seq()->initRunning();
// check if FunctionSequence needs to handle dynamic tensor
- bool handle_dynamic_tensor = op_seq->has_dynamic_tensor() || dynamic_input_exists;
+ bool handle_dynamic_tensor =
+ _lowered_graph->getHasDynamicTensor(op_ind) || dynamic_input_exists;
job->fn_seq()->enableDynamicShapeInferer(handle_dynamic_tensor);
job->run();
- _subject.notifyJobEnd(this, op_seq, backend);
+ subject.notifyJobEnd(this, profiling_subg_index, op_ind, backend);
notify(job_index);
_finished_jobs[job_index] = std::move(job);
}
assert(noWaitingJobs());
- _subject.notifyModelEnd(this);
+ subject.notifySubgraphEnd(profiling_subg_index);
// Reset input info for the next execution
_input_info = _initial_input_info;
diff --git a/runtime/onert/core/src/exec/DataflowExecutor.h b/runtime/onert/core/src/exec/DataflowExecutor.h
index 8d60e3e4b..750dc244f 100644
--- a/runtime/onert/core/src/exec/DataflowExecutor.h
+++ b/runtime/onert/core/src/exec/DataflowExecutor.h
@@ -17,17 +17,17 @@
#ifndef __ONERT_EXEC_DATAFLOW_EXECUTOR_H__
#define __ONERT_EXEC_DATAFLOW_EXECUTOR_H__
-#include <list>
-#include <map>
-#include <unordered_map>
-
-#include "exec/FunctionSequence.h"
+#include "ExecutorBase.h"
#include "Job.h"
+
+#include "compiler/CodeMap.h"
#include "ir/OperandIndexSequence.h"
-#include "ir/Index.h"
+#include "util/TracingCtx.h"
+
+#include <list>
+#include <map>
#include <memory>
-#include "exec/ExecutorBase.h"
-#include "compiler/CodeMap.h"
+#include <unordered_map>
namespace onert
{
@@ -47,15 +47,14 @@ public:
*
* @param lowered_graph LoweredGraph object
* @param tensor_builders Tensor builders that are currently used
- * @param code_map OpSequence and its code map
+ * @param code_map @c ir::Operation and its code map
*/
DataflowExecutor(std::unique_ptr<compiler::LoweredGraph> lowered_graph,
- const std::vector<std::shared_ptr<backend::ITensor>> &input_tensors,
- const std::vector<std::shared_ptr<backend::ITensor>> &output_tensors,
- const compiler::TensorRegistries &tensor_regs,
- backend::TensorManagerSet &&tensor_mgrs, compiler::CodeMap &&code_map);
+ backend::BackendContexts &&backend_contexts,
+ const compiler::TensorRegistries &tensor_regs, compiler::CodeMap &&code_map,
+ const util::TracingCtx *tracing_ctx);
- void executeImpl() override;
+ void executeImpl(const ExecutionObservee &subject) override;
protected:
int64_t calculateRank(const std::vector<ir::OperationIndex> &operations);
@@ -88,7 +87,7 @@ protected:
std::multimap<int64_t, std::unique_ptr<Job>, std::greater<int64_t>> _ready_jobs;
/// @brief Which job runs which op and function.
- std::unordered_map<uint32_t, ir::OpSequenceIndex> _job_to_op_seq;
+ std::unordered_map<uint32_t, ir::OperationIndex> _job_to_op;
};
} // namespace exec
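(The job bookkeeping above is a dependency-counting scheduler: each job starts with the number of producer jobs it waits on, and finishing a job decrements that count for its consumers, pushing any that reach zero into the ready pool. Below is a self-contained sketch of the scheme, independent of the onert types; the real executor additionally orders ready jobs by rank.)

    #include <cstdint>
    #include <queue>
    #include <vector>

    // Sketch only: generic dependency-counting job scheduler.
    struct DepCountScheduler
    {
      std::vector<int> input_count;                 // unfinished producers per job
      std::vector<std::vector<uint32_t>> consumers; // jobs unblocked when a job finishes
      std::queue<uint32_t> ready;                   // the executor uses a rank-ordered multimap

      void seed()
      {
        for (uint32_t j = 0; j < input_count.size(); ++j)
          if (input_count[j] == 0)
            ready.push(j); // jobs with no producers start ready
      }

      void notify(uint32_t finished)
      {
        for (auto dep : consumers[finished])
          if (--input_count[dep] == 0) // last producer done -> job becomes ready
            ready.push(dep);
      }
    };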
diff --git a/runtime/onert/core/src/exec/DynamicShapeInference.cc b/runtime/onert/core/src/exec/DynamicShapeInferer.cc
index 70bddfce4..691a11933 100644
--- a/runtime/onert/core/src/exec/DynamicShapeInference.cc
+++ b/runtime/onert/core/src/exec/DynamicShapeInferer.cc
@@ -14,7 +14,7 @@
* limitations under the License.
*/
-#include "exec/DynamicShapeInference.h"
+#include "exec/DynamicShapeInferer.h"
#include "util/ShapeInference.h"
#include <assert.h>
@@ -23,14 +23,6 @@ namespace onert
namespace exec
{
-inline backend::IDynamicTensorManager *
-dynamicTensorManagerOf(const std::shared_ptr<backend::ITensor> &tensor)
-{
- if (!tensor->dynamic_tensor_manager())
- throw std::runtime_error{"Dynamic Tensor Manager is not available for this tensor."};
- return tensor->dynamic_tensor_manager();
-}
-
void DynamicShapeInferer::handleBinaryArithmeticOp(const ir::Operation &op,
const ir::OperandIndex lhs_idx,
const ir::OperandIndex rhs_idx)
@@ -56,15 +48,15 @@ void DynamicShapeInferer::handleBinaryArithmeticOp(const ir::Operation &op,
So, only when all inputs are static, we can skip dynamic shape inference.
*/
- if ((!lhs->is_dynamic()) && (!rhs->is_dynamic()))
- return;
-
auto output_idx = op.getOutputs().at(0);
auto output = _tensor_registry->getITensor(output_idx);
+ if ((currently_static(lhs) && currently_static(rhs)) && previously_static(output))
+ return;
+
ir::Shape new_shape = shape_inference::inferEltwiseShape(lhs_shape, rhs_shape);
- dynamicTensorManagerOf(output)->applyShape(output_idx, new_shape);
+ output->applyShape(new_shape);
assert(output->buffer() != nullptr);
}
@@ -96,30 +88,32 @@ void DynamicShapeInferer::handleSimpleUnaryOp(const ir::Operation &op,
auto output_ind = op.getOutputs().at(0);
auto output = _tensor_registry->getITensor(output_ind);
- dynamicTensorManagerOf(output)->applyShape(output_ind, output_shape);
+ output->applyShape(output_shape);
assert(output->buffer() != nullptr);
}
-void DynamicShapeInferer::visit(const ir::operation::ArgMax &op)
+void DynamicShapeInferer::visit(const ir::operation::ArgMinMax &op)
{
- const auto input_idx{op.getInputs().at(ir::operation::ArgMax::Input::INPUT)};
- const auto &input = _tensor_registry->getITensor(input_idx);
- auto input_shape = input->getShape();
+ const auto input_idx{op.getInputs().at(ir::operation::ArgMinMax::Input::INPUT)};
+ const auto input = _tensor_registry->getITensor(input_idx);
- if (!input->is_dynamic())
- return;
-
- const auto rank = input_shape.rank();
- const auto axis = ((op.param().axis < 0) ? rank + op.param().axis : op.param().axis);
-
- assert(0 <= axis && axis < rank);
+ const auto axis_idx{op.getInputs().at(ir::operation::ArgMinMax::Input::AXIS)};
+ const auto axis = _tensor_registry->getITensor(axis_idx);
auto output_ind = op.getOutputs().at(0);
auto output = _tensor_registry->getITensor(output_ind);
- ir::Shape new_shape = shape_inference::inferArgMaxShape(input_shape, axis, rank);
+ if (!input->is_dynamic() && !output->is_dynamic())
+ return;
- dynamicTensorManagerOf(output)->applyShape(output_ind, new_shape);
+ auto input_shape = input->getShape();
+ auto axis_value = *reinterpret_cast<const int32_t *>(axis->buffer());
+ const auto rank = input_shape.rank();
+ axis_value = axis_value < 0 ? axis_value + rank : axis_value;
+
+ ir::Shape new_shape = shape_inference::inferArgMinMaxShape(input_shape, axis_value, rank);
+
+ output->applyShape(new_shape);
assert(output->buffer() != nullptr);
}
@@ -141,7 +135,68 @@ void DynamicShapeInferer::visit(const ir::operation::BatchMatMul &op)
// TODO
auto new_shape = shape_inference::inferBatchMatMulShape(lhs_shape, rhs_shape, op.param());
- dynamicTensorManagerOf(output)->applyShape(output_index, new_shape);
+ output->applyShape(new_shape);
+}
+
+void DynamicShapeInferer::visit(const ir::operation::BCQFullyConnected &op)
+{
+ const auto input_idx{op.getInputs().at(ir::operation::BCQFullyConnected::Input::INPUT)};
+ const auto &input = _tensor_registry->getITensor(input_idx);
+
+ const auto cluster_idx{
+ op.getInputs().at(ir::operation::BCQFullyConnected::Input::WEIGHTS_CLUSTERS)};
+ const auto &cluster = _tensor_registry->getITensor(cluster_idx);
+ assert(cluster->is_constant());
+
+ if (!input->is_dynamic())
+ return;
+
+ auto input_shape = input->getShape();
+ auto cluster_shape = cluster->getShape();
+
+ auto cluster_buf = reinterpret_cast<const int32_t *>(cluster->buffer());
+ assert(cluster_buf);
+
+ ir::Shape new_shape =
+ shape_inference::inferBCQFullyConnectedShape(input_shape, cluster_shape, cluster_buf);
+
+ auto output_ind = op.getOutputs().at(0);
+ auto output = _tensor_registry->getITensor(output_ind);
+
+ output->applyShape(new_shape);
+ assert(output->buffer() != nullptr);
+}
+
+void DynamicShapeInferer::visit(const ir::operation::BCQGather &op)
+{
+ const auto indices_idx{op.getInputs().at(ir::operation::BCQGather::Input::INDICES)};
+ const auto &indices = _tensor_registry->getITensor(indices_idx);
+
+ const auto input_binary_idx{op.getInputs().at(ir::operation::BCQGather::Input::INPUT_BINARY)};
+ const auto &input_binary = _tensor_registry->getITensor(input_binary_idx);
+
+ const auto cluster_idx{op.getInputs().at(ir::operation::BCQGather::Input::INPUT_CLUSTERS)};
+ const auto &cluster = _tensor_registry->getITensor(cluster_idx);
+ assert(cluster->is_constant());
+
+ if (!indices->is_dynamic())
+ return;
+
+ auto indices_shape = indices->getShape();
+ auto cluster_shape = cluster->getShape();
+ auto rank = input_binary->getShape().rank();
+
+ auto cluster_buf = reinterpret_cast<const int32_t *>(cluster->buffer());
+ assert(cluster_buf);
+
+ ir::Shape new_shape = shape_inference::inferBCQGatherShape(indices_shape, cluster_shape,
+ cluster_buf, rank, op.param());
+
+ auto output_ind = op.getOutputs().at(0);
+ auto output = _tensor_registry->getITensor(output_ind);
+
+ output->applyShape(new_shape);
+ assert(output->buffer() != nullptr);
}
void DynamicShapeInferer::visit(const ir::operation::BinaryArithmetic &op)
@@ -167,10 +222,10 @@ void DynamicShapeInferer::visit(const ir::operation::BroadcastTo &op)
assert(shape); // It shouldn't be 0.
auto output_shape = shape_inference::inferBroadcastToShape(
- shape->getShape(), reinterpret_cast<const int32_t *>(shape->buffer()));
+ shape->getShape(), reinterpret_cast<const int32_t *>(shape->buffer()));
// set output shape and output buffer
- dynamicTensorManagerOf(output)->applyShape(output_ind, output_shape);
+ output->applyShape(output_shape);
assert(output->buffer() != nullptr);
}
@@ -198,7 +253,7 @@ void DynamicShapeInferer::visit(const ir::operation::Concat &op)
So, only when all inputs are static, we can skip dynamic shape inference.
*/
bool all_static = true;
- for (auto input_ind : op.getInputs())
+ for (auto &&input_ind : op.getInputs())
{
auto input = _tensor_registry->getITensor(input_ind);
if (input->is_dynamic())
@@ -215,15 +270,17 @@ void DynamicShapeInferer::visit(const ir::operation::Concat &op)
{
auto isConcatible = [](const backend::ITensor *input1, const backend::ITensor *input2,
int32_t axis) {
- if (input1->num_dimensions() != input2->num_dimensions())
+ auto shape1 = input1->getShape();
+ auto shape2 = input2->getShape();
+ if (shape1.rank() != shape2.rank())
return false;
- for (size_t i = 0; i < input1->num_dimensions(); i++)
+ for (int i = 0; i < shape1.rank(); i++)
{
- auto positive_axis = (axis >= 0) ? axis : axis + input1->num_dimensions();
+ auto positive_axis = (axis >= 0) ? axis : axis + input1->getShape().rank();
if (i != positive_axis)
- if (input1->dimension(i) != input2->dimension(i))
+ if (shape1.dim(i) != shape2.dim(i))
return false;
}
@@ -233,17 +290,17 @@ void DynamicShapeInferer::visit(const ir::operation::Concat &op)
auto first_input_ind = op.getInputs().at(0);
auto first_input = _tensor_registry->getITensor(first_input_ind);
- for (auto input_ind : op.getInputs())
+ for (auto &&input_ind : op.getInputs())
{
auto input = _tensor_registry->getITensor(input_ind);
- if (input != first_input && !isConcatible(first_input.get(), input.get(), op.param().axis))
+ if (input != first_input && !isConcatible(first_input, input, op.param().axis))
throw std::runtime_error("input shapes does not matched for concat");
}
}
// getting output shape
onert::shape_inference::Shapes in_shapes;
- for (auto input_ind : op.getInputs())
+ for (auto &&input_ind : op.getInputs())
{
auto input = _tensor_registry->getITensor(input_ind);
ir::Shape shape = input->getShape();
@@ -255,7 +312,7 @@ void DynamicShapeInferer::visit(const ir::operation::Concat &op)
auto output = _tensor_registry->getITensor(output_ind);
auto output_shape = shape_inference::inferConcatShape(in_shapes, op.param());
- dynamicTensorManagerOf(output)->applyShape(output_ind, output_shape);
+ output->applyShape(output_shape);
}
void DynamicShapeInferer::visit(const ir::operation::Conv2D &op)
@@ -278,7 +335,7 @@ void DynamicShapeInferer::visit(const ir::operation::Conv2D &op)
ir::Shape output_shape = shape_inference::inferConv2DShape(input_shape, ker_shape, op.param());
- dynamicTensorManagerOf(output)->applyShape(output_ind, output_shape);
+ output->applyShape(output_shape);
assert(output->buffer() != nullptr);
}
@@ -333,12 +390,18 @@ void DynamicShapeInferer::visit(const ir::operation::ExpandDims &op)
auto axis_ind = op.getInputs().at(ir::operation::ExpandDims::AXIS);
auto axis = _tensor_registry->getITensor(axis_ind);
- auto axis_buf = reinterpret_cast<const int32_t *>(axis->buffer());
- assert(axis_buf);
+ auto axis_type = axis->data_type();
+ assert(axis_type == ir::DataType::INT32 || axis_type == ir::DataType::INT64);
+
+ assert(axis->buffer());
+ int32_t axis_value =
+ (axis_type == ir::DataType::INT32)
+ ? reinterpret_cast<const int32_t *>(axis->buffer())[0]
+ : static_cast<int32_t>(reinterpret_cast<const int64_t *>(axis->buffer())[0]);
- auto output_shape = shape_inference::inferExpandDimsShape(input_shape, axis_buf[0]);
+ auto output_shape = shape_inference::inferExpandDimsShape(input_shape, axis_value);
- dynamicTensorManagerOf(output)->applyShape(output_ind, output_shape);
+ output->applyShape(output_shape);
assert(output->buffer() != nullptr);
}
@@ -347,21 +410,26 @@ void DynamicShapeInferer::visit(const ir::operation::Fill &op)
// check if output is not dynamic
auto output_ind = op.getOutputs().at(0);
auto output = _tensor_registry->getITensor(output_ind);
- auto input_ind = op.getInputs().at(ir::operation::Fill::Input::INPUT);
- auto input = _tensor_registry->getITensor(input_ind);
- ir::Shape input_shape = input->getShape();
+ auto shape_ind = op.getInputs().at(ir::operation::Fill::Input::SHAPE);
+ auto shape = _tensor_registry->getITensor(shape_ind);
- if ((!input->is_dynamic()) && (!output->is_dynamic()))
+ if ((!shape->is_dynamic()) && (!output->is_dynamic()))
return;
- assert(input.get()->data_type() == ir::DataType::INT32);
+ const auto dims_type = shape->data_type();
+ assert(dims_type == ir::DataType::INT32 || dims_type == ir::DataType::INT64);
- auto input_buf = reinterpret_cast<const int32_t *>(input->buffer());
- assert(input_buf);
+ auto dims_buf = shape->buffer();
+ assert(dims_buf);
- auto output_shape = shape_inference::inferFillShape(input_shape, input_buf);
+ const auto &dims_shape = shape->getShape();
+ const auto &output_shape = ((dims_type == ir::DataType::INT32)
+ ? shape_inference::inferFillShape<int32_t>(
+ dims_shape, reinterpret_cast<const int32_t *>(dims_buf))
+ : shape_inference::inferFillShape<int64_t>(
+ dims_shape, reinterpret_cast<const int64_t *>(dims_buf)));
- dynamicTensorManagerOf(output)->applyShape(output_ind, output_shape);
+ output->applyShape(output_shape);
assert(output->buffer() != nullptr);
}
@@ -384,7 +452,7 @@ void DynamicShapeInferer::visit(const ir::operation::FullyConnected &op)
auto output_ind = op.getOutputs().at(0);
auto output = _tensor_registry->getITensor(output_ind);
- dynamicTensorManagerOf(output)->applyShape(output_ind, new_shape);
+ output->applyShape(new_shape);
assert(output->buffer() != nullptr);
}
@@ -416,7 +484,7 @@ void DynamicShapeInferer::visit(const ir::operation::Gather &op)
auto output_ind = op.getOutputs().at(0);
auto output = _tensor_registry->getITensor(output_ind);
- dynamicTensorManagerOf(output)->applyShape(output_ind, new_shape);
+ output->applyShape(new_shape);
assert(output->buffer() != nullptr);
}
@@ -425,11 +493,122 @@ void DynamicShapeInferer::visit(const ir::operation::L2Normalization &op)
handleSimpleUnaryOp(op, op.getInputs().at(ir::operation::L2Normalization::INPUT));
}
+void DynamicShapeInferer::visit(const ir::operation::LSTM &op)
+{
+ const auto output_index{op.getOutputs().at(ir::operation::LSTM::Output::OUTPUT)};
+ auto output = _tensor_registry->getITensor(output_index);
+
+ const auto output_state_out_index{
+ op.getOutputs().at(ir::operation::LSTM::Output::OUTPUT_STATE_OUT)};
+
+ const auto cell_state_out_index{op.getOutputs().at(ir::operation::LSTM::Output::CELL_STATE_OUT)};
+
+ const auto scratch_buffer_index{op.getOutputs().at(ir::operation::LSTM::Output::SCRATCH_BUFFER)};
+
+ if (!output->is_dynamic() &&
+ !(_tensor_registry->getITensor(output_state_out_index) != nullptr &&
+ _tensor_registry->getITensor(output_state_out_index)->is_dynamic()) &&
+ !(_tensor_registry->getITensor(cell_state_out_index) != nullptr &&
+ _tensor_registry->getITensor(cell_state_out_index)->is_dynamic()) &&
+ !(_tensor_registry->getITensor(scratch_buffer_index) != nullptr &&
+ _tensor_registry->getITensor(scratch_buffer_index)->is_dynamic()))
+ return;
+
+ const auto input_index{op.getInputs().at(ir::operation::LSTM::Input::INPUT)};
+ const auto input = _tensor_registry->getITensor(input_index);
+ const auto input_shape = input->getShape();
+
+ const auto input_to_output_weights_index{
+ op.getInputs().at(ir::operation::LSTM::Input::INPUT_TO_OUTPUT_WEIGHTS)};
+ const auto input_to_output_weights = _tensor_registry->getITensor(input_to_output_weights_index);
+ const auto input_to_output_weights_shape = input_to_output_weights->getShape();
+
+ const auto recurrent_to_output_weights_index{
+ op.getInputs().at(ir::operation::LSTM::Input::RECURRENT_TO_OUTPUT_WEIGHTS)};
+ const auto recurrent_to_output_weights =
+ _tensor_registry->getITensor(recurrent_to_output_weights_index);
+ const auto recurrent_to_output_weights_shape = recurrent_to_output_weights->getShape();
+
+ // re-sizing outputs
+ const int n_batch =
+ (input_shape.rank() == 3 && op.param().time_major) ? input_shape.dim(1) : input_shape.dim(0);
+ const int n_cell = input_to_output_weights_shape.dim(0);
+ const int n_output = recurrent_to_output_weights_shape.dim(1);
+ if (input_shape.rank() == 3)
+ {
+ if (op.param().time_major)
+ output->applyShape(ir::Shape{input_shape.dim(0), n_batch, n_output});
+ else
+ output->applyShape(ir::Shape{n_batch, input_shape.dim(1), n_output});
+ }
+ else
+ {
+ assert(input_shape.rank() == 2);
+ output->applyShape(ir::Shape{n_batch, n_output});
+ }
+ assert(output->buffer() != nullptr);
+
+ auto output_state_out = _tensor_registry->getITensor(output_state_out_index);
+ if (output_state_out != nullptr)
+ {
+ output_state_out->applyShape(ir::Shape{n_batch, n_output});
+ assert(output_state_out->buffer() != nullptr);
+ }
+
+ auto cell_state_out = _tensor_registry->getITensor(cell_state_out_index);
+ if (cell_state_out != nullptr)
+ {
+ cell_state_out->applyShape(ir::Shape{n_batch, n_cell});
+ assert(cell_state_out->buffer() != nullptr);
+ }
+
+ auto scratch_buffer = _tensor_registry->getITensor(scratch_buffer_index);
+ if (scratch_buffer != nullptr)
+ {
+ const auto input_to_input_weights_index{
+ op.getInputs().at(ir::operation::LSTM::Input::INPUT_TO_INPUT_WEIGHTS)};
+ const auto recurrent_to_input_weights_index{
+ op.getInputs().at(ir::operation::LSTM::Input::RECURRENT_TO_INPUT_WEIGHTS)};
+
+ const auto input_to_input_weights_shape =
+ _tensor_registry->getITensor(input_to_input_weights_index)->getShape();
+ bool has_input_to_input_weights =
+ input_to_input_weights_shape.dim(0) != 0 && input_to_input_weights_shape.dim(1) != 0;
+
+ const auto recurrent_to_input_weights_shape =
+ _tensor_registry->getITensor(recurrent_to_input_weights_index)->getShape();
+ bool has_recurrent_to_input_weights =
+ recurrent_to_input_weights_shape.dim(0) != 0 && recurrent_to_input_weights_shape.dim(1) != 0;
+
+ // NOTE cell_to_input_weights do not exist in non-peephole mode, even for a regular (non-CIFG) LSTM.
+ // true: no CIFG
+ // false: CIFG
+ bool has_cifg_param = has_input_to_input_weights && has_recurrent_to_input_weights;
+ if (has_cifg_param)
+ {
+ scratch_buffer->applyShape(ir::Shape{n_batch, n_cell * 4});
+ }
+ else
+ {
+ scratch_buffer->applyShape(ir::Shape{n_batch, n_cell * 3});
+ }
+ assert(scratch_buffer->buffer() != nullptr);
+ }
+}
+
void DynamicShapeInferer::visit(const ir::operation::MatrixBandPart &op)
{
handleSimpleUnaryOp(op, op.getInputs().at(ir::operation::MatrixBandPart::INPUT));
}
+void DynamicShapeInferer::visit(const ir::operation::DetectionPostProcess & /* op */)
+{
+ // NOTE The shapes of DetectionPostProcess's undefined outputs are decided at compile time
+ // by the static shape inferer.
+ // They are independent of the input shape and are determined by parameter values.
+}
+
void DynamicShapeInferer::visit(const ir::operation::OneHot &op)
{
auto output_ind = op.getOutputs().at(0);
@@ -452,7 +631,7 @@ void DynamicShapeInferer::visit(const ir::operation::OneHot &op)
const auto axis_val = op.param().axis;
ir::Shape new_shape = shape_inference::inferOnehotShape(indices_shape, *depth_buf, axis_val);
- dynamicTensorManagerOf(output)->applyShape(output_ind, new_shape);
+ output->applyShape(new_shape);
assert(output->buffer() != nullptr);
}
@@ -488,7 +667,7 @@ void DynamicShapeInferer::visit(const ir::operation::Pack &op)
ir::Shape new_shape = shape_inference::inferPackShape(input_shape, axis, rank, num);
- dynamicTensorManagerOf(output)->applyShape(output_ind, new_shape);
+ output->applyShape(new_shape);
assert(output->buffer() != nullptr);
}
@@ -512,10 +691,10 @@ void DynamicShapeInferer::visit(const ir::operation::Pad &op)
assert(pad_buf);
auto output_shape =
- shape_inference::inferPadShape(input->getShape(), pad_buf, pad->getShape().num_elements());
+ shape_inference::inferPadShape(input->getShape(), pad_buf, pad->getShape().num_elements());
// change output shape and reallocate output tensor memory
- dynamicTensorManagerOf(output)->applyShape(output_ind, output_shape);
+ output->applyShape(output_shape);
assert(output->buffer() != nullptr);
}
@@ -526,6 +705,26 @@ void DynamicShapeInferer::visit(const ir::operation::Permute & /* op */)
// on-the-fly, as it must support inter-backend inference/allocation.
}
+void DynamicShapeInferer::visit(const ir::operation::Pool2D &op)
+{
+ // check if input is not dynamic
+ auto input_ind = op.getInputs().at(ir::operation::Pool2D::INPUT);
+ auto input = _tensor_registry->getITensor(input_ind);
+
+ if (!input->is_dynamic())
+ return;
+
+ ir::Shape input_shape = input->getShape();
+
+ auto output_ind = op.getOutputs().at(0);
+ auto output = _tensor_registry->getITensor(output_ind);
+
+ ir::Shape output_shape = shape_inference::inferPoolShape(input_shape, op.param());
+
+ output->applyShape(output_shape);
+ assert(output->buffer() != nullptr);
+}
+
void DynamicShapeInferer::visit(const ir::operation::Pow &op)
{
handleBinaryArithmeticOp(op, op.getInputs().at(ir::operation::Pow::Input::LHS),
@@ -556,18 +755,18 @@ void DynamicShapeInferer::visit(const ir::operation::Range &op)
if (output->data_type() == ir::DataType::FLOAT32)
{
new_shape =
- shape_inference::inferRangeShape<float>(*reinterpret_cast<float *>(start_tensor->buffer()),
- *reinterpret_cast<float *>(limit_tensor->buffer()),
- *reinterpret_cast<float *>(delta_tensor->buffer()));
+ shape_inference::inferRangeShape<float>(*reinterpret_cast<float *>(start_tensor->buffer()),
+ *reinterpret_cast<float *>(limit_tensor->buffer()),
+ *reinterpret_cast<float *>(delta_tensor->buffer()));
}
else if (output->data_type() == ir::DataType::INT32)
{
new_shape = shape_inference::inferRangeShape<int32_t>(
- *reinterpret_cast<int32_t *>(start_tensor->buffer()),
- *reinterpret_cast<int32_t *>(limit_tensor->buffer()),
- *reinterpret_cast<int32_t *>(delta_tensor->buffer()));
+ *reinterpret_cast<int32_t *>(start_tensor->buffer()),
+ *reinterpret_cast<int32_t *>(limit_tensor->buffer()),
+ *reinterpret_cast<int32_t *>(delta_tensor->buffer()));
}
- dynamicTensorManagerOf(output)->applyShape(output_ind, new_shape);
+ output->applyShape(new_shape);
assert(output->buffer() != nullptr);
}
@@ -611,7 +810,7 @@ void DynamicShapeInferer::visit(const ir::operation::Reduce &op)
ir::Shape new_shape = shape_inference::inferReduceShape(input_shape, axes_vec, keep_dims);
- dynamicTensorManagerOf(output)->applyShape(output_ind, new_shape);
+ output->applyShape(new_shape);
assert(output->buffer() != nullptr);
}
@@ -658,14 +857,14 @@ void DynamicShapeInferer::visit(const ir::operation::Reshape &op)
int32_t *new_shape_buf = reinterpret_cast<int32_t *>(new_shape->buffer());
assert(new_shape_buf);
- auto output_shape = shape_inference::inferReshapeShape(
- new_shape_buf, new_shape->getShape().num_elements(), input->getShape().num_elements());
+ auto output_shape = shape_inference::inferReshapeShape(input->getShape(), new_shape_buf,
+ new_shape->getShape().num_elements());
// if shape is changed, change output shape and reallocate output tensor memory
if (output_shape != output->getShape() || output->buffer() == nullptr)
{
// change on output shape
- dynamicTensorManagerOf(output)->applyShape(output_ind, output_shape);
+ output->applyShape(output_shape);
}
assert(output->buffer() != nullptr);
}
@@ -674,14 +873,14 @@ void DynamicShapeInferer::visit(const ir::operation::Reshape &op)
{
// Let's check the new_shape option
auto shape = op.param().new_shape;
- auto output_shape = shape_inference::inferReshapeShape(shape.data(), shape.size(),
- input->getShape().num_elements());
+ auto output_shape =
+ shape_inference::inferReshapeShape(input->getShape(), shape.data(), shape.size());
// if shape is changed, change output shape and reallocate output tensor memory
if (output_shape != output->getShape() || output->buffer() == nullptr)
{
// change on output shape
- dynamicTensorManagerOf(output)->applyShape(output_ind, output_shape);
+ output->applyShape(output_shape);
}
assert(output->buffer() != nullptr);
}
@@ -705,14 +904,35 @@ void DynamicShapeInferer::visit(const ir::operation::ResizeBilinear &op)
return;
// getting output shape from input shape and Params
- auto output_shape = shape_inference::inferResizeBilinearShape(
- input->getShape(), op.param().height_out, op.param().width_out);
+ int32_t height_out, width_out;
+ if (op.getInputs().size() == 2)
+ {
+ auto size_ind = op.getInputs().at(ir::operation::ResizeBilinear::Input::SIZE);
+ auto size = _tensor_registry->getITensor(size_ind);
+ if (size->data_type() == ir::DataType::INT32)
+ {
+ auto size_buf = reinterpret_cast<const int32_t *>(size->buffer());
+ height_out = size_buf[0];
+ width_out = size_buf[1];
+ }
+ else
+ {
+ throw std::runtime_error("DynamicShapeInferer ResizeBilinear : Unsupported data type");
+ }
+ }
+ else
+ {
+ height_out = op.param().height_out;
+ width_out = op.param().width_out;
+ }
+ auto output_shape =
+ shape_inference::inferResizeBilinearShape(input->getShape(), height_out, width_out);
// if shape is changed, change output shape and reallocate output tensor memory
if (output_shape != output->getShape() || output->buffer() == nullptr)
{
// change on output shape
- dynamicTensorManagerOf(output)->applyShape(output_ind, output_shape);
+ output->applyShape(output_shape);
}
assert(output->buffer() != nullptr);
}
@@ -744,12 +964,12 @@ void DynamicShapeInferer::visit(const ir::operation::Select &op)
// Select output shape
ir::Shape new_shape =
- shape_inference::inferSelectShape(input_cond_shape, input_true_shape, input_false_shape);
+ shape_inference::inferSelectShape(input_cond_shape, input_true_shape, input_false_shape);
auto output_ind = op.getOutputs().at(0);
auto output = _tensor_registry->getITensor(output_ind);
- dynamicTensorManagerOf(output)->applyShape(output_ind, new_shape);
+ output->applyShape(new_shape);
assert(output->buffer() != nullptr);
}
@@ -768,7 +988,7 @@ void DynamicShapeInferer::visit(const ir::operation::Shape &op)
ir::Shape output_shape;
output_shape.append(input_shape.rank());
- dynamicTensorManagerOf(output)->applyShape(output_ind, output_shape);
+ output->applyShape(output_shape);
assert(output->buffer() != nullptr);
}
@@ -794,7 +1014,7 @@ void DynamicShapeInferer::visit(const ir::operation::Slice &op)
ir::Shape new_shape = shape_inference::inferSliceShape(input_shape, begins_buf, sizes_buf);
- dynamicTensorManagerOf(output)->applyShape(output_index, new_shape);
+ output->applyShape(new_shape);
assert(output->buffer() != nullptr);
}
@@ -829,9 +1049,9 @@ void DynamicShapeInferer::visit(const ir::operation::SpaceToBatchND &op)
auto padding_data = reinterpret_cast<int32_t *>(padding->buffer());
ir::Shape new_shape = shape_inference::inferSpaceToBatchNDShape(
- input_shape, block_shape_shape, padding_shape, block_shape_data, padding_data);
+ input_shape, block_shape_shape, padding_shape, block_shape_data, padding_data);
- dynamicTensorManagerOf(output)->applyShape(output_idx, new_shape);
+ output->applyShape(new_shape);
assert(output->buffer() != nullptr);
}
@@ -840,27 +1060,37 @@ void DynamicShapeInferer::visit(const ir::operation::Split &op)
const auto input_idx{op.getInputs().at(ir::operation::Split::Input::INPUT)};
const auto &input = _tensor_registry->getITensor(input_idx);
- if (!input->is_dynamic())
+ // Return early if neither the input nor any output is dynamic
+ bool has_dynamic = false;
+ for (const auto &output_idx : op.getOutputs())
+ {
+ auto output = _tensor_registry->getITensor(output_idx);
+ has_dynamic |= output->is_dynamic();
+ }
+ if (!input->is_dynamic() && !has_dynamic)
{
return;
}
auto input_shape = input->getShape();
- const auto axis = op.param().axis;
+ const auto axis_idx{op.getInputs().at(ir::operation::Split::Input::AXIS)};
+ const auto &axis = _tensor_registry->getITensor(axis_idx);
+
+ auto axis_value = *reinterpret_cast<const int32_t *>(axis->buffer());
const auto num_splits = op.param().num_splits;
const auto rank = input_shape.rank();
- auto axis_resolved = axis < 0 ? axis + rank : axis;
+ axis_value = axis_value < 0 ? axis_value + rank : axis_value;
- assert(0 <= axis_resolved && axis_resolved < rank);
+ assert(0 <= axis_value && axis_value < rank);
- ir::Shape new_shape = shape_inference::inferSplitShape(input_shape, axis_resolved, num_splits);
+ ir::Shape new_shape = shape_inference::inferSplitShape(input_shape, axis_value, num_splits);
for (int out_tensor_idx = 0; out_tensor_idx < num_splits; out_tensor_idx++)
{
auto output_ind = op.getOutputs().at(out_tensor_idx);
auto output = _tensor_registry->getITensor(output_ind);
- dynamicTensorManagerOf(output)->applyShape(output_ind, new_shape);
+ output->applyShape(new_shape);
assert(output->buffer() != nullptr);
}
}
@@ -889,7 +1119,7 @@ void DynamicShapeInferer::visit(const ir::operation::Squeeze &op)
auto output_ind = op.getOutputs().at(0);
auto output = _tensor_registry->getITensor(output_ind);
- dynamicTensorManagerOf(output)->applyShape(output_ind, new_shape);
+ output->applyShape(new_shape);
assert(output->buffer() != nullptr);
}
@@ -920,17 +1150,16 @@ void DynamicShapeInferer::visit(const ir::operation::StridedSlice &op)
const auto rank = input_shape.rank();
auto op_params = shape_inference::buildStridedSliceParams(
- reinterpret_cast<uint32_t *>(starts->buffer()), reinterpret_cast<uint32_t *>(ends->buffer()),
- reinterpret_cast<uint32_t *>(strides->buffer()), begin_mask, end_mask, shrink_axis_mask,
- rank);
+ reinterpret_cast<uint32_t *>(starts->buffer()), reinterpret_cast<uint32_t *>(ends->buffer()),
+ reinterpret_cast<uint32_t *>(strides->buffer()), begin_mask, end_mask, shrink_axis_mask, rank);
auto output_index = op.getOutputs().at(0);
auto output = _tensor_registry->getITensor(output_index);
ir::Shape output_shape =
- onert::shape_inference::inferStridedSliceShape(input_shape, op_params, rank);
+ onert::shape_inference::inferStridedSliceShape(input_shape, op_params, rank);
- dynamicTensorManagerOf(output)->applyShape(output_index, output_shape);
+ output->applyShape(output_shape);
assert(output->buffer() != nullptr);
}
@@ -952,10 +1181,12 @@ void DynamicShapeInferer::visit(const ir::operation::Tile &op)
auto multiplier_buffer = reinterpret_cast<const int32_t *>(multiplier->buffer());
assert(multiplier_buffer);
- auto output_shape = shape_inference::inferTileShape(input_shape, multiplier_buffer);
+ auto mult_shape = multiplier->getShape();
+ auto output_shape = shape_inference::inferTileShape(
+ input_shape, multiplier_buffer, mult_shape.rank() == 0 ? 1 : mult_shape.dim(0));
// set output shape and output buffer
- dynamicTensorManagerOf(output)->applyShape(output_ind, output_shape);
+ output->applyShape(output_shape);
assert(output->buffer() != nullptr);
}
@@ -967,17 +1198,49 @@ void DynamicShapeInferer::visit(const ir::operation::Transpose &op)
// from op, access the buffer of second input to read new shape
auto input_ind = op.getInputs().at(ir::operation::Transpose::Input::INPUT);
- auto input_tensor = _tensor_registry->getITensor(input_ind);
- auto input_shape = input_tensor->getShape();
+ auto input = _tensor_registry->getITensor(input_ind);
+ auto input_shape = input->getShape();
+
+ /*
+ Here, the state after compilation (static shape inference) could be one of the following:
+
+ input perms output execution-time shape inf required
+ ------------------------------------ --------------------------------
+ case 1) static const static X
+ case 2) static non-const dynamic O
+ case 3) dynamic const dynamic O
+ case 4) dynamic non-const dynamic O
- if (!input_tensor->is_dynamic())
+ So, only when both input and output are static, we can skip dynamic shape inference.
+ */
+ if ((!input->is_dynamic()) && (!output->is_dynamic()))
return;
- const auto perm{op.param().perm};
- // set output shape, based on input and params
- ir::Shape new_shape = shape_inference::inferTransposeShape(input_shape, perm);
+ auto perm_ind = op.getInputs().at(ir::operation::Transpose::Input::PERMUTATION);
+ auto perm = _tensor_registry->getITensor(perm_ind);
+
+ ir::Shape new_shape;
+ // TODO Change perm->dimension(0) == 0 to perm->num_elements() == 0
+ if (perm->getShape().dim(0) == 0) // This condition means that perm is (n-1...0)
+ {
+ // Call by (n-1...0)
+ new_shape = shape_inference::inferTransposeShape(input_shape, nullptr, 0);
+ }
+ else
+ {
+ // Check rank
+ if (static_cast<size_t>(input->getShape().rank()) != perm->getShape().num_elements())
+ {
+ throw std::runtime_error("DynamicShapeInferer failed, bad rank size: " +
+ std::to_string(perm->getShape().num_elements()));
+ }
- dynamicTensorManagerOf(output)->applyShape(output_ind, new_shape);
+ // set output shape, based on input and params
+ const auto perm_buffer = reinterpret_cast<const int32_t *>(perm->buffer());
+ new_shape =
+ shape_inference::inferTransposeShape(input_shape, perm_buffer, perm->getShape().dim(0));
+ }
+ output->applyShape(new_shape);
assert(output->buffer() != nullptr);
}
@@ -1005,7 +1268,7 @@ void DynamicShapeInferer::visit(const ir::operation::Unpack &op)
auto output_ind = op.getOutputs().at(out_tensor_idx);
auto output = _tensor_registry->getITensor(output_ind);
- dynamicTensorManagerOf(output)->applyShape(output_ind, new_shape);
+ output->applyShape(new_shape);
assert(output->buffer() != nullptr);
}
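The hunk above consistently replaces dynamicTensorManagerOf(output)->applyShape(output_ind, shape) with a direct output->applyShape(shape), so each tensor now reallocates its own buffer during dynamic shape inference. Below is a minimal standalone sketch of that pattern, using hypothetical Shape and Tensor types rather than the onert classes, and assuming reallocation is only needed when the element count changes:

// Minimal sketch (hypothetical types, not the onert API) of the applyShape pattern:
// the inferer hands the new shape to the output tensor, which reallocates if needed.
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <memory>
#include <vector>

struct Shape
{
  std::vector<int32_t> dims;
  size_t num_elements() const
  {
    size_t n = 1;
    for (auto d : dims)
      n *= static_cast<size_t>(d);
    return n;
  }
};

class Tensor
{
public:
  // Reallocate only when there is no buffer yet or the element count changes.
  bool applyShape(const Shape &new_shape)
  {
    if (_buffer == nullptr || new_shape.num_elements() != _shape.num_elements())
      _buffer = std::make_unique<float[]>(new_shape.num_elements());
    _shape = new_shape;
    return true;
  }
  const float *buffer() const { return _buffer.get(); }

private:
  Shape _shape;
  std::unique_ptr<float[]> _buffer;
};

int main()
{
  Tensor output;
  output.applyShape(Shape{{1, 2, 2, 1}}); // as each visit() above does after inference
  assert(output.buffer() != nullptr);     // matching the asserts in the hunk
  return 0;
}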
diff --git a/runtime/onert/core/src/exec/EdgeTensor.cc b/runtime/onert/core/src/exec/EdgeTensor.cc
new file mode 100644
index 000000000..569a2f697
--- /dev/null
+++ b/runtime/onert/core/src/exec/EdgeTensor.cc
@@ -0,0 +1,55 @@
+/*
+ * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "EdgeTensor.h"
+
+namespace onert
+{
+namespace exec
+{
+
+bool EdgeTensor::applyShape(const ir::Shape &new_shape)
+{
+ bool previously_dynamic = is_dynamic();
+ if (!previously_dynamic || _buffer == nullptr)
+ {
+ // Always set shape - when buffer with same size was already allocated, shape could differ
+ setShape(new_shape);
+ set_dynamic();
+ const auto total_size = get_info().total_size();
+ _buffer = std::make_unique<uint8_t[]>(total_size);
+ }
+ else
+ {
+ auto previous_size = total_size();
+ auto new_size = new_shape.num_elements() * ir::sizeOfDataType(data_type());
+ if (previous_size != new_size)
+ {
+ setShape(new_shape);
+ set_dynamic();
+ const auto total_size = get_info().total_size();
+ _buffer = std::make_unique<uint8_t[]>(total_size);
+ }
+ else
+ { // when buffer with same size was already allocated, shape could differ
+ setShape(new_shape);
+ }
+ }
+ return true;
+}
+
+} // namespace exec
+} // namespace onert
diff --git a/runtime/onert/core/src/exec/EdgeTensor.h b/runtime/onert/core/src/exec/EdgeTensor.h
new file mode 100644
index 000000000..8df79c389
--- /dev/null
+++ b/runtime/onert/core/src/exec/EdgeTensor.h
@@ -0,0 +1,72 @@
+/*
+ * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_EXEC_EDGE_TENSOR_H__
+#define __ONERT_EXEC_EDGE_TENSOR_H__
+
+#include "backend/IPortableTensor.h"
+
+#include <memory>
+
+namespace onert
+{
+namespace exec
+{
+
+class EdgeTensor : public backend::IPortableTensor
+{
+public:
+ EdgeTensor(const ir::OperandInfo &info, ir::Layout layout)
+ : IPortableTensor(info), _layout{layout}, _buffer{nullptr}, _ref_count{0}
+ {
+ }
+ ~EdgeTensor() = default;
+
+ uint8_t *buffer() const override { return _buffer.get(); }
+ ir::Layout layout() const override { return _layout; }
+ void set_dynamic() override { _info.setDynamic(); }
+ bool applyShape(const ir::Shape &new_shape) override;
+ void setShape(const ir::Shape &new_shape) override { _info.shape(new_shape); }
+
+ void allocate_buffer()
+ {
+ const auto total_size = _info.total_size();
+ _buffer = std::make_unique<uint8_t[]>(total_size);
+ _ref_count = 1;
+ }
+
+ void increase_ref() { _ref_count++; }
+
+ void decrease_ref()
+ {
+ assert(_ref_count > 0);
+ _ref_count--;
+ if (_ref_count == 0)
+ {
+ _buffer.reset();
+ }
+ }
+
+private:
+ ir::Layout _layout;
+ std::unique_ptr<uint8_t[]> _buffer;
+ int32_t _ref_count;
+};
+
+} // namespace exec
+} // namespace onert
+
+#endif // __ONERT_EXEC_EDGE_TENSOR_H__
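EdgeTensor above combines lazy buffer allocation with manual reference counting so that data flowing between subgraphs can be released as soon as the last consumer has read it. A standalone sketch of that lifecycle, with a hypothetical RefCountedBuffer standing in for the real class:

// Standalone sketch of the increase_ref/decrease_ref lifecycle (hypothetical
// RefCountedBuffer, not the onert EdgeTensor): allocate starts the count at 1,
// each extra consumer bumps it, and the buffer is freed when it reaches zero.
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <memory>

class RefCountedBuffer
{
public:
  void allocate(size_t size)
  {
    _buffer = std::make_unique<uint8_t[]>(size);
    _ref_count = 1;
  }
  void increase_ref() { ++_ref_count; }
  void decrease_ref()
  {
    assert(_ref_count > 0);
    if (--_ref_count == 0)
      _buffer.reset(); // last consumer finished: release the memory early
  }
  bool allocated() const { return _buffer != nullptr; }

private:
  std::unique_ptr<uint8_t[]> _buffer;
  int32_t _ref_count = 0;
};

int main()
{
  RefCountedBuffer edge;
  edge.allocate(64);   // producer writes the edge data
  edge.increase_ref(); // a second consumer registers interest
  edge.decrease_ref(); // first consumer done, buffer still alive
  assert(edge.allocated());
  edge.decrease_ref(); // second consumer done, buffer released
  assert(!edge.allocated());
  return 0;
}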
diff --git a/runtime/onert/core/src/exec/ExecTime.cc b/runtime/onert/core/src/exec/ExecTime.cc
index 6bf2744a9..4b82655b9 100644
--- a/runtime/onert/core/src/exec/ExecTime.cc
+++ b/runtime/onert/core/src/exec/ExecTime.cc
@@ -14,12 +14,10 @@
* limitations under the License.
*/
-#include "exec/ExecTime.h"
+#include "ExecTime.h"
-#include <fstream>
-#include <cassert>
-#include <limits>
#include <algorithm>
+#include <cassert>
namespace onert
{
diff --git a/runtime/onert/core/src/exec/ExecTime.h b/runtime/onert/core/src/exec/ExecTime.h
index 846d0930b..95f460053 100644
--- a/runtime/onert/core/src/exec/ExecTime.h
+++ b/runtime/onert/core/src/exec/ExecTime.h
@@ -34,7 +34,7 @@ class ExecTime
{
public:
explicit ExecTime(const std::vector<const backend::Backend *> &backends)
- : _json(backends, _measurements)
+ : _json(backends, _measurements)
{
}
@@ -94,7 +94,7 @@ public:
/**
* @brief Update metrics file with new data.
*/
- void uploadOperationsExecTime() const { _json.uploadOperationsExecTime(); }
+ void storeOperationsExecTime() const { _json.storeOperationsExecTime(); }
static const int64_t NOT_FOUND = -1;
private:
diff --git a/runtime/onert/core/src/exec/ExecTime.test.cc b/runtime/onert/core/src/exec/ExecTime.test.cc
new file mode 100644
index 000000000..939184e4e
--- /dev/null
+++ b/runtime/onert/core/src/exec/ExecTime.test.cc
@@ -0,0 +1,106 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ExecTime.h"
+
+#include "backend/IConfig.h"
+#include "backend/Backend.h"
+
+#include <gtest/gtest.h>
+
+#include <string>
+
+namespace
+{
+using namespace onert;
+using namespace exec;
+using namespace backend;
+
+struct MockConfig : public IConfig
+{
+ std::string id() override { return "b1"; }
+ bool initialize() override { return true; };
+ bool supportPermutation() override { return false; }
+ ir::Layout supportLayout(const ir::IOperation &, ir::Layout) override
+ {
+ return ir::Layout::UNKNOWN;
+ }
+ bool supportDynamicTensor() override { return false; }
+ bool supportFP16() override { return false; }
+};
+
+struct MockBackend : public ::onert::backend::Backend
+{
+ std::shared_ptr<onert::backend::IConfig> config() const override
+ {
+ return std::make_shared<MockConfig>();
+ }
+ std::unique_ptr<onert::backend::BackendContext> newContext(ContextData &&) const override
+ {
+ return nullptr;
+ }
+};
+
+TEST(ExecTime, roundtrip_ok)
+{
+ const auto *b = new MockBackend();
+ std::vector<const Backend *> bs = {b};
+ {
+ ExecTime et(bs);
+ et.updateOperationExecTime(b, "op1", true, 100, 100);
+ et.updateOperationExecTime(b, "op1", true, 200, 200);
+ et.updateOperationExecTime(b, "op1", false, 100, 888);
+ et.storeOperationsExecTime();
+ }
+ {
+ ExecTime et(bs);
+ auto time = et.getOperationExecTime(b, "op1", true, 100);
+ ASSERT_EQ(time, 100);
+ // Check interpolation
+ time = et.getOperationExecTime(b, "op1", true, 150);
+ ASSERT_EQ(time, 150);
+ time = et.getOperationExecTime(b, "op1", false, 100);
+ ASSERT_EQ(time, 888);
+ et.storeOperationsExecTime();
+ }
+ // clean up
+ EXPECT_EQ(remove("exec_time.json"), 0);
+}
+
+TEST(ExecTime, structure)
+{
+
+ const auto *b = new MockBackend();
+ std::vector<const Backend *> bs = {b};
+ {
+ ExecTime et(bs);
+ et.updateOperationExecTime(b, "op1", true, 100, 100);
+ et.updateOperationExecTime(b, "op1", true, 200, 200);
+ et.storeOperationsExecTime();
+ }
+ {
+ ExecTime et(bs);
+ auto time = et.getOperationExecTime(b, "op1", true, 100);
+ ASSERT_EQ(time, 100);
+ // Check interpolation
+ time = et.getOperationExecTime(b, "op1", true, 200);
+ ASSERT_EQ(time, 200);
+ et.storeOperationsExecTime();
+ }
+ // clean up
+ EXPECT_EQ(remove("exec_time.json"), 0);
+}
+} // unnamed namespace
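The roundtrip test above assumes getOperationExecTime interpolates linearly between recorded (operand size, time) measurements: with points (100, 100) and (200, 200), a query at 150 returns 150. A standalone illustration of that interpolation, not the actual ExecTime implementation:

// Standalone illustration of the linear interpolation checked by the test above
// (not the actual ExecTime code). Queries beyond the recorded range are clamped.
#include <cassert>
#include <cstdint>
#include <iterator>
#include <map>

int64_t interpolate(const std::map<uint32_t, int64_t> &measurements, uint32_t size)
{
  auto upper = measurements.lower_bound(size);
  if (upper == measurements.end())
    return std::prev(upper)->second; // past the last measurement: clamp
  if (upper->first == size || upper == measurements.begin())
    return upper->second; // exact match, or before the first measurement
  auto lower = std::prev(upper);
  // Linear interpolation between the two neighboring measurements
  return lower->second + (upper->second - lower->second) *
                           static_cast<int64_t>(size - lower->first) /
                           static_cast<int64_t>(upper->first - lower->first);
}

int main()
{
  std::map<uint32_t, int64_t> m{{100, 100}, {200, 200}};
  assert(interpolate(m, 100) == 100);
  assert(interpolate(m, 150) == 150); // the interpolated case checked in roundtrip_ok
  assert(interpolate(m, 200) == 200);
  return 0;
}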
diff --git a/runtime/onert/core/src/exec/Execution.cc b/runtime/onert/core/src/exec/Execution.cc
index 7feb3ab68..895a82ff8 100644
--- a/runtime/onert/core/src/exec/Execution.cc
+++ b/runtime/onert/core/src/exec/Execution.cc
@@ -16,6 +16,8 @@
#include "exec/Execution.h"
+#include "ir/DataType.h"
+#include "train/TrainableExecutors.h"
#include "util/logging.h"
namespace onert
@@ -23,116 +25,120 @@ namespace onert
namespace exec
{
-Execution::Execution(const std::shared_ptr<ExecutorMap> &executors) : _executors{executors}
+Execution::Execution(const std::shared_ptr<IExecutors> &executors) : _executors{executors}
{
assert(executors != nullptr);
- assert(executors->at(ir::SubgraphIndex{0}) != nullptr);
- const auto &primary_subg = primary_subgraph();
- _io_desc.inputs.resize(primary_subg.getInputs().size());
- _io_desc.outputs.resize(primary_subg.getOutputs().size());
+ assert(executors->entryExecutor() != nullptr);
+
+ // Initialize I/O description
+ _ctx.desc.inputs.resize(_executors->inputSize());
+ for (uint32_t i = 0; i < _executors->inputSize(); ++i)
+ _ctx.desc.inputs.at(i) = std::make_unique<InputDesc>(_executors->inputInfo(ir::IOIndex(i)));
+
+ _ctx.desc.outputs.resize(_executors->outputSize());
+ for (uint32_t i = 0; i < _executors->outputSize(); ++i)
+ _ctx.desc.outputs.at(i) = std::make_unique<OutputDesc>(_executors->outputInfo(ir::IOIndex(i)));
+ _ctx.shape_updated = false;
+
+ // Initialize options
+ ExecutionOptions::fromGlobalConfig(_ctx.options);
}
void Execution::changeInputShape(const ir::IOIndex &index, const ir::Shape &new_shape)
{
- // This should be called BEFORE setInput.
- if (_io_desc.inputs.at(index.value()) != 0)
- throw std::runtime_error("Error in calling order");
-
// This will be used later to set input tensor dynamic
// Note that 'compiled' model will not be updated with new_shape
// but new_shape will change model input shape while 'running' the model
- _io_desc.dynamic_input_shapes[index] = new_shape;
-}
-
-// TODO Remove default parameter
-void Execution::setInput(const ir::IOIndex &index, const void *buffer, size_t length,
- ir::Layout layout)
-{
- const auto input_index = primary_subgraph().getInputs().at(index);
- const auto info = primary_subgraph().operands().at(input_index).info();
-
- // TODO handle when (!buffer && length != 0) : setting the input as an optional tensor
-
- // check if size enough for input is passed
- // if input_shape_sig is set, input_shape_sig overrides shape in info
- // note: input_shape_sig contains shape passed by nnfw_set_input_tensorinfo()
+ auto &input_desc = _ctx.desc.inputs.at(index.value());
+ if (new_shape != input_desc->info.shape())
{
- auto input_shape_sig = _io_desc.dynamic_input_shapes.find(index);
- auto size_required = (input_shape_sig != _io_desc.dynamic_input_shapes.end())
- ? input_shape_sig->second.num_elements() *
- onert::ir::sizeOfDataType(info.typeInfo().type())
- : info.total_size();
+ input_desc->info.shape(new_shape);
+ _ctx.shape_updated = true;
- if (length < size_required)
- {
- throw std::runtime_error{"Too small length"};
- }
+ VERBOSE(Execution) << "Model input shape will be changed at the start of execute()"
+ << "(index: " << index << ")" << std::endl;
}
-
- _io_desc.inputs.at(index.value()) = std::make_unique<InputDesc>(info, buffer, length, layout);
}
// TODO Remove default parameter
-void Execution::setInput(const ir::IOIndex &index, const ir::TypeInfo &type, const ir::Shape &shape,
- const void *buffer, size_t length, ir::Layout layout)
+void Execution::setInput(const ir::IOIndex &index, const void *buffer, size_t length)
{
- auto info = ir::OperandInfo::createStaticInfo(shape, type);
-
- if (length < info.total_size())
- {
- throw std::runtime_error{"Too small length"};
- }
-
- _io_desc.inputs.at(index.value()) = std::make_unique<InputDesc>(info, buffer, length, layout);
+ // Length validation in execute(): datatype can be changed by API call
+ auto &input_desc = _ctx.desc.inputs.at(index.value());
+ input_desc->buffer = buffer;
+ input_desc->size = length;
}
-// TODO Remove default parameter
-void Execution::setOutput(const ir::IOIndex &index, void *buffer, size_t length, ir::Layout layout)
+void Execution::setInput(const ir::IOIndex &index, const ir::Shape &shape, const void *buffer,
+ size_t length)
{
- const auto output_index = primary_subgraph().getOutputs().at(index);
- const auto info = primary_subgraph().operands().at(output_index).info();
-
- if (length < info.total_size())
- {
- throw std::runtime_error{"Too small length"};
- }
-
- _io_desc.outputs.at(index.value()) = std::make_unique<OutputDesc>(info, buffer, length, layout);
+ changeInputShape(index, shape);
+ setInput(index, buffer, length);
}
-// TODO Remove default parameter
-void Execution::setOutput(const ir::IOIndex &index, const ir::TypeInfo &type,
- const ir::Shape &shape, void *buffer, size_t length, ir::Layout layout)
+void Execution::setOutput(const ir::IOIndex &index, void *buffer, size_t length)
{
- auto info = ir::OperandInfo::createStaticInfo(shape, type);
+ // Length validation in execute()
+ // - datatype can be changed by API call
+ // - shape can be changed by dynamic shape inference
+ auto &output_desc = _ctx.desc.outputs.at(index.value());
+ output_desc->buffer = buffer;
+ output_desc->size = length;
+}
- if (length < info.total_size())
- {
- throw std::runtime_error{"Too small length"};
- }
+void Execution::setOutput(const ir::IOIndex &index, const ir::Shape &shape, void *buffer,
+ size_t length)
+{
+ auto &output_desc = _ctx.desc.outputs.at(index.value());
+ output_desc->info.shape(shape);
- _io_desc.outputs.at(index.value()) = std::make_unique<OutputDesc>(info, buffer, length, layout);
+ setOutput(index, buffer, length);
}
void Execution::setInputLayout(const ir::IOIndex &index, ir::Layout layout)
{
- const auto &input_desc = _io_desc.inputs.at(index.value());
- _io_desc.inputs.at(index.value()) =
- std::make_unique<InputDesc>(input_desc->info, input_desc->buffer, input_desc->size, layout);
+ _ctx.desc.inputs.at(index.value())->layout = layout;
}
void Execution::setOutputLayout(const ir::IOIndex &index, ir::Layout layout)
{
- const auto &output_desc = _io_desc.outputs.at(index.value());
- _io_desc.outputs.at(index.value()) = std::make_unique<OutputDesc>(
- output_desc->info, output_desc->buffer, output_desc->size, layout);
+ _ctx.desc.outputs.at(index.value())->layout = layout;
+}
+
+void Execution::setInputType(const ir::IOIndex &index, const ir::TypeInfo &typeInfo)
+{
+ _ctx.desc.inputs.at(index.value())->info.typeInfo(typeInfo);
+ _ctx.shape_updated = true;
+}
+
+void Execution::setOutputType(const ir::IOIndex &index, const ir::TypeInfo &typeInfo)
+{
+ _ctx.desc.outputs.at(index.value())->info.typeInfo(typeInfo);
+ _ctx.shape_updated = true;
}
void Execution::execute()
{
VERBOSE(Execution) << "Start execution" << std::endl;
- primary_executor()->execute(_io_desc);
+ // Input length validation check
+ for (const auto &input : _ctx.desc.inputs)
+ {
+ if (input->info.total_size() > input->size)
+ throw std::runtime_error{"Too small input buffer length"};
+ }
+
+ // Output length validation check
+ if (!_ctx.shape_updated)
+ {
+ for (const auto &output : _ctx.desc.outputs)
+ {
+ if (output->info.total_size() > output->size)
+ throw std::runtime_error{"Too small output buffer length"};
+ }
+ }
+
+ _executors->execute(_ctx);
finished = true;
VERBOSE(Execution) << "Execution finished" << std::endl;
@@ -155,28 +161,66 @@ void Execution::waitFinish()
bool Execution::isFinished(void) const { return finished; }
-ir::Shape Execution::getInputShape(ir::IOIndex ind) const
+void Execution::train(uint32_t training_step)
+{
+ auto execs = dynamic_cast<exec::train::TrainableExecutors *>(_executors.get());
+ if (!execs)
+ {
+ throw std::runtime_error{"Supported only TrainableExecutors"};
+ }
+
+ execs->train(_ctx, training_step);
+ finished = true;
+}
+
+float Execution::getLoss(const ir::IOIndex &ind)
{
- auto itr = _io_desc.dynamic_input_shapes.find(ind);
- if (itr == _io_desc.dynamic_input_shapes.end())
+ auto execs = dynamic_cast<exec::train::TrainableExecutors *>(_executors.get());
+ if (!execs)
{
- auto operand_idx = primary_subgraph().getInputs().at(ind.value());
- return primary_subgraph().operands().at(operand_idx).shape();
+ throw std::runtime_error{"Supported only TrainableExecutors"};
}
- else
+
+ return execs->getLoss(ind);
+}
+
+void Execution::iterateTrainableTensors(
+ const std::function<void(const ir::OperandIndex &, const backend::train::ITrainableTensor *)> &fn)
+ const
+{
+ auto execs = dynamic_cast<exec::train::TrainableExecutors *>(_executors.get());
+ if (!execs)
{
- return itr->second;
+ throw std::runtime_error{"Supported only TrainableExecutors"};
}
+ execs->iterateTrainableTensors(fn);
}
+ir::Shape Execution::getInputShape(ir::IOIndex ind) const
+{
+ return _ctx.desc.inputs.at(ind.value())->info.shape();
+}
+
+// NNAPI returns failure if ANeuralNetworksExecution_getOutputOperandRank or
+// ANeuralNetworksExecution_getOutputOperandDimensions is called before execution.
+// On the other hand, the NNFW API returns the static shape inference result if
+// nnfw_output_tensorinfo is called before execution.
+// To handle both cases, this method returns the static shape inference result; the failure case
+// is handled on the NNAPI frontend.
ir::Shape Execution::getOutputShape(ir::IOIndex ind) const
{
- if (!isFinished())
- throw std::runtime_error("Cannot get output shape before execution is finished");
+ return _ctx.desc.outputs.at(ind.value())->info.shape();
+}
- const auto &output_desc = _io_desc.outputs.at(ind.value());
+size_t Execution::getInputTotalSize(ir::IOIndex ind) const
+{
+ // TODO Support dynamic shape
+ return _ctx.desc.inputs.at(ind.value())->info.total_size();
+}
- return output_desc->info.shape();
+size_t Execution::getOutputTotalSize(ir::IOIndex ind) const
+{
+ return _ctx.desc.outputs.at(ind.value())->info.total_size();
}
} // namespace exec
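The reworked execute() above validates buffer lengths up front: every input buffer is checked against its current tensor info, while output buffers are only checked when no shape update is pending, since dynamic shape inference may change the required output size at run time. A standalone sketch of that validation order, using a hypothetical IODesc struct rather than the onert types:

// Standalone sketch (hypothetical IODesc, not the onert types) of the up-front
// length validation: inputs are always checked, outputs only when the shapes
// are known not to change during execution.
#include <cstddef>
#include <stdexcept>
#include <vector>

struct IODesc
{
  size_t required_size; // bytes implied by the current tensor info
  size_t provided_size; // bytes supplied by the caller
};

void validate(const std::vector<IODesc> &inputs, const std::vector<IODesc> &outputs,
              bool shape_updated)
{
  for (const auto &in : inputs)
    if (in.required_size > in.provided_size)
      throw std::runtime_error{"Too small input buffer length"};

  if (!shape_updated) // output sizes may still change if a shape update is pending
    for (const auto &out : outputs)
      if (out.required_size > out.provided_size)
        throw std::runtime_error{"Too small output buffer length"};
}

int main()
{
  validate({{16, 16}}, {{16, 8}}, /*shape_updated=*/true); // output check deferred
  try
  {
    validate({{16, 16}}, {{16, 8}}, /*shape_updated=*/false);
    return 1; // should not get here
  }
  catch (const std::runtime_error &)
  {
    return 0; // expected: output buffer smaller than the static size
  }
}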
diff --git a/runtime/onert/core/src/exec/Execution.test.cc b/runtime/onert/core/src/exec/Execution.test.cc
new file mode 100644
index 000000000..15f94445a
--- /dev/null
+++ b/runtime/onert/core/src/exec/Execution.test.cc
@@ -0,0 +1,783 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "exec/Execution.h"
+
+#include "compiler/Compiler.h"
+#include "compiler/CompilerFactory.h"
+#include "ir/Graph.h"
+#include "ir/operation/BinaryArithmetic.h"
+#include "util/TracingCtx.h"
+
+#include <gtest/gtest.h>
+#include <thread>
+
+namespace
+{
+
+using namespace onert::ir;
+
+class CompiledMockUpModel
+{
+public:
+ CompiledMockUpModel()
+ {
+ // Model: two elementwise add operation
+ // model input: lhs, rhs1
+ // model output: second add result (result2)
+ // constant: rhs2
+ // result1 <= (lhs + rhs1)
+ // result2 <= (result1 + rhs2)
+ // lhs, rhs1, rhs2, result1, result2 shape: {1, 2, 2, 1}
+ // activation: none (constant)
+ graph = std::make_shared<Graph>();
+ // 1st add operands (result1 <= lhs + rhs1)
+ Shape shape{1, 2, 2, 1};
+ TypeInfo type{DataType::FLOAT32};
+ static float rhs2_data[4] = {3, 1, -1, 5};
+ auto operand_lhs = graph->addOperand(shape, type);
+ auto operand_rhs1 = graph->addOperand(shape, type);
+ auto operand_result1 = graph->addOperand(shape, type);
+ auto operand_rhs2 = graph->addOperand(shape, type);
+ auto operand_result2 = graph->addOperand(shape, type);
+ graph->operands()
+ .at(operand_rhs2)
+ .data(std::make_unique<CachedData>(reinterpret_cast<const uint8_t *>(&rhs2_data), 16));
+ // 2nd add operations (result2 <= result1 + rhs2)
+ operation::BinaryArithmetic::Param param1;
+ param1.arithmetic_type = operation::BinaryArithmetic::ArithmeticType::ADD;
+ param1.activation = Activation::NONE;
+ auto input_set1 = OperandIndexSequence{operand_lhs, operand_rhs1};
+ auto output_set1 = OperandIndexSequence{operand_result1};
+ graph->addOperation(
+ std::make_unique<operation::BinaryArithmetic>(input_set1, output_set1, param1));
+ operation::BinaryArithmetic::Param param2;
+ param2.arithmetic_type = operation::BinaryArithmetic::ArithmeticType::ADD;
+ param2.activation = Activation::NONE;
+ auto input_set2 = OperandIndexSequence{operand_result1, operand_rhs2};
+ auto output_set2 = OperandIndexSequence{operand_result2};
+ graph->addOperation(
+ std::make_unique<operation::BinaryArithmetic>(input_set2, output_set2, param2));
+ // Identify model inputs and outputs
+ graph->addInput(operand_lhs);
+ graph->addInput(operand_rhs1);
+ graph->addOutput(operand_result2);
+ graph->verify();
+
+ // Compile
+ auto model = std::make_shared<onert::ir::Model>();
+ model->push(onert::ir::SubgraphIndex{0}, graph);
+ coptions = onert::compiler::CompilerOptions::fromGlobalConfig();
+ onert::compiler::Compiler compiler{model, coptions.get()};
+ artifact = compiler.compile();
+ }
+
+public:
+ std::shared_ptr<Graph> graph;
+ std::unique_ptr<onert::compiler::CompilerOptions> coptions;
+ std::shared_ptr<onert::compiler::CompilerArtifact> artifact;
+};
+
+class CompiledMockUpMultiModel
+{
+public:
+ CompiledMockUpMultiModel()
+ {
+ // Model0: a float elementwise add operation
+ // Model0 input: lhs0, rhs0
+ // Model0 output: add result (result0)
+
+ // Model1: a qasymm8 elementwise add operation
+ // Model1 input: result0, rhs1
+ // Model1 output: add result (result1)
+
+ // Model2: a float elementwise add operation
+ // Model2 input: result0, result1
+ // Model2 output: add result (result2)
+
+ // constant: rhs1
+ // result0 <= (lhs0 + rhs0)
+ // result1 <= (result0 + rhs1)
+ // result2 <= (result0 + result1)
+ // lhs0, rhs0, rhs1, result0, result1, result2 shape: {1, 2, 2, 1}
+ // activation: none (constant)
+
+ // Update edge information
+ edges.pkg_inputs.emplace_back(ModelIndex{0}, SubgraphIndex{0}, IOIndex{0});
+ edges.pkg_inputs.emplace_back(ModelIndex{0}, SubgraphIndex{0}, IOIndex{1});
+ edges.pkg_outputs.emplace_back(ModelIndex{2}, SubgraphIndex{0}, IOIndex{0});
+ // From
+ const auto result0 = IODesc{ModelIndex{0}, SubgraphIndex{0}, IOIndex{0}};
+ const auto result1 = IODesc{ModelIndex{1}, SubgraphIndex{0}, IOIndex{0}};
+ // To
+ const auto lhs1 = IODesc{ModelIndex{1}, SubgraphIndex{0}, IOIndex{0}};
+ const auto lhs2 = IODesc{ModelIndex{2}, SubgraphIndex{0}, IOIndex{0}};
+ const auto rhs2 = IODesc{ModelIndex{2}, SubgraphIndex{0}, IOIndex{1}};
+ edges.edges.insert({result0, lhs1});
+ edges.edges.insert({result0, lhs2});
+ edges.edges.insert({result1, rhs2});
+
+ for (size_t i = 0; i < 3; ++i)
+ {
+ graphs.emplace_back(std::make_shared<Graph>());
+ }
+ Shape shape{1, 2, 2, 1};
+
+ // Model0's add operands (result0 <= lhs0 + rhs0)
+ DataType types[3] = {DataType::FLOAT32, DataType::QUANT_UINT8_ASYMM, DataType::FLOAT32};
+ auto operand_lhs0 = graphs[0]->addOperand(shape, TypeInfo{types[0]});
+ auto operand_rhs0 = graphs[0]->addOperand(shape, TypeInfo{types[0]});
+ auto operand_result0 = graphs[0]->addOperand(shape, TypeInfo{types[0]});
+
+ // Model0's add operation
+ operation::BinaryArithmetic::Param param0;
+ param0.arithmetic_type = operation::BinaryArithmetic::ArithmeticType::ADD;
+ param0.activation = Activation::NONE;
+ auto input_set0 = OperandIndexSequence{operand_lhs0, operand_rhs0};
+ auto output_set0 = OperandIndexSequence{operand_result0};
+ graphs[0]->addOperation(
+ std::make_unique<operation::BinaryArithmetic>(input_set0, output_set0, param0));
+
+ // Model0's inputs/outputs
+ graphs[0]->addInput(operand_lhs0);
+ graphs[0]->addInput(operand_rhs0);
+ graphs[0]->addOutput(operand_result0);
+ graphs[0]->verify();
+
+ // Model1's add operands (result1 <= Model0 result + rhs1)
+ // static float rhs1_data[4] = {3, 1, -1, 5};
+ static uint8_t rhs1_data[4] = {131, 129, 127, 133};
+ const float scale = 1;
+ const int32_t zero_point = 128;
+ auto operand_lhs1 = graphs[1]->addOperand(shape, TypeInfo{types[1], scale, zero_point});
+ auto operand_rhs1 = graphs[1]->addOperand(shape, TypeInfo{types[1], scale, zero_point});
+ auto operand_result1 = graphs[1]->addOperand(shape, TypeInfo{types[1], scale, zero_point});
+ graphs[1]
+ ->operands()
+ .at(operand_rhs1)
+ .data(std::make_unique<CachedData>(reinterpret_cast<const uint8_t *>(&rhs1_data), 4));
+
+ // Model1's add operation
+ operation::BinaryArithmetic::Param param1;
+ param1.arithmetic_type = operation::BinaryArithmetic::ArithmeticType::ADD;
+ param1.activation = Activation::NONE;
+ auto input_set1 = OperandIndexSequence{operand_lhs1, operand_rhs1};
+ auto output_set1 = OperandIndexSequence{operand_result1};
+ graphs[1]->addOperation(
+ std::make_unique<operation::BinaryArithmetic>(input_set1, output_set1, param1));
+
+ // Model1's inputs/outputs
+ graphs[1]->addInput(operand_lhs1);
+ graphs[1]->addOutput(operand_result1);
+ graphs[1]->verify();
+
+ // Model2's add operands (result2 <= Model0 result + Model1 result)
+ auto operand_lhs2 = graphs[2]->addOperand(shape, TypeInfo{types[2]});
+ auto operand_rhs2 = graphs[2]->addOperand(shape, TypeInfo{types[2]});
+ auto operand_result2 = graphs[2]->addOperand(shape, TypeInfo{types[2]});
+
+ // Model2's add operation
+ operation::BinaryArithmetic::Param param2;
+ param2.arithmetic_type = operation::BinaryArithmetic::ArithmeticType::ADD;
+ param2.activation = Activation::NONE;
+ auto input_set2 = OperandIndexSequence{operand_lhs2, operand_rhs2};
+ auto output_set2 = OperandIndexSequence{operand_result2};
+ graphs[2]->addOperation(
+ std::make_unique<operation::BinaryArithmetic>(input_set2, output_set2, param2));
+
+ // Model2's inputs/outputs
+ graphs[2]->addInput(operand_lhs2);
+ graphs[2]->addInput(operand_rhs2);
+ graphs[2]->addOutput(operand_result2);
+ graphs[2]->verify();
+
+ // Compile
+ compile();
+ }
+
+public:
+ void compile()
+ {
+ auto nnpkg = std::make_shared<onert::ir::NNPkg>();
+ coptions = onert::compiler::CompilerOptions::fromGlobalConfig();
+
+ for (uint16_t i = 0; i < graphs.size(); ++i)
+ {
+ auto model = std::make_shared<onert::ir::Model>();
+ model->push(SubgraphIndex{0}, graphs[i]);
+
+ nnpkg->push(onert::ir::ModelIndex{i}, std::move(model));
+ }
+ for (const auto &pkg_input : edges.pkg_inputs)
+ {
+ nnpkg->addInput(pkg_input);
+ }
+ for (const auto &pkg_output : edges.pkg_outputs)
+ {
+ nnpkg->addOutput(pkg_output);
+ }
+ for (const auto &edge : edges.edges)
+ {
+ nnpkg->addEdge(edge.from, edge.to);
+ }
+ auto compiler = onert::compiler::CompilerFactory::get().create(nnpkg, coptions.get());
+ nnpkg.reset();
+ artifact = compiler->compile();
+ }
+
+public:
+ std::vector<std::shared_ptr<Graph>> graphs;
+ std::unique_ptr<onert::compiler::CompilerOptions> coptions;
+ std::shared_ptr<onert::compiler::CompilerArtifact> artifact;
+ ModelEdges edges;
+};
+
+class CompiledMockUpQuantModel
+{
+public:
+ CompiledMockUpQuantModel()
+ {
+ // Model: two elementwise add operation
+ // model input: lhs, rhs1
+ // model output: second add result (result2)
+ // constant: rhs2
+ // result1 <= (lhs + rhs1)
+ // result2 <= (result1 + rhs2)
+ // lhs, rhs1, rhs2, result1, result2 shape: {1, 2, 2, 1}
+ // activation: none (constant)
+ graph = std::make_shared<Graph>();
+ // 1st add operands (result1 <= lhs + rhs1)
+ Shape shape{1, 2, 2, 1};
+ TypeInfo type{DataType::QUANT_UINT8_ASYMM, 1.0f, 128};
+ static uint8_t rhs2_data[4] = {131, 129, 127, 133};
+ auto operand_lhs = graph->addOperand(shape, type);
+ auto operand_rhs1 = graph->addOperand(shape, type);
+ auto operand_result1 = graph->addOperand(shape, type);
+ auto operand_rhs2 = graph->addOperand(shape, type);
+ auto operand_result2 = graph->addOperand(shape, type);
+ graph->operands()
+ .at(operand_rhs2)
+ .data(std::make_unique<CachedData>(reinterpret_cast<const uint8_t *>(&rhs2_data), 4));
+ // 2nd add operations (result2 <= result1 + rhs2)
+ operation::BinaryArithmetic::Param param1;
+ param1.arithmetic_type = operation::BinaryArithmetic::ArithmeticType::ADD;
+ param1.activation = Activation::NONE;
+ auto input_set1 = OperandIndexSequence{operand_lhs, operand_rhs1};
+ auto output_set1 = OperandIndexSequence{operand_result1};
+ graph->addOperation(
+ std::make_unique<operation::BinaryArithmetic>(input_set1, output_set1, param1));
+ operation::BinaryArithmetic::Param param2;
+ param2.arithmetic_type = operation::BinaryArithmetic::ArithmeticType::ADD;
+ param2.activation = Activation::NONE;
+ auto input_set2 = OperandIndexSequence{operand_result1, operand_rhs2};
+ auto output_set2 = OperandIndexSequence{operand_result2};
+ graph->addOperation(
+ std::make_unique<operation::BinaryArithmetic>(input_set2, output_set2, param2));
+ // Identify model inputs and outputs
+ graph->addInput(operand_lhs);
+ graph->addInput(operand_rhs1);
+ graph->addOutput(operand_result2);
+ graph->verify();
+
+ // Compile
+ auto model = std::make_shared<onert::ir::Model>();
+ model->push(onert::ir::SubgraphIndex{0}, graph);
+ coptions = onert::compiler::CompilerOptions::fromGlobalConfig();
+ onert::compiler::Compiler compiler{model, coptions.get()};
+ artifact = compiler.compile();
+ }
+
+public:
+ std::shared_ptr<Graph> graph;
+ std::unique_ptr<onert::compiler::CompilerOptions> coptions;
+ std::shared_ptr<onert::compiler::CompilerArtifact> artifact;
+};
+
+TEST(ExecInstance, simple)
+{
+ auto mockup = CompiledMockUpModel();
+ auto graph = mockup.graph;
+ auto executors = mockup.artifact->_executors;
+
+ auto input1 = IOIndex{0};
+ auto input2 = IOIndex{1};
+ auto output = IOIndex{0};
+
+ const float input1_buffer[4] = {1, 0, -1, -2};
+ const float input2_buffer[4] = {1, -3, 2, -4};
+ float output_buffer[4] = {};
+ const float output_expected[4] = {5, -2, 0, -1};
+
+ onert::exec::Execution execution{executors};
+
+ execution.setInput(input1, reinterpret_cast<const void *>(input1_buffer), 16);
+ execution.setInput(input2, reinterpret_cast<const void *>(input2_buffer), 16);
+ execution.setOutput(output, reinterpret_cast<void *>(output_buffer), 16);
+ execution.execute();
+
+ for (auto i = 0; i < 4; i++)
+ {
+ EXPECT_EQ(output_buffer[i], output_expected[i]);
+ }
+}
+
+TEST(ExecInstance, neg_small_outputbuffer)
+{
+ auto mockup = CompiledMockUpModel();
+ auto graph = mockup.graph;
+ auto executors = mockup.artifact->_executors;
+
+ auto input1 = IOIndex{0};
+ auto input2 = IOIndex{1};
+ auto output = IOIndex{0};
+
+ const float input1_buffer[4] = {1, 0, -1, -2};
+ const float input2_buffer[4] = {1, -3, 2, -4};
+ float output_buffer[2] = {};
+
+ onert::exec::Execution execution{executors};
+
+ execution.setInput(input1, reinterpret_cast<const void *>(input1_buffer), 16);
+ execution.setInput(input2, reinterpret_cast<const void *>(input2_buffer), 16);
+ execution.setOutput(output, reinterpret_cast<void *>(output_buffer), 8);
+ EXPECT_ANY_THROW(execution.execute());
+}
+
+TEST(ExecInstance, neg_small_inoutsize)
+{
+ auto mockup = CompiledMockUpModel();
+ auto graph = mockup.graph;
+ auto executors = mockup.artifact->_executors;
+
+ auto input1 = IOIndex{0};
+ auto input2 = IOIndex{1};
+ auto output = IOIndex{0};
+
+ const float input1_buffer[2] = {1, 0};
+ const float input2_buffer[2] = {1, -3};
+ const auto new_shape = onert::ir::Shape({1, 1, 2, 1});
+ float output_buffer[2] = {};
+
+ onert::exec::Execution execution{executors};
+
+ execution.setInput(input1, new_shape, reinterpret_cast<const void *>(input1_buffer), 8);
+ execution.setInput(input2, new_shape, reinterpret_cast<const void *>(input2_buffer), 2);
+ EXPECT_THROW(execution.execute(), std::exception);
+
+ // No exception is thrown because the input shape was changed and the output buffer is large enough
+ execution.setInput(input2, new_shape, reinterpret_cast<const void *>(input2_buffer), 8);
+ execution.setOutput(output, reinterpret_cast<void *>(output_buffer), 16);
+ execution.execute();
+
+ execution.setOutput(output, reinterpret_cast<void *>(output_buffer), 8);
+ // Shape inference throws an exception because the output buffer is too small:
+ // the output shape is {1, 2, 2, 1}
+ EXPECT_THROW(execution.execute(), std::exception);
+}
+
+TEST(ExecInstance, twoCompile)
+{
+ auto mockup = CompiledMockUpModel();
+ auto graph = mockup.graph;
+ auto executors1 = mockup.artifact->_executors;
+ onert::exec::Execution execution1{executors1};
+
+ auto input1 = IOIndex{0};
+ auto input2 = IOIndex{1};
+ auto output = IOIndex{0};
+
+ const float exe1_input1_buffer[4] = {1, 0, -1, -2};
+ const float exe1_input2_buffer[4] = {1, -3, 2, -4};
+ float exe1_output_buffer[4] = {};
+ const float exe1_output_expected[4] = {5, -2, 0, -1};
+
+ execution1.setInput(input1, reinterpret_cast<const void *>(exe1_input1_buffer), 16);
+ execution1.setInput(input2, reinterpret_cast<const void *>(exe1_input2_buffer), 16);
+ execution1.setOutput(output, reinterpret_cast<void *>(exe1_output_buffer), 16);
+
+ // Make new executor: compile again
+ auto model = std::make_shared<onert::ir::Model>();
+ model->push(onert::ir::SubgraphIndex{0}, graph);
+ auto coptions = onert::compiler::CompilerOptions::fromGlobalConfig();
+ onert::compiler::Compiler compiler{model, coptions.get()};
+ std::shared_ptr<onert::compiler::CompilerArtifact> artifact = compiler.compile();
+ onert::exec::Execution execution2{artifact->_executors};
+
+ const float exe2_input1_buffer[4] = {2, 1, -2, 0};
+ const float exe2_input2_buffer[4] = {-3, 3, 1, 2};
+ float exe2_output_buffer[4] = {};
+ const float exe2_output_expected[4] = {2, 5, -2, 7};
+
+ execution2.setInput(input1, reinterpret_cast<const void *>(exe2_input1_buffer), 16);
+ execution2.setInput(input2, reinterpret_cast<const void *>(exe2_input2_buffer), 16);
+ execution2.setOutput(output, reinterpret_cast<void *>(exe2_output_buffer), 16);
+
+ execution1.execute();
+ execution2.execute();
+
+ for (auto i = 0; i < 4; i++)
+ {
+ EXPECT_EQ(exe1_output_buffer[i], exe1_output_expected[i]);
+ EXPECT_EQ(exe2_output_buffer[i], exe2_output_expected[i]);
+ }
+}
+
+// Two execution instances can be initialized and then executed in order
+TEST(ExecInstance, twoExecution)
+{
+ auto mockup = CompiledMockUpModel();
+ auto executors = mockup.artifact->_executors;
+ auto input1 = IOIndex{0};
+ auto input2 = IOIndex{1};
+ auto output1 = IOIndex{0};
+
+ const float exe1_input1_buffer[4] = {1, 0, -1, -2};
+ const float exe1_input2_buffer[4] = {1, -3, 2, -4};
+ float exe1_output_buffer[4] = {};
+ const float exe1_output_expected[4] = {5, -2, 0, -1};
+ const float exe2_output_expected[4] = {2, 5, -2, 7};
+
+ onert::exec::Execution execution1{executors};
+ execution1.setInput(input1, reinterpret_cast<const void *>(exe1_input1_buffer), 16);
+ execution1.setInput(input2, reinterpret_cast<const void *>(exe1_input2_buffer), 16);
+ execution1.setOutput(output1, reinterpret_cast<void *>(exe1_output_buffer), 16);
+
+ const float exe2_input1_buffer[4] = {2, 1, -2, 0};
+ const float exe2_input2_buffer[4] = {-3, 3, 1, 2};
+ float exe2_output_buffer[4] = {};
+
+ // Make new execution
+ onert::exec::Execution execution2{executors};
+ execution2.setInput(input1, reinterpret_cast<const void *>(exe2_input1_buffer), 16);
+ execution2.setInput(input2, reinterpret_cast<const void *>(exe2_input2_buffer), 16);
+ execution2.setOutput(output1, reinterpret_cast<void *>(exe2_output_buffer), 16);
+
+ execution1.execute();
+ execution2.execute();
+
+ for (auto i = 0; i < 4; i++)
+ {
+ EXPECT_EQ(exe1_output_buffer[i], exe1_output_expected[i]);
+ EXPECT_EQ(exe2_output_buffer[i], exe2_output_expected[i]);
+ }
+}
+
+TEST(ExecInstance, quantModel_floatIO)
+{
+ auto mockup = CompiledMockUpQuantModel();
+ auto graph = mockup.graph;
+ auto executors = mockup.artifact->_executors;
+
+ auto input1 = IOIndex{0};
+ auto input2 = IOIndex{1};
+ auto output = IOIndex{0};
+
+ const float input1_buffer[4] = {1, 0, -1, -2};
+ const float input2_buffer[4] = {1, -3, 2, -4};
+ float output_buffer[4] = {};
+ const float output_expected[4] = {5, -2, 0, -1};
+
+ onert::exec::Execution execution{executors};
+
+ execution.setInput(input1, reinterpret_cast<const void *>(input1_buffer), 16);
+ execution.setInput(input2, reinterpret_cast<const void *>(input2_buffer), 16);
+ execution.setOutput(output, reinterpret_cast<void *>(output_buffer), 16);
+ execution.setInputType(input1, onert::ir::TypeInfo{onert::ir::DataType::FLOAT32});
+ execution.setInputType(input2, onert::ir::TypeInfo{onert::ir::DataType::FLOAT32});
+ execution.setOutputType(output, onert::ir::TypeInfo{onert::ir::DataType::FLOAT32});
+ execution.execute();
+
+ EXPECT_EQ(output_buffer[0], output_expected[0]);
+ EXPECT_EQ(output_buffer[1], output_expected[1]);
+ EXPECT_EQ(output_buffer[2], output_expected[2]);
+ EXPECT_EQ(output_buffer[3], output_expected[3]);
+}
+
+class Inference
+{
+public:
+ Inference(const float (&input1)[4], const float (&input2)[4], float (&output)[4],
+ std::shared_ptr<onert::exec::IExecutors> &executors)
+ : _input1{input1}, _input2{input2}, _output{output}, _executors{executors}
+ {
+ // DO NOTHING
+ }
+
+ void inference(void)
+ {
+ auto input1 = IOIndex{0};
+ auto input2 = IOIndex{1};
+ auto output1 = IOIndex{0};
+
+ onert::exec::Execution execution{_executors};
+ execution.setInput(input1, reinterpret_cast<const void *>(_input1), 16);
+ execution.setInput(input2, reinterpret_cast<const void *>(_input2), 16);
+ execution.setOutput(output1, reinterpret_cast<void *>(_output), 16);
+
+ execution.execute();
+ }
+
+private:
+ const float (&_input1)[4];
+ const float (&_input2)[4];
+ float (&_output)[4];
+ std::shared_ptr<onert::exec::IExecutors> &_executors;
+};
+
+// Support multi-thread execution
+TEST(ExecInstance, twoThreads)
+{
+ auto mockup = CompiledMockUpModel();
+ auto executors = mockup.artifact->_executors;
+
+ const float exe1_input1_buffer[4] = {1, 0, -1, -2};
+ const float exe1_input2_buffer[4] = {1, -3, 2, -4};
+ float exe1_output_buffer[4] = {};
+ const float exe1_output_expected[4] = {5, -2, 0, -1};
+
+ Inference execution1{exe1_input1_buffer, exe1_input2_buffer, exe1_output_buffer, executors};
+
+ const float exe2_input1_buffer[4] = {2, 1, -2, 0};
+ const float exe2_input2_buffer[4] = {-3, 3, 1, 2};
+ float exe2_output_buffer[4] = {};
+ const float exe2_output_expected[4] = {2, 5, -2, 7};
+
+ Inference execution2{exe2_input1_buffer, exe2_input2_buffer, exe2_output_buffer, executors};
+
+ std::thread t1{&Inference::inference, &execution1};
+ std::thread t2{&Inference::inference, &execution2};
+
+ t1.join();
+ t2.join();
+
+ for (auto i = 0; i < 4; i++)
+ {
+ EXPECT_EQ(exe1_output_buffer[i], exe1_output_expected[i]);
+ EXPECT_EQ(exe2_output_buffer[i], exe2_output_expected[i]);
+ }
+}
+
+// Support asynchronous execution
+TEST(ExecInstance, async)
+{
+ auto mockup = CompiledMockUpModel();
+ auto graph = mockup.graph;
+ auto executors = mockup.artifact->_executors;
+
+ auto input1 = IOIndex{0};
+ auto input2 = IOIndex{1};
+ auto output = IOIndex{0};
+
+ const float input1_buffer[4] = {1, 0, -1, -2};
+ const float input2_buffer[4] = {1, -3, 2, -4};
+ float output_buffer[4] = {};
+ const float output_expected[4] = {5, -2, 0, -1};
+
+ onert::exec::Execution execution{executors};
+
+ execution.setInput(input1, reinterpret_cast<const void *>(input1_buffer), 16);
+ execution.setInput(input2, reinterpret_cast<const void *>(input2_buffer), 16);
+ execution.setOutput(output, reinterpret_cast<void *>(output_buffer), 16);
+ execution.startExecute();
+ execution.waitFinish();
+
+ for (auto i = 0; i < 4; i++)
+ {
+ EXPECT_EQ(output_buffer[i], output_expected[i]);
+ }
+}
+
+TEST(ExecInstance, multi_model_simple)
+{
+ auto mockup = CompiledMockUpMultiModel();
+ auto executors = mockup.artifact->_executors;
+
+ auto input1 = IOIndex{0};
+ auto input2 = IOIndex{1};
+ auto output = IOIndex{0};
+
+ const float input1_buffer[4] = {1, 0, -1, -2};
+ const float input2_buffer[4] = {1, -3, 2, -4};
+ float output_buffer[4] = {};
+ const float output_expected[4] = {7, -5, 1, -7};
+
+ onert::exec::Execution execution{executors};
+
+ execution.setInput(input1, reinterpret_cast<const void *>(input1_buffer), 16);
+ execution.setInput(input2, reinterpret_cast<const void *>(input2_buffer), 16);
+ execution.setOutput(output, reinterpret_cast<void *>(output_buffer), 16);
+ execution.execute();
+
+ for (auto i = 0; i < 4; i++)
+ {
+ EXPECT_EQ(output_buffer[i], output_expected[i]);
+ }
+}
+
+TEST(ExecInstance, multi_model_twoCompile)
+{
+ auto mockup = CompiledMockUpMultiModel();
+ auto executors1 = mockup.artifact->_executors;
+ onert::exec::Execution execution1{executors1};
+
+ auto input1 = IOIndex{0};
+ auto input2 = IOIndex{1};
+ auto output = IOIndex{0};
+
+ const float exe1_input1_buffer[4] = {1, 0, -1, -2};
+ const float exe1_input2_buffer[4] = {1, -3, 2, -4};
+ float exe1_output_buffer[4] = {};
+ const float exe1_output_expected[4] = {7, -5, 1, -7};
+
+ execution1.setInput(input1, reinterpret_cast<const void *>(exe1_input1_buffer), 16);
+ execution1.setInput(input2, reinterpret_cast<const void *>(exe1_input2_buffer), 16);
+ execution1.setOutput(output, reinterpret_cast<void *>(exe1_output_buffer), 16);
+
+ // Make new executor: compile again
+ mockup.compile();
+ onert::exec::Execution execution2{mockup.artifact->_executors};
+
+ const float exe2_input1_buffer[4] = {2, 1, -2, 0};
+ const float exe2_input2_buffer[4] = {-3, 3, 1, 2};
+ float exe2_output_buffer[4] = {};
+ const float exe2_output_expected[4] = {1, 9, -3, 9};
+
+ execution2.setInput(input1, reinterpret_cast<const void *>(exe2_input1_buffer), 16);
+ execution2.setInput(input2, reinterpret_cast<const void *>(exe2_input2_buffer), 16);
+ execution2.setOutput(output, reinterpret_cast<void *>(exe2_output_buffer), 16);
+
+ execution1.execute();
+ execution2.execute();
+
+ for (auto i = 0; i < 4; i++)
+ {
+ EXPECT_EQ(exe1_output_buffer[i], exe1_output_expected[i]);
+ EXPECT_EQ(exe2_output_buffer[i], exe2_output_expected[i]);
+ }
+}
+
+// Two execution instances can be initialized and then executed in order
+TEST(ExecInstance, multi_model_twoExecution)
+{
+ auto mockup = CompiledMockUpMultiModel();
+ auto executors = mockup.artifact->_executors;
+ auto input1 = IOIndex{0};
+ auto input2 = IOIndex{1};
+ auto output1 = IOIndex{0};
+
+ const float exe1_input1_buffer[4] = {1, 0, -1, -2};
+ const float exe1_input2_buffer[4] = {1, -3, 2, -4};
+ float exe1_output_buffer[4] = {};
+ const float exe1_output_expected[4] = {7, -5, 1, -7};
+ const float exe2_output_expected[4] = {1, 9, -3, 9};
+
+ onert::exec::Execution execution1{executors};
+ execution1.setInput(input1, reinterpret_cast<const void *>(exe1_input1_buffer), 16);
+ execution1.setInput(input2, reinterpret_cast<const void *>(exe1_input2_buffer), 16);
+ execution1.setOutput(output1, reinterpret_cast<void *>(exe1_output_buffer), 16);
+
+ const float exe2_input1_buffer[4] = {2, 1, -2, 0};
+ const float exe2_input2_buffer[4] = {-3, 3, 1, 2};
+ float exe2_output_buffer[4] = {};
+
+ // Make new execution
+ onert::exec::Execution execution2{executors};
+ execution2.setInput(input1, reinterpret_cast<const void *>(exe2_input1_buffer), 16);
+ execution2.setInput(input2, reinterpret_cast<const void *>(exe2_input2_buffer), 16);
+ execution2.setOutput(output1, reinterpret_cast<void *>(exe2_output_buffer), 16);
+
+ execution1.execute();
+ execution1.execute();
+ execution2.execute();
+ execution2.execute();
+
+ for (auto i = 0; i < 4; i++)
+ {
+ EXPECT_EQ(exe1_output_buffer[i], exe1_output_expected[i]);
+ EXPECT_EQ(exe2_output_buffer[i], exe2_output_expected[i]);
+ }
+}
+
+// Multi-model is not thread-safe yet
+
+// Support asynchronous execution
+TEST(ExecInstance, multi_model_async)
+{
+ auto mockup = CompiledMockUpMultiModel();
+ auto executors = mockup.artifact->_executors;
+
+ auto input1 = IOIndex{0};
+ auto input2 = IOIndex{1};
+ auto output = IOIndex{0};
+
+ const float input1_buffer[4] = {1, 0, -1, -2};
+ const float input2_buffer[4] = {1, -3, 2, -4};
+ float output_buffer[4] = {};
+ const float output_expected[4] = {7, -5, 1, -7};
+
+ onert::exec::Execution execution{executors};
+
+ execution.setInput(input1, reinterpret_cast<const void *>(input1_buffer), 16);
+ execution.setInput(input2, reinterpret_cast<const void *>(input2_buffer), 16);
+ execution.setOutput(output, reinterpret_cast<void *>(output_buffer), 16);
+ execution.startExecute();
+ execution.waitFinish();
+
+ for (auto i = 0; i < 4; i++)
+ {
+ EXPECT_EQ(output_buffer[i], output_expected[i]);
+ }
+}
+
+TEST(ExecInstance, multi_model_dequant_input_quant_output)
+{
+ auto mockup = CompiledMockUpMultiModel();
+ auto executors = mockup.artifact->_executors;
+
+ auto input1 = IOIndex{0};
+ auto input2 = IOIndex{1};
+ auto output = IOIndex{0};
+
+ const uint8_t input1_buffer[4] = {138, 128, 118, 108}; // {1, 0, -1, -2}
+ const uint8_t input2_buffer[4] = {138, 98, 148, 88}; // {1, -3, 2, -4}
+ uint8_t output_buffer[4] = {};
+ const uint8_t output_expected[4] = {198, 78, 138, 58}; // {7, -5, 1, -7}
+ float scale = 0.1;
+ int32_t zero_point = 128;
+
+ onert::exec::Execution execution{executors};
+
+ onert::ir::TypeInfo type_info{onert::ir::DataType::QUANT_UINT8_ASYMM, scale, zero_point};
+ execution.setInputType(input1, type_info);
+ execution.setInput(input1, execution.getInputShape(input1),
+ reinterpret_cast<const void *>(input1_buffer), 4);
+ execution.setInputType(input2, type_info);
+ execution.setInput(input2, execution.getInputShape(input2),
+ reinterpret_cast<const void *>(input2_buffer), 4);
+ execution.setOutputType(output, type_info);
+ execution.setOutput(output, execution.getOutputShape(output),
+ reinterpret_cast<void *>(output_buffer), 4);
+ execution.execute();
+
+ for (auto i = 0; i < 4; i++)
+ {
+ EXPECT_EQ(output_buffer[i], output_expected[i]);
+ }
+}
+
+// TODO Add a unit test multi_model_quant_input_dequant_output
+
+} // namespace
diff --git a/runtime/onert/core/src/exec/ExecutionContext.cc b/runtime/onert/core/src/exec/ExecutionContext.cc
new file mode 100644
index 000000000..aec10ee5b
--- /dev/null
+++ b/runtime/onert/core/src/exec/ExecutionContext.cc
@@ -0,0 +1,34 @@
+/*
+ * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "exec/ExecutionContext.h"
+
+#include "util/ConfigSource.h"
+
+namespace onert
+{
+namespace exec
+{
+
+void ExecutionOptions::fromGlobalConfig(ExecutionOptions &options)
+{
+ options.dump_minmax = util::getConfigBool(util::config::MINMAX_DUMP);
+ options.trace = util::getConfigBool(util::config::TRACING_MODE);
+ options.profile = util::getConfigBool(util::config::PROFILING_MODE);
+}
+
+} // namespace exec
+} // namespace onert
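A minimal sketch (not part of the patch) of how options filled by ExecutionOptions::fromGlobalConfig() are meant to be consumed; only the type, function and field names come from the hunk above, the surrounding setup is hypothetical:

#include "exec/ExecutionContext.h"

void configure_options_example()
{
  onert::exec::ExecutionOptions options;
  // Reads MINMAX_DUMP, TRACING_MODE and PROFILING_MODE from the global config
  onert::exec::ExecutionOptions::fromGlobalConfig(options);
  // These flags later decide which observers an ExecutionObservee will pick up
  (void)options.dump_minmax;
  (void)options.trace;
  (void)options.profile;
}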
diff --git a/runtime/onert/core/src/exec/ExecutionObservee.cc b/runtime/onert/core/src/exec/ExecutionObservee.cc
index ddb1fb6a0..22881b8c8 100644
--- a/runtime/onert/core/src/exec/ExecutionObservee.cc
+++ b/runtime/onert/core/src/exec/ExecutionObservee.cc
@@ -21,42 +21,72 @@ namespace onert
namespace exec
{
-void ExecutionObservee::add(std::unique_ptr<IExecutionObserver> observer)
+ExecutionObservee::ExecutionObservee(const ExecObservers &observers,
+ const ExecutionOptions &options)
{
- _observers.emplace_back(std::move(observer));
+ // TODO Use execution option
+ if (options.dump_minmax)
+ {
+ auto observer = observers.get(ObserverType::MINMAX_DUMP);
+ if (!observer)
+ throw std::runtime_error{"MinMaxRecorder is only supported on LinearExecutor, single model"};
+
+ _observers.emplace_back(observer);
+ }
+
+ if (options.trace)
+ {
+ auto observer = observers.get(ObserverType::TRACING);
+ if (!observer)
+ throw std::runtime_error{"Cannot find TracingObserver"};
+
+ _observers.emplace_back(observer);
+ }
+
+ if (options.profile)
+ {
+ auto observer = observers.get(ObserverType::PROFILE);
+ if (!observer)
+ throw std::runtime_error{
+ "Profiling is only supported on DataflowExecutor with heterogenous scheduler"};
+
+ _observers.emplace_back(observer);
+ }
}
-void ExecutionObservee::notifyModelBegin(IExecutor *executor)
+void ExecutionObservee::notifySubgraphBegin(ir::SubgraphIndex ind) const
{
- for (auto &o : _observers)
+ for (auto &&o : _observers)
{
- o->handleBegin(executor);
+ o->handleSubgraphBegin(ind);
}
}
-void ExecutionObservee::notifyModelEnd(IExecutor *executor)
+void ExecutionObservee::notifySubgraphEnd(ir::SubgraphIndex ind) const
{
- for (auto &o : _observers)
+ for (auto &&o : _observers)
{
- o->handleEnd(executor);
+ o->handleSubgraphEnd(ind);
}
}
-void ExecutionObservee::notifyJobBegin(IExecutor *executor, const ir::OpSequence *op_seq,
- const backend::Backend *backend)
+void ExecutionObservee::notifyJobBegin(IExecutor *executor, ir::SubgraphIndex subg_ind,
+ ir::OperationIndex op_ind,
+ const backend::Backend *backend) const
{
- for (auto &o : _observers)
+ for (auto &&o : _observers)
{
- o->handleBegin(executor, op_seq, backend);
+ o->handleJobBegin(executor, subg_ind, op_ind, backend);
}
}
-void ExecutionObservee::notifyJobEnd(IExecutor *executor, const ir::OpSequence *op_seq,
- const backend::Backend *backend)
+void ExecutionObservee::notifyJobEnd(IExecutor *executor, ir::SubgraphIndex subg_ind,
+ ir::OperationIndex op_ind,
+ const backend::Backend *backend) const
{
- for (auto &o : _observers)
+ for (auto &&o : _observers)
{
- o->handleEnd(executor, op_seq, backend);
+ o->handleJobEnd(executor, subg_ind, op_ind, backend);
}
}
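For context, a hedged sketch of how the reworked observee is wired together; workspace_dir, graph and tracing_ctx stand in for values the compiler already owns, while everything else is named as in this patch:

#include "ExecutionObservee.h"
#include "ExecutionObservers.h"

#include <memory>
#include <string>

void observee_wiring_example(const std::string &workspace_dir, const onert::ir::Graph &graph,
                             const onert::util::TracingCtx *tracing_ctx)
{
  onert::exec::ExecObservers observers;
  observers.add(std::make_unique<onert::exec::TracingObserver>(workspace_dir, graph, tracing_ctx));

  onert::exec::ExecutionOptions options;
  options.trace = true;

  // The constructor keeps only the observers enabled by the options (here: TRACING)
  onert::exec::ExecutionObservee subject(observers, options);
  subject.notifySubgraphBegin(onert::ir::SubgraphIndex{0});
  // ... notifyJobBegin()/notifyJobEnd() would be called per operation ...
  subject.notifySubgraphEnd(onert::ir::SubgraphIndex{0});
}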
diff --git a/runtime/onert/core/src/exec/ExecutionObservee.h b/runtime/onert/core/src/exec/ExecutionObservee.h
index 49d409a3a..e6461c788 100644
--- a/runtime/onert/core/src/exec/ExecutionObservee.h
+++ b/runtime/onert/core/src/exec/ExecutionObservee.h
@@ -17,9 +17,11 @@
#ifndef __ONERT_EXEC_EXECUTION_OBSERVEE_H__
#define __ONERT_EXEC_EXECUTION_OBSERVEE_H__
-#include <list>
+#include "ExecutionObservers.h"
+
+#include "ir/Index.h"
-#include "exec/ExecutionObservers.h"
+#include <list>
namespace onert
{
@@ -34,20 +36,21 @@ class ExecutionObservee
{
public:
/**
- * @brief Register an observer
+ * @brief Register enabled observers
*
- * @param observer Observer to be added
+   * @param observers Observers generated by the compiler
*/
- void add(std::unique_ptr<IExecutionObserver> observer);
- void notifyModelBegin(IExecutor *executor);
- void notifyModelEnd(IExecutor *executor);
- void notifyJobBegin(IExecutor *executor, const ir::OpSequence *op_seq,
- const backend::Backend *backend);
- void notifyJobEnd(IExecutor *executor, const ir::OpSequence *op_seq,
- const backend::Backend *backend);
+ ExecutionObservee(const ExecObservers &observers, const ExecutionOptions &options);
+ void notifySubgraphBegin(ir::SubgraphIndex ind) const;
+ void notifySubgraphEnd(ir::SubgraphIndex ind) const;
+ void notifyJobBegin(IExecutor *executor, ir::SubgraphIndex subg_ind, ir::OperationIndex op_ind,
+ const backend::Backend *backend) const;
+ void notifyJobEnd(IExecutor *executor, ir::SubgraphIndex subg_ind, ir::OperationIndex op_ind,
+ const backend::Backend *backend) const;
+ bool isEmpty() const { return _observers.size() == 0; }
private:
- std::list<std::unique_ptr<IExecutionObserver>> _observers;
+ std::list<IExecutionObserver *> _observers;
};
} // namespace exec
diff --git a/runtime/onert/core/src/exec/ExecutionObservers.cc b/runtime/onert/core/src/exec/ExecutionObservers.cc
index 060f874de..a58daeabd 100644
--- a/runtime/onert/core/src/exec/ExecutionObservers.cc
+++ b/runtime/onert/core/src/exec/ExecutionObservers.cc
@@ -14,14 +14,58 @@
* limitations under the License.
*/
-#include "exec/ExecutionObservers.h"
+#include "ExecutionObservers.h"
-#include <string>
+#include "../util/EventWriter.h"
#include "util/logging.h"
-#include "exec/IExecutor.h"
-#include "misc/polymorphic_downcast.h"
-#include "ir/OpSequence.h"
+
+#include <misc/polymorphic_downcast.h>
+
+#include <string>
+#include <sstream>
+
+namespace
+{
+
+void setUserData(const onert::ir::Graph &g, const onert::ir::IOperation *op,
+ decltype(EventCollector::Event::userData) &data)
+{
+ // From a tensor of shape [a, b, c], this will return a string "shape(a b c)".
+  // A string like "[1, 2, 3]" would look nicer, but it would be parsed as a list in JSON,
+  // which makes text search (e.g., Ctrl-F in Chrome Tracing) difficult
+ auto build_shape_str = [&](onert::ir::OperandIndex operand_idx) {
+ std::string shape_str;
+ auto &shape = g.operands().at(operand_idx).info().shape();
+ for (int i = 0; i < shape.rank(); i++)
+ {
+ if (i == 0)
+ shape_str = "shape(" + std::to_string(shape.dim(i));
+ else
+ shape_str += " " + std::to_string(shape.dim(i));
+ }
+ shape_str += ")";
+
+ return shape_str;
+ };
+
+ auto &inputs = op->getInputs();
+ auto size = inputs.size();
+ for (size_t i = 0; i < size; i++)
+ {
+ auto operand_idx = inputs.at(i);
+ if (operand_idx.undefined())
+ continue;
+
+ std::string key("input_shape_" + std::to_string(i));
+ std::string value = build_shape_str(operand_idx);
+ data.emplace_back(std::make_pair(key, value));
+ }
+
+ // add other userData as needed
+}
+
+} // namespace
namespace onert
{
@@ -29,8 +73,8 @@ namespace onert
namespace exec
{
-void ProfileObserver::handleBegin(onert::exec::IExecutor *, const ir::OpSequence *,
- const onert::backend::Backend *backend)
+void ProfileObserver::handleJobBegin(onert::exec::IExecutor *, ir::SubgraphIndex,
+ ir::OperationIndex, const onert::backend::Backend *backend)
{
_timer = backend->config()->timer();
if (_timer == nullptr)
@@ -38,14 +82,14 @@ void ProfileObserver::handleBegin(onert::exec::IExecutor *, const ir::OpSequence
_timer->handleBegin();
}
-void ProfileObserver::handleEnd(IExecutor *exec, const ir::OpSequence *op_seq,
- const backend::Backend *backend)
+void ProfileObserver::handleJobEnd(IExecutor *exec, ir::SubgraphIndex,
+ const ir::OperationIndex op_ind, const backend::Backend *backend)
{
_timer->handleEnd();
const auto timer_res = _timer->getTime();
- // NOTE This assumes there is just one operation in a op_seq
- const auto &node = _graph.operations().at(op_seq->operations().at(0));
+  // NOTE This assumes there is just one operation per job
+ const auto &node = _graph.operations().at(op_ind);
auto node_name = node.name();
VERBOSE(ProfileInfo) << "Time for " << node_name << " : " << timer_res << std::endl;
@@ -54,7 +98,7 @@ void ProfileObserver::handleEnd(IExecutor *exec, const ir::OpSequence *op_seq,
ir::DataType::QUANT_UINT8_ASYMM;
uint32_t size = 0;
- for (const auto &ind : node.getInputs() + node.getOutputs())
+ for (const auto &ind : (node.getInputs() + node.getOutputs()) | ir::Remove::UNDEFINED)
{
size += exec->graph().operands().at(ind).info().total_size();
}
@@ -69,64 +113,66 @@ void ProfileObserver::handleEnd(IExecutor *exec, const ir::OpSequence *op_seq,
}
};
-ChromeTracingObserver::ChromeTracingObserver(const std::string &filepath, const ir::Graph &graph)
- : _ofs{filepath, std::ofstream::out}, _recorder{}, _collector{&_recorder}, _graph{graph}
+TracingObserver::TracingObserver(const std::string &workspace_dir, const ir::Graph &graph,
+ const util::TracingCtx *tracing_ctx)
+ : _recorder{std::make_unique<EventRecorder>()}, _collector{_recorder.get()}, _graph{graph},
+ _workspace_dir{workspace_dir}, _tracing_ctx{tracing_ctx}, _triggered{false}
{
+ // DO NOTHING
}
-ChromeTracingObserver::~ChromeTracingObserver()
+TracingObserver::~TracingObserver()
{
try
{
- _recorder.writeToFile(_ofs);
+ // Write file if this observer is triggered at least once
+ if (_triggered)
+ {
+ auto event_writer = EventWriter::get(_workspace_dir);
+ event_writer->startToUse();
+ event_writer->readyToFlush(std::move(_recorder));
+ }
}
catch (const std::exception &e)
{
- std::cerr << "E: Fail to record event in ChromeTracingObserver: " << e.what() << std::endl;
+ std::cerr << "E: Fail to record event in TracingObserver: " << e.what() << std::endl;
}
}
-void ChromeTracingObserver::handleBegin(IExecutor *)
+void TracingObserver::handleSubgraphBegin(ir::SubgraphIndex subg_ind)
{
- _collector.onEvent(EventCollector::Event{EventCollector::Edge::BEGIN, "runtime", "Graph"});
-}
+ _triggered = true;
-void ChromeTracingObserver::handleBegin(IExecutor *, const ir::OpSequence *op_seq,
- const backend::Backend *backend)
-{
- std::string backend_id = backend->config()->id();
- _collector.onEvent(EventCollector::Event{EventCollector::Edge::BEGIN, backend_id,
- opSequenceTag(op_seq, _graph.operations())});
+ _collector.onEvent(
+ EventCollector::SubgEvent{_tracing_ctx, EventCollector::Edge::BEGIN, subg_ind.value()});
}
-void ChromeTracingObserver::handleEnd(IExecutor *, const ir::OpSequence *op_seq,
- const backend::Backend *backend)
+void TracingObserver::handleJobBegin(IExecutor *, ir::SubgraphIndex subg_ind,
+ ir::OperationIndex op_ind, const backend::Backend *backend)
{
std::string backend_id = backend->config()->id();
- _collector.onEvent(EventCollector::Event{EventCollector::Edge::END, backend_id,
- opSequenceTag(op_seq, _graph.operations())});
+ const auto &op = _graph.operations().at(op_ind);
+ auto ev = EventCollector::OpSeqEvent{_tracing_ctx, EventCollector::Edge::BEGIN,
+ subg_ind.value(), backend_id,
+ op_ind.value(), op.name()};
+ // add shape of inputs
+ setUserData(_graph, &op, ev.userData);
+ _collector.onEvent(ev);
}
-void ChromeTracingObserver::handleEnd(IExecutor *)
+void TracingObserver::handleJobEnd(IExecutor *, ir::SubgraphIndex subg_ind,
+ ir::OperationIndex op_ind, const backend::Backend *backend)
{
- _collector.onEvent(EventCollector::Event{EventCollector::Edge::END, "runtime", "Graph"});
+ std::string backend_id = backend->config()->id();
+ _collector.onEvent(EventCollector::OpSeqEvent{_tracing_ctx, EventCollector::Edge::END,
+ subg_ind.value(), backend_id, op_ind.value(),
+ _graph.operations().at(op_ind).name()});
}
-std::string ChromeTracingObserver::opSequenceTag(const ir::OpSequence *op_seq,
- const ir::Operations &operations)
+void TracingObserver::handleSubgraphEnd(ir::SubgraphIndex subg_ind)
{
- if (op_seq->size() == 0)
- return "Empty OpSequence";
-
- const auto &first_op_idx = op_seq->operations().at(0);
- const auto &first_op_node = operations.at(first_op_idx);
- std::string tag = "$" + std::to_string(first_op_idx.value());
- tag += " " + first_op_node.name();
- if (op_seq->size() > 1)
- {
- tag += " (+" + std::to_string(op_seq->size() - 1) + ")";
- }
- return tag;
+ _collector.onEvent(
+ EventCollector::SubgEvent{_tracing_ctx, EventCollector::Edge::END, subg_ind.value()});
}
} // namespace exec
diff --git a/runtime/onert/core/src/exec/ExecutionObservers.h b/runtime/onert/core/src/exec/ExecutionObservers.h
index ac0076ed2..e59d58766 100644
--- a/runtime/onert/core/src/exec/ExecutionObservers.h
+++ b/runtime/onert/core/src/exec/ExecutionObservers.h
@@ -17,44 +17,82 @@
#ifndef __ONERT_EXEC_OBSREVERS_H__
#define __ONERT_EXEC_OBSREVERS_H__
-#include "exec/IFunction.h"
-#include "ir/OpSequence.h"
#include "ExecTime.h"
-#include "util/ITimer.h"
+#include "../util/EventCollector.h"
+#include "../util/EventRecorder.h"
+#include "../util/EventWriter.h"
+
#include "exec/IExecutor.h"
-#include "util/EventCollector.h"
-#include "util/EventRecorder.h"
+#include "ir/Index.h"
+#include "ir/IOperation.h"
+#include "util/ITimer.h"
+#include "util/TracingCtx.h"
namespace onert
{
namespace exec
{
+
+enum class ObserverType
+{
+ PROFILE,
+ TRACING,
+ MINMAX_DUMP,
+};
+
class IExecutionObserver
{
public:
/// @brief Invoked just before model (not individual operation) execution begins
- virtual void handleBegin(IExecutor *) { return; }
+ virtual void handleSubgraphBegin(ir::SubgraphIndex) { return; }
- virtual void handleBegin(IExecutor *, const ir::OpSequence *, const backend::Backend *) = 0;
- virtual void handleEnd(IExecutor *, const ir::OpSequence *, const backend::Backend *) = 0;
+ virtual void handleJobBegin(IExecutor *, ir::SubgraphIndex, ir::OperationIndex,
+ const backend::Backend *) = 0;
+ virtual void handleJobEnd(IExecutor *, ir::SubgraphIndex, ir::OperationIndex,
+ const backend::Backend *) = 0;
/// @brief Invoked just after model (not individual operation) execution ends
- virtual void handleEnd(IExecutor *) { return; }
+ virtual void handleSubgraphEnd(ir::SubgraphIndex) { return; }
+
+ virtual ObserverType type() const = 0;
virtual ~IExecutionObserver() = default;
};
+class ExecObservers
+{
+public:
+ void add(std::unique_ptr<IExecutionObserver> &&observer)
+ {
+ _observers.emplace(observer->type(), std::move(observer));
+ }
+
+ IExecutionObserver *get(ObserverType type) const
+ {
+ if (_observers.find(type) != _observers.end())
+ return _observers.at(type).get();
+
+ return nullptr;
+ }
+
+private:
+ std::unordered_map<ObserverType, std::unique_ptr<IExecutionObserver>> _observers;
+};
+
class ProfileObserver : public IExecutionObserver
{
public:
explicit ProfileObserver(std::shared_ptr<ExecTime> et, const ir::Graph &graph)
- : _et(std::move(et)), _graph(graph)
+ : _et(std::move(et)), _graph(graph)
{
}
- void handleBegin(IExecutor *, const ir::OpSequence *, const backend::Backend *) override;
- void handleEnd(IExecutor *, const ir::OpSequence *, const backend::Backend *) override;
+ void handleJobBegin(IExecutor *, ir::SubgraphIndex, ir::OperationIndex,
+ const backend::Backend *) override;
+ void handleJobEnd(IExecutor *, ir::SubgraphIndex, ir::OperationIndex,
+ const backend::Backend *) override;
- void handleEnd(IExecutor *) override { _et->uploadOperationsExecTime(); }
+ void handleSubgraphEnd(ir::SubgraphIndex) override { _et->storeOperationsExecTime(); }
+ ObserverType type() const override { return ObserverType::PROFILE; }
private:
std::unique_ptr<util::ITimer> _timer;
@@ -62,24 +100,27 @@ private:
const ir::Graph &_graph;
};
-class ChromeTracingObserver : public IExecutionObserver
+class TracingObserver : public IExecutionObserver
{
public:
- ChromeTracingObserver(const std::string &filepath, const ir::Graph &graph);
- ~ChromeTracingObserver();
- void handleBegin(IExecutor *) override;
- void handleBegin(IExecutor *, const ir::OpSequence *, const backend::Backend *) override;
- void handleEnd(IExecutor *, const ir::OpSequence *, const backend::Backend *) override;
- void handleEnd(IExecutor *) override;
-
-private:
- static std::string opSequenceTag(const ir::OpSequence *op_seq, const ir::Operations &operations);
+ TracingObserver(const std::string &workspace_dir, const ir::Graph &graph,
+ const util::TracingCtx *tracing_ctx);
+ ~TracingObserver();
+ void handleSubgraphBegin(ir::SubgraphIndex) override;
+ void handleJobBegin(IExecutor *, ir::SubgraphIndex, ir::OperationIndex,
+ const backend::Backend *) override;
+ void handleJobEnd(IExecutor *, ir::SubgraphIndex, ir::OperationIndex,
+ const backend::Backend *) override;
+ void handleSubgraphEnd(ir::SubgraphIndex) override;
+ ObserverType type() const override { return ObserverType::TRACING; }
private:
- std::ofstream _ofs;
- EventRecorder _recorder;
+ std::unique_ptr<EventRecorder> _recorder;
EventCollector _collector;
const ir::Graph &_graph;
+ std::string _workspace_dir;
+ const util::TracingCtx *_tracing_ctx;
+ bool _triggered;
};
} // namespace exec
diff --git a/runtime/onert/core/src/exec/ExecutorBase.cc b/runtime/onert/core/src/exec/ExecutorBase.cc
index f835a9675..2526e4e6e 100644
--- a/runtime/onert/core/src/exec/ExecutorBase.cc
+++ b/runtime/onert/core/src/exec/ExecutorBase.cc
@@ -16,10 +16,10 @@
#include "ExecutorBase.h"
-#include "backend/ITensor.h"
-#include "backend/controlflow/UserTensor.h"
-#include "backend/cpu_common/Tensor.h"
-#include "util/logging.h"
+#include "ShapeConverter.h"
+
+#include "util/ConfigSource.h"
+#include <misc/polymorphic_downcast.h>
namespace onert
{
@@ -27,214 +27,68 @@ namespace exec
{
ExecutorBase::ExecutorBase(std::unique_ptr<compiler::LoweredGraph> &&lowered_graph,
- const std::vector<std::shared_ptr<backend::ITensor>> &input_tensors,
- const std::vector<std::shared_ptr<backend::ITensor>> &output_tensors,
+ backend::BackendContexts &&backend_contexts,
const compiler::TensorRegistries &tensor_regs,
- backend::TensorManagerSet &&tensor_mgrs)
- : _lowered_graph{std::move(lowered_graph)}, _graph{_lowered_graph->graph()},
- _input_tensors{input_tensors}, _output_tensors{output_tensors},
- _tensor_mgrs{std::move(tensor_mgrs)}, _mutex()
+ const util::TracingCtx *tracing_ctx)
+ : _lowered_graph{std::move(lowered_graph)}, _backend_contexts{std::move(backend_contexts)},
+ _graph{_lowered_graph->graph()}, _mutex(), _tracing_ctx(tracing_ctx)
{
- // TODO Fix the way of knowing whether it is primary or not
- bool primary_executor = !(_input_tensors.empty() && _output_tensors.empty());
- if (!primary_executor)
- {
- auto build_input_tensor_list = [&](const onert::ir::OperandIndexSequence &ind_seq) {
- std::vector<std::shared_ptr<backend::ITensor>> list;
- for (auto ind : ind_seq)
- {
- std::shared_ptr<backend::ITensor> tensor = tensor_regs.getITensor(ind);
- assert(tensor != nullptr);
- DynAllocInfo dyn_alloc_info{ind};
- _input_to_dyn_alloc_info.emplace(tensor, dyn_alloc_info);
- list.push_back(tensor);
- }
- return list;
- };
- auto build_output_tensor_list = [&](const onert::ir::OperandIndexSequence &ind_seq) {
- std::vector<std::shared_ptr<backend::ITensor>> list;
- for (auto ind : ind_seq)
- {
- std::shared_ptr<backend::ITensor> tensor = tensor_regs.getITensor(ind);
- assert(tensor != nullptr);
- DynAllocInfo dyn_alloc_info{ind};
- _output_to_dyn_alloc_info.emplace(tensor, dyn_alloc_info);
- list.push_back(tensor);
- }
- return list;
- };
- _input_tensors = build_input_tensor_list(_graph.getInputs());
- _output_tensors = build_output_tensor_list(_graph.getOutputs());
- }
- else
- {
- assert(input_tensors.size() == _graph.getInputs().size());
- assert(output_tensors.size() == _graph.getOutputs().size());
- for (uint32_t i = 0; i < input_tensors.size(); i++)
- {
- auto tensor = input_tensors[i];
- auto ind = _graph.getInputs().at(i);
- DynAllocInfo dyn_alloc_info{ind};
- _input_to_dyn_alloc_info.emplace(tensor, dyn_alloc_info);
- }
- for (uint32_t i = 0; i < output_tensors.size(); i++)
+ auto build_tensor_list = [&](const auto &ind_seq, auto &tensors) {
+ assert(tensors.empty());
+ for (auto &&ind : ind_seq)
{
- auto tensor = output_tensors[i];
- auto ind = _graph.getOutputs().at(i);
- DynAllocInfo dyn_alloc_info{ind};
- _output_to_dyn_alloc_info.emplace(tensor, dyn_alloc_info);
+ backend::ITensor *tensor = tensor_regs.getITensor(ind);
+ assert(tensor != nullptr);
+ auto io_tensor = nnfw::misc::polymorphic_downcast<backend::builtin::IOTensor *>(tensor);
+ tensors.push_back(io_tensor);
}
- }
+ };
+ build_tensor_list(_graph.getInputs(), _input_tensors);
+ build_tensor_list(_graph.getOutputs(), _output_tensors);
}
-void ExecutorBase::execute(const std::vector<std::shared_ptr<backend::ITensor>> &src_tensors,
- const std::shared_ptr<IPermuteFunction> &pre_fn)
+void ExecutorBase::execute(const std::vector<backend::IPortableTensor *> &inputs,
+ const std::vector<backend::IPortableTensor *> &outputs,
+ const ExecutionOptions &options)
{
// For thread-safe, use mutex
// TODO: if all used backends on this executor are thread-safe,
// do not need to use mutex (otherwise, use mutex)
// Deadlock occurs when an Executor is called recursively.
std::lock_guard<std::mutex> lock(_mutex);
+ _current_options = options;
- assert(src_tensors.size() == _graph.getInputs().size());
- assert(src_tensors.size() == _input_tensors.size());
- for (uint32_t n = 0; n < _graph.getInputs().size(); ++n)
+ assert(inputs.size() == _graph.getInputs().size());
+ assert(inputs.size() == _input_tensors.size());
+ for (uint32_t n = 0; n < inputs.size(); ++n)
{
- // when user changes input shape, the input tensor is dynamic and its memory is not allocated.
- // This code find the info to allocate dynamic tensor, and allocate memory based on the source
- // tensor's shape set by caller.
- const auto src_tensor = src_tensors[n];
+ const auto input = inputs[n];
+ assert(input->buffer() != nullptr || input->get_info().total_size() == 0);
auto input_tensor = _input_tensors[n];
- // If src_tensor or input_tensor is nullptr, pre_fn does not copy the tensors
- if (src_tensor != nullptr && input_tensor != nullptr)
- {
- auto dyn_alloc_info = _input_to_dyn_alloc_info.find(_input_tensors[n]);
- const auto orig_input_shape = input_tensor->getShape();
- const auto changed_input_shape =
- convertShape(src_tensor->getShape(), src_tensor->layout(), input_tensor->layout());
- if (orig_input_shape != changed_input_shape)
- {
- if (dyn_alloc_info == _input_to_dyn_alloc_info.end())
- {
- // The input_tensor is a dynamic tensor of backend that doesn't support dynamic tensor
- throw std::runtime_error("Unknown dim is found at execution time for a backend that "
- "does not support dynamic tensor");
- }
- else
- {
- input_tensor->set_dynamic();
- }
- }
- }
+ assert(input_tensor != nullptr);
+ input_tensor->setTensor(input);
}
- // TODO Move calling permute_fn.run() into executeImpl()
- assert(pre_fn);
- pre_fn->run();
-
- executeImpl();
-}
-
-void ExecutorBase::execute(const IODescription &desc)
-{
- // For thread-safe, use mutex
- // TODO: if all used backends on this executor are thread-safe,
- // do not need to use mutex (otherwise, use mutex)
- std::lock_guard<std::mutex> lock(_mutex);
-
- // Set input(s)
- assert(_input_tensors.size() == desc.inputs.size());
- for (uint32_t i = 0; i < _input_tensors.size(); ++i)
+ assert(outputs.size() == _graph.getOutputs().size());
+ assert(outputs.size() == _output_tensors.size());
+ for (uint32_t n = 0; n < outputs.size(); ++n)
{
- // TODO Remove dynamic_cast
- auto tensor = std::dynamic_pointer_cast<backend::controlflow::UserTensor>(_input_tensors[i]);
- assert(tensor);
- auto input_shape = desc.dynamic_input_shapes.find(ir::IOIndex{i});
- if (input_shape != desc.dynamic_input_shapes.end())
- {
- tensor->set_dynamic();
- tensor->setShape(input_shape->second);
- }
- // TODO Better design for ITensor? (we need const_cast as ITensor is writable)
- tensor->setBuffer(static_cast<uint8_t *>(const_cast<void *>(desc.inputs[i]->buffer)),
- desc.inputs[i]->size);
-
- handleDynamicInputTensor(ir::IOIndex{i}, desc);
+ const auto output = outputs[n];
+ assert(output->buffer() != nullptr || output->get_info().total_size() == 0);
+ auto output_tensor = _output_tensors[n];
+ assert(output_tensor != nullptr);
+ output_tensor->setTensor(output);
}
- assert(_output_tensors.size() == desc.outputs.size());
- for (uint32_t i = 0; i < _output_tensors.size(); ++i)
- {
- // TODO Remove dynamic_cast
- auto tensor = std::dynamic_pointer_cast<backend::controlflow::UserTensor>(_output_tensors[i]);
- assert(tensor);
- tensor->set_dynamic(); // It can't be resized but shape could change
- // TODO Better design for ITensor? (we need const_cast as ITensor is writable)
- tensor->setBuffer(static_cast<uint8_t *>(const_cast<void *>(desc.outputs[i]->buffer)),
- desc.outputs[i]->size);
- }
-
- executeImpl();
-
- // Update output(s) desc
- for (uint32_t n = 0; n < _graph.getOutputs().size(); ++n)
- {
- ir::IOIndex output_index{n};
- // Optional output
- if (desc.outputs.at(n) == nullptr)
- {
- continue;
- }
- auto &output = *desc.outputs.at(n);
-
- // set shape of outputDesc to tensor shape since tensor can be dynamic
- const auto output_tensor_shape = _output_tensors[n]->getShape();
- output.info.shape(
- convertShape(output_tensor_shape, _output_tensors[n]->layout(), output.layout));
- }
-}
+ // Create observee
+ ExecutionObservee subject(_observers, options);
-/**
- * @brief Changes tensor shape and allocate memory
- * if input shape was changed by nnfw_set_input_tensorinfo()
- *
- * @note Cases are:
- * 1) static operand -> nnfw_set_input_tensorinfo() -> execute() -> execute()
- * (a) (b)
- *
- * at (a), operand is static, tensor is static - memory dealloc is not needed
- * (DynamicTensorManager cannot dealloc memory allocated by StaticTensorManager)
- * at (b), operand is static, tensor is dynamic - memory dealloc is needed
- *
- * 2) dynamic operand -> nnfw_set_input_tensorinfo() -> execute() -> execute()
- * (a) (b)
- *
- * at (a), operand is dynamic, tensor is dynamic - memory dealloc is not needed
- * since it has not been allocated yet
- * at (b), operand is dynamic, tensor is dynamic - memory dealloc is needed
- */
-void ExecutorBase::handleDynamicInputTensor(ir::IOIndex io_ind, const IODescription &desc)
-{
- auto shape_sig_found = desc.dynamic_input_shapes.find(io_ind);
- if (shape_sig_found != desc.dynamic_input_shapes.end())
- {
- auto dyn_alloc_info = _input_to_dyn_alloc_info.find(_input_tensors[io_ind.value()]);
- if (dyn_alloc_info == _input_to_dyn_alloc_info.end())
- throw std::runtime_error("Unknown dim is found at execution time for a backend that "
- "does not support dynamic tensor");
-
- auto changed_input_shape = shape_sig_found->second;
- auto operand_ind = dyn_alloc_info->second.ind;
-
- auto dyn_tensor_manager = _input_tensors[io_ind.value()]->dynamic_tensor_manager();
- assert(dyn_tensor_manager);
- dyn_tensor_manager->applyShape(operand_ind, changed_input_shape);
- }
+ executeImpl(subject);
}
bool ExecutorBase::hasDynamicInput()
{
- for (auto &tensor : _input_tensors)
+ for (auto &&tensor : _input_tensors)
{
if (tensor->is_dynamic())
return true;
diff --git a/runtime/onert/core/src/exec/ExecutorBase.h b/runtime/onert/core/src/exec/ExecutorBase.h
index a13be7dbf..2ae63ddd4 100644
--- a/runtime/onert/core/src/exec/ExecutorBase.h
+++ b/runtime/onert/core/src/exec/ExecutorBase.h
@@ -17,25 +17,20 @@
#ifndef __ONERT_EXEC_EXECUTOR_BASE_H__
#define __ONERT_EXEC_EXECUTOR_BASE_H__
-#include <mutex>
+#include "ExecutionObservee.h"
+#include "../backend/builtin/IOTensor.h"
+#include "../compiler/TensorRegistries.h"
-#include "IPermuteFunction.h"
-#include "Source.h"
-#include "exec/ExecutionObservers.h"
-#include "Sink.h"
-#include "ShapeConverter.h"
-#include "exec/IExecutor.h"
#include "compiler/LoweredGraph.h"
-#include "ir/LowerInfoMap.h"
-#include "backend/IConfig.h"
-#include "backend/Backend.h"
-#include "exec/ExecTime.h"
-#include "exec/IFunction.h"
-#include "backend/IDynamicTensorManager.h"
-#include "backend/ITensorManager.h"
-#include "exec/ExecutionObservee.h"
-#include "compiler/TensorRegistries.h"
-#include <list>
+#include "exec/IExecutor.h"
+#include "exec/ExecutionContext.h"
+#include "ir/Graph.h"
+#include "ir/OperationIndexMap.h"
+#include "util/TracingCtx.h"
+
+#include <memory>
+#include <mutex>
+#include <vector>
namespace onert
{
@@ -51,47 +46,51 @@ public:
* @param tensor_builders Tensor builders that are currently used
*/
ExecutorBase(std::unique_ptr<compiler::LoweredGraph> &&lowered_graph,
- const std::vector<std::shared_ptr<backend::ITensor>> &input_tensors,
- const std::vector<std::shared_ptr<backend::ITensor>> &output_tensors,
- const compiler::TensorRegistries &tensor_regs,
- backend::TensorManagerSet &&tensor_mgrs);
+ backend::BackendContexts &&backend_contexts,
+ const compiler::TensorRegistries &tensor_regs, const util::TracingCtx *tracing_ctx);
virtual ~ExecutorBase() = default;
- const ir::Graph &graph() final { return _graph; }
+ const ir::Graph &graph() const final { return _graph; }
- /**
- * @brief Execute without IODescription
- *
- * @param src_tensor Tensor list that will be copied to input tensors of this
- * @param pre_fn The permutation function that copy from src_tensor to input tensors of this
- */
- void execute(const std::vector<std::shared_ptr<backend::ITensor>> &src_tensors,
- const std::shared_ptr<IPermuteFunction> &pre_fn);
+ void execute(const std::vector<backend::IPortableTensor *> &inputs,
+ const std::vector<backend::IPortableTensor *> &outputs,
+ const ExecutionOptions &options) override;
- void execute(const IODescription &desc) final;
+ uint32_t inputSize() const override { return _input_tensors.size(); }
- // Used only in Dataflow and Parallel Executors
- void setIndexedRanks(std::shared_ptr<ir::OperationIndexMap<int64_t>> ranks) final
+ uint32_t outputSize() const override { return _output_tensors.size(); }
+
+ const ir::OperandInfo &inputInfo(uint32_t index) const override
{
- _indexed_ranks = std::move(ranks);
- };
+ return _input_tensors[index]->get_info();
+ }
- virtual void executeImpl(void) = 0;
+ const ir::OperandInfo &outputInfo(uint32_t index) const override
+ {
+ return _output_tensors[index]->get_info();
+ }
- void addObserver(std::unique_ptr<IExecutionObserver> ref) { _subject.add(std::move(ref)); };
+ ir::Layout inputLayout(uint32_t index) const override { return _input_tensors[index]->layout(); }
- const std::vector<std::shared_ptr<backend::ITensor>> &getInputTensors() const
+ ir::Layout outputLayout(uint32_t index) const override
{
- return _input_tensors;
+ return _output_tensors[index]->layout();
}
- const std::vector<std::shared_ptr<backend::ITensor>> &getOutputTensors() const
+ // Used only in Dataflow and Parallel Executors
+ void setIndexedRanks(std::shared_ptr<ir::OperationIndexMap<int64_t>> ranks) final
{
- return _output_tensors;
- }
+ _indexed_ranks = std::move(ranks);
+ };
- const DynAllocInfoMap &getInputsDynamicAllocInfo() const { return _input_to_dyn_alloc_info; }
+ virtual void executeImpl(const ExecutionObservee &subject) = 0;
+
+ void addObserver(std::unique_ptr<IExecutionObserver> ref) { _observers.add(std::move(ref)); };
+
+ backend::BackendContexts &getBackendContexts() { return _backend_contexts; }
+
+ const ExecutionOptions &currentOptions() const override { return _current_options; }
protected:
/**
@@ -100,19 +99,23 @@ protected:
bool hasDynamicInput();
protected:
- ExecutionObservee _subject;
+ ExecObservers _observers;
std::shared_ptr<ir::OperationIndexMap<int64_t>> _indexed_ranks;
std::unique_ptr<compiler::LoweredGraph> _lowered_graph;
+ backend::BackendContexts _backend_contexts;
const ir::Graph &_graph;
- std::vector<std::shared_ptr<backend::ITensor>> _input_tensors;
- std::vector<std::shared_ptr<backend::ITensor>> _output_tensors;
- DynAllocInfoMap _input_to_dyn_alloc_info;
- DynAllocInfoMap _output_to_dyn_alloc_info;
- backend::TensorManagerSet _tensor_mgrs;
+ std::vector<backend::builtin::IOTensor *> _input_tensors;
+ std::vector<backend::builtin::IOTensor *> _output_tensors;
std::mutex _mutex;
-
-private:
- void handleDynamicInputTensor(ir::IOIndex input_index, const IODescription &desc);
+ const util::TracingCtx *_tracing_ctx;
+ /**
+   * Set by the execute() method only in a thread-safe environment.
+   * Used for non-primary executor calls on the builtin backend and
+   * accessed through the entry executor's currentOptions() method.
+ *
+ * TODO: Find better way to pass config to non-primary executor
+ */
+ ExecutionOptions _current_options;
};
} // namespace exec
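A hedged call-site sketch of the new execute() signature; the executor pointer and the IPortableTensor pointers in0, in1 and out0 are assumed to be prepared by the caller (as Execution does):

#include "ExecutorBase.h"

#include <vector>

void execute_example(onert::exec::ExecutorBase *executor, onert::backend::IPortableTensor *in0,
                     onert::backend::IPortableTensor *in1, onert::backend::IPortableTensor *out0)
{
  std::vector<onert::backend::IPortableTensor *> inputs{in0, in1};
  std::vector<onert::backend::IPortableTensor *> outputs{out0};

  onert::exec::ExecutionOptions options;
  onert::exec::ExecutionOptions::fromGlobalConfig(options);

  // ExecutorBase binds the user tensors to its builtin IOTensors, builds an
  // ExecutionObservee from its observers and the options, then runs executeImpl()
  executor->execute(inputs, outputs, options);
}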
diff --git a/runtime/onert/core/src/exec/FunctionSequence.cc b/runtime/onert/core/src/exec/FunctionSequence.cc
index fb31f7582..578123a54 100644
--- a/runtime/onert/core/src/exec/FunctionSequence.cc
+++ b/runtime/onert/core/src/exec/FunctionSequence.cc
@@ -16,8 +16,6 @@
#include "exec/FunctionSequence.h"
-#include "ir/Operation.h"
-#include "backend/IDynamicTensorManager.h"
#include "backend/ITensorRegistry.h"
#include "util/logging.h"
@@ -28,19 +26,19 @@ namespace exec
void FunctionSequence::run()
{
- // TODO Find out when `_enable_dynamic_shape_inferer` is true but `_dynamic_tensor_ctx` is false
if (_enable_dynamic_shape_inferer && _dynamic_tensor_ctx)
{
- if (_dynamic_tensor_ctx->op_seq->size() != _functions.size())
- throw std::runtime_error("operation and functions should be mapped one by one");
+    // The acl_cl and acl_neon backends don't support dynamic shapes.
+    // _dynamic_tensor_ctx is always nullptr for acl_cl and acl_neon,
+    // so those two backends cannot reach here.
+
+ // Do dynamic shape inference
+ _dynamic_tensor_ctx->op->accept(*_dynamic_tensor_ctx->dynamic_shape_inferer);
- auto op_seq_iter = _dynamic_tensor_ctx->op_seq->begin();
for (const auto &function : _functions)
{
- // set shape of output and allocate memory when needed
- auto &op = _dynamic_tensor_ctx->operations->at(*op_seq_iter);
- op.accept(*_dynamic_tensor_ctx->dynamic_shape_inferer);
-
+      // NOTE The function could itself be a FunctionSequence, so handle that case here
+      // TODO Remove this or handle it recursively
auto *sub_func_seq = dynamic_cast<FunctionSequence *>(function.get());
if (sub_func_seq != nullptr)
{
@@ -50,22 +48,12 @@ void FunctionSequence::run()
// run kernel
function->run();
-
- // deallocate input tensors which is no longer used
- _dynamic_tensor_ctx->dynamic_tensor_manager->deallocInput(*op_seq_iter);
-
- op_seq_iter++;
}
}
else
{
for (const auto &function : _functions)
{
- auto *sub_func_seq = dynamic_cast<FunctionSequence *>(function.get());
- if (sub_func_seq != nullptr)
- {
- sub_func_seq->enableDynamicShapeInferer(false);
- }
function->run();
}
}
diff --git a/runtime/onert/core/src/exec/IPermuteFunction.cc b/runtime/onert/core/src/exec/IPermuteFunction.cc
new file mode 100644
index 000000000..9d548e6dc
--- /dev/null
+++ b/runtime/onert/core/src/exec/IPermuteFunction.cc
@@ -0,0 +1,320 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "IPermuteFunction.h"
+
+#include <cker/operation/Quantize.h>
+#include <cker/operation/Dequantize.h>
+#include "backend/IPortableTensor.h"
+#include "exec/IFunction.h"
+#include "ir/Index.h"
+#include "ir/Shape.h"
+#include <memory>
+#include <misc/polymorphic_downcast.h>
+#include <typeinfo>
+#include "util/Utils.h"
+#include <vector>
+#include <unordered_map>
+
+namespace
+{
+using namespace onert;
+
+inline nnfw::cker::Shape getShape(const backend::ITensor *tensor)
+{
+ const ir::Shape shape = tensor->getShape();
+
+ assert(tensor->layout() == ir::Layout::NHWC);
+
+ auto rank = shape.rank();
+ nnfw::cker::Shape ret(rank);
+ auto data = ret.DimsData();
+ for (int i = 0; i < rank; ++i)
+ {
+ data[i] = shape.dim(i);
+ }
+ return ret;
+}
+
+// Quantize per element
+template <typename InputT, typename OutputT>
+void elementwiseQuantize(const backend::ITensor *src_tensor, backend::ITensor *dst_tensor)
+{
+ const auto scale = dst_tensor->data_scale();
+ const auto zero_point = dst_tensor->data_zero_point();
+
+ int min_val = std::numeric_limits<OutputT>::min();
+ int max_val = std::numeric_limits<OutputT>::max();
+
+ auto loop_shape = src_tensor->getShape();
+ const auto src_layout = src_tensor->layout();
+ const auto dst_layout = dst_tensor->layout();
+ const bool is_permutation = src_layout != dst_layout && loop_shape.rank() == 4;
+ ShapeLoop(loop_shape, [&](const onert::ir::Coordinates &coords) {
+ const InputT *input_data =
+ reinterpret_cast<const InputT *>(src_tensor->buffer() + src_tensor->calcOffset(coords));
+ int32_t unclamped = static_cast<int32_t>(round(*input_data / scale)) + zero_point;
+ int32_t clamped = std::min(std::max(unclamped, min_val), max_val);
+
+ ir::Coordinates dst_coords =
+ is_permutation ? ir::convertCoordinates(coords, src_layout, dst_layout) : coords;
+ OutputT *output_data =
+ reinterpret_cast<OutputT *>(dst_tensor->buffer() + dst_tensor->calcOffset(dst_coords));
+ *output_data = clamped;
+ });
+}
+
+// TODO Optimize the case where tensors have the same layout
+template <typename InputT, typename OutputT>
+void quantize(const backend::ITensor *src_tensor, backend::ITensor *dst_tensor)
+{
+ if (!src_tensor->has_padding() && !dst_tensor->has_padding() &&
+ src_tensor->layout() == dst_tensor->layout() && !src_tensor->is_dynamic())
+ {
+ assert(!dst_tensor->is_dynamic());
+
+ // Call optimized neon kernel
+ nnfw::cker::Quantize(getShape(src_tensor),
+ reinterpret_cast<const InputT *>(src_tensor->buffer()),
+ getShape(dst_tensor), reinterpret_cast<OutputT *>(dst_tensor->buffer()),
+ dst_tensor->data_scale(), dst_tensor->data_zero_point());
+ }
+ else
+ {
+ elementwiseQuantize<InputT, OutputT>(src_tensor, dst_tensor);
+ }
+}
+
+// Dequantize per element
+template <typename InputT, typename OutputT>
+void elementwiseDequantize(const backend::ITensor *src_tensor, backend::ITensor *dst_tensor)
+{
+ const auto scale = src_tensor->data_scale();
+ const auto zero_point = src_tensor->data_zero_point();
+
+ auto loop_shape = src_tensor->getShape();
+ const auto src_layout = src_tensor->layout();
+ const auto dst_layout = dst_tensor->layout();
+ const bool is_permutation = src_layout != dst_layout && loop_shape.rank() == 4;
+ ShapeLoop(loop_shape, [&](const onert::ir::Coordinates &coords) {
+ const InputT *input_data =
+ reinterpret_cast<const InputT *>(src_tensor->buffer() + src_tensor->calcOffset(coords));
+ const OutputT result = static_cast<OutputT>(scale * (*input_data - zero_point));
+
+ ir::Coordinates dst_coords =
+ is_permutation ? ir::convertCoordinates(coords, src_layout, dst_layout) : coords;
+ OutputT *output_data =
+ reinterpret_cast<OutputT *>(dst_tensor->buffer() + dst_tensor->calcOffset(dst_coords));
+ *output_data = result;
+ });
+}
+
+// TODO Optimize the case where tensors have the same layout
+template <typename InputT, typename OutputT>
+void dequantize(const backend::ITensor *src_tensor, backend::ITensor *dst_tensor)
+{
+ if (!src_tensor->has_padding() && !dst_tensor->has_padding() &&
+ src_tensor->layout() == dst_tensor->layout() && !src_tensor->is_dynamic())
+ {
+ assert(!dst_tensor->is_dynamic());
+
+ // Call optimized neon kernel
+ nnfw::cker::Dequantize(getShape(src_tensor),
+ reinterpret_cast<const InputT *>(src_tensor->buffer()),
+ getShape(dst_tensor), reinterpret_cast<OutputT *>(dst_tensor->buffer()),
+ src_tensor->data_scale(), src_tensor->data_zero_point());
+ }
+ else
+ {
+ elementwiseDequantize<InputT, OutputT>(src_tensor, dst_tensor);
+ }
+}
+
+template <typename SRC_T, typename DST_T,
+ std::enable_if_t<std::is_base_of<backend::ITensor, SRC_T>::value &&
+ std::is_base_of<backend::ITensor, DST_T>::value,
+ bool> = true>
+void typeAwareQuantize(const SRC_T *src_tensor, DST_T *dst_tensor)
+{
+ // TODO Support other types
+ if (src_tensor->data_type() == ir::DataType::FLOAT32)
+ {
+ switch (dst_tensor->data_type())
+ {
+ case ir::DataType::QUANT_UINT8_ASYMM:
+ {
+ quantize<float, uint8_t>(src_tensor, dst_tensor);
+ break;
+ }
+ case ir::DataType::QUANT_INT8_SYMM:
+ {
+ quantize<float, int8_t>(src_tensor, dst_tensor);
+ break;
+ }
+ case ir::DataType::QUANT_INT16_SYMM:
+ {
+ quantize<float, int16_t>(src_tensor, dst_tensor);
+ break;
+ }
+ default:
+ {
+ throw std::runtime_error("IPermuteFunction: Unsupported quantization type");
+ break;
+ }
+ }
+ }
+ else if (dst_tensor->data_type() == ir::DataType::FLOAT32)
+ {
+ switch (src_tensor->data_type())
+ {
+ case ir::DataType::QUANT_UINT8_ASYMM:
+ {
+ dequantize<uint8_t, float>(src_tensor, dst_tensor);
+ break;
+ }
+ case ir::DataType::QUANT_INT8_SYMM:
+ {
+ dequantize<int8_t, float>(src_tensor, dst_tensor);
+ break;
+ }
+ case ir::DataType::QUANT_INT16_SYMM:
+ {
+ dequantize<int16_t, float>(src_tensor, dst_tensor);
+ break;
+ }
+ default:
+ {
+ throw std::runtime_error("IPermuteFunction: Unsupported dequantization type");
+ break;
+ }
+ }
+ }
+ else
+ {
+ throw std::runtime_error("IPermuteFunction: Unsupported type for type-aware quantization yet");
+ }
+}
+
+} // namespace
+
+namespace onert
+{
+namespace exec
+{
+
+void IPermuteFunction::run()
+{
+  // TODO Optimization: ensure control does not reach here when _src_tensors.size() == 0
+ assert(_src_tensors.size() == _dst_tensors.size());
+ if (_src_tensors_offsets.size() == 0)
+ {
+ _src_tensors_offsets.resize(_src_tensors.size());
+ _dst_tensors_offsets.resize(_dst_tensors.size());
+ }
+ assert(_src_tensors.size() == _src_tensors_offsets.size());
+ assert(_src_tensors_offsets.size() == _dst_tensors_offsets.size());
+
+ for (size_t i = 0; i < _src_tensors.size(); ++i)
+ {
+ auto src_tensor = _src_tensors.at(i);
+ auto dst_tensor = _dst_tensors.at(i);
+ auto &src_offsets = _src_tensors_offsets.at(i);
+ auto &dst_offsets = _dst_tensors_offsets.at(i);
+ if (src_tensor != dst_tensor)
+ {
+ const auto rank = src_tensor->getShape().rank();
+ permute(src_tensor, dst_tensor, rank, src_offsets, dst_offsets);
+ }
+ }
+}
+
+void IPermuteFunction::permute(backend::ITensor *src_tensor, backend::ITensor *dst_tensor,
+ size_t rank, std::vector<size_t> &src_offsets,
+ std::vector<size_t> &dst_offsets)
+{
+ if (src_tensor->total_size() == 0)
+ {
+ assert(dst_tensor->total_size() == 0);
+ return;
+ }
+
+ assert(src_tensor != dst_tensor);
+ if (underlying_type(src_tensor->data_type()) != underlying_type(dst_tensor->data_type()))
+ {
+ typeAwareQuantize(src_tensor, dst_tensor);
+ return;
+ }
+
+ switch (src_tensor->data_type())
+ {
+ case ir::DataType::FLOAT32:
+ permute<float>(src_tensor, dst_tensor, rank, src_offsets, dst_offsets);
+ break;
+ case ir::DataType::INT32:
+ permute<int32_t>(src_tensor, dst_tensor, rank, src_offsets, dst_offsets);
+ break;
+ case ir::DataType::UINT32:
+ permute<uint32_t>(src_tensor, dst_tensor, rank, src_offsets, dst_offsets);
+ break;
+ case ir::DataType::BOOL8:
+ case ir::DataType::QUANT_UINT8_ASYMM:
+ case ir::DataType::UINT8:
+ permute<uint8_t>(src_tensor, dst_tensor, rank, src_offsets, dst_offsets);
+ break;
+ case ir::DataType::QUANT_INT8_ASYMM:
+ case ir::DataType::QUANT_INT8_SYMM:
+ permute<int8_t>(src_tensor, dst_tensor, rank, src_offsets, dst_offsets);
+ break;
+ case ir::DataType::INT64:
+ permute<int64_t>(src_tensor, dst_tensor, rank, src_offsets, dst_offsets);
+ break;
+ case ir::DataType::QUANT_INT16_SYMM:
+ permute<int16_t>(src_tensor, dst_tensor, rank, src_offsets, dst_offsets);
+ break;
+ default:
+ throw std::runtime_error("IPermuteFunction: Not supported data type");
+ break;
+ }
+}
+
+const std::type_info &IPermuteFunction::underlying_type(ir::DataType type) const
+{
+ switch (type)
+ {
+ case ir::DataType::FLOAT32:
+ return typeid(float);
+ case ir::DataType::INT32:
+ return typeid(int32_t);
+ case ir::DataType::UINT32:
+ return typeid(uint32_t);
+ case ir::DataType::INT64:
+ return typeid(int64_t);
+ case ir::DataType::BOOL8:
+ case ir::DataType::QUANT_UINT8_ASYMM:
+ case ir::DataType::UINT8:
+ return typeid(uint8_t);
+ case ir::DataType::QUANT_INT8_ASYMM:
+ case ir::DataType::QUANT_INT8_SYMM:
+ return typeid(int8_t);
+ case ir::DataType::QUANT_INT16_SYMM:
+ return typeid(int16_t);
+ default:
+ throw std::runtime_error("IPermuteFunction: Not supported data type");
+ }
+}
+
+} // namespace exec
+} // namespace onert
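
For reference, the quantize/dequantize helpers dispatched in the switches above are defined earlier in this file's anonymous namespace. A minimal sketch of what such an element-wise quantize helper amounts to, assuming a contiguous (unpadded) destination tensor; the name quantize_sketch and the flat loop are illustrative only, not this file's actual implementation:

#include "backend/ITensor.h"

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <limits>

// Illustrative only: quantize a contiguous FLOAT32 tensor into an integer tensor
// using the destination's scale and zero point. The real helper also handles
// padded/strided tensors by iterating coordinates and calling calcOffset().
template <typename InputT, typename OutputT>
void quantize_sketch(const onert::backend::ITensor *src, onert::backend::ITensor *dst)
{
  const auto *in = reinterpret_cast<const InputT *>(src->buffer());
  auto *out = reinterpret_cast<OutputT *>(dst->buffer());
  const float scale = dst->data_scale();
  const int32_t zero_point = dst->data_zero_point();
  const int64_t min_val = std::numeric_limits<OutputT>::min();
  const int64_t max_val = std::numeric_limits<OutputT>::max();
  const size_t num_elements = src->getShape().num_elements();
  for (size_t i = 0; i < num_elements; ++i)
  {
    const int64_t q = static_cast<int64_t>(std::round(in[i] / scale)) + zero_point;
    out[i] = static_cast<OutputT>(std::min(std::max(q, min_val), max_val));
  }
}
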
diff --git a/runtime/onert/core/src/exec/IPermuteFunction.h b/runtime/onert/core/src/exec/IPermuteFunction.h
index 6b4d15380..ccac66cad 100644
--- a/runtime/onert/core/src/exec/IPermuteFunction.h
+++ b/runtime/onert/core/src/exec/IPermuteFunction.h
@@ -25,21 +25,48 @@
#include "backend/ITensor.h"
#include "exec/IFunction.h"
-#include "ir/Index.h"
-#include "ir/Shape.h"
#include <memory>
-#include <typeinfo>
-#include "util/Utils.h"
#include <vector>
+#include <unordered_map>
namespace onert
{
namespace exec
{
+inline void UpdateOffsets(::onert::backend::ITensor *src, ::onert::backend::ITensor *dst,
+ const ::onert::ir::Shape &loop_shape, std::vector<size_t> &src_offsets,
+ std::vector<size_t> &dst_offsets)
+{
+ ShapeLoop(loop_shape, [&](const onert::ir::Coordinates &coords) {
+ src_offsets.emplace_back(src->calcOffset(coords));
+ dst_offsets.emplace_back(dst->calcOffset(coords));
+ });
+}
+
+inline void CopyStatic(const uint8_t *src_buffer, uint8_t *dst_buffer,
+ const std::vector<size_t> &src_offsets,
+ const std::vector<size_t> &dst_offsets, size_t copy_len)
+{
+ assert(src_offsets.size() == dst_offsets.size());
+ for (size_t i = 0; i < src_offsets.size(); ++i)
+ {
+ memcpy(dst_buffer + dst_offsets.at(i), src_buffer + src_offsets.at(i), copy_len);
+ }
+}
+
+inline void CopyDynamic(const ::onert::backend::ITensor *src, const ::onert::backend::ITensor *dst,
+ uint8_t *dst_buffer, const ::onert::ir::Shape &loop_shape, size_t copy_len)
+{
+ ShapeLoop(loop_shape, [&](const onert::ir::Coordinates &coords) {
+    // Copy the src tensor's data to dst_buffer using the calculated offset of the dst tensor
+ memcpy(dst_buffer + dst->calcOffset(coords), src->buffer() + src->calcOffset(coords), copy_len);
+ });
+}
+
class IPermuteFunction : public IFunction
{
-private:
+protected:
enum class PermuteType
{
NHWC_TO_NCHW,
@@ -48,63 +75,69 @@ private:
};
public:
- virtual void run() override
+ virtual void run() override;
+
+ virtual void prepare() override { optimize(); }
+
+ virtual void optimize() = 0;
+
+protected:
+ void permute(backend::ITensor *src_tensor, backend::ITensor *dst_tensor, size_t rank,
+ std::vector<size_t> &src_offsets, std::vector<size_t> &dst_offsets);
+
+private:
+  // TODO make src const by providing const access()
+ template <class T>
+ void permute(backend::ITensor *src, backend::ITensor *dst, size_t rank,
+ std::vector<size_t> &src_offsets, std::vector<size_t> &dst_offsets)
{
- assert(_src_tensors.size() > 0);
- assert(_src_tensors.size() == _dst_tensors.size());
- auto src_it = _src_tensors.begin();
- auto dst_it = _dst_tensors.begin();
- while (src_it != _src_tensors.end())
+ assert(src->total_size() != 0 && dst->total_size() != 0);
+    // If dst is a subtensor, we have to use clEnqueueMapBuffer instead of clEnqueueWriteBuffer
+ if (dst->needMemoryMap() && !dst->is_subtensor())
{
- const auto src_tensor = *src_it;
- auto dst_tensor = *dst_it;
- if (src_tensor != dst_tensor)
+      // An assertion to check mapping without calling map()
+      // Currently there is no case where both src and dst have a CL buffer.
+ assert(!src->needMemoryMap());
+
+ if (!src->has_padding() && !dst->has_padding() && src->layout() == dst->layout())
{
- // TODO Change to permute in parallel
- assert(underlying_type(src_tensor->data_type()) ==
- underlying_type(dst_tensor->data_type()));
- const auto rank = src_tensor->num_dimensions();
- switch (src_tensor->data_type())
- {
- case ir::DataType::FLOAT32:
- permute<float>(src_tensor, dst_tensor, rank);
- break;
- case ir::DataType::INT32:
- permute<int32_t>(src_tensor, dst_tensor, rank);
- break;
- case ir::DataType::UINT32:
- permute<uint32_t>(src_tensor, dst_tensor, rank);
- break;
- case ir::DataType::BOOL8:
- case ir::DataType::QUANT_UINT8_ASYMM:
- case ir::DataType::UINT8:
- permute<uint8_t>(src_tensor, dst_tensor, rank);
- break;
- case ir::DataType::QUANT_INT8_SYMM:
- permute<int8_t>(src_tensor, dst_tensor, rank);
- break;
- case ir::DataType::INT64:
- permute<int64_t>(src_tensor, dst_tensor, rank);
- break;
- default:
- throw std::runtime_error("IPermuteFunction: Not supported data type");
- break;
- }
+ src->access([&](backend::ITensor &) { dst->enqueueWriteBuffer(src->buffer(), false); });
+ }
+ else
+ {
+        // TODO Optimize this block in case the padding size of dst is big.
+        _buffers_map[dst].resize(dst->total_size());
+ auto dst_buffer = _buffers_map[dst].data();
+ src->access([&](backend::ITensor &) {
+ permute<T>(src, dst, rank, dst_buffer, dst->total_size(), src_offsets, dst_offsets);
+ });
+ dst->enqueueWriteBuffer(dst_buffer, false);
}
- src_it++;
- dst_it++;
+ }
+ else if (src->needMemoryMap() && !src->is_subtensor() && !src->has_padding() &&
+ !dst->has_padding() && src->layout() == dst->layout())
+ {
+ assert(!dst->needMemoryMap());
+ dst->access([&](backend::ITensor &) { src->enqueueReadBuffer(dst->buffer(), true); });
+ }
+ else
+ {
+ auto fn = [&](backend::ITensor &) {
+ dst->access([&](backend::ITensor &) {
+ permute<T>(src, dst, rank, dst->buffer(), dst->total_size(), src_offsets, dst_offsets);
+ });
+ };
+ src->access(fn);
}
}
- virtual void prepare() override { optimize(); }
-
- virtual void optimize() = 0;
-
-private:
template <class T>
- void permute(const std::shared_ptr<backend::ITensor> &src, std::shared_ptr<backend::ITensor> &dst,
- size_t rank)
+ void permute(backend::ITensor *src, backend::ITensor *dst, size_t rank, uint8_t *dst_buffer,
+ size_t dst_size, std::vector<size_t> &src_offsets, std::vector<size_t> &dst_offsets)
{
+ assert(dst_buffer != nullptr);
+ assert(dst_size == dst->total_size());
+
const auto permute_type = [&]() -> PermuteType {
if (src->layout() == ir::Layout::NHWC && dst->layout() == ir::Layout::NCHW)
{
@@ -119,166 +152,130 @@ private:
return PermuteType::COPY;
}
}();
- auto fn = [&](backend::ITensor &src_tensor) {
- dst->access([&](backend::ITensor &dst_tensor) {
- auto src_buffer = src_tensor.buffer();
- auto src_size = src_tensor.total_size();
- auto dst_buffer = dst_tensor.buffer();
- if (permute_type == PermuteType::COPY)
+ if (rank == 4 && permute_type != PermuteType::COPY)
+ {
+ switch (permute_type)
+ {
+ case PermuteType::NHWC_TO_NCHW:
{
- assert(src_tensor.layout() == dst_tensor.layout());
- if (!src_tensor.has_padding() && !dst_tensor.has_padding())
- {
- assert(src_size <= dst_tensor.total_size());
- memcpy(dst_buffer, src_buffer, src_size);
- return;
- }
+ ir::FeatureShape shape;
+ auto dst_shape = dst->getShape();
+ shape.N = dst_shape.dim(0);
+ shape.C = dst_shape.dim(1);
+ shape.H = dst_shape.dim(2);
+ shape.W = dst_shape.dim(3);
+
+ typename feature::nchw::View<T>::Strides strides;
+ const auto start_offset = dst->calcOffset({0, 0, 0, 0});
+ strides.W = dst_shape.dim(3) == 1 ? 0 : dst->calcOffset({0, 0, 0, 1}) - start_offset;
+ strides.H = dst_shape.dim(2) == 1 ? 0 : dst->calcOffset({0, 0, 1, 0}) - start_offset;
+ strides.C = dst_shape.dim(1) == 1 ? 0 : dst->calcOffset({0, 1, 0, 0}) - start_offset;
+ strides.N = dst_shape.dim(0) == 1 ? 0 : dst->calcOffset({1, 0, 0, 0}) - start_offset;
+
+ const feature::nhwc::Reader<T> from(src);
+ feature::nchw::View<T> into(shape, strides,
+ reinterpret_cast<T *>(dst_buffer + start_offset), dst_size);
+ feature::iterate(shape) << [&](uint32_t batch, uint32_t ch, uint32_t row, uint32_t col) {
+ const auto value = from.at(batch, row, col, ch);
+ into.at(batch, ch, row, col) = value;
+ };
+ break;
}
- switch (rank)
+ case PermuteType::NCHW_TO_NHWC:
{
- case 0:
- case 1:
- {
- const int32_t copy_len = dst_tensor.dimension(0);
+ ir::FeatureShape shape;
+ auto dst_shape = dst->getShape();
+ shape.N = dst_shape.dim(0);
+ shape.H = dst_shape.dim(1);
+ shape.W = dst_shape.dim(2);
+ shape.C = dst_shape.dim(3);
- memcpy(dst_buffer, src_buffer, copy_len * sizeof(T));
- break;
- }
- case 2:
- {
- const int32_t dim_0 = dst_tensor.dimension(0);
- const int32_t copy_len = dst_tensor.dimension(1);
+ typename feature::nhwc::View<T>::Strides strides;
+ const auto start_offset = dst->calcOffset({0, 0, 0, 0});
+ strides.C = dst_shape.dim(3) == 1 ? 0 : dst->calcOffset({0, 0, 0, 1}) - start_offset;
+ strides.W = dst_shape.dim(2) == 1 ? 0 : dst->calcOffset({0, 0, 1, 0}) - start_offset;
+ strides.H = dst_shape.dim(1) == 1 ? 0 : dst->calcOffset({0, 1, 0, 0}) - start_offset;
+ strides.N = dst_shape.dim(0) == 1 ? 0 : dst->calcOffset({1, 0, 0, 0}) - start_offset;
- for (int32_t i = 0; i < dim_0; ++i)
- {
- ir::Coordinates coords{i, 0};
- memcpy(dst_buffer + dst_tensor.calcOffset(coords),
- src_buffer + src_tensor.calcOffset(coords), copy_len * sizeof(T));
- }
- break;
- }
- case 3:
- {
- const int32_t dim_0 = dst_tensor.dimension(0);
- const int32_t dim_1 = dst_tensor.dimension(1);
- const int32_t copy_len = dst_tensor.dimension(2);
+ const feature::nchw::Reader<T> from(src);
+ feature::nhwc::View<T> into(shape, strides,
+ reinterpret_cast<T *>(dst_buffer + start_offset), dst_size);
+ feature::iterate(shape) << [&](uint32_t batch, uint32_t ch, uint32_t row, uint32_t col) {
+ const auto value = from.at(batch, ch, row, col);
+ into.at(batch, row, col, ch) = value;
+ };
+ break;
+ }
+ default:
+ {
+ throw std::runtime_error("Unsupported Permutation");
+ break;
+ }
+ }
+ }
+ else if (!src->has_padding() && !dst->has_padding())
+ {
+ auto src_size = src->total_size();
+ assert(src_size <= dst->total_size());
+ memcpy(dst_buffer, src->buffer(), src_size);
+ }
+ else
+ {
+ auto loop_shape = src->getShape();
+ const auto copy_axis = loop_shape.rank() - 1;
+ const auto copy_len = loop_shape.dim(copy_axis) * sizeof(T);
+ loop_shape.dim(copy_axis) = 1;
- for (auto i = 0; i < dim_0; ++i)
- {
- for (auto j = 0; j < dim_1; ++j)
- {
- ir::Coordinates coords{i, j, 0};
- memcpy(dst_buffer + dst_tensor.calcOffset(coords),
- src_buffer + src_tensor.calcOffset(coords), copy_len * sizeof(T));
- }
- }
- break;
- }
- case 4:
- {
- switch (permute_type)
- {
- case PermuteType::NHWC_TO_NCHW:
- {
- ir::FeatureShape shape;
- shape.N = dst_tensor.dimension(0);
- shape.C = dst_tensor.dimension(1);
- shape.H = dst_tensor.dimension(2);
- shape.W = dst_tensor.dimension(3);
- const feature::nhwc::Reader<T> from(&src_tensor);
- feature::nchw::View<T> into(&dst_tensor);
- feature::iterate(shape)
- << [&](uint32_t batch, uint32_t ch, uint32_t row, uint32_t col) {
- const auto value = from.at(batch, row, col, ch);
- into.at(batch, ch, row, col) = value;
- };
- break;
- }
- case PermuteType::NCHW_TO_NHWC:
- {
- ir::FeatureShape shape;
- shape.N = src_tensor.dimension(0);
- shape.C = src_tensor.dimension(1);
- shape.H = src_tensor.dimension(2);
- shape.W = src_tensor.dimension(3);
- const feature::nchw::Reader<T> from(&src_tensor);
- feature::nhwc::View<T> into(&dst_tensor);
- feature::iterate(shape)
- << [&](uint32_t batch, uint32_t ch, uint32_t row, uint32_t col) {
- const auto value = from.at(batch, ch, row, col);
- into.at(batch, row, col, ch) = value;
- };
- break;
- }
- case PermuteType::COPY:
- {
- const int32_t dim_0 = dst_tensor.dimension(0);
- const int32_t dim_1 = dst_tensor.dimension(1);
- const int32_t dim_2 = dst_tensor.dimension(2);
- const int32_t copy_len = dst_tensor.dimension(3);
+ if (src->is_dynamic())
+ {
+ assert(dst->is_dynamic());
+ CopyDynamic(src, dst, dst_buffer, loop_shape, copy_len);
+ }
+ else
+ {
+ // TODO Uncomment the assertion below
+ // assert(!dst->is_dynamic() || dst is output of graph);
+ if (src_offsets.size() == 0)
+ {
+ assert(dst_offsets.size() == 0);
- for (auto i = 0; i < dim_0; ++i)
- {
- for (auto j = 0; j < dim_1; ++j)
- {
- for (auto k = 0; k < dim_2; ++k)
- {
- ir::Coordinates coords{i, j, k, 0};
- memcpy(dst_buffer + dst_tensor.calcOffset(coords),
- src_buffer + src_tensor.calcOffset(coords), copy_len * sizeof(T));
- }
- }
- }
- break;
- }
- default:
- {
- throw std::runtime_error("Unsupported Permutation");
- break;
- }
- }
- break;
- }
- default:
- throw std::runtime_error("Unsupported rank in permutation");
- break;
+ auto loop_shape = src->getShape();
+ const auto copy_axis = loop_shape.rank() - 1;
+ loop_shape.dim(copy_axis) = 1;
+ UpdateOffsets(src, dst, loop_shape, src_offsets, dst_offsets);
}
- });
- };
- src->access(fn);
+ CopyStatic(src->buffer(), dst_buffer, src_offsets, dst_offsets, copy_len);
+ }
+ }
}
+protected:
// NOTE The typeid expression is lvalue expression which refers to an object with static storage
// duration, of the polymorphic type const std::type_info or of some type derived from it.
// So std::type_info is non-copyable
- const std::type_info &underlying_type(ir::DataType type) const
- {
- switch (type)
- {
- case ir::DataType::FLOAT32:
- return typeid(float);
- case ir::DataType::INT32:
- return typeid(int32_t);
- case ir::DataType::UINT32:
- return typeid(uint32_t);
- case ir::DataType::INT64:
- return typeid(int64_t);
- case ir::DataType::BOOL8:
- case ir::DataType::QUANT_UINT8_ASYMM:
- case ir::DataType::UINT8:
- return typeid(uint8_t);
- case ir::DataType::QUANT_INT8_SYMM:
- return typeid(int8_t);
- default:
- throw std::runtime_error("IPermuteFunction: Not supported data type");
- }
- }
+ const std::type_info &underlying_type(ir::DataType type) const;
protected:
- std::vector<std::shared_ptr<backend::ITensor>> _src_tensors;
- std::vector<std::shared_ptr<backend::ITensor>> _dst_tensors;
- // TODO Remove this member if it is possible
- std::vector<size_t> _ranks;
+ std::vector<backend::ITensor *> _src_tensors;
+ std::vector<backend::ITensor *> _dst_tensors;
+ std::vector<std::vector<size_t>> _src_tensors_offsets;
+ std::vector<std::vector<size_t>> _dst_tensors_offsets;
+ std::unordered_map<const backend::ITensor *, std::vector<uint8_t>> _buffers_map;
+};
+
+// Simple PermuteLayer
+class PermuteLayer : public onert::exec::IPermuteFunction
+{
+public:
+ PermuteLayer(const std::vector<onert::backend::ITensor *> &inputs,
+ const std::vector<onert::backend::ITensor *> &outputs)
+ {
+ assert(inputs.size() == outputs.size());
+ _src_tensors = inputs;
+ _dst_tensors = outputs;
+ }
+ virtual ~PermuteLayer() {}
+ void optimize() override {}
};
} // namespace exec
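
As a usage note, the PermuteLayer declared above can be driven directly once raw ITensor pointers are available. A minimal sketch, assuming both tensors are already allocated; the free function permute_once is illustrative (the test below includes the header the same way):

#include "IPermuteFunction.h"

// Copies or permutes src into dst, handling layout differences (NHWC <-> NCHW)
// and type-aware quantization according to the rules implemented above.
void permute_once(onert::backend::ITensor *src, onert::backend::ITensor *dst)
{
  onert::exec::PermuteLayer layer({src}, {dst});
  layer.prepare(); // calls optimize(), a no-op for PermuteLayer
  layer.run();
}
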
diff --git a/runtime/onert/core/src/exec/IPermuteFunction.test.cc b/runtime/onert/core/src/exec/IPermuteFunction.test.cc
new file mode 100644
index 000000000..fb2dd3b95
--- /dev/null
+++ b/runtime/onert/core/src/exec/IPermuteFunction.test.cc
@@ -0,0 +1,920 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "IPermuteFunction.h"
+
+#include <ir/Layout.h>
+#include <ir/Shape.h>
+#include <ir/TypeInfo.h>
+
+#include <cmath>
+#include <gtest/gtest.h>
+
+namespace
+{
+using namespace onert;
+using namespace ir;
+using namespace backend;
+using namespace exec;
+
+class MockUpTensor : public ITensor
+{
+public:
+ MockUpTensor(const Shape &shape, const TypeInfo &type_info, Layout layout, size_t pad)
+    : _shape(shape), _type_info(type_info), _layout(layout), _data(nullptr)
+ {
+ _strides.resize(shape.rank());
+
+ std::vector<size_t> pads(shape.rank(), 0);
+ pads[shape.rank() - 1] = pad;
+ size_t stride = 1;
+ for (int32_t i = _shape.rank() - 1; i >= 0; --i)
+ {
+ _strides.at(i) = stride;
+ stride = stride * (_shape.dim(i) + pads.at(i));
+ }
+ }
+ virtual ~MockUpTensor() {}
+
+ void setBuffer(uint8_t *data) { _data = data; }
+
+ size_t total_size() const override
+ {
+ size_t total_size = _strides[0] * _shape.dim(0);
+ total_size *= sizeOfDataType(data_type());
+ return total_size;
+ }
+
+ size_t calcOffset(const ir::Coordinates &coords) const override
+ {
+ size_t offset = 0;
+ for (size_t i = 0; i < _shape.rank(); ++i)
+ {
+ offset += (_strides[i] * coords[i]);
+ }
+ offset *= sizeOfDataType(data_type());
+ return offset;
+ }
+
+ uint8_t *buffer() const override { return _data; }
+
+ ir::Layout layout() const override { return _layout; }
+ ir::DataType data_type() const override { return _type_info.type(); }
+ float data_scale() const override { return _type_info.scale(); }
+ int32_t data_zero_point() const override { return _type_info.zero_point(); }
+ const std::vector<float> &data_scales() const override { return _type_info.scales(); }
+ const std::vector<int32_t> &data_zero_points() const override { return _type_info.zero_points(); }
+ bool has_padding() const override
+ {
+ return total_size() / sizeOfDataType(data_type()) != _shape.num_elements();
+ }
+ void access(const std::function<void(ITensor &tensor)> &fn) final { fn(*this); }
+
+ bool is_dynamic() const override { return false; }
+ Shape getShape() const override { return _shape; }
+
+private:
+ Shape _shape;
+ TypeInfo _type_info;
+ Layout _layout;
+ uint8_t *_data;
+ std::vector<size_t> _strides;
+};
+
+class MockUpLayer : public IPermuteFunction
+{
+public:
+ MockUpLayer(const std::vector<ITensor *> &inputs, const std::vector<ITensor *> &outputs)
+ {
+ assert(inputs.size() == outputs.size());
+ _src_tensors = inputs;
+ _dst_tensors = outputs;
+ }
+ virtual ~MockUpLayer() {}
+ void optimize() override {}
+};
+
+TEST(IPermuteFunction, float_to_float)
+{
+ // rank 1
+ {
+ const size_t input_pads[4] = {0, 1, 0, 2};
+ const size_t output_pads[4] = {0, 0, 2, 1};
+ const std::vector<Shape> shapes{{1}, {4}, {5}, {2}};
+ float expected_buffer[] = {1, 0, -1, -2, 3};
+ const auto type_info = TypeInfo(DataType::FLOAT32);
+
+ std::vector<std::unique_ptr<MockUpTensor>> inputs(4);
+ std::vector<std::unique_ptr<MockUpTensor>> outputs(4);
+
+ std::vector<std::unique_ptr<uint8_t[]>> output_buffers(4);
+ for (size_t i = 0; i < 4; ++i)
+ {
+ inputs[i] = std::make_unique<MockUpTensor>(shapes[i], type_info, Layout::NHWC, input_pads[i]);
+ inputs[i]->setBuffer(reinterpret_cast<uint8_t *>(expected_buffer));
+
+ outputs[i] =
+ std::make_unique<MockUpTensor>(shapes[i], type_info, Layout::NHWC, output_pads[i]);
+ output_buffers[i] = std::make_unique<uint8_t[]>(outputs[i]->total_size());
+ outputs[i]->setBuffer(output_buffers[i].get());
+ }
+
+ auto mockup_layer = std::make_unique<MockUpLayer>(
+ std::vector<ITensor *>{inputs[0].get(), inputs[1].get(), inputs[2].get(), inputs[3].get()},
+ std::vector<ITensor *>{outputs[0].get(), outputs[1].get(), outputs[2].get(),
+ outputs[3].get()});
+ mockup_layer->run();
+
+ for (size_t i = 0; i < 4; ++i)
+ {
+ for (int32_t j = 0; j < shapes[i].dim(0); ++j)
+ {
+ Coordinates coords{j};
+ float result =
+ *reinterpret_cast<float *>(outputs[i]->buffer() + outputs[i]->calcOffset(coords));
+ float expected =
+ *reinterpret_cast<float *>(inputs[i]->buffer() + inputs[i]->calcOffset(coords));
+ EXPECT_EQ(result, expected);
+ }
+ }
+ }
+
+ // rank 2
+ {
+ const size_t input_pads[4] = {0, 1, 0, 2};
+ const size_t output_pads[4] = {0, 0, 2, 1};
+ const std::vector<Shape> shapes{{1, 4}, {2, 2}, {1, 5}, {2, 3}};
+ float expected_buffer[] = {1, 0, -1, -2, 3, -4, 5, -6, 7, -8};
+ const auto type_info = TypeInfo(DataType::FLOAT32);
+
+ std::vector<std::unique_ptr<MockUpTensor>> inputs(4);
+ std::vector<std::unique_ptr<MockUpTensor>> outputs(4);
+ std::vector<std::unique_ptr<uint8_t[]>> output_buffers(4);
+ for (size_t i = 0; i < 4; ++i)
+ {
+ inputs[i] = std::make_unique<MockUpTensor>(shapes[i], type_info, Layout::NHWC, input_pads[i]);
+ inputs[i]->setBuffer(reinterpret_cast<uint8_t *>(expected_buffer));
+
+ outputs[i] =
+ std::make_unique<MockUpTensor>(shapes[i], type_info, Layout::NHWC, output_pads[i]);
+ output_buffers[i] = std::make_unique<uint8_t[]>(outputs[i]->total_size());
+ outputs[i]->setBuffer(output_buffers[i].get());
+ }
+
+ auto mockup_layer = std::make_unique<MockUpLayer>(
+ std::vector<ITensor *>{inputs[0].get(), inputs[1].get(), inputs[2].get(), inputs[3].get()},
+ std::vector<ITensor *>{outputs[0].get(), outputs[1].get(), outputs[2].get(),
+ outputs[3].get()});
+ mockup_layer->run();
+
+ for (size_t i = 0; i < 4; ++i)
+ {
+ for (int32_t j = 0; j < shapes[i].dim(0); ++j)
+ {
+ for (int32_t k = 0; k < shapes[i].dim(1); ++k)
+ {
+ Coordinates coords{j, k};
+ float result =
+ *reinterpret_cast<float *>(outputs[i]->buffer() + outputs[i]->calcOffset(coords));
+ float expected =
+ *reinterpret_cast<float *>(inputs[i]->buffer() + inputs[i]->calcOffset(coords));
+ EXPECT_EQ(result, expected);
+ }
+ }
+ }
+ }
+
+ // rank 3
+ {
+ const size_t input_pads[4] = {0, 5, 0, 2};
+ const size_t output_pads[4] = {0, 3, 2, 1};
+ const std::vector<Shape> shapes{{1, 4, 1}, {1, 2, 1}, {2, 1, 5}, {1, 2, 3}};
+ float expected_buffer[] = {1, 0, -1, -2, 3, -4, 5, -6, 7, -8, 9, -10};
+ const auto type_info = TypeInfo(DataType::FLOAT32);
+
+ std::vector<std::unique_ptr<MockUpTensor>> inputs(4);
+ std::vector<std::unique_ptr<MockUpTensor>> outputs(4);
+ std::vector<std::unique_ptr<uint8_t[]>> output_buffers(4);
+ for (size_t i = 0; i < 4; ++i)
+ {
+ inputs[i] = std::make_unique<MockUpTensor>(shapes[i], type_info, Layout::NHWC, input_pads[i]);
+ inputs[i]->setBuffer(reinterpret_cast<uint8_t *>(expected_buffer));
+
+ outputs[i] =
+ std::make_unique<MockUpTensor>(shapes[i], type_info, Layout::NHWC, output_pads[i]);
+ output_buffers[i] = std::make_unique<uint8_t[]>(outputs[i]->total_size());
+ outputs[i]->setBuffer(output_buffers[i].get());
+ }
+
+ auto mockup_layer = std::make_unique<MockUpLayer>(
+ std::vector<ITensor *>{inputs[0].get(), inputs[1].get(), inputs[2].get(), inputs[3].get()},
+ std::vector<ITensor *>{outputs[0].get(), outputs[1].get(), outputs[2].get(),
+ outputs[3].get()});
+ mockup_layer->run();
+
+ for (size_t i = 0; i < 4; ++i)
+ {
+ for (int32_t j = 0; j < shapes[i].dim(0); ++j)
+ {
+ for (int32_t k = 0; k < shapes[i].dim(1); ++k)
+ {
+ for (int32_t l = 0; l < shapes[i].dim(2); ++l)
+ {
+ Coordinates coords{j, k, l};
+ float result =
+ *reinterpret_cast<float *>(outputs[i]->buffer() + outputs[i]->calcOffset(coords));
+ float expected =
+ *reinterpret_cast<float *>(inputs[i]->buffer() + inputs[i]->calcOffset(coords));
+ EXPECT_EQ(result, expected);
+ }
+ }
+ }
+ }
+ }
+
+ // rank 4
+ {
+ const size_t input_pads[4] = {0, 0, 1, 2};
+ const size_t output_pads[4] = {0, 3, 2, 1};
+ const std::vector<Shape> shapes{{1, 1, 4, 1}, {2, 1, 2, 3}, {1, 2, 1, 2}, {1, 1, 2, 3}};
+ float expected_buffer[] = {1, 0, -1, -2, 3, -4, 5, -6, 7, -8, 9, -10};
+ const auto type_info = TypeInfo(DataType::FLOAT32);
+
+ std::vector<std::unique_ptr<MockUpTensor>> inputs(4);
+ std::vector<std::unique_ptr<MockUpTensor>> outputs(4);
+ std::vector<std::unique_ptr<uint8_t[]>> output_buffers(4);
+ for (size_t i = 0; i < 4; ++i)
+ {
+ inputs[i] = std::make_unique<MockUpTensor>(shapes[i], type_info, Layout::NHWC, input_pads[i]);
+ inputs[i]->setBuffer(reinterpret_cast<uint8_t *>(expected_buffer));
+
+ outputs[i] =
+ std::make_unique<MockUpTensor>(shapes[i], type_info, Layout::NHWC, output_pads[i]);
+ output_buffers[i] = std::make_unique<uint8_t[]>(outputs[i]->total_size());
+ outputs[i]->setBuffer(output_buffers[i].get());
+ }
+
+ auto mockup_layer = std::make_unique<MockUpLayer>(
+ std::vector<ITensor *>{inputs[0].get(), inputs[1].get(), inputs[2].get(), inputs[3].get()},
+ std::vector<ITensor *>{outputs[0].get(), outputs[1].get(), outputs[2].get(),
+ outputs[3].get()});
+ mockup_layer->run();
+
+ for (size_t i = 0; i < 4; ++i)
+ {
+ for (int32_t j = 0; j < shapes[i].dim(0); ++j)
+ {
+ for (int32_t k = 0; k < shapes[i].dim(1); ++k)
+ {
+ for (int32_t l = 0; l < shapes[i].dim(2); ++l)
+ {
+ for (int32_t m = 0; m < shapes[i].dim(3); ++m)
+ {
+ Coordinates coords{j, k, l, m};
+ float result =
+ *reinterpret_cast<float *>(outputs[i]->buffer() + outputs[i]->calcOffset(coords));
+ float expected =
+ *reinterpret_cast<float *>(inputs[i]->buffer() + inputs[i]->calcOffset(coords));
+ EXPECT_EQ(result, expected);
+ }
+ }
+ }
+ }
+ }
+ }
+
+  // rank 4 layout
+ {
+ const size_t input_pads[4] = {0, 0, 1, 2};
+ const size_t output_pads[4] = {0, 3, 2, 1};
+ const std::vector<Shape> shapes{{1, 1, 4, 1}, {2, 1, 2, 3}, {1, 2, 1, 2}, {1, 1, 2, 3}};
+ float expected_buffer[] = {1, 0, -1, -2, 3, -4, 5, -6, 7,
+ -8, 9, -10, 11, -12, 13, -14, 15, -16};
+ const auto type_info = TypeInfo(DataType::FLOAT32);
+
+ std::vector<std::unique_ptr<MockUpTensor>> inputs(4);
+ std::vector<std::unique_ptr<MockUpTensor>> outputs(4);
+ std::vector<std::unique_ptr<uint8_t[]>> output_buffers(4);
+ for (size_t i = 0; i < 4; ++i)
+ {
+ Layout layout = Layout::NHWC;
+ Shape shape = shapes[i];
+ if (i % 2 == 1)
+ {
+ layout = Layout::NCHW;
+ shape = Shape{shapes[i].dim(0), shapes[i].dim(3), shapes[i].dim(1), shapes[i].dim(2)};
+ }
+ inputs[i] = std::make_unique<MockUpTensor>(shape, type_info, layout, input_pads[i]);
+ inputs[i]->setBuffer(reinterpret_cast<uint8_t *>(expected_buffer));
+
+ if (layout == Layout::NHWC)
+ {
+ layout = Layout::NCHW;
+ shape = Shape{shapes[i].dim(0), shapes[i].dim(3), shapes[i].dim(1), shapes[i].dim(2)};
+ }
+ else
+ {
+ layout = Layout::NHWC;
+ shape = shapes[i];
+ }
+ outputs[i] = std::make_unique<MockUpTensor>(shape, type_info, layout, output_pads[i]);
+ output_buffers[i] = std::make_unique<uint8_t[]>(outputs[i]->total_size());
+ outputs[i]->setBuffer(output_buffers[i].get());
+ }
+
+ auto mockup_layer = std::make_unique<MockUpLayer>(
+ std::vector<ITensor *>{inputs[0].get(), inputs[1].get(), inputs[2].get(), inputs[3].get()},
+ std::vector<ITensor *>{outputs[0].get(), outputs[1].get(), outputs[2].get(),
+ outputs[3].get()});
+ mockup_layer->run();
+
+ for (size_t i = 0; i < 4; ++i)
+ {
+ for (int32_t j = 0; j < shapes[i].dim(0); ++j)
+ {
+ for (int32_t k = 0; k < shapes[i].dim(1); ++k)
+ {
+ for (int32_t l = 0; l < shapes[i].dim(2); ++l)
+ {
+ for (int32_t m = 0; m < shapes[i].dim(3); ++m)
+ {
+ Coordinates input_coords;
+ Coordinates output_coords;
+ if (inputs[i]->layout() == Layout::NHWC)
+ {
+ input_coords = Coordinates{j, k, l, m};
+ }
+ else
+ {
+ input_coords = Coordinates{j, m, k, l};
+ }
+ if (outputs[i]->layout() == Layout::NHWC)
+ {
+ output_coords = Coordinates{j, k, l, m};
+ }
+ else
+ {
+ output_coords = Coordinates{j, m, k, l};
+ }
+ float result = *reinterpret_cast<float *>(outputs[i]->buffer() +
+ outputs[i]->calcOffset(output_coords));
+ float expected = *reinterpret_cast<float *>(inputs[i]->buffer() +
+ inputs[i]->calcOffset(input_coords));
+ EXPECT_EQ(result, expected);
+ }
+ }
+ }
+ }
+ }
+ }
+}
+
+TEST(IPermuteFunction, float_to_qasymm8)
+{
+ const size_t input_pads[4] = {0, 0, 1, 2};
+ const size_t output_pads[4] = {0, 3, 2, 1};
+ const std::vector<Shape> shapes{{1, 1, 4, 1}, {2, 1, 2, 3}, {1, 2, 1, 2}, {1, 1, 2, 3}};
+ float expected_buffer[] = {10, 0, -10, -20, 30, -40, 50, -60, 70, -80, 90, -100};
+ float scale = 10;
+ int32_t zero_point = 128;
+
+ std::vector<std::unique_ptr<MockUpTensor>> inputs(4);
+ std::vector<std::unique_ptr<MockUpTensor>> outputs(4);
+ std::vector<std::unique_ptr<uint8_t[]>> output_buffers(4);
+ for (size_t i = 0; i < 4; ++i)
+ {
+ inputs[i] = std::make_unique<MockUpTensor>(shapes[i], TypeInfo(DataType::FLOAT32), Layout::NHWC,
+ input_pads[i]);
+ inputs[i]->setBuffer(reinterpret_cast<uint8_t *>(expected_buffer));
+
+ TypeInfo type_info{DataType::QUANT_UINT8_ASYMM, scale, zero_point};
+ outputs[i] = std::make_unique<MockUpTensor>(shapes[i], type_info, Layout::NHWC, output_pads[i]);
+ output_buffers[i] = std::make_unique<uint8_t[]>(outputs[i]->total_size());
+ outputs[i]->setBuffer(output_buffers[i].get());
+ }
+
+ auto mockup_layer = std::make_unique<MockUpLayer>(
+ std::vector<ITensor *>{inputs[0].get(), inputs[1].get(), inputs[2].get(), inputs[3].get()},
+ std::vector<ITensor *>{outputs[0].get(), outputs[1].get(), outputs[2].get(), outputs[3].get()});
+ mockup_layer->run();
+
+ for (size_t i = 0; i < 4; ++i)
+ {
+ for (int32_t j = 0; j < shapes[i].dim(0); ++j)
+ {
+ for (int32_t k = 0; k < shapes[i].dim(1); ++k)
+ {
+ for (int32_t l = 0; l < shapes[i].dim(2); ++l)
+ {
+ for (int32_t m = 0; m < shapes[i].dim(3); ++m)
+ {
+ Coordinates coords{j, k, l, m};
+ uint8_t qasymm8 =
+ *reinterpret_cast<uint8_t *>(outputs[i]->buffer() + outputs[i]->calcOffset(coords));
+ float result = (qasymm8 - zero_point) * scale;
+ float expected =
+ *reinterpret_cast<float *>(inputs[i]->buffer() + inputs[i]->calcOffset(coords));
+ EXPECT_EQ(result, expected);
+ }
+ }
+ }
+ }
+ }
+}
+
+TEST(IPermuteFunction, float_to_qsymm8)
+{
+ const size_t input_pads[4] = {0, 0, 1, 2};
+ const size_t output_pads[4] = {0, 3, 2, 1};
+ const std::vector<Shape> shapes{{1, 1, 4, 1}, {2, 1, 2, 3}, {1, 2, 1, 2}, {1, 1, 2, 3}};
+ float expected_buffer[] = {10, 0, -10, -20, 30, -40, 50, -60, 70, -80, 90, -100};
+ float scale = 10;
+ int32_t zero_point = 0;
+
+ std::vector<std::unique_ptr<MockUpTensor>> inputs(4);
+ std::vector<std::unique_ptr<MockUpTensor>> outputs(4);
+ std::vector<std::unique_ptr<uint8_t[]>> output_buffers(4);
+ for (size_t i = 0; i < 4; ++i)
+ {
+ inputs[i] = std::make_unique<MockUpTensor>(shapes[i], TypeInfo(DataType::FLOAT32), Layout::NHWC,
+ input_pads[i]);
+ inputs[i]->setBuffer(reinterpret_cast<uint8_t *>(expected_buffer));
+
+ TypeInfo type_info{DataType::QUANT_INT8_SYMM, scale, zero_point};
+ outputs[i] = std::make_unique<MockUpTensor>(shapes[i], type_info, Layout::NHWC, output_pads[i]);
+ output_buffers[i] = std::make_unique<uint8_t[]>(outputs[i]->total_size());
+ outputs[i]->setBuffer(output_buffers[i].get());
+ }
+
+ auto mockup_layer = std::make_unique<MockUpLayer>(
+ std::vector<ITensor *>{inputs[0].get(), inputs[1].get(), inputs[2].get(), inputs[3].get()},
+ std::vector<ITensor *>{outputs[0].get(), outputs[1].get(), outputs[2].get(), outputs[3].get()});
+ mockup_layer->run();
+
+ for (size_t i = 0; i < 4; ++i)
+ {
+ for (int32_t j = 0; j < shapes[i].dim(0); ++j)
+ {
+ for (int32_t k = 0; k < shapes[i].dim(1); ++k)
+ {
+ for (int32_t l = 0; l < shapes[i].dim(2); ++l)
+ {
+ for (int32_t m = 0; m < shapes[i].dim(3); ++m)
+ {
+ Coordinates coords{j, k, l, m};
+ int8_t qsymm8 =
+ *reinterpret_cast<int8_t *>(outputs[i]->buffer() + outputs[i]->calcOffset(coords));
+ float result = (qsymm8 - zero_point) * scale;
+ float expected =
+ *reinterpret_cast<float *>(inputs[i]->buffer() + inputs[i]->calcOffset(coords));
+ EXPECT_EQ(result, expected);
+ }
+ }
+ }
+ }
+ }
+}
+
+TEST(IPermuteFunction, float_to_qsymm16)
+{
+ const size_t input_pads[4] = {0, 0, 1, 2};
+ const size_t output_pads[4] = {0, 3, 2, 1};
+ const std::vector<Shape> shapes{{1, 1, 4, 1}, {2, 1, 2, 3}, {1, 2, 1, 2}, {1, 1, 2, 3}};
+ float expected_buffer[] = {10, 0, -10, -20, 30, -40, 50, -60, 70, -80, 90, -100};
+ float scale = 10;
+ int32_t zero_point = 0;
+
+ std::vector<std::unique_ptr<MockUpTensor>> inputs(4);
+ std::vector<std::unique_ptr<MockUpTensor>> outputs(4);
+ std::vector<std::unique_ptr<uint8_t[]>> output_buffers(4);
+ for (size_t i = 0; i < 4; ++i)
+ {
+ inputs[i] = std::make_unique<MockUpTensor>(shapes[i], TypeInfo(DataType::FLOAT32), Layout::NHWC,
+ input_pads[i]);
+ inputs[i]->setBuffer(reinterpret_cast<uint8_t *>(expected_buffer));
+
+ TypeInfo type_info{DataType::QUANT_INT16_SYMM, scale, zero_point};
+ outputs[i] = std::make_unique<MockUpTensor>(shapes[i], type_info, Layout::NHWC, output_pads[i]);
+ output_buffers[i] = std::make_unique<uint8_t[]>(outputs[i]->total_size());
+ outputs[i]->setBuffer(output_buffers[i].get());
+ }
+
+ auto mockup_layer = std::make_unique<MockUpLayer>(
+ std::vector<ITensor *>{inputs[0].get(), inputs[1].get(), inputs[2].get(), inputs[3].get()},
+ std::vector<ITensor *>{outputs[0].get(), outputs[1].get(), outputs[2].get(), outputs[3].get()});
+ mockup_layer->run();
+
+ for (size_t i = 0; i < 4; ++i)
+ {
+ for (int32_t j = 0; j < shapes[i].dim(0); ++j)
+ {
+ for (int32_t k = 0; k < shapes[i].dim(1); ++k)
+ {
+ for (int32_t l = 0; l < shapes[i].dim(2); ++l)
+ {
+ for (int32_t m = 0; m < shapes[i].dim(3); ++m)
+ {
+ Coordinates coords{j, k, l, m};
+ int16_t qsymm16 =
+ *reinterpret_cast<int16_t *>(outputs[i]->buffer() + outputs[i]->calcOffset(coords));
+ float result = (qsymm16 - zero_point) * scale;
+ float expected =
+ *reinterpret_cast<float *>(inputs[i]->buffer() + inputs[i]->calcOffset(coords));
+ EXPECT_EQ(result, expected);
+ }
+ }
+ }
+ }
+ }
+}
+
+TEST(IPermuteFunction, qasymm8_to_float)
+{
+ const size_t input_pads[4] = {0, 0, 1, 2};
+ const size_t output_pads[4] = {0, 3, 2, 1};
+ const std::vector<Shape> shapes{{1, 1, 4, 1}, {2, 1, 2, 3}, {1, 2, 1, 2}, {1, 1, 2, 3}};
+ float expected_buffer[] = {10, 0, -10, -20, 30, -40, 50, -60, 70, -80, 90, -100};
+ float scale = 10;
+ int32_t zero_point = 128;
+ uint8_t input_buffer[12];
+
+ int32_t min_val = std::numeric_limits<uint8_t>::min();
+ int32_t max_val = std::numeric_limits<uint8_t>::max();
+ for (int32_t i = 0; i < sizeof(expected_buffer) / sizeof(float); ++i)
+ {
+ int32_t unclamped = static_cast<int32_t>(std::round(expected_buffer[i] / scale)) + zero_point;
+ input_buffer[i] = std::min(std::max(unclamped, min_val), max_val);
+ }
+
+ std::vector<std::unique_ptr<MockUpTensor>> inputs(4);
+ std::vector<std::unique_ptr<MockUpTensor>> outputs(4);
+ std::vector<std::unique_ptr<uint8_t[]>> output_buffers(4);
+ for (size_t i = 0; i < 4; ++i)
+ {
+ TypeInfo type_info{DataType::QUANT_UINT8_ASYMM, scale, zero_point};
+ inputs[i] = std::make_unique<MockUpTensor>(shapes[i], type_info, Layout::NHWC, input_pads[i]);
+ inputs[i]->setBuffer(reinterpret_cast<uint8_t *>(input_buffer));
+
+ outputs[i] = std::make_unique<MockUpTensor>(shapes[i], TypeInfo(DataType::FLOAT32),
+ Layout::NHWC, output_pads[i]);
+ output_buffers[i] = std::make_unique<uint8_t[]>(outputs[i]->total_size());
+ outputs[i]->setBuffer(output_buffers[i].get());
+ }
+
+ auto mockup_layer = std::make_unique<MockUpLayer>(
+ std::vector<ITensor *>{inputs[0].get(), inputs[1].get(), inputs[2].get(), inputs[3].get()},
+ std::vector<ITensor *>{outputs[0].get(), outputs[1].get(), outputs[2].get(), outputs[3].get()});
+ mockup_layer->run();
+
+ for (size_t i = 0; i < 4; ++i)
+ {
+ for (int32_t j = 0; j < shapes[i].dim(0); ++j)
+ {
+ for (int32_t k = 0; k < shapes[i].dim(1); ++k)
+ {
+ for (int32_t l = 0; l < shapes[i].dim(2); ++l)
+ {
+ for (int32_t m = 0; m < shapes[i].dim(3); ++m)
+ {
+ Coordinates coords{j, k, l, m};
+ float result =
+ *reinterpret_cast<float *>(outputs[i]->buffer() + outputs[i]->calcOffset(coords));
+ uint8_t qasymm8 =
+ *reinterpret_cast<uint8_t *>(inputs[i]->buffer() + inputs[i]->calcOffset(coords));
+ float expected = (qasymm8 - zero_point) * scale;
+ EXPECT_EQ(result, expected);
+ }
+ }
+ }
+ }
+ }
+}
+
+TEST(IPermuteFunction, qsymm8_to_float)
+{
+ const size_t input_pads[4] = {0, 0, 1, 2};
+ const size_t output_pads[4] = {0, 3, 2, 1};
+ const std::vector<Shape> shapes{{1, 1, 4, 1}, {2, 1, 2, 3}, {1, 2, 1, 2}, {1, 1, 2, 3}};
+ float expected_buffer[] = {10, 0, -10, -20, 30, -40, 50, -60, 70, -80, 90, -100};
+ float scale = 10;
+ int32_t zero_point = 0;
+ uint8_t input_buffer[12];
+
+ int32_t min_val = std::numeric_limits<int8_t>::min();
+ int32_t max_val = std::numeric_limits<int8_t>::max();
+ for (int32_t i = 0; i < sizeof(expected_buffer) / sizeof(float); ++i)
+ {
+ int32_t unclamped = static_cast<int32_t>(std::round(expected_buffer[i] / scale)) + zero_point;
+ input_buffer[i] = std::min(std::max(unclamped, min_val), max_val);
+ }
+
+ std::vector<std::unique_ptr<MockUpTensor>> inputs(4);
+ std::vector<std::unique_ptr<MockUpTensor>> outputs(4);
+ std::vector<std::unique_ptr<uint8_t[]>> output_buffers(4);
+ for (size_t i = 0; i < 4; ++i)
+ {
+ TypeInfo type_info{DataType::QUANT_INT8_SYMM, scale, zero_point};
+ inputs[i] = std::make_unique<MockUpTensor>(shapes[i], type_info, Layout::NHWC, input_pads[i]);
+ inputs[i]->setBuffer(reinterpret_cast<uint8_t *>(input_buffer));
+
+ outputs[i] = std::make_unique<MockUpTensor>(shapes[i], TypeInfo(DataType::FLOAT32),
+ Layout::NHWC, output_pads[i]);
+ output_buffers[i] = std::make_unique<uint8_t[]>(outputs[i]->total_size());
+ outputs[i]->setBuffer(output_buffers[i].get());
+ }
+
+ auto mockup_layer = std::make_unique<MockUpLayer>(
+ std::vector<ITensor *>{inputs[0].get(), inputs[1].get(), inputs[2].get(), inputs[3].get()},
+ std::vector<ITensor *>{outputs[0].get(), outputs[1].get(), outputs[2].get(), outputs[3].get()});
+ mockup_layer->run();
+
+ for (size_t i = 0; i < 4; ++i)
+ {
+ for (int32_t j = 0; j < shapes[i].dim(0); ++j)
+ {
+ for (int32_t k = 0; k < shapes[i].dim(1); ++k)
+ {
+ for (int32_t l = 0; l < shapes[i].dim(2); ++l)
+ {
+ for (int32_t m = 0; m < shapes[i].dim(3); ++m)
+ {
+ Coordinates coords{j, k, l, m};
+ float result =
+ *reinterpret_cast<float *>(outputs[i]->buffer() + outputs[i]->calcOffset(coords));
+            int8_t qsymm8 =
+              *reinterpret_cast<int8_t *>(inputs[i]->buffer() + inputs[i]->calcOffset(coords));
+            float expected = (qsymm8 - zero_point) * scale;
+ EXPECT_EQ(result, expected);
+ }
+ }
+ }
+ }
+ }
+}
+
+TEST(IPermuteFunction, qsymm16_to_float)
+{
+ const size_t input_pads[4] = {0, 0, 1, 2};
+ const size_t output_pads[4] = {0, 3, 2, 1};
+ const std::vector<Shape> shapes{{1, 1, 4, 1}, {2, 1, 2, 3}, {1, 2, 1, 2}, {1, 1, 2, 3}};
+ float expected_buffer[] = {10, 0, -10, -20, 30, -40, 50, -60, 70, -80, 90, -100};
+ float scale = 10;
+ int32_t zero_point = 0;
+ uint8_t input_buffer[12];
+
+ int32_t min_val = std::numeric_limits<int16_t>::min();
+ int32_t max_val = std::numeric_limits<int16_t>::max();
+ for (int32_t i = 0; i < sizeof(expected_buffer) / sizeof(float); ++i)
+ {
+ int32_t unclamped = static_cast<int32_t>(std::round(expected_buffer[i] / scale)) + zero_point;
+ input_buffer[i] = std::min(std::max(unclamped, min_val), max_val);
+ }
+
+ std::vector<std::unique_ptr<MockUpTensor>> inputs(4);
+ std::vector<std::unique_ptr<MockUpTensor>> outputs(4);
+ std::vector<std::unique_ptr<uint8_t[]>> output_buffers(4);
+ for (size_t i = 0; i < 4; ++i)
+ {
+ TypeInfo type_info{DataType::QUANT_INT16_SYMM, scale, zero_point};
+ inputs[i] = std::make_unique<MockUpTensor>(shapes[i], type_info, Layout::NHWC, input_pads[i]);
+ inputs[i]->setBuffer(reinterpret_cast<uint8_t *>(input_buffer));
+
+ outputs[i] = std::make_unique<MockUpTensor>(shapes[i], TypeInfo(DataType::FLOAT32),
+ Layout::NHWC, output_pads[i]);
+ output_buffers[i] = std::make_unique<uint8_t[]>(outputs[i]->total_size());
+ outputs[i]->setBuffer(output_buffers[i].get());
+ }
+
+ auto mockup_layer = std::make_unique<MockUpLayer>(
+ std::vector<ITensor *>{inputs[0].get(), inputs[1].get(), inputs[2].get(), inputs[3].get()},
+ std::vector<ITensor *>{outputs[0].get(), outputs[1].get(), outputs[2].get(), outputs[3].get()});
+ mockup_layer->run();
+
+ for (size_t i = 0; i < 4; ++i)
+ {
+ for (int32_t j = 0; j < shapes[i].dim(0); ++j)
+ {
+ for (int32_t k = 0; k < shapes[i].dim(1); ++k)
+ {
+ for (int32_t l = 0; l < shapes[i].dim(2); ++l)
+ {
+ for (int32_t m = 0; m < shapes[i].dim(3); ++m)
+ {
+ Coordinates coords{j, k, l, m};
+ float result =
+ *reinterpret_cast<float *>(outputs[i]->buffer() + outputs[i]->calcOffset(coords));
+            int16_t qsymm16 =
+              *reinterpret_cast<int16_t *>(inputs[i]->buffer() + inputs[i]->calcOffset(coords));
+            float expected = (qsymm16 - zero_point) * scale;
+ EXPECT_EQ(result, expected);
+ }
+ }
+ }
+ }
+ }
+}
+
+TEST(IPermuteFunction, float_qasymm8_layout)
+{
+  // float -> qasymm8
+ {
+ const size_t input_pads[4] = {0, 0, 1, 2};
+ const size_t output_pads[4] = {0, 3, 2, 1};
+ const std::vector<Shape> shapes{{1, 1, 4, 1}, {2, 1, 2, 3}, {1, 2, 1, 2}, {1, 1, 2, 3}};
+ float expected_buffer[] = {10, 0, -10, -20, 30, -40, 50, -60, 70,
+ -80, 90, -100, 110, -120, 130, -140, 150, -160};
+ float scale = 10;
+ int32_t zero_point = 128;
+
+ std::vector<std::unique_ptr<MockUpTensor>> inputs(4);
+ std::vector<std::unique_ptr<MockUpTensor>> outputs(4);
+ std::vector<std::unique_ptr<uint8_t[]>> output_buffers(4);
+ for (size_t i = 0; i < 4; ++i)
+ {
+ Layout layout = Layout::NHWC;
+ Shape shape = shapes[i];
+ if (i % 2 == 1)
+ {
+ layout = Layout::NCHW;
+ shape = Shape{shapes[i].dim(0), shapes[i].dim(3), shapes[i].dim(1), shapes[i].dim(2)};
+ }
+ inputs[i] =
+ std::make_unique<MockUpTensor>(shape, TypeInfo(DataType::FLOAT32), layout, input_pads[i]);
+ inputs[i]->setBuffer(reinterpret_cast<uint8_t *>(expected_buffer));
+
+ if (layout == Layout::NHWC)
+ {
+ layout = Layout::NCHW;
+ shape = Shape{shapes[i].dim(0), shapes[i].dim(3), shapes[i].dim(1), shapes[i].dim(2)};
+ }
+ else
+ {
+ layout = Layout::NHWC;
+ shape = shapes[i];
+ }
+ TypeInfo type_info{DataType::QUANT_UINT8_ASYMM, scale, zero_point};
+ outputs[i] = std::make_unique<MockUpTensor>(shape, type_info, layout, output_pads[i]);
+ output_buffers[i] = std::make_unique<uint8_t[]>(outputs[i]->total_size());
+ outputs[i]->setBuffer(output_buffers[i].get());
+ }
+
+ auto mockup_layer = std::make_unique<MockUpLayer>(
+ std::vector<ITensor *>{inputs[0].get(), inputs[1].get(), inputs[2].get(), inputs[3].get()},
+ std::vector<ITensor *>{outputs[0].get(), outputs[1].get(), outputs[2].get(),
+ outputs[3].get()});
+ mockup_layer->run();
+
+ for (size_t i = 0; i < 4; ++i)
+ {
+ for (int32_t j = 0; j < shapes[i].dim(0); ++j)
+ {
+ for (int32_t k = 0; k < shapes[i].dim(1); ++k)
+ {
+ for (int32_t l = 0; l < shapes[i].dim(2); ++l)
+ {
+ for (int32_t m = 0; m < shapes[i].dim(3); ++m)
+ {
+ Coordinates input_coords;
+ Coordinates output_coords;
+ if (inputs[i]->layout() == Layout::NHWC)
+ {
+ input_coords = Coordinates{j, k, l, m};
+ }
+ else
+ {
+ input_coords = Coordinates{j, m, k, l};
+ }
+ if (outputs[i]->layout() == Layout::NHWC)
+ {
+ output_coords = Coordinates{j, k, l, m};
+ }
+ else
+ {
+ output_coords = Coordinates{j, m, k, l};
+ }
+ uint8_t qasymm8 = *reinterpret_cast<uint8_t *>(outputs[i]->buffer() +
+ outputs[i]->calcOffset(output_coords));
+ float result = (qasymm8 - zero_point) * scale;
+ float expected = *reinterpret_cast<float *>(inputs[i]->buffer() +
+ inputs[i]->calcOffset(input_coords));
+ EXPECT_EQ(result, expected);
+ }
+ }
+ }
+ }
+ }
+ }
+
+ // qasymm8 -> float
+ {
+ const size_t input_pads[4] = {0, 0, 1, 2};
+ const size_t output_pads[4] = {0, 3, 2, 1};
+ const std::vector<Shape> shapes{{1, 1, 4, 1}, {2, 1, 2, 3}, {1, 2, 1, 2}, {1, 1, 2, 3}};
+ float expected_buffer[] = {10, 0, -10, -20, 30, -40, 50, -60, 70,
+ -80, 90, -100, 110, -120, 130, -140, 150, -160};
+ float scale = 10;
+ int32_t zero_point = 128;
+ uint8_t input_buffer[18];
+
+ int32_t min_val = std::numeric_limits<int16_t>::min();
+ int32_t max_val = std::numeric_limits<int16_t>::max();
+ for (int32_t i = 0; i < sizeof(expected_buffer) / sizeof(float); ++i)
+ {
+ int32_t unclamped = static_cast<int32_t>(std::round(expected_buffer[i] / scale)) + zero_point;
+ input_buffer[i] = std::min(std::max(unclamped, min_val), max_val);
+ }
+
+ std::vector<std::unique_ptr<MockUpTensor>> inputs(4);
+ std::vector<std::unique_ptr<MockUpTensor>> outputs(4);
+ std::vector<std::unique_ptr<uint8_t[]>> output_buffers(4);
+ for (size_t i = 0; i < 4; ++i)
+ {
+ Layout layout = Layout::NHWC;
+ Shape shape = shapes[i];
+ if (i % 2 == 1)
+ {
+ layout = Layout::NCHW;
+ shape = Shape{shapes[i].dim(0), shapes[i].dim(3), shapes[i].dim(1), shapes[i].dim(2)};
+ }
+ TypeInfo type_info{DataType::QUANT_UINT8_ASYMM, scale, zero_point};
+ inputs[i] = std::make_unique<MockUpTensor>(shape, type_info, layout, input_pads[i]);
+ inputs[i]->setBuffer(reinterpret_cast<uint8_t *>(expected_buffer));
+
+ if (layout == Layout::NHWC)
+ {
+ layout = Layout::NCHW;
+ shape = Shape{shapes[i].dim(0), shapes[i].dim(3), shapes[i].dim(1), shapes[i].dim(2)};
+ }
+ else
+ {
+ layout = Layout::NHWC;
+ shape = shapes[i];
+ }
+ outputs[i] =
+ std::make_unique<MockUpTensor>(shape, TypeInfo(DataType::FLOAT32), layout, output_pads[i]);
+ output_buffers[i] = std::make_unique<uint8_t[]>(outputs[i]->total_size());
+ outputs[i]->setBuffer(output_buffers[i].get());
+ }
+
+ auto mockup_layer = std::make_unique<MockUpLayer>(
+ std::vector<ITensor *>{inputs[0].get(), inputs[1].get(), inputs[2].get(), inputs[3].get()},
+ std::vector<ITensor *>{outputs[0].get(), outputs[1].get(), outputs[2].get(),
+ outputs[3].get()});
+ mockup_layer->run();
+
+ for (size_t i = 0; i < 4; ++i)
+ {
+ for (int32_t j = 0; j < shapes[i].dim(0); ++j)
+ {
+ for (int32_t k = 0; k < shapes[i].dim(1); ++k)
+ {
+ for (int32_t l = 0; l < shapes[i].dim(2); ++l)
+ {
+ for (int32_t m = 0; m < shapes[i].dim(3); ++m)
+ {
+ Coordinates input_coords;
+ Coordinates output_coords;
+ if (inputs[i]->layout() == Layout::NHWC)
+ {
+ input_coords = Coordinates{j, k, l, m};
+ }
+ else
+ {
+ input_coords = Coordinates{j, m, k, l};
+ }
+ if (outputs[i]->layout() == Layout::NHWC)
+ {
+ output_coords = Coordinates{j, k, l, m};
+ }
+ else
+ {
+ output_coords = Coordinates{j, m, k, l};
+ }
+ float result = *reinterpret_cast<float *>(outputs[i]->buffer() +
+ outputs[i]->calcOffset(output_coords));
+ uint8_t qasymm8 = *reinterpret_cast<uint8_t *>(inputs[i]->buffer() +
+ inputs[i]->calcOffset(input_coords));
+ float expected = (qasymm8 - zero_point) * scale;
+ EXPECT_EQ(result, expected);
+ }
+ }
+ }
+ }
+ }
+ }
+}
+
+} // namespace
diff --git a/runtime/onert/core/src/exec/JSONExecTime.cc b/runtime/onert/core/src/exec/JSONExecTime.cc
index 72a18def1..d149345fd 100644
--- a/runtime/onert/core/src/exec/JSONExecTime.cc
+++ b/runtime/onert/core/src/exec/JSONExecTime.cc
@@ -14,8 +14,8 @@
* limitations under the License.
*/
-#include "exec/JSONExecTime.h"
-#include "backend/IConfig.h"
+#include "JSONExecTime.h"
+
#include <fstream>
namespace onert
@@ -135,7 +135,7 @@ void JSON::printOperation(const std::map<uint32_t, int64_t> &operation_info,
stream.seekp(-2, std::ofstream::end);
}
-void JSON::uploadOperationsExecTime() const
+void JSON::storeOperationsExecTime() const
{
std::ofstream stream(_measurement_file);
if (!stream.is_open())
diff --git a/runtime/onert/core/src/exec/JSONExecTime.h b/runtime/onert/core/src/exec/JSONExecTime.h
index a64cb3133..e01723611 100644
--- a/runtime/onert/core/src/exec/JSONExecTime.h
+++ b/runtime/onert/core/src/exec/JSONExecTime.h
@@ -37,15 +37,15 @@ namespace exec
* _measurements[Backend*]["string"][bool][uint32_t] = int64_t
*/
using MeasurementData = std::unordered_map<
- const backend::Backend *,
- std::unordered_map<std::string, std::unordered_map<bool, std::map<uint32_t, int64_t>>>>;
+ const backend::Backend *,
+ std::unordered_map<std::string, std::unordered_map<bool, std::map<uint32_t, int64_t>>>>;
class JSON
{
public:
explicit JSON(const std::vector<const backend::Backend *> &backends,
MeasurementData &measurements)
- : _measurement_file("exec_time.json"), _backends(), _measurements(measurements)
+ : _measurement_file("exec_time.json"), _backends(), _measurements(measurements)
{
for (const auto b : backends)
{
@@ -54,18 +54,16 @@ public:
loadOperationsExecTime();
};
/**
- * @brief Update _operations_exec_time_file with new data.
+ * @brief Update _measurement_file with new data.
*/
- void uploadOperationsExecTime() const;
+ void storeOperationsExecTime() const;
private:
///@brief file containing measurements
std::string _measurement_file;
std::unordered_map<std::string, const backend::Backend *> _backends;
- std::unordered_map<
- const backend::Backend *,
- std::unordered_map<std::string, std::unordered_map<bool, std::map<uint32_t, int64_t>>>>
- &_measurements;
+ MeasurementData &_measurements;
+
/**
* @brief Helper function for inserting data to OperationExecTimes
*
@@ -86,7 +84,7 @@ private:
void printOperation(const std::map<uint32_t, int64_t> &operation_info,
std::ofstream &stream) const;
/**
- * @brief Parse and load operations_exec_time from _operations_exec_time_file.
+ * @brief Parse and load _measurements from _measurement_file.
*/
void loadOperationsExecTime();
};
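
A hedged sketch of driving the renamed storeOperationsExecTime() through the MeasurementData indexing documented above (backend -> operation name -> "is quantized" flag -> operand size -> time); the operation name, operand size, and time value below are placeholders:

#include "JSONExecTime.h"

#include <vector>

// Illustrative only: record one measurement and persist it to exec_time.json.
void store_example(const std::vector<const onert::backend::Backend *> &backends,
                   const onert::backend::Backend *backend,
                   onert::exec::MeasurementData &measurements)
{
  measurements[backend]["Conv2D"][true][224 * 224 * 3] = 1500;

  onert::exec::JSON json(backends, measurements); // loads exec_time.json if present
  json.storeOperationsExecTime();                 // was uploadOperationsExecTime()
}
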
diff --git a/runtime/onert/core/src/exec/LinearExecutor.cc b/runtime/onert/core/src/exec/LinearExecutor.cc
index 69dfe9b9b..228c4d3c0 100644
--- a/runtime/onert/core/src/exec/LinearExecutor.cc
+++ b/runtime/onert/core/src/exec/LinearExecutor.cc
@@ -24,41 +24,54 @@ namespace onert
namespace exec
{
-#ifdef RUY_PROFILER
-namespace
-{
-char *seq_to_label(const onert::ir::OpSequence *op_seq, const onert::ir::Operations &operations)
+void LinearExecutor::executeImpl(const ExecutionObservee &subject)
{
- auto node_name = operations.at(*op_seq->begin()).name();
- char *cstr = new char[node_name.length() + 1];
- std::strcpy(cstr, node_name.c_str());
- return cstr;
-}
-} // namespace
+ if (!subject.isEmpty() && _tracing_ctx)
+ {
+ auto profiling_subg_index = _tracing_ctx->getSubgraphIndex(&_graph);
+
+ subject.notifySubgraphBegin(profiling_subg_index);
+ for (auto &&code : _code)
+ {
+ const auto backend = code.lower_info->backend();
+// TODO : Move ruy profiler into ExecutionObserver
+#ifdef RUY_PROFILER
+ ruy::profiler::ScopeLabel label(code.op->name());
#endif
+ subject.notifyJobBegin(this, profiling_subg_index, code.op_ind, backend);
-void LinearExecutor::executeImpl()
-{
- _subject.notifyModelBegin(this);
- for (auto &&code : _code)
+ auto &fn_seq = code.fn_seq;
+
+ fn_seq->initRunning();
+
+ bool handle_dynamic_tensor =
+ _lowered_graph->getHasDynamicTensor(code.op_ind) || hasDynamicInput();
+ fn_seq->enableDynamicShapeInferer(handle_dynamic_tensor);
+ fn_seq->run();
+
+ subject.notifyJobEnd(this, profiling_subg_index, code.op_ind, backend);
+ }
+ subject.notifySubgraphEnd(profiling_subg_index);
+ }
+ else
{
- const auto op_seq = code.op_seq;
- const auto backend = code.lower_info->backend();
+ for (auto &&code : _code)
+ {
// TODO : Move ruy profiler into ExecutionObserver
#ifdef RUY_PROFILER
- ruy::profiler::ScopeLabel label(seq_to_label(op_seq, _graph.operations()));
+ ruy::profiler::ScopeLabel label(code.op->name());
#endif
- _subject.notifyJobBegin(this, op_seq, backend);
- auto &fn_seq = code.fn_seq;
- bool handle_dynamic_tensor = op_seq->has_dynamic_tensor() || hasDynamicInput();
+ auto &fn_seq = code.fn_seq;
- fn_seq->enableDynamicShapeInferer(handle_dynamic_tensor);
- fn_seq->run();
+ fn_seq->initRunning();
- _subject.notifyJobEnd(this, op_seq, backend);
+ bool handle_dynamic_tensor =
+ _lowered_graph->getHasDynamicTensor(code.op_ind) || hasDynamicInput();
+ fn_seq->enableDynamicShapeInferer(handle_dynamic_tensor);
+ fn_seq->run();
+ }
}
- _subject.notifyModelEnd(this);
}
} // namespace exec
diff --git a/runtime/onert/core/src/exec/LinearExecutor.h b/runtime/onert/core/src/exec/LinearExecutor.h
index c224d3f4f..853632a4e 100644
--- a/runtime/onert/core/src/exec/LinearExecutor.h
+++ b/runtime/onert/core/src/exec/LinearExecutor.h
@@ -22,11 +22,11 @@
#ifndef __ONERT_EXEC_EXECUTOR_H_
#define __ONERT_EXEC_EXECUTOR_H_
-#include "ir/Index.h"
#include "ExecutorBase.h"
-#include "compiler/Linear.h"
-#include "exec/FunctionSequence.h"
+
#include "compiler/CodeMap.h"
+#include "ir/Index.h"
+#include "util/TracingCtx.h"
namespace onert
{
@@ -44,25 +44,22 @@ public:
* @brief Construct a new LinearExecutor object
* @param lowered_graph LoweredGraph object
* @param tensor_builders Tensor builders that are currently used
- * @param code_map OpSequence and its code map
+ * @param code_map @c ir::Operation and its code map
*/
LinearExecutor(std::unique_ptr<compiler::LoweredGraph> lowered_graph,
- const std::vector<std::shared_ptr<backend::ITensor>> &input_tensors,
- const std::vector<std::shared_ptr<backend::ITensor>> &output_tensors,
- const compiler::TensorRegistries &tensor_regs,
- backend::TensorManagerSet &&tensor_mgrs, compiler::CodeMap &&code_map,
- const std::vector<ir::OpSequenceIndex> &order)
- : ExecutorBase{std::move(lowered_graph), input_tensors, output_tensors, tensor_regs,
- std::move(tensor_mgrs)}
+ backend::BackendContexts &&backend_contexts,
+ const compiler::TensorRegistries &tensor_regs, compiler::CodeMap &&code_map,
+ const std::vector<ir::OperationIndex> &order, const util::TracingCtx *tracing_ctx)
+ : ExecutorBase{std::move(lowered_graph), std::move(backend_contexts), tensor_regs, tracing_ctx}
{
- for (auto index : order)
+ for (auto &&index : order)
{
_code.emplace_back(std::move(code_map.at(index)));
}
}
public:
- void executeImpl(void) override;
+ void executeImpl(const ExecutionObservee &subject) override;
private:
std::vector<compiler::CodeAndInfo> _code;
diff --git a/runtime/onert/core/src/exec/MinMaxData.cc b/runtime/onert/core/src/exec/MinMaxData.cc
new file mode 100644
index 000000000..1d18252e8
--- /dev/null
+++ b/runtime/onert/core/src/exec/MinMaxData.cc
@@ -0,0 +1,135 @@
+/*
+ * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "MinMaxData.h"
+
+#include <iostream>
+
+namespace onert
+{
+namespace exec
+{
+
+RawMinMaxDumper::RawMinMaxDumper(const std::string &filename) : _filename(filename) {}
+
+void RawMinMaxDumper::dump(const exec::IOMinMaxMap &input_minmax,
+ const exec::OpMinMaxMap &op_minmax) const
+{
+  // Try to open an existing file for modification
+ auto file = std::fopen(_filename.c_str(), "rb+");
+ uint32_t runs = 1;
+
+ // Magic code and version
+ // Match with runtime/onert/odc/MinMaxReader.cc
+ // TODO Use util to share code and version
+ const uint32_t MAGIC_CODE = 0x4F4D4D44;
+ const uint32_t VERSION = 1;
+ if (!file)
+ {
+    // If the file does not exist, create a new one
+ file = std::fopen(_filename.c_str(), "wb+");
+ if (!file)
+ throw std::runtime_error{"RawMinMaxDumper: Failed to open minmax file " + _filename};
+
+ // Write magic code and version
+ std::fwrite(&MAGIC_CODE, sizeof(uint32_t), 1, file);
+ std::fwrite(&VERSION, sizeof(uint32_t), 1, file);
+ }
+ else
+ {
+ // Check magic code and version
+ std::fseek(file, 0, SEEK_SET);
+ uint32_t read_magic_code = 0;
+ uint32_t read_version = 0;
+ bool rewrite = true;
+ if (std::fread(&read_magic_code, sizeof(uint32_t), 1, file) == 1 &&
+ read_magic_code == MAGIC_CODE &&
+ std::fread(&read_version, sizeof(uint32_t), 1, file) == 1 && read_version == VERSION)
+ rewrite = false;
+
+    // Destroy and recreate the file if it is not valid
+ if (rewrite)
+ {
+ std::fclose(file);
+ file = std::fopen(_filename.c_str(), "wb+");
+ if (!file)
+ throw std::runtime_error{"RawMinMaxDumper: Failed to rewrite minmax file " + _filename};
+
+ // Write magic code and version
+ std::fwrite(&MAGIC_CODE, sizeof(uint32_t), 1, file);
+ std::fwrite(&VERSION, sizeof(uint32_t), 1, file);
+ }
+ }
+
+ // Read run count
+ if (std::fread(&runs, sizeof(uint32_t), 1, file) == 1)
+ runs++;
+ else
+ runs = 1;
+
+ // TODO Verify file size
+
+ // Overwrite run count
+ std::fseek(file, sizeof(MAGIC_CODE) + sizeof(VERSION), SEEK_SET);
+ std::fwrite(&runs, sizeof(uint32_t), 1, file);
+
+ // Go to end of file to append new data
+ std::fseek(file, 0, SEEK_END);
+
+ uint32_t input_count = input_minmax.size();
+ uint32_t op_count = op_minmax.size();
+
+ // Write op_count and input_count
+ std::fwrite(&op_count, sizeof(uint32_t), 1, file);
+ std::fwrite(&input_count, sizeof(uint32_t), 1, file);
+
+ // For each op
+ for (auto &&elem : op_minmax)
+ {
+ const uint32_t model_idx = 0;
+ const uint32_t subg_idx = elem.first.first.value();
+ const uint32_t op_idx = elem.first.second.value();
+
+ // Write model/subg/op index
+ std::fwrite(&model_idx, sizeof(uint32_t), 1, file);
+ std::fwrite(&subg_idx, sizeof(uint32_t), 1, file);
+ std::fwrite(&op_idx, sizeof(uint32_t), 1, file);
+
+ // Write min/max
+ std::fwrite(elem.second.data, sizeof(float), 2, file);
+ }
+
+ // For each input
+ for (auto &&elem : input_minmax)
+ {
+ const uint32_t model_idx = 0;
+ const uint32_t subg_idx = elem.first.first.value();
+ const uint32_t input_idx = elem.first.second.value();
+
+ // Write model/subg/input index
+ std::fwrite(&model_idx, sizeof(uint32_t), 1, file);
+ std::fwrite(&subg_idx, sizeof(uint32_t), 1, file);
+ std::fwrite(&input_idx, sizeof(uint32_t), 1, file);
+
+ // Write min/max
+ std::fwrite(elem.second.data, sizeof(float), 2, file);
+ }
+
+ std::fclose(file);
+}
+
+} // namespace exec
+} // namespace onert
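The "TODO Verify file size" above is left open in this change. As a rough illustration of what such a check could compare against, here is a minimal standalone sketch that computes how many bytes one dump() call appends, derived from the record layout written above (two uint32 counters, then three uint32 indices plus min/max per entry). The function appended_bytes() is illustrative only and not part of the runtime.

#include <cstddef>
#include <cstdint>
#include <cstdio>

// Bytes appended by a single dump() call, derived from the layout written above:
// op_count and input_count headers, then 3 indices and min/max per op or input entry.
static std::size_t appended_bytes(uint32_t op_count, uint32_t input_count)
{
  const std::size_t per_entry = 3 * sizeof(uint32_t) + 2 * sizeof(float); // indices + min/max
  return 2 * sizeof(uint32_t) + (op_count + input_count) * per_entry;
}

int main()
{
  // Example: a run that recorded 4 operations and 1 model input
  std::printf("one run appends %zu bytes after the 12-byte file header\n", appended_bytes(4, 1));
  return 0;
}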
diff --git a/runtime/onert/core/src/exec/MinMaxData.h b/runtime/onert/core/src/exec/MinMaxData.h
new file mode 100644
index 000000000..2538d444c
--- /dev/null
+++ b/runtime/onert/core/src/exec/MinMaxData.h
@@ -0,0 +1,75 @@
+/*
+ * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_EXEC_MINMAX_DATA_H__
+#define __ONERT_EXEC_MINMAX_DATA_H__
+
+#include "exec/MinMaxMap.h"
+
+#include <string>
+
+namespace onert
+{
+namespace exec
+{
+
+// Because IOMinMaxMap and OpMinMaxMap do not carry ordering and size information,
+// we need to dump the model and subgraph id for each minmax entry
+
+// File structure
+// uint32_t magic code
+// uint32_t version
+// uint32_t num of runs
+
+// For each run
+// uint32_t num of operations
+// uint32_t num of inputs
+
+// For each operation
+// uint32_t model id
+// uint32_t subgraph id
+// uint32_t operation id
+// float min
+// float max
+
+// For each input
+// uint32_t model id
+// uint32_t subgraph id
+// uint32_t input id
+// float min
+// float max
+
+class RawMinMaxDumper
+{
+public:
+ RawMinMaxDumper(const std::string &filename);
+ /**
+ * @brief Dump min/max maps for model inputs and operations
+ *
+ * @param[in] in_minmax input minmax map
+ * @param[in] op_minmax op minmax map
+ */
+
+ void dump(const exec::IOMinMaxMap &in_minmax, const exec::OpMinMaxMap &op_minmax) const;
+
+private:
+ std::string _filename;
+};
+
+} // namespace exec
+} // namespace onert
+
+#endif // __ONERT_EXEC_MINMAX_DATA_H__
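What follows is a minimal, standalone sketch of a reader that walks the layout documented in the comment block above. It is illustrative only; the reader actually used by the runtime is runtime/onert/odc/MinMaxReader.cc (referenced from MinMaxData.cc), and the function name walkMinMaxFile is made up for this example.

#include <cstdint>
#include <cstdio>
#include <stdexcept>
#include <string>

// Walks the file layout documented above. For brevity the file is not closed on error paths.
void walkMinMaxFile(const std::string &path)
{
  auto file = std::fopen(path.c_str(), "rb");
  if (!file)
    throw std::runtime_error{"cannot open " + path};

  auto read_u32 = [&file](const char *what) {
    uint32_t v = 0;
    if (std::fread(&v, sizeof(uint32_t), 1, file) != 1)
      throw std::runtime_error{std::string{"failed to read "} + what};
    return v;
  };

  const uint32_t magic = read_u32("magic code");
  const uint32_t version = read_u32("version");
  const uint32_t runs = read_u32("run count");
  std::printf("magic=0x%X version=%u runs=%u\n", magic, version, runs);

  for (uint32_t r = 0; r < runs; ++r)
  {
    const uint32_t op_count = read_u32("op count");
    const uint32_t input_count = read_u32("input count");
    // Operation records come first, then input records; both share the same layout
    for (uint32_t i = 0; i < op_count + input_count; ++i)
    {
      const uint32_t model_id = read_u32("model id");
      const uint32_t subg_id = read_u32("subgraph id");
      const uint32_t entry_id = read_u32("op/input id");
      float minmax[2] = {0.0f, 0.0f};
      if (std::fread(minmax, sizeof(float), 2, file) != 2)
        throw std::runtime_error{"failed to read min/max"};
      std::printf("  %u:%u:%u min=%f max=%f\n", model_id, subg_id, entry_id, minmax[0], minmax[1]);
    }
  }
  std::fclose(file);
}

int main(int argc, char **argv)
{
  if (argc > 1)
    walkMinMaxFile(argv[1]);
  return 0;
}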
diff --git a/runtime/onert/core/src/exec/MinMaxRecorder.cc b/runtime/onert/core/src/exec/MinMaxRecorder.cc
new file mode 100644
index 000000000..179800011
--- /dev/null
+++ b/runtime/onert/core/src/exec/MinMaxRecorder.cc
@@ -0,0 +1,161 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "MinMaxRecorder.h"
+#if MINMAX_H5DUMPER
+#include "../dumper/h5/MinMaxDumper.h"
+#else
+#include "MinMaxData.h"
+#endif
+#include "backend/ITensor.h"
+
+#include <cassert>
+#include <cmath>
+
+namespace onert
+{
+namespace exec
+{
+
+MinMaxRecorder::MinMaxRecorder(const std::string &workspace_dir, const ir::Graph &graph,
+ const backend::BackendContexts &backend_contexts)
+ : _graph{graph}, _backend_contexts{backend_contexts}, _workspace_dir(workspace_dir)
+{
+ // DO NOTHING
+}
+
+std::pair<float, float> minmaxFrom(const backend::ITensor *tensor)
+{
+ const auto data = reinterpret_cast<float *>(tensor->buffer());
+ const auto num_elements = tensor->total_size() / sizeof(float);
+
+ float max = std::numeric_limits<float>::lowest();
+ float min = std::numeric_limits<float>::max();
+
+ bool all_nan = true;
+ for (size_t i = 0; i < num_elements; ++i)
+ {
+ const float number = data[i];
+ if (std::isnan(number))
+ continue;
+
+ if (number == std::numeric_limits<float>::lowest())
+ continue;
+
+ all_nan = false;
+
+ if (number > max)
+ max = number;
+
+ if (number < min)
+ min = number;
+ }
+
+ if (all_nan)
+ throw std::runtime_error("All values are NaN(Not a Number)");
+
+ return {min, max};
+}
+
+void MinMaxRecorder::handleJobEnd(IExecutor *, ir::SubgraphIndex subg_idx,
+ ir::OperationIndex op_idx, const backend::Backend *backend)
+{
+ const auto &tensor_reg = _backend_contexts.at(backend)->tensor_registry;
+ const auto &op = _graph.operations().at(op_idx);
+ const auto &outputs = op.getOutputs();
+ // TODO Support multiple outputs
+ if (outputs.size() != 1)
+ throw std::runtime_error("Only 1 output operator is supported for recording minmax.");
+
+ auto tensor = tensor_reg->getITensor(outputs.at(0));
+
+ // Logic copied from MinMaxObserver.cpp.
+
+ // Filter Ops
+ if (tensor->is_constant())
+ return;
+
+ if (tensor->data_type() != ir::DataType::FLOAT32)
+ return;
+
+ switch (op.opcode())
+ {
+ // Operators with multiple outputs
+ case ir::OpCode::If:
+ case ir::OpCode::Split:
+ case ir::OpCode::SplitV:
+ case ir::OpCode::TopKV2:
+ case ir::OpCode::Unpack:
+ case ir::OpCode::While:
+ return;
+ // NOTE: Sin, Cos, Tanh's output is in [-1, 1]
+ // We may not need to dump those operators.
+ default:; // Do Nothing
+ }
+
+ // Otherwise, dump!
+ assert(tensor->data_type() == ir::DataType::FLOAT32);
+ auto minmax = minmaxFrom(tensor);
+ _op_minmax.append({subg_idx, op_idx}, minmax.first, minmax.second);
+}
+
+void MinMaxRecorder::handleSubgraphBegin(ir::SubgraphIndex subg_idx)
+{
+ // Make sure that, apart from the builtin backend, only the cpu backend is used
+ std::set<std::string> backend_names;
+ backend::ITensorRegistry *tensor_reg = nullptr;
+ for (const auto &pair : _backend_contexts)
+ {
+ backend_names.insert(pair.first->config()->id());
+ if (pair.first->config()->id() == "cpu")
+ {
+ tensor_reg = pair.second->tensor_registry.get();
+ }
+ }
+ if (backend_names != std::set<std::string>{"builtin", "cpu"})
+ throw std::runtime_error("MinMaxRecorder must have cpu backend only.");
+
+ const auto &inputs = _graph.getInputs();
+ for (uint32_t i = 0; i < inputs.size(); ++i)
+ {
+ auto input_idx = inputs.at(i);
+ auto tensor = tensor_reg->getITensor(input_idx);
+
+ if (tensor->is_constant())
+ return;
+ if (tensor->data_type() != ir::DataType::FLOAT32)
+ return;
+
+ auto minmax = minmaxFrom(tensor);
+ _input_minmax.append({subg_idx, ir::IOIndex{i}}, minmax.first, minmax.second);
+ }
+}
+
+void MinMaxRecorder::handleSubgraphEnd(ir::SubgraphIndex)
+{
+ // It would be better to dump at the end of model execution rather than per subgraph,
+ // but that would require more changes than the per-subgraph approach.
+#if MINMAX_H5DUMPER
+ auto h5dumper = dumper::h5::MinMaxDumper(_workspace_dir + "/minmax.h5");
+ h5dumper.dump(_input_minmax, _op_minmax);
+#else
+ auto raw_dumper = RawMinMaxDumper(_workspace_dir + "/minmax.bin");
+ raw_dumper.dump(_input_minmax, _op_minmax);
+#endif
+}
+
+} // namespace exec
+} // namespace onert
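minmaxFrom() above skips NaN values (and the lowest-float sentinel) while scanning a tensor buffer. The following self-contained sketch reproduces that scan over a plain float array so the filtering behaviour can be checked in isolation; the function name minmaxOf is illustrative and not part of the runtime.

#include <cmath>
#include <cstddef>
#include <cstdio>
#include <limits>
#include <stdexcept>
#include <utility>

// Same scan as minmaxFrom() above, but over a plain array: NaN values and the
// lowest-float sentinel are skipped; if nothing remains, an exception is thrown.
std::pair<float, float> minmaxOf(const float *data, std::size_t num_elements)
{
  float max = std::numeric_limits<float>::lowest();
  float min = std::numeric_limits<float>::max();
  bool all_skipped = true;

  for (std::size_t i = 0; i < num_elements; ++i)
  {
    const float v = data[i];
    if (std::isnan(v) || v == std::numeric_limits<float>::lowest())
      continue;
    all_skipped = false;
    if (v > max)
      max = v;
    if (v < min)
      min = v;
  }

  if (all_skipped)
    throw std::runtime_error("All values are NaN(Not a Number)");
  return {min, max};
}

int main()
{
  const float samples[] = {0.5f, NAN, -2.0f, 3.25f};
  auto mm = minmaxOf(samples, 4);
  std::printf("min=%f max=%f\n", mm.first, mm.second); // min=-2.000000 max=3.250000
  return 0;
}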
diff --git a/runtime/onert/core/src/exec/MinMaxRecorder.h b/runtime/onert/core/src/exec/MinMaxRecorder.h
new file mode 100644
index 000000000..ed5163972
--- /dev/null
+++ b/runtime/onert/core/src/exec/MinMaxRecorder.h
@@ -0,0 +1,58 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_EXEC_MINMAX_RECORDER__
+#define __ONERT_EXEC_MINMAX_RECORDER__
+
+#include "ExecutionObservers.h"
+#include "ir/Index.h"
+#include "exec/MinMaxMap.h"
+
+#include <string>
+
+namespace onert
+{
+namespace exec
+{
+
+class MinMaxRecorder : public IExecutionObserver
+{
+public:
+ MinMaxRecorder(const std::string &workspace_dir, const ir::Graph &graph,
+ const backend::BackendContexts &backend_contexts);
+ void handleJobBegin(IExecutor *, ir::SubgraphIndex, ir::OperationIndex,
+ const backend::Backend *) override
+ {
+ return;
+ }
+ void handleJobEnd(IExecutor *, ir::SubgraphIndex, ir::OperationIndex,
+ const backend::Backend *) override;
+ void handleSubgraphBegin(ir::SubgraphIndex) override;
+ void handleSubgraphEnd(ir::SubgraphIndex) override;
+ ObserverType type() const override { return ObserverType::MINMAX_DUMP; }
+
+private:
+ const ir::Graph &_graph;
+ const backend::BackendContexts &_backend_contexts;
+ std::string _workspace_dir;
+ OpMinMaxMap _op_minmax;
+ IOMinMaxMap _input_minmax;
+};
+
+} // namespace exec
+} // namespace onert
+
+#endif // __ONERT_EXEC_MINMAX_RECORDER__
diff --git a/runtime/onert/core/src/exec/MultiModelExecutors.cc b/runtime/onert/core/src/exec/MultiModelExecutors.cc
new file mode 100644
index 000000000..920b17d45
--- /dev/null
+++ b/runtime/onert/core/src/exec/MultiModelExecutors.cc
@@ -0,0 +1,589 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "MultiModelExecutors.h"
+
+namespace
+{
+
+using namespace onert;
+
+int32_t find_input_index(const std::vector<ir::IODesc> &pkg_inputs,
+ const ir::ModelIndex &model_index, const ir::SubgraphIndex &subg_index,
+ const ir::IOIndex &io_index)
+{
+ for (size_t i = 0; i < pkg_inputs.size(); i++)
+ {
+ auto &input_desc = pkg_inputs[i];
+ if ((std::get<ir::ModelIndex>(input_desc) == model_index) &&
+ (std::get<ir::SubgraphIndex>(input_desc) == subg_index) &&
+ (std::get<ir::IOIndex>(input_desc) == io_index))
+ return static_cast<int32_t>(i);
+ }
+ return -1;
+}
+
+int32_t find_output_index(const std::vector<ir::IODesc> &pkg_outputs,
+ const ir::ModelIndex &model_index, const ir::SubgraphIndex &subg_index,
+ const ir::IOIndex &io_index)
+{
+ for (size_t i = 0; i < pkg_outputs.size(); i++)
+ {
+ auto &output_desc = pkg_outputs[i];
+ if ((std::get<ir::ModelIndex>(output_desc) == model_index) &&
+ (std::get<ir::SubgraphIndex>(output_desc) == subg_index) &&
+ (std::get<ir::IOIndex>(output_desc) == io_index))
+ return static_cast<int32_t>(i);
+ }
+ return -1;
+}
+
+} // namespace
+
+namespace onert
+{
+namespace exec
+{
+
+void MultiModelExecutors::emplace(const ir::ModelIndex &model_index,
+ const ir::SubgraphIndex &subg_index,
+ std::unique_ptr<IExecutor> exec)
+{
+ _executors.emplace(std::make_pair(model_index, subg_index), std::move(exec));
+}
+
+IExecutor *MultiModelExecutors::at(const ir::ModelIndex &model_index,
+ const ir::SubgraphIndex &subg_index) const
+{
+ return _executors.at(std::make_pair(model_index, subg_index)).get();
+}
+
+uint32_t MultiModelExecutors::inputSize() const { return _model_edges->pkg_inputs.size(); }
+
+uint32_t MultiModelExecutors::outputSize() const { return _model_edges->pkg_outputs.size(); }
+
+const ir::OperandInfo &MultiModelExecutors::inputInfo(const ir::IOIndex &index) const
+{
+ auto const desc = _model_edges->pkg_inputs[index.value()];
+ auto const model_index = std::get<0>(desc);
+ auto const subg_index = std::get<1>(desc);
+ auto const io_index = std::get<2>(desc);
+ auto const executor = at(model_index, subg_index);
+ return executor->inputInfo(io_index.value());
+}
+
+const ir::OperandInfo &MultiModelExecutors::outputInfo(const ir::IOIndex &index) const
+{
+ auto const desc = _model_edges->pkg_outputs[index.value()];
+ auto const model_index = std::get<0>(desc);
+ auto const subg_index = std::get<1>(desc);
+ auto const io_index = std::get<2>(desc);
+ auto const executor = at(model_index, subg_index);
+ return executor->outputInfo(io_index.value());
+}
+
+// Allow only the edges below
+// m1 < m2, s1 == 0 and s2 == 0 for edge 'm1:s1:o1 -> m2:s2:o2'
+void MultiModelExecutors::checkSupportedMultimodel() const
+{
+ // If the package includes a model with no connections, model_count is less than the real model
+ // count in the package. Then this method throws an exception based on the model index:
+ // 1st model: input assumption
+ // Otherwise: edge assumptions
+
+ // Assumption: edges
+ // m1 < m2, s1 == 0 and s2 == 0 if edge 'm1:s1:o1 -> m2:s2:o2'
+ for (auto &&edge : _model_edges->edges)
+ {
+ auto const model_from = std::get<ir::ModelIndex>(edge.from);
+ auto const model_to = std::get<ir::ModelIndex>(edge.to);
+ auto const subg_from = std::get<ir::SubgraphIndex>(edge.from);
+ auto const subg_to = std::get<ir::SubgraphIndex>(edge.to);
+
+ if (model_from.value() == model_to.value())
+ {
+ throw std::runtime_error{"Multi model's edge set has invalid edge"};
+ }
+
+ if ((model_from.value() > model_to.value()) || (subg_from != ir::SubgraphIndex{0}) ||
+ (subg_to != ir::SubgraphIndex{0}))
+ throw std::runtime_error{"NYI: Multi model execution for this edge set is not supported yet"};
+ }
+
+ // Assumption: package inputs
+ // All inputs of the 1st model come from package inputs (given that m1 < m2 always holds)
+ {
+ auto first_executor = at(ir::ModelIndex{0}, ir::SubgraphIndex{0});
+ auto search_first_model = [&](const ir::IOIndex &input_index) {
+ for (const auto &input : _model_edges->pkg_inputs)
+ {
+ if ((std::get<ir::ModelIndex>(input) == ir::ModelIndex{0}) &&
+ (std::get<ir::SubgraphIndex>(input) == ir::SubgraphIndex{0}) &&
+ (std::get<ir::IOIndex>(input) == input_index))
+ return true;
+ }
+
+ return false;
+ };
+
+ for (uint32_t i = 0; i < first_executor->inputSize(); i++)
+ {
+ if (!search_first_model(ir::IOIndex{i}))
+ throw std::runtime_error{"Cannot find 1st model's input buffer"};
+ }
+ }
+
+ // Check whether nnpkg outputs and Edge `from` are duplicated
+ for (const auto &edge : _model_edges->edges)
+ {
+ if (std::find(_model_edges->pkg_outputs.begin(), _model_edges->pkg_outputs.end(), edge.from) !=
+ _model_edges->pkg_outputs.end())
+ {
+ throw std::runtime_error{"Multi model execution does not support duplicating nnpkg outputs "
+ "with `from` of edges yet"};
+ }
+ }
+}
+
+void MultiModelExecutors::createEdgeQuantLayers()
+{
+ if (_is_created_edge_quant_layers)
+ {
+ return;
+ }
+
+ // Create EdgeTensor for edges between executors
+ for (const auto &pair : _edge_map)
+ {
+ const auto &from_iodesc = pair.first;
+ const auto &from_model_index = std::get<ir::ModelIndex>(from_iodesc);
+ const auto &from_subg_index = std::get<ir::SubgraphIndex>(from_iodesc);
+ const auto &from_io_index = std::get<ir::IOIndex>(from_iodesc);
+
+ const auto from_executor = _executors.at({from_model_index, from_subg_index}).get();
+ const auto &from_info = from_executor->inputInfo(from_io_index.value());
+ const auto from_layout = from_executor->inputLayout(from_io_index.value());
+ _edge_tensors[from_iodesc] = std::make_unique<EdgeTensor>(from_info, from_layout);
+ }
+
+ // Append type-aware quantization layer for edges between executors
+ for (const auto &executor_pair : _executors)
+ {
+ const auto &executor_index = executor_pair.first;
+ const auto &model_index = executor_index.first;
+ const auto &subg_index = executor_index.second;
+
+ std::vector<backend::ITensor *> inputs;
+ std::vector<backend::ITensor *> outputs;
+ for (const auto &pair : _edge_map)
+ {
+ const auto &from_iodesc = pair.first;
+ if (std::get<ir::ModelIndex>(from_iodesc) == model_index &&
+ std::get<ir::SubgraphIndex>(from_iodesc) == subg_index)
+ {
+ const auto from_tensor = _edge_tensors[from_iodesc].get();
+ const auto &to_list = pair.second;
+
+ for (const auto &to_iodesc : to_list)
+ {
+ const auto &to_model_index = std::get<ir::ModelIndex>(to_iodesc);
+ const auto &to_subg_index = std::get<ir::SubgraphIndex>(to_iodesc);
+ const auto &to_io_index = std::get<ir::IOIndex>(to_iodesc);
+
+ const auto to_executor = _executors.at({to_model_index, to_subg_index}).get();
+ const auto &to_info = to_executor->inputInfo(to_io_index.value());
+ const auto to_layout = to_executor->inputLayout(to_io_index.value());
+
+ // TODO Unify tensors with the same `from` tensor and same type
+ if (from_tensor->data_type() != to_info.typeInfo().type())
+ {
+ assert(inputs.size() == outputs.size());
+ inputs.emplace_back(from_tensor);
+
+ auto type_aware_quant_tensor = std::make_unique<EdgeTensor>(to_info, to_layout);
+ outputs.emplace_back(type_aware_quant_tensor.get());
+
+ _edge_quant_tensors[to_iodesc] = std::move(type_aware_quant_tensor);
+ }
+ }
+ }
+ }
+
+ auto layer = std::make_unique<PermuteLayer>(inputs, outputs);
+ layer->prepare();
+ _edge_quant_layers[{model_index, subg_index}] = std::move(layer);
+ }
+
+ _is_created_edge_quant_layers = true;
+}
+
+void MultiModelExecutors::CreatePkgIOTensors(const IODescription &desc)
+{
+ for (const auto &pkg_input : _model_edges->pkg_inputs)
+ {
+ // Create IOTensor for nnpkg inputs
+ const auto &model_index = std::get<ir::ModelIndex>(pkg_input);
+ const auto &subg_index = std::get<ir::SubgraphIndex>(pkg_input);
+ const auto &io_index = std::get<ir::IOIndex>(pkg_input);
+ const auto input_pkg_index =
+ find_input_index(_model_edges->pkg_inputs, model_index, subg_index, io_index);
+ if (input_pkg_index == -1)
+ throw std::runtime_error{"Cannot find multi model input index"};
+ auto input_desc = desc.inputs[input_pkg_index].get();
+ // TODO Remove const_cast (needed since ITensor exposes a writable buffer but the input buffer is const)
+ _pkg_input_tensors[pkg_input] = std::make_unique<backend::builtin::UserTensor>(
+ input_desc->info, input_desc->layout,
+ const_cast<uint8_t *>(reinterpret_cast<const uint8_t *>(input_desc->buffer)),
+ input_desc->size);
+ }
+
+ for (const auto &pkg_output : _model_edges->pkg_outputs)
+ {
+ // Create IOTensor for nnpkg outputs
+ const auto &model_index = std::get<ir::ModelIndex>(pkg_output);
+ const auto &subg_index = std::get<ir::SubgraphIndex>(pkg_output);
+ const auto &io_index = std::get<ir::IOIndex>(pkg_output);
+ const auto output_pkg_index =
+ find_output_index(_model_edges->pkg_outputs, model_index, subg_index, io_index);
+ if (output_pkg_index == -1)
+ throw std::runtime_error{"Cannot find multi model output index"};
+ auto output_desc = desc.outputs[output_pkg_index].get();
+ _pkg_output_tensors[pkg_output] = std::make_unique<backend::builtin::UserTensor>(
+ output_desc->info, output_desc->layout, reinterpret_cast<uint8_t *>(output_desc->buffer),
+ output_desc->size);
+ }
+}
+
+void MultiModelExecutors::createPkgIOQuantLayers(const IODescription &desc)
+{
+ // Append type-aware quantization layer for nnpkg inputs/outputs between executors
+ for (const auto &pair : _executors)
+ {
+ const auto &executor_index = pair.first;
+ const auto &model_index = executor_index.first;
+ const auto &subg_index = executor_index.second;
+ const auto executor = pair.second.get();
+
+ // Find pkg inputs of current executor
+ std::vector<ir::IODesc> pkg_inputs;
+ for (const auto &pkg_input : _model_edges->pkg_inputs)
+ {
+ if (std::get<ir::ModelIndex>(pkg_input) == model_index &&
+ std::get<ir::SubgraphIndex>(pkg_input) == subg_index)
+ {
+ pkg_inputs.emplace_back(pkg_input);
+ }
+ }
+ std::vector<backend::ITensor *> src_tensors;
+ std::vector<backend::ITensor *> dst_tensors;
+ for (const auto &pkg_input : pkg_inputs)
+ {
+ const auto &io_index = std::get<ir::IOIndex>(pkg_input);
+ const auto input_pkg_index =
+ find_input_index(_model_edges->pkg_inputs, model_index, subg_index, io_index);
+ if (input_pkg_index == -1)
+ throw std::runtime_error{"Cannot find multi model input index"};
+ auto input_desc = desc.inputs[input_pkg_index].get();
+
+ // Create EdgeTensor for nnpkg input if type is different
+ const auto &orig_info = executor->inputInfo(io_index.value());
+ const auto orig_layout = executor->inputLayout(io_index.value());
+ if (input_desc->info.typeInfo().type() != orig_info.typeInfo().type())
+ {
+ auto pkg_input_edge_tensor = std::make_unique<EdgeTensor>(orig_info, orig_layout);
+ _pkg_input_quant_tensors[pkg_input] = std::move(pkg_input_edge_tensor);
+
+ // Append type-aware quantization layer's inputs/outputs
+ src_tensors.emplace_back(_pkg_input_tensors[pkg_input].get());
+ dst_tensors.emplace_back(_pkg_input_quant_tensors[pkg_input].get());
+ }
+ }
+
+ // Create type-aware quantization layer for nnpkg inputs
+ auto pkg_input_layer = std::make_unique<PermuteLayer>(src_tensors, dst_tensors);
+ pkg_input_layer->prepare();
+ _pkg_input_quant_layers[{model_index, subg_index}] = std::move(pkg_input_layer);
+
+ // Find pkg outputs of current executor
+ std::vector<ir::IODesc> pkg_outputs;
+ for (const auto &pkg_output : _model_edges->pkg_outputs)
+ {
+ if (std::get<ir::ModelIndex>(pkg_output) == model_index &&
+ std::get<ir::SubgraphIndex>(pkg_output) == subg_index)
+ {
+ pkg_outputs.emplace_back(pkg_output);
+ }
+ }
+ src_tensors.clear();
+ dst_tensors.clear();
+ // Create Tensors of nnpkg outputs for type-aware quantization
+ for (const auto &pkg_output : pkg_outputs)
+ {
+ const auto &io_index = std::get<ir::IOIndex>(pkg_output);
+ const auto output_pkg_index =
+ find_output_index(_model_edges->pkg_outputs, model_index, subg_index, io_index);
+ if (output_pkg_index == -1)
+ throw std::runtime_error{"Cannot find multi model output index"};
+ auto output_desc = desc.outputs[output_pkg_index].get();
+
+ // Create EdgeTensor for nnpkg output if type is different
+ const auto &orig_info = executor->outputInfo(io_index.value());
+ const auto orig_layout = executor->outputLayout(io_index.value());
+ if (output_desc->info.typeInfo().type() != orig_info.typeInfo().type())
+ {
+ auto pkg_output_edge_tensor = std::make_unique<EdgeTensor>(orig_info, orig_layout);
+ _pkg_output_quant_tensors[pkg_output] = std::move(pkg_output_edge_tensor);
+
+ // Append type-aware quantization layer's inputs/outputs
+ src_tensors.emplace_back(_pkg_output_quant_tensors[pkg_output].get());
+ dst_tensors.emplace_back(_pkg_output_tensors[pkg_output].get());
+ }
+ }
+
+ // Create type-aware quantization layer for nnpkg outputs
+ auto pkg_output_layer = std::make_unique<PermuteLayer>(src_tensors, dst_tensors);
+ pkg_output_layer->prepare();
+ _pkg_output_quant_layers[{model_index, subg_index}] = std::move(pkg_output_layer);
+ }
+}
+
+void MultiModelExecutors::execute(const ExecutionContext &ctx)
+{
+ auto &desc = ctx.desc;
+
+ // Check supported multi model package
+ checkSupportedMultimodel();
+
+ // TODO Move creating type-aware quantization layers for edges in compilation stage
+ createEdgeQuantLayers();
+
+ // TODO Create IOTensors only once and recreate them only if nnpkg info changes
+ CreatePkgIOTensors(desc);
+
+ // TODO Create type-aware quantization layers only once and recreate them only if type changes
+ createPkgIOQuantLayers(desc);
+
+ // TODO Find better way to schedule order of executors
+ auto const model_count = modelCount();
+
+ auto find_from = [&](const ir::ModelIndex &model_index, const ir::SubgraphIndex &subg_index,
+ const ir::IOIndex &io_index) {
+ for (const auto &edge : _model_edges->edges)
+ {
+ if ((std::get<ir::ModelIndex>(edge.to) == model_index) &&
+ (std::get<ir::SubgraphIndex>(edge.to) == subg_index) &&
+ (std::get<ir::IOIndex>(edge.to) == io_index))
+ return edge.from;
+ }
+
+ throw std::runtime_error{"Cannot find edge for model input"};
+ };
+
+ // Execute each model
+ // NOTE It may be better to use a vector instead of an unordered_map for _executors
+ for (auto model_index = ir::ModelIndex{0}; model_index.value() < model_count; model_index++)
+ {
+ // Find executor
+ auto executor = at(model_index, ir::SubgraphIndex{0});
+
+ // Set IOTensors
+ // TODO Set internal IOTensors only once
+ std::vector<backend::IPortableTensor *> inputs_inter;
+ std::vector<backend::IPortableTensor *> outputs_inter;
+ auto const input_size = executor->inputSize();
+ auto const output_size = executor->outputSize();
+ inputs_inter.resize(input_size);
+ outputs_inter.resize(output_size);
+
+ // Set inputs of executor
+ // TODO Create layer to allocate/deallocate buffers of EdgeTensor for each executor
+ for (uint32_t i = 0; i < input_size; i++)
+ {
+ const auto input_pkg_index = find_input_index(_model_edges->pkg_inputs, model_index,
+ ir::SubgraphIndex{0}, ir::IOIndex{i});
+ const auto input_io_desc = ir::IODesc{model_index, ir::SubgraphIndex{0}, ir::IOIndex{i}};
+ if (input_pkg_index != -1)
+ {
+ // Allocate type-aware quantization tensors for nnpkg inputs and set internal tensors
+ if (_pkg_input_quant_tensors.find(input_io_desc) != _pkg_input_quant_tensors.end())
+ {
+ _pkg_input_quant_tensors[input_io_desc]->allocate_buffer();
+
+ inputs_inter[i] = _pkg_input_quant_tensors[input_io_desc].get();
+ }
+ else
+ {
+ inputs_inter[i] = _pkg_input_tensors[input_io_desc].get();
+ }
+ }
+ else
+ {
+ auto from_iodesc = find_from(model_index, ir::SubgraphIndex{0}, ir::IOIndex{i});
+
+ // Only sequential execution of models is supported
+ assert(std::get<ir::ModelIndex>(from_iodesc).value() < model_index.value());
+ assert(std::get<ir::SubgraphIndex>(from_iodesc).value() == 0);
+ const auto to_iodesc = ir::IODesc{model_index, ir::SubgraphIndex{0}, ir::IOIndex{i}};
+ if (_edge_quant_tensors.find(to_iodesc) == _edge_quant_tensors.end())
+ {
+ inputs_inter[i] = _edge_tensors.at(from_iodesc).get();
+ }
+ else
+ {
+ inputs_inter[i] = _edge_quant_tensors.at(to_iodesc).get();
+ }
+ assert(inputs_inter[i]->buffer() != nullptr);
+ }
+ }
+
+ // Set outputs of executor
+ for (uint32_t i = 0; i < output_size; i++)
+ {
+ const auto output_pkg_index = find_output_index(_model_edges->pkg_outputs, model_index,
+ ir::SubgraphIndex{0}, ir::IOIndex{i});
+ const auto output_io_desc = ir::IODesc{model_index, ir::SubgraphIndex{0}, ir::IOIndex{i}};
+ if (output_pkg_index != -1)
+ {
+ // Allocate type-aware quantization tensors for nnpkg outputs and set internal tensors
+ if (_pkg_output_quant_tensors.find(output_io_desc) != _pkg_output_quant_tensors.end())
+ {
+ _pkg_output_quant_tensors[output_io_desc]->allocate_buffer();
+
+ outputs_inter[i] = _pkg_output_quant_tensors[output_io_desc].get();
+ }
+ else
+ {
+ outputs_inter[i] = _pkg_output_tensors[output_io_desc].get();
+ }
+ }
+ else
+ {
+ // Allocate buffer of `from` tensors
+ const auto from_iodesc = ir::IODesc{model_index, ir::SubgraphIndex{0}, ir::IOIndex{i}};
+ _edge_tensors[from_iodesc]->allocate_buffer();
+ outputs_inter[i] = _edge_tensors[from_iodesc].get();
+
+ // Allocate buffer of tensors for type-aware quantization
+ for (const auto &to_iodesc : _edge_map[from_iodesc])
+ {
+ _edge_tensors[from_iodesc]->increase_ref();
+ if (_edge_quant_tensors.find(to_iodesc) != _edge_quant_tensors.end())
+ {
+ auto type_aware_quant_tensor = _edge_quant_tensors.at(to_iodesc).get();
+ type_aware_quant_tensor->allocate_buffer();
+
+ _edge_tensors[from_iodesc]->decrease_ref();
+ }
+ }
+ }
+ }
+
+ _pkg_input_quant_layers[{model_index, ir::SubgraphIndex{0}}]->run();
+
+ executor->execute(inputs_inter, outputs_inter, ctx.options);
+
+ _edge_quant_layers[{model_index, ir::SubgraphIndex{0}}]->run();
+ _pkg_output_quant_layers[{model_index, ir::SubgraphIndex{0}}]->run();
+
+ // Release input buffers that are no longer needed
+ for (uint32_t i = 0; i < input_size; i++)
+ {
+ const auto input_pkg_index = find_input_index(_model_edges->pkg_inputs, model_index,
+ ir::SubgraphIndex{0}, ir::IOIndex{i});
+
+ const auto to_iodesc = ir::IODesc{model_index, ir::SubgraphIndex{0}, ir::IOIndex{i}};
+ if (input_pkg_index == -1)
+ {
+ if (_edge_quant_tensors.find(to_iodesc) != _edge_quant_tensors.end())
+ {
+ // Decrease the reference count of the type-aware quantization tensor when the input
+ // tensor is that tensor (the outer check already guarantees it exists)
+ _edge_quant_tensors[to_iodesc]->decrease_ref();
+ }
+ else
+ {
+ // Decrease reference count of `from` tensor if input tensor is the `from` tensor
+ const auto from_iodesc = find_from(model_index, ir::SubgraphIndex{0}, ir::IOIndex{i});
+ _edge_tensors[from_iodesc]->decrease_ref();
+
+ // Decrease reference count of nnpkg inputs
+ if (_pkg_input_quant_tensors.find(to_iodesc) != _pkg_input_quant_tensors.end())
+ {
+ _pkg_input_quant_tensors[to_iodesc]->decrease_ref();
+ }
+ }
+ }
+ }
+
+ // Release output buffers if those buffers are no longer used by other executors because of
+ // type-aware quantization
+ // FIXME Not needed once tensors for type-aware quantization are unified per `from` tensor and type
+ for (uint32_t i = 0; i < output_size; i++)
+ {
+ auto from_iodesc = ir::IODesc{model_index, ir::SubgraphIndex{0}, ir::IOIndex{i}};
+
+ // Check if other executors will use the buffer of edge tensor
+ const auto &to_list = _edge_map[from_iodesc];
+ if (to_list.size() == 0)
+ {
+ // This condition means `from_iodesc` tensor is an output of nnpkg
+ continue;
+ }
+
+ bool to_be_release =
+ !std::any_of(to_list.begin(), to_list.end(), [&](const ir::IODesc &to_iodesc) {
+ // This condition means another executor uses the buffer of edge tensor
+ return _edge_quant_tensors.find(to_iodesc) == _edge_quant_tensors.end();
+ });
+
+ if (to_be_release)
+ {
+ // This edge tensor's buffer won't be used in other executors
+ // Tensors for type-aware quantization take over the role of this edge tensor instead
+ _edge_tensors[from_iodesc]->decrease_ref();
+ }
+
+ // Decrease reference count of nnpkg outputs
+ if (_pkg_output_quant_tensors.find(from_iodesc) != _pkg_output_quant_tensors.end())
+ {
+ _pkg_output_quant_tensors[from_iodesc]->decrease_ref();
+ }
+ }
+ }
+}
+
+// modelCount() iterates over _executors.
+// It assumes that the Compiler generates an Executor for every model and that _executors
+// contains all generated Executors.
+// If the nnpackage includes model(s) with no connections and the Compiler does not generate
+// Executors for them, modelCount() returns a smaller value than the real model count.
+uint16_t MultiModelExecutors::modelCount() const
+{
+ uint16_t model_count = 0;
+ for (; _executors.find(std::make_pair(ir::ModelIndex{model_count}, ir::SubgraphIndex{0})) !=
+ _executors.end();
+ model_count++)
+ ;
+
+ return model_count;
+}
+
+} // namespace exec
+} // namespace onert
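checkSupportedMultimodel() above only accepts edges of the form 'm1:s1:o1 -> m2:s2:o2' with m1 < m2 and s1 == s2 == 0, which is what makes the simple sequential scheduling in execute() possible. The sketch below restates just that constraint over a plain edge list so the rule can be tested on its own; the Edge struct and edgesSupported() are stand-ins for illustration, not the runtime's ir::ModelEdge type.

#include <cstdint>
#include <vector>

// Stand-in for an inter-model edge: (model, subgraph, io) -> (model, subgraph, io)
struct Edge
{
  uint16_t from_model, from_subg;
  uint32_t from_io;
  uint16_t to_model, to_subg;
  uint32_t to_io;
};

// Mirrors the edge assumptions in checkSupportedMultimodel(): an edge must go from a
// lower model index to a strictly higher one, and both endpoints must be on subgraph 0.
bool edgesSupported(const std::vector<Edge> &edges)
{
  for (const auto &e : edges)
  {
    if (e.from_model >= e.to_model)
      return false;
    if (e.from_subg != 0 || e.to_subg != 0)
      return false;
  }
  return true;
}

int main()
{
  // m0:0:0 -> m1:0:0 is allowed; m1:0:0 -> m0:0:1 is not (it goes backwards)
  std::vector<Edge> ok{{0, 0, 0, 1, 0, 0}};
  std::vector<Edge> bad{{1, 0, 0, 0, 0, 1}};
  return (edgesSupported(ok) && !edgesSupported(bad)) ? 0 : 1;
}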
diff --git a/runtime/onert/core/src/exec/MultiModelExecutors.h b/runtime/onert/core/src/exec/MultiModelExecutors.h
new file mode 100644
index 000000000..0bd9f1143
--- /dev/null
+++ b/runtime/onert/core/src/exec/MultiModelExecutors.h
@@ -0,0 +1,152 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_EXEC_EXECUTORS_H__
+#define __ONERT_EXEC_EXECUTORS_H__
+
+#include "exec/IExecutors.h"
+#include "ir/NNPkg.h"
+#include "IPermuteFunction.h"
+#include "EdgeTensor.h"
+#include "../backend/builtin/UserTensor.h"
+
+namespace std
+{
+
+template <> struct hash<std::pair<::onert::ir::ModelIndex, ::onert::ir::SubgraphIndex>>
+{
+ size_t operator()(
+ const std::pair<::onert::ir::ModelIndex, ::onert::ir::SubgraphIndex> &pair) const noexcept
+ {
+ return (hash<uint32_t>()(pair.first.value()) << 16) ^ hash<uint32_t>()(pair.second.value());
+ }
+};
+
+} // namespace std
+
+namespace onert
+{
+namespace exec
+{
+
+/**
+ * @brief Class to gather executors
+ */
+class MultiModelExecutors : public IExecutors
+{
+public:
+ MultiModelExecutors(void) = delete;
+ MultiModelExecutors(std::unique_ptr<ir::ModelEdges> model_edges)
+ : _executors{}, _model_edges{std::move(model_edges)}, _edge_quant_layers{},
+ _edge_quant_tensors{}, _edge_tensors{}, _is_created_edge_quant_layers{false},
+ _pkg_input_quant_layers{}, _pkg_output_quant_layers{}, _pkg_input_quant_tensors{},
+ _pkg_output_quant_tensors{}, _pkg_input_tensors{}, _pkg_output_tensors{}
+ {
+ for (const auto &edge : _model_edges->edges)
+ {
+ _edge_map[edge.from].emplace_back(edge.to);
+ }
+ }
+ MultiModelExecutors(const MultiModelExecutors &) = delete;
+ MultiModelExecutors(MultiModelExecutors &&) = default;
+ ~MultiModelExecutors() = default;
+
+ // TODO Use Executor index
+ void emplace(const ir::ModelIndex &model_index, const ir::SubgraphIndex &subg_index,
+ std::unique_ptr<IExecutor> exec) override;
+
+ IExecutor *at(const ir::ModelIndex &model_index,
+ const ir::SubgraphIndex &subg_index) const override;
+
+ uint32_t inputSize() const override;
+
+ uint32_t outputSize() const override;
+
+ const ir::OperandInfo &inputInfo(const ir::IOIndex &index) const override;
+
+ const ir::OperandInfo &outputInfo(const ir::IOIndex &index) const override;
+
+ void execute(const ExecutionContext &ctx) override;
+
+private:
+ void checkSupportedMultimodel() const;
+ void createEdgeQuantLayers();
+ void CreatePkgIOTensors(const IODescription &desc);
+ void createPkgIOQuantLayers(const IODescription &desc);
+ uint16_t modelCount() const;
+
+private:
+ std::unordered_map<std::pair<ir::ModelIndex, ir::SubgraphIndex>, std::unique_ptr<IExecutor>>
+ _executors;
+
+ // NOTE _model_edges may use different struct type for executor implementation
+ std::unique_ptr<ir::ModelEdges> _model_edges;
+ std::unordered_map<ir::IODesc, std::vector<ir::IODesc>> _edge_map;
+
+ /**
+ * @brief Type-aware quantization layers for edges between executors
+ *
+ */
+ // TODO Move variables related to type-aware quantization for edges into compilation stage
+ // TODO Replace PermuteLayer with backend::builtin::kernel::PermuteLayer
+ std::unordered_map<std::pair<ir::ModelIndex, ir::SubgraphIndex>, std::unique_ptr<PermuteLayer>>
+ _edge_quant_layers;
+
+ /**
+ * @brief Tensors for type-aware quantization of edges
+ * Key: `to` IODesc, Value: EdgeTensor
+ */
+ // Q: Why is the key a `to` IODesc?
+ // A: These tensors are currently created depending on the type of `to`
+ // TODO Unify tensors with the same `from` tensor and the same type
+ // NOTE The incomplete type 'EdgeTensor' cannot be declared as unique_ptr.
+ std::unordered_map<ir::IODesc, std::shared_ptr<EdgeTensor>> _edge_quant_tensors;
+
+ /**
+ * @brief Tensors for edges between executors that are not related to type-aware quantization
+ * Key: `from` IODesc, Value: EdgeTensor
+ */
+ // Q: Why is the key a `from` IODesc?
+ // A: A `from` can be connected to multiple `to`s
+ // NOTE The incomplete type 'EdgeTensor' cannot be declared as unique_ptr.
+ std::unordered_map<ir::IODesc, std::shared_ptr<EdgeTensor>> _edge_tensors;
+ /**
+ * @brief Whether type-aware quantization layers for edges between executors are created
+ *
+ */
+ // TODO Remove this member after the creation of type-aware quantization layers for edges
+ // is moved into compilation stage
+ bool _is_created_edge_quant_layers;
+
+ // TODO Replace PermuteLayer with backend::builtin::kernel::PermuteLayer
+ std::unordered_map<std::pair<ir::ModelIndex, ir::SubgraphIndex>, std::unique_ptr<PermuteLayer>>
+ _pkg_input_quant_layers;
+ // TODO Replace PermuteLayer with backend::builtin::kernel::PermuteLayer
+ std::unordered_map<std::pair<ir::ModelIndex, ir::SubgraphIndex>, std::unique_ptr<PermuteLayer>>
+ _pkg_output_quant_layers;
+ // Edge tensors of nnpkg inputs/outputs for type-aware quantization
+ std::unordered_map<ir::IODesc, std::shared_ptr<EdgeTensor>> _pkg_input_quant_tensors;
+ std::unordered_map<ir::IODesc, std::shared_ptr<EdgeTensor>> _pkg_output_quant_tensors;
+ // IOTensors for user buffer
+ std::unordered_map<ir::IODesc, std::unique_ptr<backend::builtin::UserTensor>> _pkg_input_tensors;
+ std::unordered_map<ir::IODesc, std::unique_ptr<backend::builtin::UserTensor>> _pkg_output_tensors;
+};
+
+} // namespace exec
+} // namespace onert
+
+#endif // __ONERT_EXEC_EXECUTORS_H__
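MultiModelExecutors keys several unordered_maps by a (ModelIndex, SubgraphIndex) pair, which is why the std::hash specialization above exists. The snippet below shows the same shift-and-xor combination with plain integer key types so it can be seen in isolation; Key and KeyHash are hypothetical stand-ins, not runtime types.

#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <functional>
#include <unordered_map>
#include <utility>

// Hypothetical stand-in for std::pair<ir::ModelIndex, ir::SubgraphIndex>
using Key = std::pair<uint16_t, uint16_t>;

struct KeyHash
{
  std::size_t operator()(const Key &k) const noexcept
  {
    // Same combination as the std::hash specialization above: shift one hash, xor the other
    return (std::hash<uint32_t>()(k.first) << 16) ^ std::hash<uint32_t>()(k.second);
  }
};

int main()
{
  std::unordered_map<Key, const char *, KeyHash> executors;
  executors[{0, 0}] = "model 0, subgraph 0";
  executors[{1, 0}] = "model 1, subgraph 0";
  std::printf("%s\n", executors.at({1, 0}));
  return 0;
}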
diff --git a/runtime/onert/core/src/exec/ParallelExecutor.cc b/runtime/onert/core/src/exec/ParallelExecutor.cc
index ab234aacd..152fa7cd3 100644
--- a/runtime/onert/core/src/exec/ParallelExecutor.cc
+++ b/runtime/onert/core/src/exec/ParallelExecutor.cc
@@ -31,7 +31,7 @@ class HookFunction : public IFunction
public:
HookFunction(IFunction *fn, const std::function<void()> &setup,
const std::function<void()> &teardown)
- : _fn{fn}, _setup{setup}, _teardown{teardown}
+ : _fn{fn}, _setup{setup}, _teardown{teardown}
{
}
@@ -59,29 +59,28 @@ void ParallelExecutor::notify(uint32_t finished_job_id)
_cv_jobs.notify_all();
}
-ParallelExecutor::ParallelExecutor(
- std::unique_ptr<compiler::LoweredGraph> lowered_graph,
- const std::vector<std::shared_ptr<backend::ITensor>> &input_tensors,
- const std::vector<std::shared_ptr<backend::ITensor>> &output_tensors,
- const compiler::TensorRegistries &tensor_regs, backend::TensorManagerSet &&tensor_mgrs,
- compiler::CodeMap &&code_map)
- : DataflowExecutor{std::move(lowered_graph), input_tensors, output_tensors, tensor_regs,
- std::move(tensor_mgrs), std::move(code_map)}
+ParallelExecutor::ParallelExecutor(std::unique_ptr<compiler::LoweredGraph> lowered_graph,
+ backend::BackendContexts &&backend_contexts,
+ const compiler::TensorRegistries &tensor_regs,
+ compiler::CodeMap &&code_map,
+ const util::TracingCtx *tracing_ctx)
+ : DataflowExecutor{std::move(lowered_graph), std::move(backend_contexts), tensor_regs,
+ std::move(code_map), tracing_ctx}
{
VERBOSE(ParallelExecutor) << "Constructing Parallel Executor" << std::endl;
}
-void ParallelExecutor::executeImpl()
+void ParallelExecutor::executeImpl(const ExecutionObservee &subject)
{
bool dynamic_input_exists = hasDynamicInput();
// Init scheduler
- // TODO Consider to have distinct backend set in LowerInfoMap
+ // TODO Consider having a distinct backend set in GraphLowerInfo
BackendSet backends;
- for (auto &itr : _lowered_graph->getLowerInfo()->op_seq)
- {
- backends.add(itr.second->backend());
- }
+ _lowered_graph->lower_info().operation.iterate(
+ [&](const ir::OperationIndex &, const compiler::OperationLowerInfo &lower_info) {
+ backends.add(lower_info.backend());
+ });
_scheduler = std::make_unique<ParallelScheduler>(backends);
assert(noWaitingJobs());
@@ -101,7 +100,10 @@ void ParallelExecutor::executeImpl()
VERBOSE(ParallelExecutor) << "INITIAL JOBS : " << _ready_jobs.size() << std::endl;
- _subject.notifyModelBegin(this);
+ auto profiling_subg_index = _tracing_ctx->getSubgraphIndex(&_graph);
+
+ subject.notifySubgraphBegin(profiling_subg_index);
+
while (true)
{
std::unique_lock<std::mutex> lock{_mu_jobs};
@@ -121,20 +123,24 @@ void ParallelExecutor::executeImpl()
lock.unlock();
- VERBOSE(ParallelExecutor) << "Assigning fn #" << job->index() << std::endl;
+ VERBOSE(ParallelExecutor) << "Assigning fn " << job->index() << std::endl;
auto job_index = job->index();
- auto op_sequence_index = _job_to_op_seq[job_index];
- auto op_seq = &_lowered_graph->op_seqs().at(op_sequence_index);
- auto backend = _lowered_graph->getLowerInfo()->op_seq.at(op_sequence_index)->backend();
- auto setup = [&, op_seq, backend]() { _subject.notifyJobBegin(this, op_seq, backend); };
- auto teardown = [&, job_index, op_seq, backend]() {
- _subject.notifyJobEnd(this, op_seq, backend);
+ auto op_ind = _job_to_op[job_index];
+ auto backend = _lowered_graph->lower_info().operation.at(op_ind).backend();
+ auto setup = [&, op_ind, backend]() {
+ subject.notifyJobBegin(this, profiling_subg_index, op_ind, backend);
+ };
+ auto teardown = [&, job_index, op_ind, backend]() {
+ subject.notifyJobEnd(this, profiling_subg_index, op_ind, backend);
notify(job_index);
};
+ job->fn_seq()->initRunning();
+
// dynamic tensor setting
- bool handle_dynamic_tensor = op_seq->has_dynamic_tensor() || dynamic_input_exists;
+ bool handle_dynamic_tensor =
+ _lowered_graph->getHasDynamicTensor(op_ind) || dynamic_input_exists;
job->fn_seq()->enableDynamicShapeInferer(handle_dynamic_tensor);
_scheduler->assign(std::make_unique<HookFunction>(job->fn_seq(), setup, teardown), backend);
@@ -145,7 +151,7 @@ void ParallelExecutor::executeImpl()
// Wait for all the jobs done
_scheduler->finish();
- _subject.notifyModelEnd(this);
+ subject.notifySubgraphEnd(profiling_subg_index);
// Reset input info for the next execution
_input_info = _initial_input_info;
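The HookFunction shown at the top of this file wraps each job's function sequence so that the job-begin and job-end notifications fire around the actual run. The sketch below isolates that wrapper pattern using std::function callables; the Hook class is a hypothetical stand-in, and IFunction and the parallel scheduler are intentionally left out.

#include <cstdio>
#include <functional>
#include <utility>

// Stand-in wrapper mirroring HookFunction above: run setup, then the wrapped
// callable, then teardown.
class Hook
{
public:
  Hook(std::function<void()> fn, std::function<void()> setup, std::function<void()> teardown)
    : _fn{std::move(fn)}, _setup{std::move(setup)}, _teardown{std::move(teardown)}
  {
  }
  void run()
  {
    _setup();
    _fn();
    _teardown();
  }

private:
  std::function<void()> _fn, _setup, _teardown;
};

int main()
{
  Hook hook{[] { std::printf("job body\n"); },
            [] { std::printf("notify job begin\n"); },
            [] { std::printf("notify job end\n"); }};
  hook.run();
  return 0;
}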
diff --git a/runtime/onert/core/src/exec/ParallelExecutor.h b/runtime/onert/core/src/exec/ParallelExecutor.h
index 929edfce9..3162d865f 100644
--- a/runtime/onert/core/src/exec/ParallelExecutor.h
+++ b/runtime/onert/core/src/exec/ParallelExecutor.h
@@ -17,17 +17,12 @@
#ifndef __ONERT_EXEC_PARALLEL_EXECUTOR_H__
#define __ONERT_EXEC_PARALLEL_EXECUTOR_H__
-#include <list>
-#include <queue>
-#include <unordered_map>
+#include "DataflowExecutor.h"
+#include "ParallelScheduler.h"
+
+#include "util/TracingCtx.h"
-#include "exec/FunctionSequence.h"
-#include "Job.h"
-#include "ir/OperandIndexSequence.h"
-#include "ir/Index.h"
#include <memory>
-#include "exec/DataflowExecutor.h"
-#include "ParallelScheduler.h"
namespace onert
{
@@ -48,15 +43,14 @@ public:
*
* @param lowered_graph LoweredGraph object
* @param tensor_builders Tensor builders that are currently used
- * @param code_map OpSequence and its code map
+ * @param code_map @c ir::Operation and its code map
*/
ParallelExecutor(std::unique_ptr<compiler::LoweredGraph> lowered_graph,
- const std::vector<std::shared_ptr<backend::ITensor>> &input_tensors,
- const std::vector<std::shared_ptr<backend::ITensor>> &output_tensors,
- const compiler::TensorRegistries &tensor_regs,
- backend::TensorManagerSet &&tensor_mgrs, compiler::CodeMap &&code_map);
+ backend::BackendContexts &&backend_contexts,
+ const compiler::TensorRegistries &tensor_regs, compiler::CodeMap &&code_map,
+ const util::TracingCtx *tracing_ctx);
- void executeImpl() override;
+ void executeImpl(const ExecutionObservee &subject) override;
private:
std::condition_variable _cv_jobs;
diff --git a/runtime/onert/core/src/exec/ParallelScheduler.cc b/runtime/onert/core/src/exec/ParallelScheduler.cc
index 70c9c3dd6..538945631 100644
--- a/runtime/onert/core/src/exec/ParallelScheduler.cc
+++ b/runtime/onert/core/src/exec/ParallelScheduler.cc
@@ -30,7 +30,7 @@ ParallelScheduler::ParallelScheduler(const BackendSet &backends)
{
assert(!backends.empty());
- for (auto backend : backends)
+ for (auto &&backend : backends)
{
_thread_pools[backend] = std::make_unique<ThreadPool>();
}
@@ -45,7 +45,7 @@ void ParallelScheduler::assign(std::unique_ptr<IFunction> &&fn, const backend::B
void ParallelScheduler::finish()
{
- for (auto &itr : _thread_pools)
+ for (auto &&itr : _thread_pools)
{
itr.second->finish();
}
diff --git a/runtime/onert/core/src/exec/SingleModelExecutors.cc b/runtime/onert/core/src/exec/SingleModelExecutors.cc
new file mode 100644
index 000000000..44c5e5742
--- /dev/null
+++ b/runtime/onert/core/src/exec/SingleModelExecutors.cc
@@ -0,0 +1,170 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "SingleModelExecutors.h"
+
+#include "EdgeTensor.h"
+#include "IPermuteFunction.h"
+#include "../backend/builtin/UserTensor.h"
+
+namespace onert
+{
+namespace exec
+{
+
+void SingleModelExecutors::emplace(const ir::ModelIndex &, const ir::SubgraphIndex &subg_index,
+ std::unique_ptr<IExecutor> exec)
+{
+ _executors.emplace(subg_index, std::move(exec));
+}
+
+IExecutor *SingleModelExecutors::at(const ir::ModelIndex &,
+ const ir::SubgraphIndex &subg_index) const
+{
+ return _executors.at(subg_index).get();
+}
+
+uint32_t SingleModelExecutors::inputSize() const { return entryExecutor()->inputSize(); }
+
+uint32_t SingleModelExecutors::outputSize() const { return entryExecutor()->outputSize(); }
+
+const ir::OperandInfo &SingleModelExecutors::inputInfo(const ir::IOIndex &index) const
+{
+ return entryExecutor()->inputInfo(index.value());
+}
+
+const ir::OperandInfo &SingleModelExecutors::outputInfo(const ir::IOIndex &index) const
+{
+ return entryExecutor()->outputInfo(index.value());
+}
+
+void SingleModelExecutors::execute(const ExecutionContext &ctx)
+{
+ // UserTensor for Input/Output
+ std::vector<std::unique_ptr<backend::builtin::UserTensor>> tensorpool;
+
+ // EdgeTensor for Input Quantization / Output Dequantization
+ std::vector<std::unique_ptr<EdgeTensor>> qtensorpool;
+
+ // Input/Output Tensor vector for executor
+ std::vector<backend::IPortableTensor *> inputs(ctx.desc.inputs.size());
+ std::vector<backend::IPortableTensor *> outputs(ctx.desc.outputs.size());
+
+ // Vector for input quantization I/O
+ std::vector<backend::ITensor *> input_tensors;
+ std::vector<backend::ITensor *> input_qtensors;
+
+ // Vector for output dequantization I/O
+ std::vector<backend::ITensor *> output_qtensors;
+ std::vector<backend::ITensor *> output_tensors;
+
+ // Prepare UserTensor and EdgeTensor for input quantization
+ for (uint32_t i = 0; i < inputs.size(); i++)
+ {
+ auto &desc = ctx.desc.inputs[i];
+
+ // An input may be optional: its buffer may be nullptr only when its size and total size are 0
+ if (desc->buffer == nullptr && (desc->size != 0 || desc->info.total_size() != 0))
+ throw std::runtime_error{"Input " + std::to_string(i) + "'s buffer is not set."};
+
+ tensorpool.emplace_back(std::make_unique<backend::builtin::UserTensor>(
+ desc->info, desc->layout, const_cast<uint8_t *>(static_cast<const uint8_t *>(desc->buffer)),
+ desc->size));
+
+ auto user_type = desc->info.typeInfo().type();
+ auto &model_info = entryExecutor()->inputInfo(i).typeInfo();
+ auto model_type = model_info.type();
+ if (user_type != model_type && user_type == ir::DataType::FLOAT32)
+ {
+ auto quantized_info = desc->info;
+ quantized_info.typeInfo(model_info);
+ qtensorpool.emplace_back(
+ std::make_unique<EdgeTensor>(quantized_info, entryExecutor()->inputLayout(i)));
+ qtensorpool.back()->allocate_buffer();
+
+ input_tensors.push_back(tensorpool.back().get());
+ input_qtensors.push_back(qtensorpool.back().get());
+ inputs[i] = qtensorpool.back().get();
+ }
+ else
+ inputs[i] = tensorpool.back().get();
+ }
+
+ // Prepare UserTensor and EdgeTensor for output dequantization
+ for (uint32_t i = 0; i < outputs.size(); i++)
+ {
+ auto &desc = ctx.desc.outputs[i];
+
+ // An output may be optional: its buffer may be nullptr only when its size and total size are 0
+ if (desc->buffer == nullptr && (desc->size != 0 || desc->info.total_size() != 0))
+ throw std::runtime_error{"Output " + std::to_string(i) + "'s buffer is not set."};
+
+ tensorpool.emplace_back(std::make_unique<backend::builtin::UserTensor>(
+ desc->info, desc->layout, static_cast<uint8_t *>(desc->buffer), desc->size));
+
+ auto user_type = desc->info.typeInfo().type();
+ auto &model_info = entryExecutor()->outputInfo(i).typeInfo();
+ auto model_type = model_info.type();
+ if (user_type != model_type && user_type == ir::DataType::FLOAT32)
+ {
+ auto quantized_info = desc->info;
+ quantized_info.typeInfo(model_info);
+ qtensorpool.emplace_back(
+ std::make_unique<EdgeTensor>(quantized_info, entryExecutor()->outputLayout(i)));
+ qtensorpool.back()->allocate_buffer();
+
+ output_qtensors.push_back(qtensorpool.back().get());
+ output_tensors.push_back(tensorpool.back().get());
+ outputs[i] = qtensorpool.back().get();
+ }
+ else
+ outputs[i] = tensorpool.back().get();
+ }
+
+ // Run quantization
+ if (input_tensors.size() > 0)
+ {
+ auto input_quantize_layer = PermuteLayer(input_tensors, input_qtensors);
+ input_quantize_layer.prepare();
+ input_quantize_layer.run();
+ }
+
+ // Executor
+ entryExecutor()->execute(inputs, outputs, ctx.options);
+
+ // Run dequantization
+ if (output_tensors.size() != 0)
+ {
+ auto output_dequantize_layer = PermuteLayer(output_qtensors, output_tensors);
+ output_dequantize_layer.prepare();
+ output_dequantize_layer.run();
+ }
+
+ // Get dynamic shape inference result
+ for (uint32_t i = 0; i < outputs.size(); i++)
+ {
+ if (ctx.desc.outputs[i]->buffer == nullptr)
+ {
+ // Output is optional if buffer is nullptr
+ continue;
+ }
+
+ ctx.desc.outputs[i]->info.shape(outputs[i]->getShape());
+ }
+}
+
+} // namespace exec
+} // namespace onert
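SingleModelExecutors::execute() above only inserts a quantize or dequantize PermuteLayer when the user-provided I/O type differs from the model's type and the user side is FLOAT32. The sketch below captures just that decision with a hypothetical DataType enum and needsTypeConversion() helper; the actual wiring uses UserTensor, EdgeTensor and PermuteLayer as shown above.

#include <cstdio>

// Hypothetical stand-in for ir::DataType, limited to the cases relevant here
enum class DataType
{
  FLOAT32,
  QUANT_UINT8_ASYMM,
  INT32
};

// Mirrors the condition in SingleModelExecutors::execute(): a conversion tensor is
// inserted only when the user buffer is FLOAT32 and the model expects a different type.
bool needsTypeConversion(DataType user_type, DataType model_type)
{
  return user_type != model_type && user_type == DataType::FLOAT32;
}

int main()
{
  std::printf("float user -> quant model: %d\n",
              needsTypeConversion(DataType::FLOAT32, DataType::QUANT_UINT8_ASYMM)); // 1
  std::printf("int32 user -> float model: %d\n",
              needsTypeConversion(DataType::INT32, DataType::FLOAT32)); // 0
  return 0;
}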
diff --git a/runtime/onert/core/src/exec/SingleModelExecutors.h b/runtime/onert/core/src/exec/SingleModelExecutors.h
new file mode 100644
index 000000000..66dce6077
--- /dev/null
+++ b/runtime/onert/core/src/exec/SingleModelExecutors.h
@@ -0,0 +1,70 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_EXEC_SINGLE_MODEL_EXECUTORS_H__
+#define __ONERT_EXEC_SINGLE_MODEL_EXECUTORS_H__
+
+#include "exec/IExecutors.h"
+#include "ir/NNPkg.h"
+
+namespace onert
+{
+namespace exec
+{
+
+/**
+ * @brief Class to gather executor set for single model NN package
+ */
+class SingleModelExecutors : public IExecutors
+{
+public:
+ /**
+ * @brief Construct a new SingleModelExecutors object
+ */
+ SingleModelExecutors(void) = default;
+ SingleModelExecutors(const SingleModelExecutors &) = delete;
+ SingleModelExecutors(SingleModelExecutors &&) = default;
+
+ /**
+ * @brief Destroy the SingleModelExecutors object
+ */
+ ~SingleModelExecutors() = default;
+
+public:
+ void emplace(const ir::ModelIndex &model_index, const ir::SubgraphIndex &subg_index,
+ std::unique_ptr<IExecutor> exec) override;
+
+ IExecutor *at(const ir::ModelIndex &model_index,
+ const ir::SubgraphIndex &subg_index) const override;
+
+ uint32_t inputSize() const override;
+
+ uint32_t outputSize() const override;
+
+ const ir::OperandInfo &inputInfo(const ir::IOIndex &index) const override;
+
+ const ir::OperandInfo &outputInfo(const ir::IOIndex &index) const override;
+
+ void execute(const ExecutionContext &ctx) override;
+
+private:
+ std::unordered_map<ir::SubgraphIndex, std::unique_ptr<IExecutor>> _executors;
+};
+
+} // namespace exec
+} // namespace onert
+
+#endif // __ONERT_EXEC_SINGLE_MODEL_EXECUTORS_H__
diff --git a/runtime/onert/core/src/exec/Sink.h b/runtime/onert/core/src/exec/Sink.h
deleted file mode 100644
index 6a99efe60..000000000
--- a/runtime/onert/core/src/exec/Sink.h
+++ /dev/null
@@ -1,199 +0,0 @@
-/*
- * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef __ONERT_EXEC_SINK_H__
-#define __ONERT_EXEC_SINK_H__
-
-#include "feature/nchw/Reader.h"
-#include "feature/nchw/View.h"
-#include "feature/nhwc/Reader.h"
-#include "feature/nhwc/View.h"
-
-#include <cassert>
-#include <memory>
-#include "util/Utils.h"
-#include <misc/feature/IndexIterator.h>
-
-namespace onert
-{
-namespace exec
-{
-struct ISink
-{
- virtual ~ISink() = default;
-
- virtual void pull(::onert::backend::ITensor &tensor) const = 0;
-};
-
-// Create second lever inheritance: the first lever is used as a reference type in use-case places
-template <typename T> class ITemplSink : public ISink
-{
-public:
- ITemplSink(void *output_buffer, const size_t &output_size, const ir::Shape &shape,
- const bool copy, ir::Layout io_layout)
- : _output_buffer{reinterpret_cast<T *>(output_buffer)}, _output_size{output_size},
- _shape{shape}, _copy{copy}, _io_layout{io_layout}
- {
- }
-
-protected:
- void pullUnif(onert::backend::ITensor &tensor) const
- {
- assert(((_io_layout == ir::Layout::NHWC && tensor.layout() == ir::Layout::NCHW) ||
- (_io_layout == ir::Layout::NCHW && tensor.layout() == ir::Layout::NHWC)) ||
- _copy);
- auto input_buffer = tensor.buffer();
- auto rank = _shape.rank();
-
- if (!tensor.has_padding() && rank < 4 + _copy)
- {
- memcpy(_output_buffer, input_buffer, _output_size);
- return;
- }
-
- switch (rank)
- {
- case 0:
- case 1:
- {
- memcpy(_output_buffer, input_buffer, _output_size);
- break;
- }
- case 2:
- {
- const int32_t copy_len = _shape.dim(1);
-
- for (auto i = 0; i < _shape.dim(0); ++i)
- {
- ir::Coordinates coords{i, 0};
- memcpy(_output_buffer + i * copy_len, input_buffer + tensor.calcOffset(coords),
- copy_len * sizeof(T));
- }
- break;
- }
- case 3:
- {
- const int32_t dim1 = _shape.dim(1);
- const int32_t dim2 = _shape.dim(2);
-
- for (auto i = 0; i < _shape.dim(0); ++i)
- {
- for (auto j = 0; j < _shape.dim(1); ++j)
- {
- ir::Coordinates coords{i, j, 0};
- memcpy(_output_buffer + i * dim1 * dim2 + j * dim2,
- input_buffer + tensor.calcOffset(coords), dim2 * sizeof(T));
- }
- }
- break;
- }
- case 4:
- {
- if (_copy)
- {
- const int32_t dim1 = _shape.dim(1);
- const int32_t dim2 = _shape.dim(2);
- const int32_t dim3 = _shape.dim(3);
-
- for (auto i = 0; i < _shape.dim(0); ++i)
- {
- for (auto j = 0; j < _shape.dim(1); ++j)
- {
- for (auto k = 0; k < _shape.dim(2); ++k)
- {
- ir::Coordinates coords{i, j, k, 0};
- memcpy(_output_buffer + i * dim1 * dim2 * dim3 + j * dim2 * dim3 + k * dim3,
- input_buffer + tensor.calcOffset(coords), dim3 * sizeof(T));
- }
- }
- }
- }
- else
- {
- const auto shape = _shape.asFeature(_io_layout);
-
- if (_io_layout == ir::Layout::NHWC)
- {
- const exec::feature::nchw::Reader<T> from(&tensor);
- exec::feature::nhwc::View<T> into(shape, _output_buffer, _output_size);
- feature::iterate(shape)
- << [&](uint32_t batch, uint32_t ch, uint32_t row, uint32_t col) {
- const auto value = from.at(batch, ch, row, col);
- into.at(batch, row, col, ch) = value;
- };
- }
- else if (_io_layout == ir::Layout::NCHW)
- {
- const exec::feature::nhwc::Reader<T> from(&tensor);
- exec::feature::nchw::View<T> into(shape, _output_buffer, _output_size);
- feature::iterate(shape)
- << [&](uint32_t batch, uint32_t ch, uint32_t row, uint32_t col) {
- const auto value = from.at(batch, row, col, ch);
- into.at(batch, ch, row, col) = value;
- };
- }
- else
- {
- throw std::runtime_error("Wrong Layout");
- }
- }
- break;
- }
- default:
- throw std::runtime_error("NYI: rank > 4");
- break;
- }
- }
-
-private:
- T *_output_buffer;
- const size_t _output_size;
- const ir::Shape _shape;
- const bool _copy;
- const ir::Layout _io_layout;
-};
-
-template <typename T> class PermutateSink final : public ITemplSink<T>
-{
-public:
- PermutateSink(void *output_buffer, const size_t &output_size, const ir::Shape &shape,
- ir::Layout io_layout)
- : ITemplSink<T>(output_buffer, output_size, shape, false, io_layout)
- {
- }
-
-public:
- void pull(onert::backend::ITensor &tensor) const override { ITemplSink<T>::pullUnif(tensor); }
-};
-
-// Only supports NHWC format front-end(NNAPI) now
-template <typename T> class CopySink final : public ITemplSink<T>
-{
-public:
- CopySink(void *output_buffer, const size_t &output_size, const ir::Shape &shape,
- ir::Layout io_layout = ir::Layout::UNKNOWN)
- : ITemplSink<T>(output_buffer, output_size, shape, true, io_layout)
- {
- }
-
-public:
- void pull(onert::backend::ITensor &tensor) const override { ITemplSink<T>::pullUnif(tensor); }
-};
-
-} // namespace exec
-} // namespace onert
-
-#endif // __ONERT_EXEC_SINK_H__
diff --git a/runtime/onert/core/src/exec/Source.h b/runtime/onert/core/src/exec/Source.h
deleted file mode 100644
index fb2be4dd8..000000000
--- a/runtime/onert/core/src/exec/Source.h
+++ /dev/null
@@ -1,208 +0,0 @@
-/*
- * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef __ONERT_EXEC_SOURCE_H__
-#define __ONERT_EXEC_SOURCE_H__
-
-#include "feature/IndexIterator.h"
-#include "feature/nchw/Reader.h"
-#include "feature/nchw/View.h"
-#include "feature/nhwc/Reader.h"
-#include "feature/nhwc/View.h"
-
-#include <cassert>
-#include <memory>
-#include "util/Utils.h"
-#include <ir/Layout.h>
-#include "ir/Shape.h"
-
-namespace onert
-{
-namespace exec
-{
-
-struct ISource
-{
- virtual ~ISource() = default;
-
- virtual void push(::onert::backend::ITensor &tensor) const = 0;
-};
-
-// Create second-level inheritance: the first level is used as a reference type in use-case places
-template <typename T> class ITemplSource : public ISource
-{
-public:
- ITemplSource(const void *input_buffer, const size_t &input_size, const ir::Shape &shape,
- const bool copy, ir::Layout io_layout)
- : _input_buffer{reinterpret_cast<const T *>(input_buffer)}, _input_size{input_size},
- _shape{shape}, _copy(copy), _io_layout{io_layout}
- {
- }
-
- virtual void push(::onert::backend::ITensor &tensor) const = 0;
-
-protected:
- void pushUnif(onert::backend::ITensor &tensor) const
- {
- assert(((_io_layout == ir::Layout::NHWC && tensor.layout() == ir::Layout::NCHW) ||
- (_io_layout == ir::Layout::NCHW && tensor.layout() == ir::Layout::NHWC)) ||
- _copy);
- auto output_buffer = tensor.buffer();
- auto rank = _shape.rank();
-
- if (!tensor.has_padding() && rank < 4 + _copy)
- {
- memcpy(output_buffer, _input_buffer, _input_size);
- return;
- }
-
- switch (rank)
- {
- case 0:
- case 1:
- {
- memcpy(output_buffer, _input_buffer, _input_size);
- break;
- }
- case 2:
- {
- const int32_t copy_len = _shape.dim(1);
-
- for (auto i = 0; i < _shape.dim(0); ++i)
- {
- ir::Coordinates coords{i, 0};
- memcpy(output_buffer + tensor.calcOffset(coords), _input_buffer + i * copy_len,
- copy_len * sizeof(T));
- }
- break;
- }
- case 3:
- {
- const int32_t dim1 = _shape.dim(1);
- const int32_t dim2 = _shape.dim(2);
-
- for (auto i = 0; i < _shape.dim(0); ++i)
- {
- for (auto j = 0; j < _shape.dim(1); ++j)
- {
- ir::Coordinates coords{i, j, 0};
- memcpy(output_buffer + tensor.calcOffset(coords),
- _input_buffer + i * dim1 * dim2 + j * dim2, dim2 * sizeof(T));
- }
- }
- break;
- }
- case 4:
- {
- if (_copy)
- {
- const int32_t dim1 = _shape.dim(1);
- const int32_t dim2 = _shape.dim(2);
- const int32_t dim3 = _shape.dim(3);
- for (auto i = 0; i < _shape.dim(0); ++i)
- {
- for (auto j = 0; j < _shape.dim(1); ++j)
- {
- for (auto k = 0; k < _shape.dim(2); ++k)
- {
- ir::Coordinates coords{i, j, k, 0};
- memcpy(output_buffer + tensor.calcOffset(coords),
- _input_buffer + i * dim1 * dim2 * dim3 + j * dim2 * dim3 + k * dim3,
- dim3 * sizeof(T));
- }
- }
- }
- }
- else
- {
- const auto shape = _shape.asFeature(_io_layout);
-
- if (_io_layout == ir::Layout::NCHW)
- {
- const exec::feature::nchw::Reader<T> from(shape, _input_buffer, _input_size);
- exec::feature::nhwc::View<T> into(&tensor);
- feature::iterate(shape)
- << [&](uint32_t batch, uint32_t ch, uint32_t row, uint32_t col) {
- const auto value = from.at(batch, ch, row, col);
- into.at(batch, row, col, ch) = value;
- };
- }
- else if (_io_layout == ir::Layout::NHWC)
- {
- const exec::feature::nhwc::Reader<T> from(shape, _input_buffer, _input_size);
- exec::feature::nchw::View<T> into(&tensor);
- feature::iterate(shape)
- << [&](uint32_t batch, uint32_t ch, uint32_t row, uint32_t col) {
- const auto value = from.at(batch, row, col, ch);
- into.at(batch, ch, row, col) = value;
- };
- }
- else
- {
- throw std::runtime_error("Wrong Layout");
- }
- }
-
- break;
- }
- default:
- throw std::runtime_error("NYI: rank > 4");
- break;
- }
- }
-
-private:
- const T *_input_buffer;
- const size_t _input_size;
- const ir::Shape _shape;
- const bool _copy;
- const ir::Layout _io_layout;
-};
-
-template <typename T> class PermutateSource final : public ITemplSource<T>
-{
-public:
- PermutateSource(const void *input_buffer, const size_t &input_size, const ir::Shape &shape,
- ir::Layout io_layout)
- : ITemplSource<T>(input_buffer, input_size, shape, false, io_layout)
- {
- }
-
-public:
- void push(onert::backend::ITensor &tensor) const override
- {
- // do NHWC_TO_NCHW or NCHW_TO_NHWC permutation
- ITemplSource<T>::pushUnif(tensor);
- }
-};
-
-template <typename T> class CopySource final : public ITemplSource<T>
-{
-public:
- CopySource(const void *input_buffer, const size_t &input_size, const ir::Shape &shape,
- ir::Layout io_layout = ir::Layout::UNKNOWN)
- : ITemplSource<T>(input_buffer, input_size, shape, true, io_layout)
- {
- }
-
-public:
- void push(onert::backend::ITensor &tensor) const override { ITemplSource<T>::pushUnif(tensor); }
-};
-
-} // namespace exec
-} // namespace onert
-
-#endif // __ONERT_EXEC_SOURCE_H__
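Note on the two deletions above: Sink.h and Source.h implemented the same idea in opposite directions. When the I/O buffer layout and the backend tensor layout disagree (NHWC vs. NCHW), elements are permuted one by one; otherwise rows are copied with memcpy. A minimal standalone sketch of that permutation for dense buffers follows (loop structure only, not code from this patch):

#include <cstdint>

// Copy a dense NHWC source buffer into a dense NCHW destination buffer,
// mirroring the batch/channel/row/column iteration of the removed helpers.
template <typename T>
void permute_nhwc_to_nchw(const T *src, T *dst, int32_t N, int32_t H, int32_t W, int32_t C)
{
  for (int32_t n = 0; n < N; ++n)
    for (int32_t c = 0; c < C; ++c)
      for (int32_t h = 0; h < H; ++h)
        for (int32_t w = 0; w < W; ++w)
        {
          // NHWC index: ((n*H + h)*W + w)*C + c ; NCHW index: ((n*C + c)*H + h)*W + w
          dst[((n * C + c) * H + h) * W + w] = src[((n * H + h) * W + w) * C + c];
        }
}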
diff --git a/runtime/onert/core/src/exec/ThreadPool.cc b/runtime/onert/core/src/exec/ThreadPool.cc
index c8e0e3265..bf85e59f6 100644
--- a/runtime/onert/core/src/exec/ThreadPool.cc
+++ b/runtime/onert/core/src/exec/ThreadPool.cc
@@ -48,7 +48,7 @@ uint32_t ThreadPool::numJobsInQueue() { return _worker.numJobsInQueue(); }
void ThreadPool::join()
{
- for (auto &thread : _threads)
+ for (auto &&thread : _threads)
{
thread.join();
}
diff --git a/runtime/onert/core/src/exec/feature/MockTensor.test.h b/runtime/onert/core/src/exec/feature/MockTensor.test.h
new file mode 100644
index 000000000..1d2d375e2
--- /dev/null
+++ b/runtime/onert/core/src/exec/feature/MockTensor.test.h
@@ -0,0 +1,66 @@
+
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/ITensor.h"
+
+template <typename T> class MockTensor : public onert::backend::ITensor
+{
+public:
+ MockTensor<T>(onert::ir::Shape &shape, T *buf, onert::ir::Layout layout)
+ : _buf(reinterpret_cast<uint8_t *>(buf)), _shape(shape), _layout(layout)
+ {
+ }
+
+public:
+ uint8_t *buffer() const override { return _buf; }
+
+ size_t calcOffset(const onert::ir::Coordinates &coords) const override
+ {
+ size_t rank = _shape.rank();
+ rank = rank == 0 ? 1 : rank;
+ size_t offset = 0;
+ for (size_t i = 0; i < rank; ++i)
+ {
+ auto dim = _shape.rank() == 0 ? 1 : _shape.dim(i);
+ offset = offset * dim + coords[i];
+ }
+ offset *= sizeof(T);
+
+ return offset;
+ }
+
+ onert::ir::Shape getShape() const override { return _shape; }
+
+public: // DUMMY methods
+ size_t total_size() const override { return 0; }
+ onert::ir::Layout layout() const override { return _layout; }
+ onert::ir::DataType data_type() const override { return onert::ir::DataType::UINT8; }
+ float data_scale() const override { return 0; }
+ int32_t data_zero_point() const override { return 0; }
+ const std::vector<float> &data_scales() const override { return _dummy_scales; }
+ const std::vector<int32_t> &data_zero_points() const override { return _dummy_zerops; }
+ bool has_padding() const override { return false; }
+ void access(const std::function<void(ITensor &tensor)> &fn) override {}
+ bool is_dynamic() const override { return false; }
+
+private:
+ uint8_t *_buf = nullptr;
+ onert::ir::Shape _shape;
+ onert::ir::Layout _layout = onert::ir::Layout::UNKNOWN;
+ std::vector<float> _dummy_scales;
+ std::vector<int32_t> _dummy_zerops;
+};
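The mock above is only used by the feature Reader/View tests that follow. A short usage sketch (buffer contents and shape are illustrative):

// Wrap a dense float buffer in the test-only MockTensor; calcOffset() then
// treats the buffer as row-major over the given shape.
#include "MockTensor.test.h"
#include <vector>

void mock_tensor_sketch()
{
  std::vector<float> data(1 * 3 * 2 * 2, 0.0f);
  onert::ir::Shape shape{1, 3, 2, 2};
  MockTensor<float> tensor(shape, data.data(), onert::ir::Layout::NHWC);
  size_t byte_offset = tensor.calcOffset({0, 1, 0, 1}); // row-major offset * sizeof(float)
  (void)byte_offset;
}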
diff --git a/runtime/onert/core/src/exec/feature/nchw/Reader.h b/runtime/onert/core/src/exec/feature/nchw/Reader.h
index 7be9df4d5..e1a963cbd 100644
--- a/runtime/onert/core/src/exec/feature/nchw/Reader.h
+++ b/runtime/onert/core/src/exec/feature/nchw/Reader.h
@@ -36,35 +36,35 @@ namespace nchw
template <typename T> class Reader : public feature::Reader<T>
{
public:
- // Construct for buffer of model inputs
- Reader(const ir::FeatureShape &shape, const T *ptr, size_t len)
- : _shape{shape}, _ptr{reinterpret_cast<const uint8_t *>(ptr)}, _len{len}
+ using Strides = ir::FeatureShape;
+ // Construct for buffer and strides
+ Reader(const ir::FeatureShape &shape, const Strides &strides, const T *ptr, size_t len)
+ : _shape{shape}, _strides{strides}, _ptr{reinterpret_cast<const uint8_t *>(ptr)}, _len{len}
{
- assert(shape.N * shape.C * shape.H * shape.W * sizeof(T) == len);
-
- // No padding
- _strides.W = sizeof(T);
- _strides.H = shape.W * sizeof(T);
- _strides.C = shape.W * shape.H * sizeof(T);
- _strides.N = shape.W * shape.H * shape.C * sizeof(T);
+ UNUSED_RELEASE(len); // Workaround for unused variable in release mode
+ assert(len == static_cast<size_t>(strides.N != 0 ? shape.N * strides.N
+ : strides.C != 0 ? shape.C * strides.C
+ : strides.H != 0 ? shape.H * strides.H
+ : shape.W * strides.W));
}
// Construct for backend tensor
Reader(backend::ITensor *tensor)
- : _ptr{tensor->buffer() + tensor->calcOffset({0, 0, 0, 0})}, _len{tensor->total_size()}
+ : _ptr{tensor->buffer() + tensor->calcOffset({0, 0, 0, 0})}, _len{tensor->total_size()}
{
assert(tensor->layout() == ir::Layout::NCHW);
const auto start_offset = tensor->calcOffset({0, 0, 0, 0});
- _strides.W = tensor->dimension(3) == 1 ? 0 : tensor->calcOffset({0, 0, 0, 1}) - start_offset;
- _strides.H = tensor->dimension(2) == 1 ? 0 : tensor->calcOffset({0, 0, 1, 0}) - start_offset;
- _strides.C = tensor->dimension(1) == 1 ? 0 : tensor->calcOffset({0, 1, 0, 0}) - start_offset;
- _strides.N = tensor->dimension(0) == 1 ? 0 : tensor->calcOffset({1, 0, 0, 0}) - start_offset;
-
- _shape.W = tensor->dimension(3);
- _shape.H = tensor->dimension(2);
- _shape.C = tensor->dimension(1);
- _shape.N = tensor->dimension(0);
+ auto shape = tensor->getShape();
+ _strides.W = shape.dim(3) == 1 ? 0 : tensor->calcOffset({0, 0, 0, 1}) - start_offset;
+ _strides.H = shape.dim(2) == 1 ? 0 : tensor->calcOffset({0, 0, 1, 0}) - start_offset;
+ _strides.C = shape.dim(1) == 1 ? 0 : tensor->calcOffset({0, 1, 0, 0}) - start_offset;
+ _strides.N = shape.dim(0) == 1 ? 0 : tensor->calcOffset({1, 0, 0, 0}) - start_offset;
+
+ _shape.W = shape.dim(3);
+ _shape.H = shape.dim(2);
+ _shape.C = shape.dim(1);
+ _shape.N = shape.dim(0);
}
public:
@@ -104,7 +104,6 @@ private:
private:
// TODO Remove _shape
ir::FeatureShape _shape;
- using Strides = ir::FeatureShape;
Strides _strides;
const uint8_t *_ptr;
size_t _len;
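With this change the caller supplies the strides that the removed constructor used to derive from the shape. A hedged sketch of wrapping a dense (unpadded) float NCHW buffer, with illustrative dimensions and function name:

// Byte strides implied by a dense NCHW buffer -- exactly what the old
// "buffer of model inputs" constructor computed internally.
float read_dense_nchw(const float *buf)
{
  const int32_t N = 1, C = 2, H = 3, W = 2;
  using Reader = onert::exec::feature::nchw::Reader<float>;
  onert::ir::FeatureShape shape(N, C, H, W);
  Reader::Strides strides(static_cast<int32_t>(C * H * W * sizeof(float)), // N
                          static_cast<int32_t>(H * W * sizeof(float)),     // C
                          static_cast<int32_t>(W * sizeof(float)),         // H
                          static_cast<int32_t>(sizeof(float)));            // W
  Reader reader(shape, strides, buf, N * C * H * W * sizeof(float));
  return reader.at(0, 1, 1, 0); // batch 0, channel 1, row 1, column 0
}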
diff --git a/runtime/onert/core/src/exec/feature/nchw/Reader.test.cc b/runtime/onert/core/src/exec/feature/nchw/Reader.test.cc
new file mode 100644
index 000000000..c405190f7
--- /dev/null
+++ b/runtime/onert/core/src/exec/feature/nchw/Reader.test.cc
@@ -0,0 +1,85 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "Reader.h"
+
+#include "../MockTensor.test.h"
+
+#include <gtest/gtest.h>
+
+using namespace onert::exec::feature;
+
+template <typename T> class Reader_nchw : public testing::Test
+{
+public:
+ void setData(std::initializer_list<T> list) { _data = std::make_shared<std::vector<T>>(list); }
+
+ void setShape(int32_t batch, int32_t depth, int32_t height, int32_t width)
+ {
+ _shape = onert::ir::FeatureShape(batch, depth, height, width);
+ }
+
+ void setStride(int32_t batch, int32_t depth, int32_t height, int32_t width)
+ {
+ auto elem_size = sizeof(T);
+ _stride = onert::ir::FeatureShape(batch * elem_size, depth * elem_size, height * elem_size,
+ width * elem_size);
+ }
+
+ void createReader()
+ {
+ _reader =
+ std::make_shared<nchw::Reader<T>>(_shape, _stride, _data->data(), _data->size() * sizeof(T));
+ }
+
+ void createUsingMockTensor()
+ {
+ onert::ir::Shape shape = {_shape.N, _shape.H, _shape.W, _shape.C};
+ _tensor = std::make_shared<MockTensor<T>>(shape, _data->data(), onert::ir::Layout::NCHW);
+ _reader = std::make_shared<nchw::Reader<T>>(_tensor.get());
+ }
+
+ std::shared_ptr<Reader<T>> _reader = nullptr;
+
+private:
+ std::shared_ptr<std::vector<T>> _data = nullptr;
+ onert::ir::FeatureShape _shape;
+ onert::ir::FeatureShape _stride;
+ std::shared_ptr<MockTensor<T>> _tensor = nullptr;
+};
+
+using ReaderTypes = ::testing::Types<float, int32_t, uint8_t, int8_t, int16_t>;
+TYPED_TEST_SUITE(Reader_nchw, ReaderTypes);
+
+TYPED_TEST(Reader_nchw, basic_reader)
+{
+ this->setData({0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11});
+ this->setShape(1, 2, 3, 2);
+ this->setStride(12, 6, 2, 1);
+ this->createReader();
+
+ // Data: NCHW
+ // Shape: NCHW
+ ASSERT_EQ(this->_reader->at(0, 1, 1, 0), 8);
+ ASSERT_EQ(this->_reader->at(1, 1, 0), 8);
+
+ // Data: NCHW
+ // Shape: NCHW
+ this->createUsingMockTensor();
+
+ ASSERT_EQ(this->_reader->at(0, 1, 1, 0), 6);
+ ASSERT_EQ(this->_reader->at(1, 1, 0), 6);
+}
diff --git a/runtime/onert/core/src/exec/feature/nchw/View.h b/runtime/onert/core/src/exec/feature/nchw/View.h
index dbaf1a91e..cdbb0cd7c 100644
--- a/runtime/onert/core/src/exec/feature/nchw/View.h
+++ b/runtime/onert/core/src/exec/feature/nchw/View.h
@@ -37,8 +37,10 @@ namespace nchw
template <typename T> class View final : public Reader<T>
{
public:
+ using Strides = typename Reader<T>::Strides;
// Construct for buffer of model inputs
- View(const ir::FeatureShape &shape, T *ptr, size_t len) : Reader<T>{shape, ptr, len}
+ View(const ir::FeatureShape &shape, const Strides &strides, T *ptr, size_t len)
+ : Reader<T>{shape, strides, ptr, len}
{
// DO NOTHING
}
diff --git a/runtime/onert/core/src/exec/feature/nchw/View.test.cc b/runtime/onert/core/src/exec/feature/nchw/View.test.cc
new file mode 100644
index 000000000..d21a8b784
--- /dev/null
+++ b/runtime/onert/core/src/exec/feature/nchw/View.test.cc
@@ -0,0 +1,85 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "View.h"
+
+#include "../MockTensor.test.h"
+
+#include <gtest/gtest.h>
+
+using namespace onert::exec::feature;
+
+template <typename T> class View_nchw : public testing::Test
+{
+public:
+ void setData(std::initializer_list<T> list) { _data = std::make_shared<std::vector<T>>(list); }
+
+ void setShape(int32_t batch, int32_t depth, int32_t height, int32_t width)
+ {
+ _shape = onert::ir::FeatureShape(batch, depth, height, width);
+ }
+
+ void setStride(int32_t batch, int32_t depth, int32_t height, int32_t width)
+ {
+ auto elem_size = sizeof(T);
+ _stride = onert::ir::FeatureShape(batch * elem_size, depth * elem_size, height * elem_size,
+ width * elem_size);
+ }
+
+ void createView()
+ {
+ _view =
+ std::make_shared<nchw::View<T>>(_shape, _stride, _data->data(), _data->size() * sizeof(T));
+ }
+
+ void createUsingMockTensor()
+ {
+ onert::ir::Shape shape = {_shape.N, _shape.H, _shape.W, _shape.C};
+ _tensor = std::make_shared<MockTensor<T>>(shape, _data->data(), onert::ir::Layout::NCHW);
+ _view = std::make_shared<nchw::View<T>>(_tensor.get());
+ }
+
+ std::shared_ptr<nchw::View<T>> _view = nullptr;
+
+private:
+ std::shared_ptr<std::vector<T>> _data = nullptr;
+ onert::ir::FeatureShape _shape;
+ onert::ir::FeatureShape _stride;
+ std::shared_ptr<MockTensor<T>> _tensor = nullptr;
+};
+
+using ViewTypes = ::testing::Types<float, int32_t, uint8_t, int8_t, int16_t>;
+TYPED_TEST_SUITE(View_nchw, ViewTypes);
+
+TYPED_TEST(View_nchw, basic_view)
+{
+ this->setData({0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11});
+ this->setShape(1, 2, 3, 2);
+ this->setStride(12, 6, 2, 1);
+ this->createView();
+
+ // Data: NCHW
+ // Shape: NCHW
+ ASSERT_EQ(this->_view->at(0, 1, 1, 0), 8);
+ ASSERT_EQ(this->_view->at(1, 1, 0), 8);
+
+ // Data: NCHW
+ // Shape: NCHW
+ this->createUsingMockTensor();
+
+ ASSERT_EQ(this->_view->at(0, 1, 1, 0), 6);
+ ASSERT_EQ(this->_view->at(1, 1, 0), 6);
+}
diff --git a/runtime/onert/core/src/exec/feature/nhwc/Reader.h b/runtime/onert/core/src/exec/feature/nhwc/Reader.h
index 7730cee72..3e3c431bf 100644
--- a/runtime/onert/core/src/exec/feature/nhwc/Reader.h
+++ b/runtime/onert/core/src/exec/feature/nhwc/Reader.h
@@ -37,36 +37,35 @@ namespace nhwc
template <typename T> class Reader : public feature::Reader<T>
{
public:
- // Construct for buffer of model inputs
- Reader(const ir::FeatureShape &shape, const T *ptr, size_t len)
- : _shape{shape}, _ptr{reinterpret_cast<const uint8_t *>(ptr)}, _len{len}
+ using Strides = ir::FeatureShape;
+ // Construct for buffer and strides
+ Reader(const ir::FeatureShape &shape, const Strides &strides, const T *ptr, size_t len)
+ : _shape{shape}, _strides{strides}, _ptr{reinterpret_cast<const uint8_t *>(ptr)}, _len{len}
{
UNUSED_RELEASE(len); // Workaround for unused variable in release mode
- assert(shape.N * shape.C * shape.H * shape.W * sizeof(T) == len);
-
- // No padding
- _strides.C = sizeof(T);
- _strides.W = shape.C * sizeof(T);
- _strides.H = shape.C * shape.W * sizeof(T);
- _strides.N = shape.C * shape.W * shape.H * sizeof(T);
+ assert(len == static_cast<size_t>(strides.N != 0 ? shape.N * strides.N
+ : strides.H != 0 ? shape.H * strides.H
+ : strides.W != 0 ? shape.W * strides.W
+ : shape.C * strides.C));
}
// Construct for backend tensor
Reader(const backend::ITensor *tensor)
- : _ptr{tensor->buffer() + tensor->calcOffset({0, 0, 0, 0})}, _len{tensor->total_size()}
+ : _ptr{tensor->buffer() + tensor->calcOffset({0, 0, 0, 0})}, _len{tensor->total_size()}
{
assert(tensor->layout() == ir::Layout::NHWC);
const auto start_offset = tensor->calcOffset({0, 0, 0, 0});
- _strides.C = tensor->dimension(3) == 1 ? 0 : tensor->calcOffset({0, 0, 0, 1}) - start_offset;
- _strides.W = tensor->dimension(2) == 1 ? 0 : tensor->calcOffset({0, 0, 1, 0}) - start_offset;
- _strides.H = tensor->dimension(1) == 1 ? 0 : tensor->calcOffset({0, 1, 0, 0}) - start_offset;
- _strides.N = tensor->dimension(0) == 1 ? 0 : tensor->calcOffset({1, 0, 0, 0}) - start_offset;
-
- _shape.C = tensor->dimension(3);
- _shape.W = tensor->dimension(2);
- _shape.H = tensor->dimension(1);
- _shape.N = tensor->dimension(0);
+ auto shape = tensor->getShape();
+ _strides.C = shape.dim(3) == 1 ? 0 : tensor->calcOffset({0, 0, 0, 1}) - start_offset;
+ _strides.W = shape.dim(2) == 1 ? 0 : tensor->calcOffset({0, 0, 1, 0}) - start_offset;
+ _strides.H = shape.dim(1) == 1 ? 0 : tensor->calcOffset({0, 1, 0, 0}) - start_offset;
+ _strides.N = shape.dim(0) == 1 ? 0 : tensor->calcOffset({1, 0, 0, 0}) - start_offset;
+
+ _shape.C = shape.dim(3);
+ _shape.W = shape.dim(2);
+ _shape.H = shape.dim(1);
+ _shape.N = shape.dim(0);
}
public:
@@ -106,7 +105,6 @@ private:
private:
// TODO Remove _shape
ir::FeatureShape _shape;
- using Strides = ir::FeatureShape;
Strides _strides;
const uint8_t *_ptr;
size_t _len;
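Same idea for the NHWC reader: for a dense buffer the caller now passes the strides the old constructor computed. Sketch (FeatureShape/Strides constructor order is N, C, H, W):

// Byte strides of a dense NHWC float buffer with shape (N, H, W, C).
onert::ir::FeatureShape dense_nhwc_strides(int32_t C, int32_t H, int32_t W)
{
  return onert::ir::FeatureShape(static_cast<int32_t>(C * W * H * sizeof(float)), // N
                                 static_cast<int32_t>(sizeof(float)),             // C
                                 static_cast<int32_t>(C * W * sizeof(float)),     // H
                                 static_cast<int32_t>(C * sizeof(float)));        // W
}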
diff --git a/runtime/onert/core/src/exec/feature/nhwc/Reader.test.cc b/runtime/onert/core/src/exec/feature/nhwc/Reader.test.cc
new file mode 100644
index 000000000..1f3a4dd06
--- /dev/null
+++ b/runtime/onert/core/src/exec/feature/nhwc/Reader.test.cc
@@ -0,0 +1,86 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "Reader.h"
+
+#include "../MockTensor.test.h"
+
+#include <gtest/gtest.h>
+
+using namespace onert::exec::feature;
+
+template <typename T> class Reader_nhwc : public testing::Test
+{
+public:
+ void setData(std::initializer_list<T> list) { _data = std::make_shared<std::vector<T>>(list); }
+
+ void setShape(int32_t batch, int32_t depth, int32_t height, int32_t width)
+ {
+ _shape = onert::ir::FeatureShape(batch, depth, height, width);
+ }
+
+ void setStride(int32_t batch, int32_t depth, int32_t height, int32_t width)
+ {
+ auto elem_size = sizeof(T);
+ _stride = onert::ir::FeatureShape(batch * elem_size, depth * elem_size, height * elem_size,
+ width * elem_size);
+ }
+
+ void createReader()
+ {
+ _reader =
+ std::make_shared<nhwc::Reader<T>>(_shape, _stride, _data->data(), _data->size() * sizeof(T));
+ }
+
+ void createUsingMockTensor()
+ {
+ onert::ir::Shape shape = {_shape.N, _shape.H, _shape.W, _shape.C};
+ _tensor = std::make_shared<MockTensor<T>>(shape, _data->data(), onert::ir::Layout::NHWC);
+ _reader = std::make_shared<nhwc::Reader<T>>(_tensor.get());
+ }
+
+ std::shared_ptr<nhwc::Reader<T>> _reader = nullptr;
+
+private:
+ std::shared_ptr<std::vector<T>> _data = nullptr;
+ onert::ir::FeatureShape _shape;
+ onert::ir::FeatureShape _stride;
+ std::shared_ptr<MockTensor<T>> _tensor = nullptr;
+};
+
+using ReaderTypes = ::testing::Types<float, int32_t, uint8_t, int8_t, int16_t>;
+TYPED_TEST_SUITE(Reader_nhwc, ReaderTypes);
+TYPED_TEST_SUITE(MockTensorReader_nhwc, ReaderTypes);
+
+TYPED_TEST(Reader_nhwc, basic_reader)
+{
+ this->setData({0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11});
+ this->setShape(1, 2, 3, 2);
+ this->setStride(12, 1, 6, 2);
+ this->createReader();
+
+ // Data: NCHW
+ // Shape: NHWC
+ ASSERT_EQ(this->_reader->at(0, 1, 1, 0), 8);
+ ASSERT_EQ(this->_reader->at(1, 1, 0), 8);
+
+ // Data: NHWC
+ // Shape: NHWC
+ this->createUsingMockTensor();
+
+ ASSERT_EQ(this->_reader->at(0, 1, 1, 0), 6);
+ ASSERT_EQ(this->_reader->at(1, 1, 0), 6);
+}
diff --git a/runtime/onert/core/src/exec/feature/nhwc/View.h b/runtime/onert/core/src/exec/feature/nhwc/View.h
index 72c8c3415..c98d050c3 100644
--- a/runtime/onert/core/src/exec/feature/nhwc/View.h
+++ b/runtime/onert/core/src/exec/feature/nhwc/View.h
@@ -17,7 +17,7 @@
#ifndef __ONERT_EXEC_FEATURE_NHWC_VIEW_H__
#define __ONERT_EXEC_FEATURE_NHWC_VIEW_H__
-#include "../Reader.h"
+#include "Reader.h"
#include <cassert>
#include <cstddef>
@@ -38,8 +38,10 @@ namespace nhwc
template <typename T> class View final : public Reader<T>
{
public:
- // Construct for buffer of model inputs
- View(const ir::FeatureShape &shape, T *ptr, size_t len) : Reader<T>{shape, ptr, len}
+ using Strides = typename Reader<T>::Strides;
+ // Construct for buffer and strides
+ View(const ir::FeatureShape &shape, const Strides &strides, T *ptr, size_t len)
+ : Reader<T>{shape, strides, ptr, len}
{
// DO NOTHING
}
diff --git a/runtime/onert/core/src/exec/feature/nhwc/View.test.cc b/runtime/onert/core/src/exec/feature/nhwc/View.test.cc
new file mode 100644
index 000000000..c9018660c
--- /dev/null
+++ b/runtime/onert/core/src/exec/feature/nhwc/View.test.cc
@@ -0,0 +1,86 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "View.h"
+
+#include "../MockTensor.test.h"
+
+#include <gtest/gtest.h>
+
+using namespace onert::exec::feature;
+
+template <typename T> class View_nhwc : public testing::Test
+{
+public:
+ void setData(std::initializer_list<T> list) { _data = std::make_shared<std::vector<T>>(list); }
+
+ void setShape(int32_t batch, int32_t depth, int32_t height, int32_t width)
+ {
+ _shape = onert::ir::FeatureShape(batch, depth, height, width);
+ }
+
+ void setStride(int32_t batch, int32_t depth, int32_t height, int32_t width)
+ {
+ auto elem_size = sizeof(T);
+ _stride = onert::ir::FeatureShape(batch * elem_size, depth * elem_size, height * elem_size,
+ width * elem_size);
+ }
+
+ void createView()
+ {
+ _view =
+ std::make_shared<nhwc::View<T>>(_shape, _stride, _data->data(), _data->size() * sizeof(T));
+ }
+
+ void createUsingMockTensor()
+ {
+ onert::ir::Shape shape = {_shape.N, _shape.H, _shape.W, _shape.C};
+ _tensor = std::make_shared<MockTensor<T>>(shape, _data->data(), onert::ir::Layout::NHWC);
+ _view = std::make_shared<nhwc::View<T>>(_tensor.get());
+ }
+
+ std::shared_ptr<nhwc::View<T>> _view = nullptr;
+
+private:
+ std::shared_ptr<std::vector<T>> _data = nullptr;
+ onert::ir::FeatureShape _shape;
+ onert::ir::FeatureShape _stride;
+ std::shared_ptr<MockTensor<T>> _tensor = nullptr;
+};
+
+using ViewTypes = ::testing::Types<float, int32_t, uint8_t, int8_t, int16_t>;
+TYPED_TEST_SUITE(View_nhwc, ViewTypes);
+TYPED_TEST_SUITE(MockTensorView_nhwc, ViewTypes);
+
+TYPED_TEST(View_nhwc, basic_view)
+{
+ this->setData({0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11});
+ this->setShape(1, 2, 3, 2);
+ this->setStride(12, 1, 6, 2);
+ this->createView();
+
+ // Data: NCHW
+ // Shape: NHWC
+ ASSERT_EQ(this->_view->at(0, 1, 1, 0), 8);
+ ASSERT_EQ(this->_view->at(1, 1, 0), 8);
+
+ // Data: NHWC
+ // Shape: NHWC
+ this->createUsingMockTensor();
+
+ ASSERT_EQ(this->_view->at(0, 1, 1, 0), 6);
+ ASSERT_EQ(this->_view->at(1, 1, 0), 6);
+}
diff --git a/runtime/onert/core/src/exec/train/TrainableExecutor.cc b/runtime/onert/core/src/exec/train/TrainableExecutor.cc
new file mode 100644
index 000000000..5d7c4f3f7
--- /dev/null
+++ b/runtime/onert/core/src/exec/train/TrainableExecutor.cc
@@ -0,0 +1,225 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "TrainableExecutor.h"
+#ifdef RUY_PROFILER
+#include "ruy/profiler/instrumentation.h"
+#endif
+
+#include <misc/polymorphic_downcast.h>
+
+namespace onert
+{
+namespace exec
+{
+namespace train
+{
+
+TrainableExecutor::TrainableExecutor(
+ std::unique_ptr<compiler::train::LoweredTrainableGraph> lowered_graph,
+ backend::train::TrainableBackendContexts &&backend_contexts,
+ const compiler::train::TensorRegistries &tensor_regs,
+ compiler::train::TrainableCodeMap &&code_map,
+ const std::vector<ir::OperationIndex> &forward_order,
+ const std::vector<ir::OperationIndex> &backward_order, const util::TracingCtx *tracing_ctx,
+ const ir::train::LossInfo &loss_info)
+ : _code_map{std::move(code_map)}, _forward_order{std::move(forward_order)},
+ _backward_order{std::move(backward_order)}, _lowered_graph{std::move(lowered_graph)},
+ _backend_contexts{std::move(backend_contexts)},
+ _trainable_graph{_lowered_graph->trainable_graph()}, _tensor_regs{std::move(tensor_regs)},
+ _mutex(), _tracing_ctx(tracing_ctx), _loss_info(loss_info)
+{
+ auto build_tensor_list = [&](const auto &ind_seq, auto &tensors) {
+ assert(tensors.empty());
+ for (auto &&ind : ind_seq)
+ {
+ backend::ITensor *tensor = _tensor_regs.getITensor(ind);
+ assert(tensor != nullptr);
+ auto io_tensor = nnfw::misc::polymorphic_downcast<backend::builtin::IOTensor *>(tensor);
+ tensors.push_back(io_tensor);
+ }
+ };
+ build_tensor_list(_trainable_graph.getInputs(), _input_tensors);
+ build_tensor_list(_trainable_graph.getOutputs(), _output_tensors);
+}
+
+void TrainableExecutor::forward(const std::vector<backend::IPortableTensor *> &inputs,
+ const std::vector<backend::IPortableTensor *> &outputs,
+ const ExecutionOptions &options, bool training)
+{
+  // For thread safety, use a mutex
+  // TODO: if all backends used by this executor are thread-safe,
+  //       the mutex is not needed (otherwise, keep it)
+ std::lock_guard<std::mutex> lock(_mutex);
+ _current_options = options;
+
+ assert(_input_tensors.size() == inputs.size());
+ for (uint32_t i = 0; i < _input_tensors.size(); ++i)
+ {
+ auto tensor = _input_tensors[i];
+ const auto input = inputs[i];
+ assert(input->buffer() != nullptr || input->get_info().total_size() == 0);
+ assert(tensor != nullptr);
+ tensor->setTensor(input);
+ }
+
+ // Set output(s)
+ assert(_output_tensors.size() == outputs.size());
+ for (uint32_t i = 0; i < _output_tensors.size(); ++i)
+ {
+ auto tensor = _output_tensors[i];
+ const auto output = outputs[i];
+ // Output may not be used on training, so don't check optional
+ assert(tensor != nullptr);
+ tensor->setTensor(output);
+ }
+
+ // Create observee
+ ExecutionObservee subject(_observers, options);
+
+ forwardImpl(subject, training);
+
+ // TODO Update output(s) desc if desc has dynamic input
+}
+
+void TrainableExecutor::forwardImpl(const ExecutionObservee &subject, bool training)
+{
+ if (!subject.isEmpty() && _tracing_ctx)
+ {
+ auto profiling_subg_index = _tracing_ctx->getSubgraphIndex(&_trainable_graph.graph());
+
+ subject.notifySubgraphBegin(profiling_subg_index);
+ for (auto &&index : _forward_order)
+ {
+ const auto &code = _code_map.at(index);
+ const auto backend = code.lower_info->backend();
+// TODO : Move ruy profiler into ExecutionObserver
+#ifdef RUY_PROFILER
+ ruy::profiler::ScopeLabel label(code.op->name());
+#endif
+ subject.notifyJobBegin(this, profiling_subg_index, code.op_ind, backend);
+
+ auto &tn_seq = code.tn_seq;
+ tn_seq->forward(training && code.op->isRequiredForBackward());
+
+ subject.notifyJobEnd(this, profiling_subg_index, code.op_ind, backend);
+ }
+ subject.notifySubgraphEnd(profiling_subg_index);
+ }
+ else
+ {
+ for (auto &&index : _forward_order)
+ {
+ const auto &code = _code_map.at(index);
+// TODO : Move ruy profiler into ExecutionObserver
+#ifdef RUY_PROFILER
+ ruy::profiler::ScopeLabel label(code.op->name());
+#endif
+ auto &tn_seq = code.tn_seq;
+ tn_seq->forward(training && code.op->isRequiredForBackward());
+ }
+ }
+}
+
+void TrainableExecutor::backward(const ExecutionOptions &options, uint32_t training_step)
+{
+  // For thread safety, use a mutex
+  // TODO: if all backends used by this executor are thread-safe,
+  //       the mutex is not needed (otherwise, keep it)
+ std::lock_guard<std::mutex> lock(_mutex);
+ _current_options = options;
+
+ // Create observee
+ ExecutionObservee subject(_observers, options);
+
+ backwardImpl(subject, training_step);
+}
+
+void TrainableExecutor::backwardImpl(const ExecutionObservee &subject, uint32_t training_step)
+{
+ if (!subject.isEmpty() && _tracing_ctx)
+ {
+ auto profiling_subg_index = _tracing_ctx->getSubgraphIndex(&_trainable_graph.graph());
+
+ subject.notifySubgraphBegin(profiling_subg_index);
+ for (auto &&index : _backward_order)
+ {
+ const auto &code = _code_map.at(index);
+ if (!code.op->isRequiredForBackward())
+ {
+ continue;
+ }
+ const auto backend = code.lower_info->backend();
+// TODO : Move ruy profiler into ExecutionObserver
+#ifdef RUY_PROFILER
+ ruy::profiler::ScopeLabel label(code.op->name());
+#endif
+ subject.notifyJobBegin(this, profiling_subg_index, code.op_ind, backend);
+
+ auto &tn_seq = code.tn_seq;
+ tn_seq->backward(training_step, code.op->isWeightsUpdateEnabled());
+
+ subject.notifyJobEnd(this, profiling_subg_index, code.op_ind, backend);
+ }
+ subject.notifySubgraphEnd(profiling_subg_index);
+ }
+ else
+ {
+ for (auto &&index : _backward_order)
+ {
+ const auto &code = _code_map.at(index);
+ if (!code.op->isRequiredForBackward())
+ {
+ continue;
+ }
+// TODO : Move ruy profiler into ExecutionObserver
+#ifdef RUY_PROFILER
+ ruy::profiler::ScopeLabel label(code.op->name());
+#endif
+ auto &tn_seq = code.tn_seq;
+ tn_seq->backward(training_step, code.op->isWeightsUpdateEnabled());
+ }
+ }
+}
+
+float TrainableExecutor::getLoss(const ir::IOIndex &pred_io_ind) const
+{
+ const auto &loss_ind = _trainable_graph.getLossIndex(pred_io_ind);
+ if (loss_ind.undefined())
+ throw std::runtime_error{"Loss " + std::to_string(loss_ind.value()) + " is not defined."};
+ backend::ITensor *tensor = _tensor_regs.getITensor(loss_ind);
+ long double sum = 0;
+ for (uint64_t i = 0; i < tensor->getShape().num_elements(); ++i)
+ {
+ sum += reinterpret_cast<float *>(tensor->buffer())[i];
+ }
+ if (_loss_info.reduction_type == ir::train::LossReductionType::SumOverBatchSize)
+ {
+ sum /= tensor->getShape().num_elements();
+ }
+ return static_cast<float>(sum);
+}
+
+void TrainableExecutor::iterateTrainableTensors(
+ const std::function<void(const ir::OperandIndex &, const backend::train::ITrainableTensor *)> &fn)
+ const
+{
+ _tensor_regs.iterateTrainableTensors(fn);
+}
+
+} // namespace train
+} // namespace exec
+} // namespace onert
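The loss reduction done by getLoss() above boils down to a sum or a mean over the loss tensor's elements. A standalone sketch of just that arithmetic:

#include <cstddef>

// Sum vs. sum-over-batch-size (mean) reduction, as applied by
// TrainableExecutor::getLoss() to the elements of the loss tensor.
float reduce_loss(const float *loss, size_t num_elements, bool sum_over_batch_size)
{
  long double sum = 0;
  for (size_t i = 0; i < num_elements; ++i)
    sum += loss[i];
  if (sum_over_batch_size)
    sum /= num_elements;
  return static_cast<float>(sum);
}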
diff --git a/runtime/onert/core/src/exec/train/TrainableExecutor.h b/runtime/onert/core/src/exec/train/TrainableExecutor.h
new file mode 100644
index 000000000..986c2236c
--- /dev/null
+++ b/runtime/onert/core/src/exec/train/TrainableExecutor.h
@@ -0,0 +1,143 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_EXEC_TRAIN_TRAINABLE_EXECUTOR_H_
+#define __ONERT_EXEC_TRAIN_TRAINABLE_EXECUTOR_H_
+
+#include "exec/IExecutor.h"
+
+#include "../ExecutionObservee.h"
+#include "../../compiler/train/TensorRegistries.h"
+
+#include "backend/train/TrainableBackendContext.h"
+#include "compiler/train/TrainableCodeMap.h"
+#include "compiler/train/LoweredTrainableGraph.h"
+#include "ir/train/LossInfo.h"
+#include "ir/Index.h"
+#include "util/TracingCtx.h"
+
+namespace onert
+{
+namespace exec
+{
+namespace train
+{
+
+class TrainableExecutor : public IExecutor
+{
+public:
+ /**
+ * @brief Construct a new TrainableExecutor object
+ * @param lowered_graph LoweredTrainableGraph object
+ * @param tensor_builders Tensor builders that are currently used
+ * @param code_map @c ir::Operation and its code map
+ */
+ TrainableExecutor(std::unique_ptr<compiler::train::LoweredTrainableGraph> lowered_graph,
+ backend::train::TrainableBackendContexts &&backend_contexts,
+ const compiler::train::TensorRegistries &tensor_regs,
+ compiler::train::TrainableCodeMap &&code_map,
+ const std::vector<ir::OperationIndex> &forward_order,
+ const std::vector<ir::OperationIndex> &backward_order,
+ const util::TracingCtx *tracing_ctx, const ir::train::LossInfo &training_info);
+
+public:
+ const ir::Graph &graph() const final { return _trainable_graph.graph(); }
+
+ void execute(const std::vector<backend::IPortableTensor *> &inputs,
+ const std::vector<backend::IPortableTensor *> &outputs,
+ const ExecutionOptions &options) override
+ {
+ forward(inputs, outputs, options, false);
+ }
+
+ uint32_t inputSize() const override { return _input_tensors.size(); }
+
+ uint32_t outputSize() const override { return _output_tensors.size(); }
+
+ const ir::OperandInfo &inputInfo(uint32_t index) const override
+ {
+ return _input_tensors[index]->get_info();
+ }
+
+ const ir::OperandInfo &outputInfo(uint32_t index) const override
+ {
+ return _output_tensors[index]->get_info();
+ }
+
+ ir::Layout inputLayout(uint32_t index) const override { return _input_tensors[index]->layout(); }
+
+ ir::Layout outputLayout(uint32_t index) const override
+ {
+ return _output_tensors[index]->layout();
+ }
+
+ void forward(const std::vector<backend::IPortableTensor *> &inputs,
+ const std::vector<backend::IPortableTensor *> &outputs,
+ const ExecutionOptions &options, bool training);
+ void backward(const ExecutionOptions &options, uint32_t training_step);
+
+ // Used only in Dataflow and Parallel Executors
+ void setIndexedRanks(std::shared_ptr<ir::OperationIndexMap<int64_t>> ranks) final
+ {
+ _indexed_ranks = std::move(ranks);
+ };
+
+ void addObserver(std::unique_ptr<IExecutionObserver> ref) { _observers.add(std::move(ref)); };
+
+ float getLoss(const ir::IOIndex &pred_io_ind) const;
+
+ void iterateTrainableTensors(
+ const std::function<void(const ir::OperandIndex &, const backend::train::ITrainableTensor *)>
+ &fn) const;
+
+ backend::train::TrainableBackendContexts &getBackendContexts() { return _backend_contexts; }
+
+ const ExecutionOptions &currentOptions() const override { return _current_options; }
+
+private:
+ void forwardImpl(const ExecutionObservee &subject, bool training);
+ void backwardImpl(const ExecutionObservee &subject, uint32_t training_step);
+
+private:
+ compiler::train::TrainableCodeMap _code_map;
+ std::vector<ir::OperationIndex> _forward_order;
+ std::vector<ir::OperationIndex> _backward_order;
+ ExecObservers _observers;
+ std::shared_ptr<ir::OperationIndexMap<int64_t>> _indexed_ranks;
+ std::unique_ptr<compiler::train::LoweredTrainableGraph> _lowered_graph;
+ backend::train::TrainableBackendContexts _backend_contexts;
+ const ir::train::TrainableGraph &_trainable_graph;
+ compiler::train::TensorRegistries _tensor_regs;
+ std::vector<backend::builtin::IOTensor *> _input_tensors;
+ std::vector<backend::builtin::IOTensor *> _output_tensors;
+ std::mutex _mutex;
+ const util::TracingCtx *_tracing_ctx;
+ const ir::train::LossInfo _loss_info;
+ /**
+ * It is set by execute() method only in thread-safe environment.
+ * It is used for non-primary executor call on builtin backend
+ * and accessed by entryExecutor's currentOptions() method.
+ *
+ * TODO: Find better way to pass config to non-primary executor
+ */
+ ExecutionOptions _current_options;
+};
+
+} // namespace train
+} // namespace exec
+} // namespace onert
+
+#endif // __ONERT_EXEC_TRAIN_TRAINABLE_EXECUTOR_H_
diff --git a/runtime/onert/core/src/exec/train/TrainableExecutors.cc b/runtime/onert/core/src/exec/train/TrainableExecutors.cc
new file mode 100644
index 000000000..73217c836
--- /dev/null
+++ b/runtime/onert/core/src/exec/train/TrainableExecutors.cc
@@ -0,0 +1,142 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "TrainableExecutors.h"
+
+#include "../../backend/builtin/IOTensor.h"
+
+#include <misc/polymorphic_downcast.h>
+
+namespace onert
+{
+namespace exec
+{
+namespace train
+{
+
+void TrainableExecutors::emplace(const ir::ModelIndex &, const ir::SubgraphIndex &subg_index,
+ std::unique_ptr<IExecutor> exec)
+{
+ std::unique_ptr<TrainableExecutor> t_exec{
+ nnfw::misc::polymorphic_downcast<TrainableExecutor *>(exec.release())};
+ _executors.emplace(subg_index, std::move(t_exec));
+}
+
+TrainableExecutor *TrainableExecutors::at(const ir::ModelIndex &,
+ const ir::SubgraphIndex &subg_index) const
+{
+ return _executors.at(subg_index).get();
+}
+
+uint32_t TrainableExecutors::inputSize() const { return entryExecutor()->inputSize(); }
+
+uint32_t TrainableExecutors::outputSize() const { return entryExecutor()->outputSize(); }
+
+const ir::OperandInfo &TrainableExecutors::inputInfo(const ir::IOIndex &index) const
+{
+ return entryExecutor()->inputInfo(index.value());
+}
+
+const ir::OperandInfo &TrainableExecutors::outputInfo(const ir::IOIndex &index) const
+{
+ return entryExecutor()->outputInfo(index.value());
+}
+
+void TrainableExecutors::execute(const ExecutionContext &ctx)
+{
+ if (_executors.size() > 1)
+ throw std::runtime_error("TrainableExecutors does not support multiple executors yet");
+
+ // UserTensor for Input/Output
+ std::vector<std::unique_ptr<backend::builtin::UserTensor>> tensorpool;
+
+ // Allocate UserTensor and call executor forward
+ forward(ctx, tensorpool, false);
+
+  // TODO Support multiple executors
+}
+
+void TrainableExecutors::train(const ExecutionContext &ctx, uint32_t training_step)
+{
+ if (_executors.size() > 1)
+ throw std::runtime_error("TrainableExecutors does not support multiple executors yet");
+
+ // UserTensor for Input/Output
+ std::vector<std::unique_ptr<backend::builtin::UserTensor>> tensorpool;
+
+ // Allocate UserTensor and call executor forward and backward
+ forward(ctx, tensorpool, true);
+ entryExecutor()->backward(ctx.options, training_step);
+
+  // TODO Support multiple executors
+}
+
+void TrainableExecutors::forward(
+ const ExecutionContext &ctx,
+ std::vector<std::unique_ptr<backend::builtin::UserTensor>> &tensorpool, bool training)
+{
+ // Input/Output Tensor vector for executor
+ std::vector<backend::IPortableTensor *> inputs(ctx.desc.inputs.size());
+ std::vector<backend::IPortableTensor *> outputs(ctx.desc.outputs.size());
+
+ // Prepare UserTensor for input
+ for (uint32_t i = 0; i < inputs.size(); i++)
+ {
+ auto &desc = ctx.desc.inputs[i];
+
+ // Input is optional if buffer is nullptr, and optional input's size is 0
+ if (desc->buffer == nullptr && (desc->size != 0 || desc->info.total_size() != 0))
+ throw std::runtime_error{"Input " + std::to_string(i) + "'s buffer is not set."};
+
+ tensorpool.emplace_back(std::make_unique<backend::builtin::UserTensor>(
+ desc->info, desc->layout, const_cast<uint8_t *>(static_cast<const uint8_t *>(desc->buffer)),
+ desc->size));
+ inputs[i] = tensorpool.back().get();
+ }
+
+ // Prepare UserTensor for output
+ for (uint32_t i = 0; i < outputs.size(); i++)
+ {
+ auto &desc = ctx.desc.outputs[i];
+
+ // If training, output buffer may not be used
+ // So don't check optional
+ tensorpool.emplace_back(std::make_unique<backend::builtin::UserTensor>(
+ desc->info, desc->layout, static_cast<uint8_t *>(desc->buffer), desc->size));
+ outputs[i] = tensorpool.back().get();
+ }
+
+ // Call forward
+ entryExecutor()->forward(inputs, outputs, ctx.options, training);
+}
+
+float TrainableExecutors::getLoss(const ir::IOIndex &index) const
+{
+ if (_executors.size() > 1)
+ throw std::runtime_error("TrainableExecutors does not support multiple executors yet");
+ return entryExecutor()->getLoss(index);
+}
+
+void TrainableExecutors::iterateTrainableTensors(
+ const std::function<void(const ir::OperandIndex &, const backend::train::ITrainableTensor *)> &fn)
+ const
+{
+ return entryExecutor()->iterateTrainableTensors(fn);
+}
+
+} // namespace train
+} // namespace exec
+} // namespace onert
diff --git a/runtime/onert/core/src/exec/train/TrainableExecutors.h b/runtime/onert/core/src/exec/train/TrainableExecutors.h
new file mode 100644
index 000000000..ae120f6f0
--- /dev/null
+++ b/runtime/onert/core/src/exec/train/TrainableExecutors.h
@@ -0,0 +1,104 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_EXEC_TRAIN_TRAINABLE_EXECUTORS_H__
+#define __ONERT_EXEC_TRAIN_TRAINABLE_EXECUTORS_H__
+
+#include "TrainableExecutor.h"
+#include "exec/IExecutors.h"
+#include "ir/NNPkg.h"
+
+namespace onert
+{
+namespace exec
+{
+namespace train
+{
+
+/**
+ * @brief Class to gather executor set for trainable model NN package
+ */
+class TrainableExecutors : public IExecutors
+{
+public:
+ /**
+ * @brief Construct a new TrainableExecutors object
+ */
+ TrainableExecutors(void) = default;
+ TrainableExecutors(const TrainableExecutors &) = delete;
+ TrainableExecutors(TrainableExecutors &&) = default;
+
+ /**
+ * @brief Destroy the TrainableExecutors object
+ */
+ ~TrainableExecutors() = default;
+
+public:
+ TrainableExecutors &operator=(const TrainableExecutors &) = delete;
+ TrainableExecutors &operator=(TrainableExecutors &&) = default;
+
+public:
+ void emplace(const ir::ModelIndex &model_index, const ir::SubgraphIndex &subg_index,
+ std::unique_ptr<IExecutor> exec) override;
+
+ TrainableExecutor *at(const ir::ModelIndex &model_index,
+ const ir::SubgraphIndex &subg_index) const override;
+
+ TrainableExecutor *entryExecutor() const { return at(ir::ModelIndex{0}, ir::SubgraphIndex{0}); }
+
+ uint32_t inputSize() const override;
+
+ uint32_t outputSize() const override;
+
+ const ir::OperandInfo &inputInfo(const ir::IOIndex &index) const override;
+
+ const ir::OperandInfo &outputInfo(const ir::IOIndex &index) const override;
+
+ void execute(const ExecutionContext &ctx) override;
+
+ /**
+ * @brief Train
+ *
+ * @param ctx Execution context
+   * @param training_step The number of iterations of a training process;
+   *                      in other words, the number of gradient updates.
+ */
+ void train(const ExecutionContext &ctx, uint32_t training_step);
+
+ float getLoss(const ir::IOIndex &index) const;
+
+ void iterateTrainableTensors(
+ const std::function<void(const ir::OperandIndex &, const backend::train::ITrainableTensor *)>
+ &fn) const;
+
+private:
+  // If the I/O buffers are used during a step, the tensorpool must stay alive until that step
+  // finishes, so this method takes the tensorpool from outside.
+  // tensorpool is not a member variable in order to avoid memory access conflicts between threads.
+ void forward(const ExecutionContext &ctx,
+ std::vector<std::unique_ptr<backend::builtin::UserTensor>> &tensorpool,
+ bool training);
+
+private:
+ // TODO Append model index to ModelIndex
+ std::unordered_map<ir::SubgraphIndex, std::unique_ptr<TrainableExecutor>> _executors;
+};
+
+} // namespace train
+} // namespace exec
+} // namespace onert
+
+#endif // __ONERT_EXEC_TRAIN_TRAINABLE_EXECUTORS_H__
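A hedged sketch of how this class is typically driven for a few gradient updates. It assumes a prepared TrainableExecutors instance and ExecutionContext, neither of which is constructed in this patch:

// Run `num_steps` training steps and read back the loss bound to output #0.
void train_loop_sketch(onert::exec::train::TrainableExecutors &executors,
                       const onert::exec::ExecutionContext &ctx, uint32_t num_steps)
{
  for (uint32_t step = 0; step < num_steps; ++step)
  {
    executors.train(ctx, step); // forward + backward (+ weight update)
    const float loss = executors.getLoss(onert::ir::IOIndex{0});
    (void)loss; // e.g. log it or use it for early stopping
  }
}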
diff --git a/runtime/onert/core/src/exec/train/TrainableFnSequence.cc b/runtime/onert/core/src/exec/train/TrainableFnSequence.cc
new file mode 100644
index 000000000..36e4c3171
--- /dev/null
+++ b/runtime/onert/core/src/exec/train/TrainableFnSequence.cc
@@ -0,0 +1,69 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "exec/train/TrainableFnSequence.h"
+
+namespace onert
+{
+namespace exec
+{
+namespace train
+{
+
+void TrainableFnSequence::forward(bool training)
+{
+ for (const auto &function : _functions)
+ {
+ function->forward(training);
+ }
+}
+
+void TrainableFnSequence::backward(uint32_t training_step, bool weight_update_enabled)
+{
+ for (auto it = _functions.rbegin(); it != _functions.rend(); ++it)
+ {
+ (*it)->backward();
+ }
+ if (weight_update_enabled)
+ {
+ for (const auto &applier : _appliers)
+ {
+ applier->applyGradient(training_step);
+ }
+ }
+}
+
+void TrainableFnSequence::append(std::unique_ptr<ITrainableFunction> &&function)
+{
+ _functions.push_back(std::move(function));
+}
+
+void TrainableFnSequence::append(std::unique_ptr<IGradientApplier> &&applier)
+{
+ _appliers.push_back(std::move(applier));
+}
+
+void TrainableFnSequence::iterate(const std::function<void(ITrainableFunction &)> &fn)
+{
+ for (const auto &func : _functions)
+ {
+ fn(*func);
+ }
+}
+
+} // namespace train
+} // namespace exec
+} // namespace onert
diff --git a/runtime/onert/core/src/exporter/CircleExporter.cc b/runtime/onert/core/src/exporter/CircleExporter.cc
new file mode 100644
index 000000000..b9ac8d5bb
--- /dev/null
+++ b/runtime/onert/core/src/exporter/CircleExporter.cc
@@ -0,0 +1,153 @@
+/*
+ * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "exporter/CircleExporter.h"
+
+#include "exec/Execution.h"
+#include "ir/train/TrainingInfo.h"
+#include "circle_schema_generated.h"
+#include "TrainInfoBuilder.h"
+
+#include <fstream>
+#include <iostream>
+
+namespace onert
+{
+namespace exporter
+{
+
+CircleExporter::CircleExporter(const std::string &source, const std::string &path)
+ : _path{path}, _data{}, _model{nullptr}
+{
+ // make sure the architecture is little endian before direct access to flatbuffers
+ assert(FLATBUFFERS_LITTLEENDIAN);
+
+ std::ifstream src(source.c_str(), std::ios::binary);
+ if (src.is_open())
+ {
+ src.seekg(0, std::ios::end);
+ _data.resize(src.tellg());
+ src.seekg(0, std::ios::beg);
+ src.read(&_data[0], static_cast<std::streamsize>(_data.size()));
+ src.close();
+ }
+
+ if (_data.size() == 0)
+ throw std::runtime_error("Invalid source file");
+
+ const auto model = ::circle::GetModel(_data.data());
+ if (!model)
+ throw std::runtime_error("Failed to load original circle file");
+ _model.reset(model->UnPack());
+}
+
+CircleExporter::~CircleExporter() { finish(); }
+
+void CircleExporter::updateWeight(const std::unique_ptr<exec::Execution> &exec)
+{
+ exec->iterateTrainableTensors(
+ [&](const ir::OperandIndex &idx, const backend::train::ITrainableTensor *tensor) {
+ std::lock_guard<std::mutex> guard(_mutex);
+ const auto &subgs = _model->subgraphs;
+ if (subgs.size() != 1)
+ throw std::runtime_error("Circle does not has valid subgraph or has multiple subgraphs");
+
+ if (!idx.valid())
+ throw std::runtime_error("Trainable tensor is invalid");
+
+ uint32_t buf_idx = -1;
+ const auto &subg = subgs.at(0); // Get 1st subgraph
+ if (idx.value() >= subg->tensors.size())
+ {
+ auto buffer = std::make_unique<::circle::BufferT>();
+ buffer->size = tensor->total_size();
+ buffer->data.resize(buffer->size);
+
+ buf_idx = _model->buffers.size();
+ _model->buffers.push_back(std::move(buffer));
+ }
+ else
+ {
+ buf_idx = subg->tensors.at(idx.value())->buffer;
+ if (buf_idx >= _model->buffers.size())
+ throw std::runtime_error("Buffer for trainable tensors is invalid");
+ }
+
+ const auto &buffer = _model->buffers.at(buf_idx);
+
+ auto org_buf_sz = buffer->data.size();
+ if (org_buf_sz != tensor->total_size())
+ throw std::runtime_error("Trained tensor buffer size does not match original tensor's one");
+
+ memcpy(buffer->data.data(), tensor->buffer(), org_buf_sz);
+ });
+}
+
+void CircleExporter::updateMetadata(const std::unique_ptr<ir::train::TrainingInfo> &training_info)
+{
+ const char *const TRAININFO_METADATA_NAME = "CIRCLE_TRAINING";
+
+ TrainInfoBuilder tbuilder(training_info);
+ bool found = false;
+ for (const auto &meta : _model->metadata)
+ {
+ if (meta->name == std::string{TRAININFO_METADATA_NAME})
+ {
+ std::lock_guard<std::mutex> guard(_mutex);
+ const uint32_t buf_idx = meta->buffer;
+ auto &buffer = _model->buffers.at(buf_idx);
+
+ if (tbuilder.size() != buffer->data.size())
+ {
+ buffer->data.resize(tbuilder.size());
+ buffer->size = tbuilder.size();
+ }
+
+ memcpy(buffer->data.data(), tbuilder.get(), tbuilder.size());
+ found = true;
+ break;
+ }
+ }
+
+ if (!found)
+ {
+ std::lock_guard<std::mutex> guard(_mutex);
+ auto buffer = std::make_unique<::circle::BufferT>();
+ buffer->size = tbuilder.size();
+ buffer->data.resize(buffer->size);
+ memcpy(buffer->data.data(), tbuilder.get(), buffer->size);
+
+ auto meta = std::make_unique<::circle::MetadataT>();
+ meta->name = std::string{TRAININFO_METADATA_NAME};
+ meta->buffer = _model->buffers.size();
+
+ _model->buffers.push_back(std::move(buffer));
+ _model->metadata.push_back(std::move(meta));
+ }
+}
+
+void CircleExporter::finish()
+{
+ flatbuffers::FlatBufferBuilder builder(1024);
+ builder.Finish(::circle::Model::Pack(builder, _model.get()), ::circle::ModelIdentifier());
+
+ std::ofstream dst(_path.c_str(), std::ios::binary);
+ dst.write(reinterpret_cast<const char *>(builder.GetBufferPointer()),
+ static_cast<std::streamsize>(builder.GetSize()));
+ dst.close();
+}
+} // namespace exporter
+} // namespace onert
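Typical use of the exporter implied by the API above (a sketch; the file names are placeholders and the Execution/TrainingInfo objects are assumed to exist already; finish() is also invoked by the destructor):

#include "exporter/CircleExporter.h"
#include "exec/Execution.h"
#include "ir/train/TrainingInfo.h"
#include <memory>

// Copy trained weights and training metadata back into a circle file.
void export_trained_model(const std::unique_ptr<onert::exec::Execution> &execution,
                          const std::unique_ptr<onert::ir::train::TrainingInfo> &training_info)
{
  onert::exporter::CircleExporter exporter("model.circle", "model.trained.circle");
  exporter.updateWeight(execution);       // overwrite weight buffers with trained values
  exporter.updateMetadata(training_info); // add or refresh the CIRCLE_TRAINING metadata
} // the destructor calls finish(), which serializes and writes the output file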
diff --git a/runtime/onert/core/src/exporter/TrainInfoBuilder.h b/runtime/onert/core/src/exporter/TrainInfoBuilder.h
new file mode 100644
index 000000000..c3084b462
--- /dev/null
+++ b/runtime/onert/core/src/exporter/TrainInfoBuilder.h
@@ -0,0 +1,116 @@
+/*
+ * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_EXPORTER_TRAININFO_BUILDER_H__
+#define __ONERT_EXPORTER_TRAININFO_BUILDER_H__
+
+#include "ir/train/TrainingInfo.h"
+#include "circle_schema_generated.h"
+#include "circle_traininfo_generated.h"
+
+namespace onert
+{
+namespace exporter
+{
+
+class TrainInfoBuilder
+{
+public:
+ TrainInfoBuilder(const std::unique_ptr<ir::train::TrainingInfo> &training_info) : _builder(1024)
+ {
+ const auto &optimizerInfo = training_info->optimizerInfo();
+ const auto &lossInfo = training_info->lossInfo();
+
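+    // Translate the IR optimizer and loss settings into their circle training-schema counterparts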
+ ::circle::Optimizer optimizer;
+ ::circle::OptimizerOptions optimizer_opt_type;
+ ::flatbuffers::Offset<void> optimizer_opt;
+ switch (optimizerInfo.optim_code)
+ {
+ case ir::train::OptimizerCode::SGD:
+ optimizer = ::circle::Optimizer_SGD;
+ optimizer_opt_type = ::circle::OptimizerOptions_SGDOptions;
+ optimizer_opt = ::circle::CreateSGDOptions(_builder, optimizerInfo.learning_rate).Union();
+ break;
+ case ir::train::OptimizerCode::Adam:
+ optimizer = ::circle::Optimizer_ADAM;
+ optimizer_opt_type = ::circle::OptimizerOptions_AdamOptions;
+ optimizer_opt = ::circle::CreateAdamOptions(_builder, optimizerInfo.learning_rate).Union();
+ break;
+ default:
+ throw std::runtime_error("Not supported optimizer code");
+ }
+
+ ::circle::LossFn lossfn;
+ ::circle::LossFnOptions lossfn_opt_type;
+ ::flatbuffers::Offset<void> lossfn_opt;
+ switch (lossInfo.loss_code)
+ {
+ case ir::train::LossCode::MeanSquaredError:
+ lossfn = ::circle::LossFn_MEAN_SQUARED_ERROR;
+ lossfn_opt_type = ::circle::LossFnOptions_MeanSquaredErrorOptions;
+ lossfn_opt = ::circle::CreateMeanSquaredErrorOptions(_builder).Union();
+ break;
+ case ir::train::LossCode::CategoricalCrossentropy:
+ lossfn = ::circle::LossFn_CATEGORICAL_CROSSENTROPY;
+ lossfn_opt_type = ::circle::LossFnOptions_CategoricalCrossentropyOptions;
+ lossfn_opt = ::circle::CreateCategoricalCrossentropyOptions(_builder).Union();
+ break;
+ default:
+ throw std::runtime_error("Not supported loss code");
+ }
+
+ ::circle::LossReductionType loss_reduction_type;
+ switch (lossInfo.reduction_type)
+ {
+ case ir::train::LossReductionType::SumOverBatchSize:
+ loss_reduction_type = ::circle::LossReductionType_SumOverBatchSize;
+ break;
+ case ir::train::LossReductionType::Sum:
+ loss_reduction_type = ::circle::LossReductionType_Sum;
+ break;
+ default:
+ throw std::runtime_error("Not supported loss reduction type");
+ }
+
+ std::vector<int32_t> trainable_ops;
+ for (const auto &op : training_info->getTrainableOps())
+ {
+ trainable_ops.push_back(op.value());
+ }
+
+ const auto end = ::circle::CreateModelTrainingDirect(
+ _builder, training_info->version(), optimizer, optimizer_opt_type, optimizer_opt, lossfn,
+ lossfn_opt_type, lossfn_opt, 0, training_info->batchSize(), loss_reduction_type,
+ &trainable_ops);
+ _builder.Finish(end, ::circle::ModelTrainingIdentifier());
+
+ ::flatbuffers::Verifier v(_builder.GetBufferPointer(), _builder.GetSize());
+ bool verified = ::circle::VerifyModelTrainingBuffer(v);
+ if (not verified)
+ throw std::runtime_error{"TrainingInfo buffer is not accessible"};
+ }
+
+ uint8_t *get() const { return _builder.GetBufferPointer(); }
+ uint32_t size() const { return _builder.GetSize(); }
+
+private:
+ ::flatbuffers::FlatBufferBuilder _builder;
+};
+
+} // namespace exporter
+} // namespace onert
+
+#endif // __ONERT_EXPORTER_TRAININFO_BUILDER_H__
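
TrainInfoBuilder above serializes an ir::train::TrainingInfo into the circle training-info flatbuffer and verifies it in its constructor; get() and size() expose the finished buffer. A minimal sketch of the consuming side, mirroring what updateMetadata() does (the std::vector staging buffer is illustrative only and needs <vector>/<cstring>; it is not taken from this diff):

    // Sketch: copy the serialized training info into a byte buffer
    // training_info: std::unique_ptr<onert::ir::train::TrainingInfo>, populated elsewhere
    onert::exporter::TrainInfoBuilder tbuilder(training_info);
    std::vector<uint8_t> payload(tbuilder.size());
    std::memcpy(payload.data(), tbuilder.get(), tbuilder.size());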
diff --git a/runtime/onert/core/src/interp/Buffer.h b/runtime/onert/core/src/interp/Buffer.h
deleted file mode 100644
index 24938f74f..000000000
--- a/runtime/onert/core/src/interp/Buffer.h
+++ /dev/null
@@ -1,91 +0,0 @@
-/*
- * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-/**
- * @file Buffer.h
- * @brief This file contains Buffer interface and InternalBuffer, ExternalBuffer class
- */
-#ifndef __ONERT_INTERP_BUFFER_H__
-#define __ONERT_INTERP_BUFFER_H__
-
-#include <memory>
-
-#include "ir/Data.h"
-
-namespace onert
-{
-namespace interp
-{
-
-/**
- * @brief Interface for writable data area
- */
-class Buffer : public ir::Data
-{
-public:
- /**
- * @brief Return writable pointer for data area
- * @return Writable pointer
- */
- virtual uint8_t *baseWritable(void) const = 0;
-};
-
-/**
- * @brief Class for internally allocated data area
- */
-class InternalBuffer final : public Buffer
-{
-public:
- InternalBuffer(size_t size) : _base{std::make_unique<uint8_t[]>(size)}, _size{size}
- {
- // DO NOTHING
- }
-
-public:
- size_t size(void) const override { return _size; }
- const uint8_t *base(void) const override { return _base.get(); }
- uint8_t *baseWritable(void) const override { return _base.get(); }
-
-private:
- std::unique_ptr<uint8_t[]> _base;
- size_t _size;
-};
-
-/**
- * @brief Class for data area from outside
- */
-class ExternalBuffer final : public Buffer
-{
-public:
- ExternalBuffer(uint8_t *base, size_t size) : _base{base}, _size{size}
- {
- // DO NOTHING
- }
-
-public:
- size_t size(void) const override { return _size; }
- const uint8_t *base(void) const override { return _base; }
- uint8_t *baseWritable(void) const override { return _base; }
-
-private:
- uint8_t *_base;
- size_t _size;
-};
-
-} // namespace interp
-} // namespace onert
-
-#endif // __ONERT_INTERP_BUFFER_H__
diff --git a/runtime/onert/core/src/interp/ExecEnv.h b/runtime/onert/core/src/interp/ExecEnv.h
deleted file mode 100644
index 7f577ea6e..000000000
--- a/runtime/onert/core/src/interp/ExecEnv.h
+++ /dev/null
@@ -1,212 +0,0 @@
-/*
- * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-/**
- * @file ExecEnv.h
- * @brief This file contains ExecEnv to access interpreter tensor and execution status
- */
-#ifndef __ONERT_INTERP_EXEC_ENV_H_
-#define __ONERT_INTERP_EXEC_ENV_H_
-
-#include <unordered_set>
-
-#include "ir/Graph.h"
-#include "Tensor.h"
-
-namespace onert
-{
-namespace interp
-{
-
-/**
- * @brief Class to gather interpreter execution environment
- * Each interpreter instance own execution environment
- */
-class ExecEnv
-{
-public:
- /**
- * @brief Construct a new Exec Env object (deleted)
- */
- ExecEnv(void) = delete;
- /**
- * @brief Construct a new ExecEnv object
- * @param[in] graph Graph to execute by interpreter
- */
- explicit ExecEnv(const ir::Graph &graph) : _graph(graph)
- {
- // DO NOTHING
- }
-
-public:
- /**
- * @brief Return graph to execute
- * @return Graph
- */
- const ir::Graph &graph(void) const { return _graph; }
- /**
- * @brief Assign tensor to environment which have allocated or assigned buffer
- * @param[in] index Tensor index
- * @param[in] tensor Tensor
- */
- void assignTensor(const ir::OperandIndex index, std::shared_ptr<ITensor> tensor)
- {
- assert(tensor->bufferRO() != nullptr);
- _tensors.emplace(index, tensor);
- }
-
- /**
- * @brief Return tensor pointer in environment
- * @param[in] index Tensor index
-   * @param[in] can_optional @c true if the tensor can be an optional input, otherwise @c false
- * @return Tensor pointer
- */
- const ITensor *tensorAt(const ir::OperandIndex index, bool can_optional = false) const
- {
- if (_tensors.find(index) == _tensors.end())
- {
-      // It may be an optional input;
-      // otherwise the input was not set by the runtime user
- if (can_optional)
- {
- return nullptr;
- }
-
- throw std::runtime_error{"ExecEnv: Input is not set"};
- }
-
- return _tensors.at(index).get();
- }
-
- /**
- * @brief Check environment contains tensor
- * @param[in] index Tensor index
- * @return @c true if environment contain tensor, otherwise @c false
- */
- bool contains(const ir::OperandIndex index) const
- {
- return (_tensors.find(index) != _tensors.end());
- }
-
- /**
- * @brief Allocate tensor using operand info
- * @param[in] index Tensor index
- * @param[in] info Operand info
- * @note If already allocated, just return
- * @TODO More smart allocation policy
- */
- void allocateIfNeeded(const ir::OperandIndex index, const ir::OperandInfo &info)
- {
- // already allocated, or constant
- if (contains(index))
- {
- return;
- }
-
- // Buffer from external (ex. model output)
- auto tensor = std::make_shared<Tensor>(info);
- if (isExtBuffer(index))
- {
- tensor->setBuffer(_external_buffers.at(index));
- assignTensor(index, tensor);
-
- return;
- }
-
- tensor->setBuffer(std::make_shared<InternalBuffer>(tensor->total_size()));
- assignTensor(index, tensor);
- _buffers.insert(index);
- }
-
- /**
- * @brief Allocate read-only tensor and share data with other tensor
- * @param[in] index Tensor index
- * @param[in] info Operand info
- * @param[in] index_to_share Tensor index that have data to share
- */
- void allocateAndShareIfNeeded(const ir::OperandIndex index, const ir::OperandInfo &info,
- const ir::OperandIndex index_to_share)
- {
- if (!contains(index_to_share))
- {
- throw std::runtime_error{"Cannot find tensor to share data"};
- }
-
- // already allocated
- if (contains(index))
- {
- return;
- }
-
- if (isExtBuffer(index))
- {
- auto tensor = std::make_shared<Tensor>(info);
- tensor->setBuffer(_external_buffers.at(index));
- assignTensor(index, tensor);
- }
- else
- {
- auto tensor = std::make_shared<ROTensor>(info);
- tensor->setData(tensorAt(index_to_share)->shareData());
- assignTensor(index, tensor);
- _buffers.insert(index);
- }
- }
-
- /**
- * @brief Free buffer if allocated by allocateIfNeed
- * @param[in] index Tensor index
- * @note If allocated by outside, just return
- */
- void freeIfAllocated(const ir::OperandIndex index)
- {
- if (_buffers.find(index) != _buffers.end())
- {
- _tensors.at(index)->releaseData();
- }
- }
-
- /**
- * @brief Assign ExternalBuffer into external buffer map
- * @param[in] index Tensor index
- * @param[in] buffer External buffer
- */
- void assignExternalBuffer(const ir::OperandIndex index, std::shared_ptr<ExternalBuffer> buffer)
- {
- _external_buffers.emplace(index, buffer);
- }
-
-private:
- bool isExtBuffer(const ir::OperandIndex index)
- {
- return (_external_buffers.find(index) != _external_buffers.end());
- }
-
-private:
- const ir::Graph &_graph;
- // Tensor map to use in interpreter
- // It should map tensors that have allocated or assigned buffer pointer
- std::unordered_map<ir::OperandIndex, std::shared_ptr<ITensor>> _tensors;
- // Tensors allocated by allocateIfNeed (buffer)
- std::unordered_set<ir::OperandIndex> _buffers;
- // Tensor buffer from external
- std::unordered_map<ir::OperandIndex, std::shared_ptr<ExternalBuffer>> _external_buffers;
-};
-
-} // namespace interp
-} // namespace onert
-
-#endif // __ONERT_INTERP_EXEC_ENV_H_
diff --git a/runtime/onert/core/src/interp/InterpExecutor.cc b/runtime/onert/core/src/interp/InterpExecutor.cc
deleted file mode 100644
index cd31a4dca..000000000
--- a/runtime/onert/core/src/interp/InterpExecutor.cc
+++ /dev/null
@@ -1,126 +0,0 @@
-/*
- * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "interp/InterpExecutor.h"
-#include "interp/ExecEnv.h"
-#include "interp/Interpreter.h"
-
-#include "util/logging.h"
-
-#include <memory>
-
-namespace onert
-{
-namespace interp
-{
-
-void InterpExecutor::execute(const exec::IODescription &desc)
-{
- /************************************************************************
- * Prepare execution model (submodel)
-     It may execute a divided model,
-     but for now assume that the whole model inference is done in the interpreter
- ***********************************************************************/
- ir::OperandIndexMap<std::shared_ptr<ITensor>> tensor_map;
-
- for (uint32_t n = 0; n < _graph.getInputs().size(); n++)
- {
- ir::IOIndex index{n};
- const auto input_index = _graph.getInputs().at(index);
-
- const auto input = desc.inputs.at(n).get();
- if (input == nullptr)
- {
- // Optional input
- continue;
- }
-
- auto input_tensor = std::make_shared<ROTensor>(input->info);
- input_tensor->setData(std::make_shared<const ir::ExternalData>(
- reinterpret_cast<const uint8_t *>(input->buffer), input->size));
- tensor_map[input_index] = input_tensor;
- }
-
- /************************************************************************
- * Prepare execution environment
- Execution environment will be assigned to invoked interpreter instance
- ***********************************************************************/
-
- std::unique_ptr<ExecEnv> interp_env = std::make_unique<ExecEnv>(_graph);
-
- // Assign input/output tensor into interpreter execution environment
- for (auto index : _graph.getInputs())
- {
- if (tensor_map.find(index) != tensor_map.end())
- {
- VERBOSE(INTERPRETER) << "Assign input tensor. operand index:" << index.value() << std::endl;
- interp_env->assignTensor(index, tensor_map.at(index));
- }
- }
-
- for (uint32_t n = 0; n < _graph.getOutputs().size(); n++)
- {
- ir::IOIndex index{n};
- const auto output_index = _graph.getOutputs().at(index);
- const auto output = desc.outputs.at(n).get();
- if (output == nullptr)
- {
- // Optional output
- continue;
- }
-
- VERBOSE(INTERPRETER) << "Set out buffer to ExecEnv. operand index:" << output_index.value()
- << std::endl;
-
- interp_env->assignExternalBuffer(
- output_index, std::make_shared<ExternalBuffer>(reinterpret_cast<uint8_t *>(output->buffer),
- output->size));
- }
-
- // Allocate constant tensor
- _graph.operands().iterate([&](const ir::OperandIndex &ind, const ir::Operand &obj) {
- if (obj.isConstant())
- {
- VERBOSE(INTERPRETER) << "Allocate and assign constant tensor. operand index:" << ind.value()
- << std::endl;
-
- assert(obj.data());
- auto const_tensor = std::make_shared<ROTensor>(obj.info());
- // Assume that interpreter's tensor layout is same with model (NHWC)
- const_tensor->setData(
- std::make_shared<ir::ExternalData>(obj.data()->base(), obj.info().total_size()));
- interp_env->assignTensor(ind, const_tensor);
- }
- });
-
- /*****************************************************************************
- * Invoke interpreter
- ****************************************************************************/
-
- interp::Interpreter interp(std::move(interp_env));
- interp.run();
-
- /*****************************************************************************
- * Invoked interpreter run is finished
- ****************************************************************************/
-
-  // If the interpreter executes a submodel:
- // 1. Get tensor output of submodel into tensor_map to save result
- // 2. Generate new ExecEnv for next interpretation
-}
-
-} // namespace interp
-} // namespace onert
diff --git a/runtime/onert/core/src/interp/InterpExecutor.h b/runtime/onert/core/src/interp/InterpExecutor.h
deleted file mode 100644
index 2e3f3ca54..000000000
--- a/runtime/onert/core/src/interp/InterpExecutor.h
+++ /dev/null
@@ -1,70 +0,0 @@
-/*
- * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-/**
- * @file InterpExecutor.h
- * @brief This file contains InterpExecutor class\n
- * to manage interpreter execution and environment
- */
-#ifndef __ONERT_INTERP_INTERP_EXECUTOR_H__
-#define __ONERT_INTERP_INTERP_EXECUTOR_H__
-
-#include "ir/OperandIndexMap.h"
-#include "ir/Graph.h"
-#include "exec/IExecutor.h"
-
-namespace onert
-{
-namespace interp
-{
-
-class ITensor;
-
-/**
- * @brief Class to execute model using interpreter
- */
-class InterpExecutor final : public exec::IExecutor
-{
-public:
- explicit InterpExecutor(const ir::Graph &graph) : _graph(graph)
- {
- // DO NOTHING
- }
-
-public:
- /**
- * @brief Return graph object
- * @return Graph object
- */
- const ir::Graph &graph() final { return _graph; }
- void setIndexedRanks(std::shared_ptr<ir::OperationIndexMap<int64_t>>) override{
- // Not implemented
- };
- /**
- * @brief Start execution
- * @note It should be called after setting input and output buffer
- */
- void execute(const exec::IODescription &desc) final;
-
-private:
- const ir::Graph &_graph;
- ir::OperandIndexMap<std::shared_ptr<ITensor>> _tensor_map;
-};
-
-} // namespace interp
-} // namespace onert
-
-#endif // __ONERT_INTERP_INTERP_EXECUTOR_H__
diff --git a/runtime/onert/core/src/interp/InterpOps.lst b/runtime/onert/core/src/interp/InterpOps.lst
deleted file mode 100644
index 0714df38a..000000000
--- a/runtime/onert/core/src/interp/InterpOps.lst
+++ /dev/null
@@ -1,73 +0,0 @@
-/*
- * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef INTERP_OP
-#error Define INTERP_OP before including this file
-#endif
-
-// Supported operation name in interpreter
-//
-// Same list with Operations.lst
-// Make comment out if operation is not supported in interpreter
-INTERP_OP(BinaryArithmetic)
-//INTERP_OP(BatchToSpaceND)
-//INTERP_OP(Cast)
-INTERP_OP(Conv2D)
-INTERP_OP(DepthwiseConv2D)
-INTERP_OP(Pool2D)
-INTERP_OP(Concat)
-INTERP_OP(FullyConnected)
-//INTERP_OP(Reduce)
-INTERP_OP(Reshape)
-INTERP_OP(Softmax)
-//INTERP_OP(Squeeze)
-//INTERP_OP(Slice)
-//INTERP_OP(StridedSlice)
-INTERP_OP(ElementwiseActivation)
-//INTERP_OP(Transpose)
-//INTERP_OP(Exp)
-//INTERP_OP(Comparison)
-//INTERP_OP(LogicalNot)
-//INTERP_OP(LSTM)
-//INTERP_OP(RSQRT)
-//INTERP_OP(ResizeBilinear)
-//INTERP_OP(RNN)
-//INTERP_OP(Floor)
-//INTERP_OP(SpaceToBatchND)
-//INTERP_OP(SpaceToDepth)
-//INTERP_OP(EmbeddingLookup)
-//INTERP_OP(L2Normalization)
-//INTERP_OP(HashtableLookup)
-INTERP_OP(InstanceNorm)
-//INTERP_OP(PReLU)
-INTERP_OP(TransposeConv)
-//INTERP_OP(SQRT)
-//INTERP_OP(SquaredDifference)
-//INTERP_OP(TopKV2)
-INTERP_OP(Gather)
-//INTERP_OP(Neg)
-//INTERP_OP(Abs)
-//INTERP_OP(ArgMax)
-//INTERP_OP(Dequantize)
-//INTERP_OP(LocalResponseNormalization)
-//INTERP_OP(DepthToSpace)
-//INTERP_OP(Pack)
-//INTERP_OP(Split)
-//INTERP_OP(Unpack)
-INTERP_OP(Pad)
-//INTERP_OP(Custom)
-//INTERP_OP(Permute)
-//INTERP_OP(OneHot)
diff --git a/runtime/onert/core/src/interp/Interpreter.cc b/runtime/onert/core/src/interp/Interpreter.cc
deleted file mode 100644
index b92afbe73..000000000
--- a/runtime/onert/core/src/interp/Interpreter.cc
+++ /dev/null
@@ -1,184 +0,0 @@
-/*
- * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "Interpreter.h"
-
-#include <stack>
-#include <unordered_set>
-
-#include "Registration.h"
-
-#include "ir/OperandIndexMap.h"
-#include "util/logging.h"
-#include "ir/OperationVisitor.h"
-
-namespace onert
-{
-namespace interp
-{
-
-// TODO more structured execution kernel implementation
-// TODO use cker for execution
-// TODO divide tensor prepare and execution
-// TODO introduce memory manager (buffer allocate and free)
-class OperationExecutor
-{
-public:
- OperationExecutor(ExecEnv *env) : _env{env}
- {
-#define INTERP_OP(InternalName) _kernels[ir::OpCode::InternalName] = get##InternalName();
-#include "InterpOps.lst"
-#undef INTERP_OP
- }
-
- void execute(const ir::OperationIndex &idx)
- {
- const ir::Operation &node = _env->graph().operations().at(idx);
- const auto nodeName = node.name();
- VERBOSE(INTERPRETER) << "Prepare output operands and execute " << nodeName
- << " operation (id: " << idx.value() << ")" << std::endl;
-
- const auto nodeOpCode = node.opcode();
- if (_kernels.find(nodeOpCode) == _kernels.end())
- {
- throw std::runtime_error{"Interpreter: Operation " + nodeName + " is not yet implemented"};
- }
-
- if (_kernels[nodeOpCode]->prepare != nullptr)
- {
- _kernels[nodeOpCode]->prepare(_env, node);
- }
- _kernels[nodeOpCode]->invoke(_env, node);
- }
-
-private:
- ExecEnv *_env;
- std::unordered_map<ir::OpCode, OpKernel *> _kernels;
-};
-
-void Interpreter::run()
-{
- VERBOSE(INTERPRETER) << "Interpreter is invoked " << std::endl;
-
- // operand_stack: save operands prepared to use
- std::stack<ir::OperandIndex> operand_stack;
-
- // Note: We should push input first, then constant.
-  // We use use-def chains to find operators that are ready to execute,
-  // but use-def cannot handle parameters (maybe constant, but not always)
-  // Note: If all model inputs are constant, it may not work (depending on tensor order).
- // But that scenario may not exist
- for (auto ind : _env->graph().getInputs())
- {
- VERBOSE(INTERPRETER) << "Input: Push to operand stack " << ind.value() << std::endl;
-
- operand_stack.push(ind);
- }
-
- _env->graph().operands().iterate([&](const ir::OperandIndex &ind, const ir::Operand &obj) {
- if (obj.isConstant())
- {
- VERBOSE(INTERPRETER) << "Constant: Push to operand stack " << ind.value() << std::endl;
-
- operand_stack.push(ind);
- }
- });
-
- // Execution
- std::unordered_set<ir::OperandIndex> ready_check;
- std::unordered_set<ir::OperationIndex> executed;
- OperationExecutor executor{_env.get()};
- while (!operand_stack.empty())
- {
- const auto current_operand_index = operand_stack.top();
- operand_stack.pop();
- VERBOSE(INTERPRETER) << "Poped operand " << current_operand_index.value()
- << " is checked ready to use" << std::endl;
-
- assert(ready_check.find(current_operand_index) == ready_check.end());
- ready_check.insert(current_operand_index);
-
- // Find prepared operations by scan use of current operand
- std::stack<ir::OperationIndex> operation_stack;
- const auto use_operators = _env->graph().operands().at(current_operand_index).getUses();
- for (const auto &use_operator : use_operators)
- {
- // Assumption: all parameters are ready to use
- bool operator_ready = true;
- for (auto input_index : _env->graph().operations().at(use_operator).getInputs())
- {
- if (ready_check.find(input_index) == ready_check.end())
- {
- operator_ready = false;
- break;
- }
- }
-
- if (operator_ready)
- {
- VERBOSE(INTERPRETER) << "Ready to execute operation " << use_operator.value() << std::endl;
- operation_stack.push(use_operator);
- }
- }
-
- while (!operation_stack.empty())
- {
- const auto current_operation_index = operation_stack.top();
- operation_stack.pop();
- VERBOSE(INTERPRETER) << "Poped operation: " << current_operation_index.value() << "("
- << _env->graph().operations().at(current_operation_index).name() << ")"
- << std::endl;
-
- // execution
- // 1. Prepare output tensor
- // 2. Call operation kernel
- executor.execute(current_operation_index);
- executed.insert(current_operation_index);
-
- // 3. Push each output into operand stack
- const auto def_operands = _env->graph().operations().at(current_operation_index).getOutputs();
- for (auto def_operand : def_operands)
- {
- VERBOSE(INTERPRETER) << "Buffer: Push to operand stack " << def_operand.value()
- << std::endl;
- operand_stack.push(def_operand);
- }
-
- // 4. Free if lifetime of buffer operands used by input is finished
- for (auto input_index : _env->graph().operations().at(current_operation_index).getInputs())
- {
- const auto use_operators = _env->graph().operands().at(input_index).getUses();
- bool dead_buffer = true;
- for (const auto &use_operator : use_operators)
- {
- if (executed.find(use_operator) == executed.end())
- {
- dead_buffer = false;
- break;
- }
- }
-
- if (dead_buffer)
- {
- _env->freeIfAllocated(input_index);
- }
- }
- }
- }
-}
-
-} // namespace interp
-} // namespace onert
diff --git a/runtime/onert/core/src/interp/Interpreter.h b/runtime/onert/core/src/interp/Interpreter.h
deleted file mode 100644
index d2165f538..000000000
--- a/runtime/onert/core/src/interp/Interpreter.h
+++ /dev/null
@@ -1,64 +0,0 @@
-/*
- * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-/**
- * @file Interpreter.h
- * @brief This file contains Interpreter class for interpretation
- */
-#ifndef __ONERT_INTERP_INTERPRETER_H__
-#define __ONERT_INTERP_INTERPRETER_H__
-
-#include "ExecEnv.h"
-
-namespace onert
-{
-namespace interp
-{
-
-/**
- * @brief Class for interpretation
- */
-class Interpreter
-{
-
-public:
- /**
- * @brief Construct a new Interpreter object (deleted)
- */
- Interpreter() = delete;
- /**
- * @brief Construct a new Interpreter object
- * @param[in] env Execution environment variable for interpreter object
- */
- Interpreter(std::unique_ptr<ExecEnv> env) : _env{std::move(env)}
- {
- // DO NOTHING
- }
-
-public:
- /**
- * @brief Run interpreter until there is no operation to execute
- */
- void run();
-
-private:
- std::unique_ptr<ExecEnv> _env;
-};
-
-} // namespace interp
-} // namespace onert
-
-#endif // __ONERT_INTERP_INTERPRETER_H__
diff --git a/runtime/onert/core/src/interp/Registration.h b/runtime/onert/core/src/interp/Registration.h
deleted file mode 100644
index 956b92a53..000000000
--- a/runtime/onert/core/src/interp/Registration.h
+++ /dev/null
@@ -1,43 +0,0 @@
-/*
- * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef __ONERT_INTERP_REGISTRATION_H__
-#define __ONERT_INTERP_REGISTRATION_H__
-
-#include "ExecEnv.h"
-
-#include "ir/Operation.h"
-
-namespace onert
-{
-namespace interp
-{
-
-struct OpKernel
-{
- std::function<void(ExecEnv *, const ir::Operation &)> prepare;
- std::function<void(const ExecEnv *, const ir::Operation &)> invoke;
-};
-
-// Defined in operations/ directory
-#define INTERP_OP(InternalName) OpKernel *get##InternalName();
-#include "InterpOps.lst"
-#undef INTERP_OP
-
-} // namespace interp
-} // namespace onert
-
-#endif // __ONERT_INTERP_REGISTRATION_H__
diff --git a/runtime/onert/core/src/interp/Tensor.cc b/runtime/onert/core/src/interp/Tensor.cc
deleted file mode 100644
index 07f8b75dc..000000000
--- a/runtime/onert/core/src/interp/Tensor.cc
+++ /dev/null
@@ -1,53 +0,0 @@
-/*
- * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "Tensor.h"
-
-#define NO_USE(a) (void)(a)
-
-namespace onert
-{
-namespace interp
-{
-
-void ITensor::access(const std::function<void(backend::ITensor &tensor)> &fn) { fn(*this); }
-
-size_t ROTensor::calcOffset(const ir::Coordinates &coords) const
-{
- NO_USE(coords);
- throw std::runtime_error("offset_element_in_bytes is not supported for cpu::Tensor now.");
-}
-
-size_t Tensor::calcOffset(const ir::Coordinates &coords) const
-{
- NO_USE(coords);
- throw std::runtime_error("offset_element_in_bytes is not supported for cpu::Tensor now.");
-}
-
-ir::Layout ROTensor::layout() const
-{
- // TODO Changes to return frontend layout
- return ir::Layout::NHWC;
-}
-
-ir::Layout Tensor::layout() const
-{
- // TODO Changes to return frontend layout
- return ir::Layout::NHWC;
-}
-
-} // namespace interp
-} // namespace onert
diff --git a/runtime/onert/core/src/interp/Tensor.h b/runtime/onert/core/src/interp/Tensor.h
deleted file mode 100644
index 008a4b9d4..000000000
--- a/runtime/onert/core/src/interp/Tensor.h
+++ /dev/null
@@ -1,184 +0,0 @@
-/*
- * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-/**
- * @file Tensor.h
- * @brief This file contains ITensor interface, ROTensor class, and Tensor class
- */
-#ifndef __ONERT_INTERP_TENSOR_H__
-#define __ONERT_INTERP_TENSOR_H__
-
-#include "Buffer.h"
-
-#include "ir/OperandInfo.h"
-#include "backend/ITensor.h"
-#include "ir/Layout.h"
-
-namespace onert
-{
-namespace interp
-{
-
-/**
- * @brief Interface to handle Tensor in interpreter
- */
-class ITensor : public backend::ITensor
-{
-public:
- virtual ~ITensor() = default;
-
-public:
- virtual uint8_t *buffer() const = 0;
- /**
- * @brief Return shared pointer for buffer
- * @return Buffer shared pointer
- */
- virtual std::shared_ptr<const Buffer> shareBuffer() const = 0;
- /**
- * @brief Return read-only buffer pointer
- * @return Read-only buffer pointer
- */
- virtual const uint8_t *bufferRO() const = 0;
- /**
- * @brief Return shared pointer for data
- * @return Data shared pointer
- */
- virtual std::shared_ptr<const ir::Data> shareData() const = 0;
- /**
- * @brief Set internal/external buffer
- * @param[in] buffer Buffer pointer
- */
- virtual void setBuffer(std::shared_ptr<const Buffer> buffer) = 0;
- /**
- * @brief Set data reference (including constant, input)
- * @param[in] data Data pointer
- */
- virtual void setData(std::shared_ptr<const ir::Data> data) = 0;
- virtual void releaseData() = 0;
-
- virtual size_t total_size() const = 0;
- virtual size_t dimension(size_t index) const = 0;
- virtual size_t num_dimensions() const = 0;
- virtual size_t calcOffset(const ir::Coordinates &coords) const = 0;
-
- virtual bool has_padding() const = 0;
- /**
- * @brief Return data type of tensor
- * @return Data type of tensor
- */
- virtual ir::DataType data_type() const = 0;
- /**
- * @brief Return TensorInfo
- * @return TensorInfo
- */
- virtual const ir::OperandInfo &tensorInfo() const = 0;
- /**
- * @brief Return number of elements
- * @return Number of elements
- */
- virtual uint64_t num_elements() const = 0;
- void access(const std::function<void(backend::ITensor &tensor)> &fn) final;
-};
-
-/**
- * @brief Class to handle tensor in interpreter as read-only
- */
-class ROTensor final : public ITensor
-{
-public:
- ROTensor() = delete;
- ROTensor(const ir::OperandInfo &info) : _info(info)
- {
- // DO NOTHING
- }
-
-public:
- uint8_t *buffer() const override { throw std::runtime_error{"Read only tensor"}; }
- std::shared_ptr<const Buffer> shareBuffer() const override
- {
- throw std::runtime_error{"Read only tensor"};
- }
- const uint8_t *bufferRO() const override { return _data->base(); }
- std::shared_ptr<const ir::Data> shareData() const override { return _data; }
- void setBuffer(std::shared_ptr<const Buffer> buffer) override { _data = buffer; }
- void setData(std::shared_ptr<const ir::Data> data) override { _data = data; }
- void releaseData() override { _data = nullptr; }
-
- size_t total_size() const override { return _info.total_size(); }
- size_t dimension(size_t index) const override { return _info.shape().dim(index); }
- size_t num_dimensions() const override { return _info.shape().rank(); }
- size_t calcOffset(const ir::Coordinates &coords) const override;
- ir::Layout layout() const override;
- bool is_dynamic() const override { return false; }
- bool has_padding() const override { return false; }
- ir::DataType data_type() const override { return _info.typeInfo().type(); }
- float data_scale() const override { return _info.typeInfo().scale(); }
- int32_t data_offset() const override { return _info.typeInfo().offset(); }
- const ir::OperandInfo &tensorInfo() const override { return _info; }
- uint64_t num_elements() const override { return _info.shape().num_elements(); };
-
-private:
- const ir::OperandInfo _info;
- std::shared_ptr<const ir::Data> _data{nullptr};
-};
-
-/**
- * @brief Class to handle tensor in interpreter as writable
- */
-class Tensor final : public ITensor
-{
-public:
- Tensor() = delete;
- Tensor(const ir::OperandInfo &info) : _info(info)
- {
- // DO NOTHING
- }
-
-public:
- uint8_t *buffer() const override { return _buffer->baseWritable(); }
- std::shared_ptr<const Buffer> shareBuffer() const override { return _buffer; };
- const uint8_t *bufferRO() const override { return _buffer->base(); }
- std::shared_ptr<const ir::Data> shareData() const override { return _buffer; }
- void setBuffer(std::shared_ptr<const Buffer> buffer) override { _buffer = buffer; }
- void setData(std::shared_ptr<const ir::Data>) override
- {
- throw std::runtime_error{"Passed data may read-only"};
- }
- void releaseData() override { _buffer = nullptr; }
-
- size_t total_size() const override { return _info.total_size(); }
- size_t dimension(size_t index) const override { return _info.shape().dim(index); }
- size_t num_dimensions() const override { return _info.shape().rank(); }
- size_t calcOffset(const ir::Coordinates &coords) const override;
- ir::Layout layout() const override;
- bool is_dynamic() const override { return false; }
- bool has_padding() const override { return false; }
- ir::DataType data_type() const override { return _info.typeInfo().type(); }
- float data_scale() const override { return _info.typeInfo().scale(); }
- int32_t data_offset() const override { return _info.typeInfo().offset(); }
- const ir::OperandInfo &tensorInfo() const override { return _info; }
- uint64_t num_elements() const override { return _info.shape().num_elements(); };
- backend::IDynamicTensorManager *dynamic_tensor_manager() override { return nullptr; }
-
-private:
- const ir::OperandInfo _info;
- std::shared_ptr<const Buffer> _buffer{nullptr};
-};
-
-} // namespace interp
-} // namespace onert
-
-#endif // __ONERT_INTERP_TENSOR_H__
diff --git a/runtime/onert/core/src/interp/operations/BinaryArithmeticOps.cc b/runtime/onert/core/src/interp/operations/BinaryArithmeticOps.cc
deleted file mode 100644
index 86e883524..000000000
--- a/runtime/onert/core/src/interp/operations/BinaryArithmeticOps.cc
+++ /dev/null
@@ -1,205 +0,0 @@
-/*
- * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include <cker/operation/BinaryArithmeticOps.h>
-
-#include "OperationUtil.h"
-
-#include "interp/Registration.h"
-#include "ir/operation/BinaryArithmetic.h"
-#include "misc/polymorphic_downcast.h"
-#include "cker/Types.h"
-
-namespace onert
-{
-namespace interp
-{
-namespace
-{
-
-enum class OpType
-{
- ADD,
- SUB,
- MUL
-};
-
-void prepare(ExecEnv *env, const ir::Operation &node)
-{
- const auto &arithmetic_node =
- nnfw::misc::polymorphic_downcast<const ir::operation::BinaryArithmetic &>(node);
-
- const auto lhs_index = node.getInputs().at(arithmetic_node.LHS);
- const auto rhs_index = node.getInputs().at(arithmetic_node.RHS);
- const auto out_index = node.getOutputs().at(0);
-
- const auto lhs_tensor = env->tensorAt(lhs_index);
- const auto rhs_tensor = env->tensorAt(rhs_index);
-
- // Check shape and type lhs is same with rhs
- // TODO Util function to compare TensorInfo
- if (lhs_tensor->data_type() != rhs_tensor->data_type())
- {
- throw std::runtime_error{"Interp(" + arithmetic_node.name() + "): Different input types"};
- }
-
- bool try_broadcast = (lhs_tensor->tensorInfo().shape() != rhs_tensor->tensorInfo().shape());
- if (try_broadcast)
- {
- bool success = true;
- auto out_shape = calcBroadcastShape(lhs_tensor->tensorInfo().shape(),
- rhs_tensor->tensorInfo().shape(), success);
- if (!success)
- {
- throw std::runtime_error{"Interp(" + arithmetic_node.name() + "): Fail to brodcasting"};
- }
-
- auto output_info =
- ir::OperandInfo::createStaticInfo(out_shape, lhs_tensor->tensorInfo().typeInfo());
- // We can handle already allocated (ex. model output)
- env->allocateIfNeeded(out_index, output_info);
- }
- else
- {
- // Output's shape and type is same with input
- auto output_info = lhs_tensor->tensorInfo();
- // We can handle already allocated (ex. model output)
- env->allocateIfNeeded(out_index, output_info);
- }
-
- auto out_tensor = env->tensorAt(out_index);
- // Check shape and type lhs is same with output
- // TODO Util function to compare TensorInfo
- if (lhs_tensor->data_type() != out_tensor->data_type())
- {
- throw std::runtime_error{"Interp(" + arithmetic_node.name() + "): Invalid output type"};
- }
-}
-
-inline void setActivationParams(float min, float max, nnfw::cker::BinaryArithmeticOpParam *params)
-{
- params->float_activation_min = min;
- params->float_activation_max = max;
-}
-
-inline void setActivationParams(int32_t min, int32_t max,
- nnfw::cker::BinaryArithmeticOpParam *params)
-{
- params->quantized_activation_min = min;
- params->quantized_activation_max = max;
-}
-
-template <typename raw_type, OpType op_type>
-void invoke(const ITensor *lhs_tensor, const ITensor *rhs_tensor, const ITensor *out_tensor,
- const ir::operation::BinaryArithmetic::Param &param)
-{
- const auto lhs_buffer = lhs_tensor->bufferRO();
- const auto rhs_buffer = rhs_tensor->bufferRO();
- auto out_buffer = out_tensor->buffer();
-
- nnfw::cker::BinaryArithmeticOpParam cker_param;
- raw_type activation_min, activation_max;
- calculateActivationRange(param.activation, &activation_min, &activation_max);
- setActivationParams(activation_min, activation_max, &cker_param);
- const raw_type *lhs_ptr = reinterpret_cast<const raw_type *>(lhs_buffer);
- const raw_type *rhs_ptr = reinterpret_cast<const raw_type *>(rhs_buffer);
- raw_type *out_ptr = reinterpret_cast<raw_type *>(out_buffer);
-
- const auto cker_op_type =
- (op_type == OpType::ADD)
- ? nnfw::cker::BinaryArithmeticOpType::ADD
- : ((op_type == OpType::SUB) ? nnfw::cker::BinaryArithmeticOpType::SUB
- : nnfw::cker::BinaryArithmeticOpType::MUL);
-
- const bool need_broadcast = nnfw::cker::ProcessBroadcastShapes(
- convertShape(lhs_tensor->tensorInfo().shape()),
- convertShape(rhs_tensor->tensorInfo().shape()), &cker_param);
-
- if (need_broadcast)
- {
- const auto lhs_shape = convertShape(lhs_tensor->tensorInfo().shape());
- const auto rhs_shape = convertShape(rhs_tensor->tensorInfo().shape());
- const auto out_shape = convertShape(out_tensor->tensorInfo().shape());
- nnfw::cker::BroadcastBinaryArithmeticOp<cker_op_type>(cker_param, lhs_shape, lhs_ptr, rhs_shape,
- rhs_ptr, out_shape, out_ptr);
- return;
- }
-
- const auto lhs_shape = convertShape(lhs_tensor->tensorInfo().shape());
- const auto rhs_shape = convertShape(rhs_tensor->tensorInfo().shape());
- const auto out_shape = convertShape(out_tensor->tensorInfo().shape());
- nnfw::cker::BinaryArithmeticOp<cker_op_type>(cker_param, lhs_shape, lhs_ptr, rhs_shape, rhs_ptr,
- out_shape, out_ptr);
-}
-
-template <OpType op_type>
-void invokeBinaryArithmetic(const ExecEnv *env, const ir::operation::BinaryArithmetic &node)
-{
- const auto lhs_index = node.getInputs().at(node.LHS);
- const auto rhs_index = node.getInputs().at(node.RHS);
- const auto out_index = node.getOutputs().at(0);
- const auto lhs_tensor = env->tensorAt(lhs_index);
- const auto rhs_tensor = env->tensorAt(rhs_index);
- const auto out_tensor = env->tensorAt(out_index);
- const auto data_type = lhs_tensor->data_type();
-
- if (data_type == ir::DataType::INT32)
- {
- invoke<int32_t, op_type>(lhs_tensor, rhs_tensor, out_tensor, node.param());
- }
- else if (data_type == ir::DataType::FLOAT32)
- {
- invoke<float, op_type>(lhs_tensor, rhs_tensor, out_tensor, node.param());
- }
- else
- {
- throw std::runtime_error{"NYI: Unsupported data type"};
- }
-}
-
-void invokeBinaryArithmeticOps(const ExecEnv *env, const ir::Operation &node)
-{
- const auto &arithmetic_node =
- nnfw::misc::polymorphic_downcast<const ir::operation::BinaryArithmetic &>(node);
-
- switch (arithmetic_node.param().arithmetic_type)
- {
- case ir::operation::BinaryArithmetic::ArithmeticType::ADD:
- invokeBinaryArithmetic<OpType::ADD>(env, arithmetic_node);
- break;
- case ir::operation::BinaryArithmetic::ArithmeticType::SUB:
- invokeBinaryArithmetic<OpType::SUB>(env, arithmetic_node);
- break;
- case ir::operation::BinaryArithmetic::ArithmeticType::MUL:
- invokeBinaryArithmetic<OpType::MUL>(env, arithmetic_node);
- break;
- default:
- throw std::runtime_error{"Interp(BinaryArithmetic): NYI unsupported operation " +
- arithmetic_node.name()};
- break;
- }
-}
-
-} // namespace
-
-OpKernel *getBinaryArithmetic()
-{
- static OpKernel kernel = {prepare, invokeBinaryArithmeticOps};
- return &kernel;
-}
-
-} // namespace interp
-} // namespace onert
diff --git a/runtime/onert/core/src/interp/operations/Concat.cc b/runtime/onert/core/src/interp/operations/Concat.cc
deleted file mode 100644
index efc46c66b..000000000
--- a/runtime/onert/core/src/interp/operations/Concat.cc
+++ /dev/null
@@ -1,147 +0,0 @@
-/*
- * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include <cker/operation/Concatenation.h>
-
-#include "OperationUtil.h"
-
-#include "interp/Registration.h"
-#include "ir/operation/Concat.h"
-#include "misc/polymorphic_downcast.h"
-
-namespace onert
-{
-namespace interp
-{
-namespace concat
-{
-
-void prepareConcat(ExecEnv *env, const ir::Operation &node)
-{
- const auto &concat_node = nnfw::misc::polymorphic_downcast<const ir::operation::Concat &>(node);
-
- const auto first_index = node.getInputs().at(0);
- const auto out_index = node.getOutputs().at(0);
-
- const auto first_tensor = env->tensorAt(first_index);
- uint32_t out_axis_dimension = 0;
- const int32_t axis_raw = concat_node.param().axis;
- const uint32_t axis = (axis_raw < 0) ? (axis_raw + first_tensor->num_dimensions()) : axis_raw;
-
- // All inputs shape should be same except axis dimension
- // All inputs type should be same
- for (auto input : node.getInputs())
- {
- assert(first_tensor->num_dimensions() == env->tensorAt(input)->num_dimensions());
- assert(first_tensor->data_type() == env->tensorAt(input)->data_type());
- for (uint32_t i = 0; i < first_tensor->num_dimensions(); i++)
- {
- if (i == axis)
- {
- out_axis_dimension += env->tensorAt(input)->dimension(i);
- continue;
- }
- assert(first_tensor->dimension(i) == env->tensorAt(input)->dimension(i));
- }
- }
-
- // Make output tensor info using first input tensor info, and accumulated axis dimension value
- auto out_shape = first_tensor->tensorInfo().shape();
- out_shape.dim(axis) = out_axis_dimension;
- env->allocateIfNeeded(out_index, ir::OperandInfo::createStaticInfo(
- out_shape, first_tensor->tensorInfo().typeInfo()));
-
- auto out_tensor = env->tensorAt(out_index);
- UNUSED_RELEASE(out_tensor);
-
- // Output shape should be same with input except axis dimension
- // Output type should be same with input
- assert(first_tensor->data_type() == out_tensor->data_type());
- for (uint32_t i = 0; i < first_tensor->num_dimensions(); i++)
- {
- if (i == axis)
- {
- continue;
- }
- assert(first_tensor->dimension(i) == out_tensor->dimension(i));
- }
-}
-
-void invoke(const std::vector<const ITensor *> in_tensors, const ITensor *out_tensor, uint32_t axis)
-{
- const uint32_t count = in_tensors.size();
-
- // Calculate
- nnfw::cker::ConcatenationParams cker_param;
- cker_param.axis = (int8_t)axis;
- cker_param.inputs_count = count;
-
- const auto out_shape = convertShape(out_tensor->tensorInfo().shape());
-
- std::vector<nnfw::cker::Shape> in_shapes;
- std::vector<const nnfw::cker::Shape *> in_shape_ptrs;
- in_shapes.reserve(count);
- in_shape_ptrs.reserve(count);
- std::vector<const float *> in_ptrs;
- for (uint32_t i = 0; i < count; i++)
- {
- in_shapes.push_back(convertShape(in_tensors[i]->tensorInfo().shape()));
- in_shape_ptrs.push_back(&in_shapes[i]);
- in_ptrs.push_back(reinterpret_cast<const float *>(in_tensors[i]->bufferRO()));
- }
-
- auto out_buffer = out_tensor->buffer();
- float *out_ptr = reinterpret_cast<float *>(out_buffer);
-
- nnfw::cker::Concatenation<float>(cker_param, in_shape_ptrs.data(), in_ptrs.data(), out_shape,
- out_ptr);
-}
-
-void invokeConcat(const ExecEnv *env, const ir::Operation &node)
-{
- const auto &concat_node = nnfw::misc::polymorphic_downcast<const ir::operation::Concat &>(node);
- const int32_t axis_raw = concat_node.param().axis;
-
- std::vector<const ITensor *> in_tensors;
- for (const auto &e : concat_node.getInputs())
- {
- in_tensors.emplace_back(env->tensorAt(e));
- }
-
- const auto out_index = node.getOutputs().at(0);
- const auto out_tensor = env->tensorAt(out_index);
- const uint32_t axis = (axis_raw < 0) ? (axis_raw + out_tensor->num_dimensions()) : axis_raw;
-
- const auto data_type = in_tensors[0]->data_type();
- if (data_type == ir::DataType::FLOAT32)
- {
- invoke(in_tensors, out_tensor, axis);
- }
- else
- {
- throw std::runtime_error{"NYI: Support float32 only"};
- }
-}
-} // namespace concat
-
-OpKernel *getConcat()
-{
- static OpKernel kernel = {concat::prepareConcat, concat::invokeConcat};
- return &kernel;
-}
-
-} // namespace interp
-} // namespace onert
diff --git a/runtime/onert/core/src/interp/operations/Conv2D.cc b/runtime/onert/core/src/interp/operations/Conv2D.cc
deleted file mode 100644
index bb00b828c..000000000
--- a/runtime/onert/core/src/interp/operations/Conv2D.cc
+++ /dev/null
@@ -1,151 +0,0 @@
-/*
- * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include <cker/operation/Conv.h>
-
-#include "OperationUtil.h"
-
-#include "interp/Registration.h"
-#include "ir/operation/Conv2D.h"
-#include "util/Utils.h"
-#include "util/ShapeInference.h"
-#include "misc/polymorphic_downcast.h"
-
-namespace onert
-{
-namespace interp
-{
-namespace conv2d
-{
-
-void prepareConv2D(ExecEnv *env, const ir::Operation &node)
-{
- const auto in_index = node.getInputs().at(ir::operation::Conv2D::INPUT);
- const auto kernel_index = node.getInputs().at(ir::operation::Conv2D::KERNEL);
- const auto bias_index = node.getInputs().at(ir::operation::Conv2D::BIAS);
- const auto out_index = node.getOutputs().at(0);
-
- const auto in_tensor = env->tensorAt(in_index);
- const auto kernel_tensor = env->tensorAt(kernel_index);
- const auto bias_tensor = env->tensorAt(bias_index);
-
- assert(in_tensor->num_dimensions() == 4);
- assert(kernel_tensor->num_dimensions() == 4);
- assert(bias_tensor->num_dimensions() == 1);
-
- UNUSED_RELEASE(in_tensor);
- UNUSED_RELEASE(kernel_tensor);
- UNUSED_RELEASE(bias_tensor);
-
- const auto output_info = env->graph().operands().at(out_index).info();
- if (output_info.total_size() == 0)
- {
- // Handle unspecified output shape
- const auto &conv_node = nnfw::misc::polymorphic_downcast<const ir::operation::Conv2D &>(node);
- const auto infered_output_shape = shape_inference::inferConv2DShape(
- in_tensor->tensorInfo().shape(), kernel_tensor->tensorInfo().shape(), conv_node.param());
- env->allocateIfNeeded(
- out_index, ir::OperandInfo::createStaticInfo(infered_output_shape, output_info.typeInfo()));
- }
- else
- {
- env->allocateIfNeeded(out_index, output_info);
- }
-
- auto out_tensor = env->tensorAt(out_index);
- UNUSED_RELEASE(out_tensor);
-
- // Handle same ifm & ofm data type only
- assert(in_tensor->data_type() == out_tensor->data_type());
- assert(out_tensor->num_dimensions() == 4);
-}
-
-void invoke(const ITensor *ifm_tensor, const ITensor *ker_tensor, const ITensor *bias_tensor,
- const ITensor *ofm_tensor, const ir::operation::Conv2D::Param &param)
-{
-  // TODO Support NCHW frontend
- const auto ifm_shape = ifm_tensor->tensorInfo().shape().asFeature(ir::Layout::NHWC);
- const auto ofm_shape = ofm_tensor->tensorInfo().shape().asFeature(ir::Layout::NHWC);
- // Kernel format is [depth_out, kernel_height, kernel_width, depth_in].
- const auto &ker_shape = ker_tensor->tensorInfo().shape();
- const auto ker_height = ker_shape.dim(1);
- const auto ker_width = ker_shape.dim(2);
- const auto padding = ir::calculatePadding(param.padding, ifm_shape, ofm_shape, param.stride,
- ker_width, ker_height);
-
- // Calculate
- float activation_min, activation_max;
- calculateActivationRange(param.activation, &activation_min, &activation_max);
-
- nnfw::cker::ConvParams cker_param;
- cker_param.padding_type = convertPaddingType(param.padding.type);
- cker_param.padding_values.width = padding.left;
- cker_param.padding_values.height = padding.top;
- cker_param.stride_width = param.stride.horizontal;
- cker_param.stride_height = param.stride.vertical;
- cker_param.dilation_width_factor = 1;
- cker_param.dilation_height_factor = 1;
- cker_param.float_activation_min = activation_min;
- cker_param.float_activation_max = activation_max;
-
- const auto cker_ifm_shape = convertShape(ifm_tensor->tensorInfo().shape());
- const auto cker_ker_shape = convertShape(ker_tensor->tensorInfo().shape());
- const auto cker_bias_shape = convertShape(bias_tensor->tensorInfo().shape());
- const auto cker_ofm_shape = convertShape(ofm_tensor->tensorInfo().shape());
- const float *ifm_ptr = reinterpret_cast<const float *>(ifm_tensor->bufferRO());
- const float *ker_ptr = reinterpret_cast<const float *>(ker_tensor->bufferRO());
- const float *bias_ptr = reinterpret_cast<const float *>(bias_tensor->bufferRO());
- float *ofm_ptr = reinterpret_cast<float *>(ofm_tensor->buffer());
-
- nnfw::cker::Conv conv_kernel;
- conv_kernel(cker_param, cker_ifm_shape, ifm_ptr, cker_ker_shape, ker_ptr, cker_bias_shape,
- bias_ptr, cker_ofm_shape, ofm_ptr);
-}
-
-void invokeConv2D(const ExecEnv *env, const ir::Operation &node)
-{
- const auto &conv_node = nnfw::misc::polymorphic_downcast<const ir::operation::Conv2D &>(node);
-
- const auto ifm_index = node.getInputs().at(ir::operation::Conv2D::INPUT);
- const auto ker_index = node.getInputs().at(ir::operation::Conv2D::KERNEL);
- const auto bias_index = node.getInputs().at(ir::operation::Conv2D::BIAS);
- const auto ofm_index = node.getOutputs().at(0);
-
- const auto ifm_tensor = env->tensorAt(ifm_index);
- const auto ker_tensor = env->tensorAt(ker_index);
- const auto bias_tensor = env->tensorAt(bias_index);
- const auto ofm_tensor = env->tensorAt(ofm_index);
-
- const auto data_type = ifm_tensor->data_type();
- if (data_type == ir::DataType::FLOAT32)
- {
- invoke(ifm_tensor, ker_tensor, bias_tensor, ofm_tensor, conv_node.param());
- }
- else
- {
- throw std::runtime_error{"NYI: Support float32 only"};
- }
-}
-} // namespace conv2d
-
-OpKernel *getConv2D()
-{
- static OpKernel kernel = {conv2d::prepareConv2D, conv2d::invokeConv2D};
- return &kernel;
-}
-
-} // namespace interp
-} // namespace onert
diff --git a/runtime/onert/core/src/interp/operations/DepthwiseConv2D.cc b/runtime/onert/core/src/interp/operations/DepthwiseConv2D.cc
deleted file mode 100644
index 0473855d9..000000000
--- a/runtime/onert/core/src/interp/operations/DepthwiseConv2D.cc
+++ /dev/null
@@ -1,156 +0,0 @@
-/*
- * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include <cker/operation/DepthwiseConv.h>
-#include <misc/polymorphic_downcast.h>
-
-#include "OperationUtil.h"
-
-#include "interp/Registration.h"
-#include "ir/operation/DepthwiseConv2D.h"
-#include "util/Utils.h"
-#include "util/ShapeInference.h"
-
-namespace onert
-{
-namespace interp
-{
-
-namespace
-{
-
-void prepareDepthwiseConv(ExecEnv *env, const ir::Operation &node)
-{
- const auto in_index = node.getInputs().at(ir::operation::DepthwiseConv2D::INPUT);
- const auto kernel_index = node.getInputs().at(ir::operation::DepthwiseConv2D::KERNEL);
- const auto bias_index = node.getInputs().at(ir::operation::DepthwiseConv2D::BIAS);
- const auto out_index = node.getOutputs().at(0);
-
- const auto in_tensor = env->tensorAt(in_index);
- const auto kernel_tensor = env->tensorAt(kernel_index);
- const auto bias_tensor = env->tensorAt(bias_index);
-
- assert(in_tensor->num_dimensions() == 4);
- assert(kernel_tensor->num_dimensions() == 4);
- assert(bias_tensor->num_dimensions() == 1);
-
- UNUSED_RELEASE(in_tensor);
- UNUSED_RELEASE(kernel_tensor);
- UNUSED_RELEASE(bias_tensor);
-
- // TODO handle unspecified output shape:
- // calculate output shape using ifm shape, kernel shape, padding, stride
- const auto output_info = env->graph().operands().at(out_index).info();
- if (output_info.total_size() == 0)
- {
- // Handle unspecified output shape
- const auto &depth_conv_node =
- nnfw::misc::polymorphic_downcast<const ir::operation::DepthwiseConv2D &>(node);
- const auto infered_output_shape = shape_inference::inferDepthwiseConv2DShape(
- in_tensor->tensorInfo().shape(), kernel_tensor->tensorInfo().shape(),
- depth_conv_node.param());
- env->allocateIfNeeded(
- out_index, ir::OperandInfo::createStaticInfo(infered_output_shape, output_info.typeInfo()));
- }
- else
- {
- env->allocateIfNeeded(out_index, output_info);
- }
-
- auto out_tensor = env->tensorAt(out_index);
- UNUSED_RELEASE(out_tensor);
-
- // Handle same ifm & ofm data type only
- assert(in_tensor->data_type() == out_tensor->data_type());
- assert(out_tensor->num_dimensions() == 4);
-}
-
-void invoke(const ITensor *ifm_tensor, const ITensor *ker_tensor, const ITensor *bias_tensor,
- const ITensor *ofm_tensor, const ir::operation::DepthwiseConv2D::Param &param)
-{
- // TODO Support NCHW frontend
- const auto ifm_shape = ifm_tensor->tensorInfo().shape().asFeature(ir::Layout::NHWC);
- const auto ofm_shape = ofm_tensor->tensorInfo().shape().asFeature(ir::Layout::NHWC);
- // Kernel format is [1, kernel_height, kernel_width, depth_out].
- const auto &ker_shape = ker_tensor->tensorInfo().shape();
- const auto ker_height = ker_shape.dim(1);
- const auto ker_width = ker_shape.dim(2);
- const auto padding = ir::calculatePadding(param.padding, ifm_shape, ofm_shape, param.stride,
- ker_width, ker_height);
-
- // Calculate
- float activation_min, activation_max;
- calculateActivationRange(param.activation, &activation_min, &activation_max);
-
- nnfw::cker::DepthwiseConvParams cker_param;
- cker_param.padding_values.width = padding.left;
- cker_param.padding_values.height = padding.top;
- cker_param.depth_multiplier = param.multiplier;
- cker_param.stride_width = param.stride.horizontal;
- cker_param.stride_height = param.stride.vertical;
- cker_param.dilation_width_factor = 1;
- cker_param.dilation_height_factor = 1;
- cker_param.float_activation_min = activation_min;
- cker_param.float_activation_max = activation_max;
-
- const auto cker_ifm_shape = convertShape(ifm_tensor->tensorInfo().shape());
- const auto cker_ker_shape = convertShape(ker_tensor->tensorInfo().shape());
- const auto cker_bias_shape = convertShape(bias_tensor->tensorInfo().shape());
- const auto cker_ofm_shape = convertShape(ofm_tensor->tensorInfo().shape());
- const float *ifm_ptr = reinterpret_cast<const float *>(ifm_tensor->bufferRO());
- const float *ker_ptr = reinterpret_cast<const float *>(ker_tensor->bufferRO());
- const float *bias_ptr = reinterpret_cast<const float *>(bias_tensor->bufferRO());
- float *ofm_ptr = reinterpret_cast<float *>(ofm_tensor->buffer());
-
- nnfw::cker::DepthwiseConv(cker_param, cker_ifm_shape, ifm_ptr, cker_ker_shape, ker_ptr,
- cker_bias_shape, bias_ptr, cker_ofm_shape, ofm_ptr);
-}
-
-void invokeDepthwiseConv(const ExecEnv *env, const ir::Operation &node)
-{
- const auto &conv_node = static_cast<const ir::operation::DepthwiseConv2D &>(node);
-
- const auto ifm_index = node.getInputs().at(ir::operation::DepthwiseConv2D::INPUT);
- const auto ker_index = node.getInputs().at(ir::operation::DepthwiseConv2D::KERNEL);
- const auto bias_index = node.getInputs().at(ir::operation::DepthwiseConv2D::BIAS);
- const auto ofm_index = node.getOutputs().at(0);
-
- const auto ifm_tensor = env->tensorAt(ifm_index);
- const auto ker_tensor = env->tensorAt(ker_index);
- const auto bias_tensor = env->tensorAt(bias_index);
- const auto ofm_tensor = env->tensorAt(ofm_index);
-
- const auto data_type = ifm_tensor->data_type();
- if (data_type == ir::DataType::FLOAT32)
- {
- invoke(ifm_tensor, ker_tensor, bias_tensor, ofm_tensor, conv_node.param());
- }
- else
- {
- throw std::runtime_error{"NYI: Support float32 only"};
- }
-}
-
-} // namespace
-
-OpKernel *getDepthwiseConv2D()
-{
- static OpKernel kernel = {prepareDepthwiseConv, invokeDepthwiseConv};
- return &kernel;
-}
-
-} // namespace interp
-} // namespace onert
diff --git a/runtime/onert/core/src/interp/operations/ElementwiseActivations.cc b/runtime/onert/core/src/interp/operations/ElementwiseActivations.cc
deleted file mode 100644
index c8773bef4..000000000
--- a/runtime/onert/core/src/interp/operations/ElementwiseActivations.cc
+++ /dev/null
@@ -1,161 +0,0 @@
-/*
- * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include <cmath>
-
-#include "OperationUtil.h"
-
-#include "interp/Registration.h"
-
-#include "ir/operation/ElementwiseActivation.h"
-
-#include <misc/polymorphic_downcast.h>
-#include <cker/operation/Logistic.h>
-#include <cker/operation/Tanh.h>
-
-namespace onert
-{
-namespace interp
-{
-namespace
-{
-
-enum class ActivationType
-{
- Logistic,
- ReLU,
- Tanh
-};
-
-void prepare(ExecEnv *env, const ir::Operation &node)
-{
- const auto input_index = node.getInputs().at(0);
- const auto output_index = node.getOutputs().at(0);
-
- const auto input_tensor = env->tensorAt(input_index);
-
- const auto output_info = env->graph().operands().at(output_index).info();
- if (output_info.total_size() == 0)
- {
- // Output's shape and type is same with input
- auto input_info = input_tensor->tensorInfo();
- // We can handle already allocated (ex. model output)
- env->allocateIfNeeded(output_index, input_info);
- }
- else
- {
- env->allocateIfNeeded(output_index, output_info);
- }
-
- const auto output_tensor = env->tensorAt(output_index);
- // Check shape and type lhs is same with output
- // TODO Util function to compare TensorInfo
- if (input_tensor->data_type() != output_tensor->data_type())
- {
- throw std::runtime_error{"Interp(ElementwiseActivation): Invalid output type"};
- }
-}
-
-template <ActivationType act_type>
-void evalFloat(const float *input_ptr, float *output_ptr, uint64_t num_elements, float alpha,
- float beta)
-{
- std::function<float(const float &)> fn = [](const float &) { return std::nanf(""); };
- switch (act_type)
- {
- case ActivationType::ReLU:
- fn = [alpha, beta](const float &in) { return std::min(std::max(beta, in), alpha); };
- break;
- case ActivationType::Tanh:
- fn = [](const float &in) { return std::tanh(in); };
- break;
- default:
- throw std::runtime_error{"Interp(ElementwiseActivation): NYI - Unsupported activation"};
- break;
- }
-
- const float *input_end = input_ptr + num_elements;
- for (; input_ptr < input_end; input_ptr++, output_ptr++)
- {
- *output_ptr = fn(*input_ptr);
- }
-}
-
-template <ActivationType act_type> void invoke(const ExecEnv *env, const ir::Operation &node)
-{
- const auto input_index = node.getInputs().at(0);
- const auto output_index = node.getOutputs().at(0);
-
- // Check lhs shape is same with rhs (with broadcast)
- const auto input_tensor = env->tensorAt(input_index);
- const auto output_tensor = env->tensorAt(output_index);
-
- const auto data_type = input_tensor->data_type();
- if (data_type == ir::DataType::FLOAT32)
- {
- uint64_t elements = input_tensor->num_elements();
- const float *input_start = reinterpret_cast<const float *>(input_tensor->bufferRO());
- float *out = reinterpret_cast<float *>(output_tensor->buffer());
- if (act_type == ActivationType::Logistic)
- {
- const auto cker_input_shape = convertShape(input_tensor->tensorInfo().shape());
- const auto cker_output_shape = convertShape(output_tensor->tensorInfo().shape());
- nnfw::cker::Logistic(cker_input_shape, input_start, cker_output_shape, out);
- }
- else
- {
- const auto &act_node =
- nnfw::misc::polymorphic_downcast<const ir::operation::ElementwiseActivation &>(node);
- evalFloat<act_type>(input_start, out, elements, act_node.param().alpha,
- act_node.param().beta);
- }
- }
- else
- {
- throw std::runtime_error{"Interp(" + node.name() + "): NYI - Support float only"};
- }
-}
-
-void invokeElementwiseActivation(const ExecEnv *env, const ir::Operation &node)
-{
- const auto &act_node =
- nnfw::misc::polymorphic_downcast<const ir::operation::ElementwiseActivation &>(node);
- switch (act_node.param().op_type)
- {
- case ir::operation::ElementwiseActivation::Type::LOGISTIC:
- invoke<ActivationType::Logistic>(env, node);
- break;
- case ir::operation::ElementwiseActivation::Type::RELU:
- invoke<ActivationType::ReLU>(env, node);
- break;
- case ir::operation::ElementwiseActivation::Type::TANH:
- invoke<ActivationType::Tanh>(env, node);
- break;
- default:
- throw std::runtime_error("Interp(" + node.name() + "): NYI - Unsupported activation");
- }
-}
-
-} // namespace
-
-OpKernel *getElementwiseActivation()
-{
- static OpKernel kernel = {prepare, invokeElementwiseActivation};
- return &kernel;
-}
-
-} // namespace interp
-} // namespace onert
diff --git a/runtime/onert/core/src/interp/operations/FullyConnected.cc b/runtime/onert/core/src/interp/operations/FullyConnected.cc
deleted file mode 100644
index 12f529dab..000000000
--- a/runtime/onert/core/src/interp/operations/FullyConnected.cc
+++ /dev/null
@@ -1,136 +0,0 @@
-/*
- * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include <cker/operation/FullyConnected.h>
-
-#include "OperationUtil.h"
-
-#include "interp/Registration.h"
-#include "ir/operation/FullyConnected.h"
-#include "misc/polymorphic_downcast.h"
-
-namespace onert
-{
-namespace interp
-{
-namespace fc
-{
-
-void prepareFC(ExecEnv *env, const ir::Operation &node)
-{
- const auto in_index = node.getInputs().at(ir::operation::FullyConnected::INPUT);
- const auto kernel_index = node.getInputs().at(ir::operation::FullyConnected::WEIGHT);
- const auto bias_index = node.getInputs().at(ir::operation::FullyConnected::BIAS);
- const auto out_index = node.getOutputs().at(0);
-
- const auto in_tensor = env->tensorAt(in_index);
- const auto kernel_tensor = env->tensorAt(kernel_index);
- const auto bias_tensor = env->tensorAt(bias_index);
-
- UNUSED_RELEASE(in_tensor);
- UNUSED_RELEASE(kernel_tensor);
- UNUSED_RELEASE(bias_tensor);
-
- assert(in_tensor->num_dimensions() >= 2);
- assert(kernel_tensor->num_dimensions() == 2);
- assert(bias_tensor->num_dimensions() == 1);
-
- const auto input_size_with_batch = in_tensor->num_elements();
- const auto num_units = kernel_tensor->dimension(0);
- const auto input_size = kernel_tensor->dimension(1);
- const auto batch_size = input_size_with_batch / input_size;
- assert(input_size_with_batch % input_size == 0);
- assert(num_units == bias_tensor->dimension(0));
-
- // Make output tensor info
- ir::Shape output_shape(2);
- output_shape.dim(0) = batch_size;
- output_shape.dim(1) = num_units;
- const auto out_info =
- ir::OperandInfo::createStaticInfo(output_shape, in_tensor->tensorInfo().typeInfo());
- env->allocateIfNeeded(out_index, out_info);
-
- auto out_tensor = env->tensorAt(out_index);
- UNUSED_RELEASE(out_tensor);
-
- // Handle same ifm & ofm data type only
- assert(in_tensor->data_type() == out_tensor->data_type());
- assert(out_tensor->num_dimensions() == 2);
- assert(out_tensor->dimension(0) == batch_size);
- assert(out_tensor->dimension(1) == num_units);
-}
-
-void invoke(const ITensor *ifm_tensor, const ITensor *ker_tensor, const ITensor *bias_tensor,
- const ITensor *ofm_tensor, const ir::operation::FullyConnected::Param &param)
-{
- const auto ifm_buffer = ifm_tensor->bufferRO();
- const auto ker_buffer = ker_tensor->bufferRO();
- const auto bias_buffer = bias_tensor->bufferRO();
- auto ofm_buffer = ofm_tensor->buffer();
-
- // Calculate
- nnfw::cker::FullyConnectedParams cker_param;
- cker_param.activation = convertActivationType(param.activation);
- calculateActivationRange(param.activation, &cker_param.float_activation_min,
- &cker_param.float_activation_max);
- const auto cker_ifm_shape = convertShape(ifm_tensor->tensorInfo().shape());
- const auto cker_ker_shape = convertShape(ker_tensor->tensorInfo().shape());
- const auto cker_bias_shape = convertShape(bias_tensor->tensorInfo().shape());
- const auto cker_ofm_shape = convertShape(ofm_tensor->tensorInfo().shape());
- const float *ifm_ptr = reinterpret_cast<const float *>(ifm_buffer);
- const float *ker_ptr = reinterpret_cast<const float *>(ker_buffer);
- const float *bias_ptr = reinterpret_cast<const float *>(bias_buffer);
- float *ofm_ptr = reinterpret_cast<float *>(ofm_buffer);
-
- nnfw::cker::FullyConnected(cker_param, cker_ifm_shape, ifm_ptr, cker_ker_shape, ker_ptr,
- cker_bias_shape, bias_ptr, cker_ofm_shape, ofm_ptr);
-}
-
-void invokeFC(const ExecEnv *env, const ir::Operation &node)
-{
- const auto &conv_node =
- nnfw::misc::polymorphic_downcast<const ir::operation::FullyConnected &>(node);
-
- const auto ifm_index = node.getInputs().at(ir::operation::FullyConnected::INPUT);
- const auto ker_index = node.getInputs().at(ir::operation::FullyConnected::WEIGHT);
- const auto bias_index = node.getInputs().at(ir::operation::FullyConnected::BIAS);
- const auto ofm_index = node.getOutputs().at(0);
-
- const auto ifm_tensor = env->tensorAt(ifm_index);
- const auto ker_tensor = env->tensorAt(ker_index);
- const auto bias_tensor = env->tensorAt(bias_index);
- const auto ofm_tensor = env->tensorAt(ofm_index);
-
- const auto data_type = ifm_tensor->data_type();
- if (data_type == ir::DataType::FLOAT32)
- {
- invoke(ifm_tensor, ker_tensor, bias_tensor, ofm_tensor, conv_node.param());
- }
- else
- {
- throw std::runtime_error{"NYI: Support float only"};
- }
-}
-} // namespace fc
-
-OpKernel *getFullyConnected()
-{
- static OpKernel kernel = {fc::prepareFC, fc::invokeFC};
- return &kernel;
-}
-
-} // namespace interp
-} // namespace onert
diff --git a/runtime/onert/core/src/interp/operations/Gather.cc b/runtime/onert/core/src/interp/operations/Gather.cc
deleted file mode 100644
index 9e82def5f..000000000
--- a/runtime/onert/core/src/interp/operations/Gather.cc
+++ /dev/null
@@ -1,138 +0,0 @@
-/*
- * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include <cker/operation/Gather.h>
-
-#include "OperationUtil.h"
-
-#include "interp/Registration.h"
-#include "ir/operation/Gather.h"
-#include "misc/polymorphic_downcast.h"
-
-namespace onert
-{
-namespace interp
-{
-namespace
-{
-
-void prepareGather(ExecEnv *env, const ir::Operation &node)
-{
- const auto input_index = node.getInputs().at(ir::operation::Gather::INPUT);
- const auto indices_index = node.getInputs().at(ir::operation::Gather::INDICES);
- const auto output_index = node.getOutputs().at(0);
-
- const auto input_tensor = env->tensorAt(input_index);
- const auto indices_tensor = env->tensorAt(indices_index);
-
- // TODO handle unspecified output shape:
- // calculate output shape using ifm shape, kernel shape, padding, stride
- const auto output_info = env->graph().operands().at(output_index).info();
- if (output_info.total_size() == 0)
- {
- throw std::runtime_error{"Interp(Gather): NYI for unspecified output shape"};
- }
- else
- {
- env->allocateIfNeeded(output_index, output_info);
- }
-
- if (indices_tensor->data_type() != ir::DataType::INT32)
- {
- throw std::runtime_error{"Interp(Gather): Invalid indices data type"};
- }
-
- auto output_tensor = env->tensorAt(output_index);
- auto output_rank = input_tensor->num_dimensions() + indices_tensor->num_dimensions() - 1;
-
- if (output_rank != output_tensor->num_dimensions())
- {
- throw std::runtime_error{"Interp(Gather): Invalid output rank"};
- }
- if (output_tensor->data_type() != input_tensor->data_type())
- {
- throw std::runtime_error{"Interp(Gather): Invalid output data type"};
- }
-
- if (input_tensor->data_type() == ir::DataType::QUANT_UINT8_ASYMM &&
- input_tensor->tensorInfo().typeInfo() != output_tensor->tensorInfo().typeInfo())
- {
- throw std::runtime_error{
- "Interp(Gather): Cannot handle different I/O QUANT_UINT8_ASYMM scale/offset"};
- }
-}
-
-template <typename raw_type>
-void invoke(const ITensor *input_tensors, const ITensor *indices_tensors,
- const ITensor *output_tensor, uint32_t axis)
-{
- // Calculate
- nnfw::cker::GatherParams cker_param;
- cker_param.axis = (int8_t)axis;
-
- const auto cker_input_shapes = convertShape(input_tensors->tensorInfo().shape());
- const auto cker_indices_shape = convertShape(indices_tensors->tensorInfo().shape());
- const auto cker_output_shape = convertShape(output_tensor->tensorInfo().shape());
- const raw_type *input_ptr = reinterpret_cast<const raw_type *>(input_tensors->bufferRO());
- const int32_t *indices_ptr = reinterpret_cast<const int32_t *>(indices_tensors->bufferRO());
- raw_type *output_ptr = reinterpret_cast<raw_type *>(output_tensor->buffer());
-
- nnfw::cker::Gather<raw_type>(cker_param, cker_input_shapes, input_ptr, cker_indices_shape,
- indices_ptr, cker_output_shape, output_ptr);
-}
-
-void invokeGather(const ExecEnv *env, const ir::Operation &node)
-{
- const auto &gather_node = nnfw::misc::polymorphic_downcast<const ir::operation::Gather &>(node);
- const int32_t axis_raw = gather_node.param().axis;
-
- const auto input_index = node.getInputs().at(ir::operation::Gather::INPUT);
- const auto indices_index = node.getInputs().at(ir::operation::Gather::INDICES);
- const auto output_index = node.getOutputs().at(0);
-
- const auto input_tensor = env->tensorAt(input_index);
- const auto indices_tensor = env->tensorAt(indices_index);
- const auto output_tensor = env->tensorAt(output_index);
- const uint32_t axis = (axis_raw < 0) ? (axis_raw + input_tensor->num_dimensions()) : axis_raw;
-
- const auto data_type = input_tensor->data_type();
-
- switch (data_type)
- {
- case ir::DataType::FLOAT32:
- invoke<float>(input_tensor, indices_tensor, output_tensor, axis);
- break;
- case ir::DataType::INT32:
- invoke<int32_t>(input_tensor, indices_tensor, output_tensor, axis);
- break;
- case ir::DataType::QUANT_UINT8_ASYMM:
- invoke<uint8_t>(input_tensor, indices_tensor, output_tensor, axis);
- break;
- default:
- throw std::runtime_error{"Interp(Gather): NYI - Not supported type"};
- }
-}
-
-} // namespace
-
-OpKernel *getGather()
-{
- static OpKernel kernel = {prepareGather, invokeGather};
- return &kernel;
-}
-
-} // namespace interp
-} // namespace onert
diff --git a/runtime/onert/core/src/interp/operations/InstanceNorm.cc b/runtime/onert/core/src/interp/operations/InstanceNorm.cc
deleted file mode 100644
index 2538bcc39..000000000
--- a/runtime/onert/core/src/interp/operations/InstanceNorm.cc
+++ /dev/null
@@ -1,121 +0,0 @@
-/*
- * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include <cker/operation/InstanceNorm.h>
-
-#include "OperationUtil.h"
-
-#include "interp/Registration.h"
-#include "ir/operation/InstanceNorm.h"
-#include "misc/polymorphic_downcast.h"
-
-namespace onert
-{
-namespace interp
-{
-namespace instancenorm
-{
-
-void prepareInstanceNorm(ExecEnv *env, const ir::Operation &node)
-{
- const auto &instancenorm_node =
- nnfw::misc::polymorphic_downcast<const ir::operation::InstanceNorm &>(node);
-
- const auto input_index = node.getInputs().at(instancenorm_node.INPUT);
- const auto output_index = node.getOutputs().at(0);
- const auto input_tensor = env->tensorAt(input_index);
-
- if (input_tensor->num_dimensions() != 4)
- {
- throw std::runtime_error{"Interp(InstanceNorm): Input should be 4D-tensor"};
- }
-
- // Output shape should be same with input
- env->allocateIfNeeded(output_index, input_tensor->tensorInfo());
-
- auto output_tensor = env->tensorAt(output_index);
- UNUSED_RELEASE(output_tensor);
-
- // Handle same ifm & ofm data type only
- assert(input_tensor->data_type() == output_tensor->data_type());
- assert(input_tensor->tensorInfo().shape() == output_tensor->tensorInfo().shape());
-}
-
-inline void setActivationParams(float min, float max, nnfw::cker::InstanceNormParams *params)
-{
- params->float_activation_min = min;
- params->float_activation_max = max;
-}
-
-void invoke(const ITensor *input_tensor, const ITensor *gamma_tensor, const ITensor *beta_tensor,
- const ITensor *output_tensor, const ir::operation::InstanceNorm::Param &param)
-{
- // Calculate
- float activation_min, activation_max;
- calculateActivationRange(param.activation, &activation_min, &activation_max);
-
- nnfw::cker::InstanceNormParams cker_param;
- cker_param.epsilon = param.epsilon;
- cker_param.float_activation_min = activation_min;
- cker_param.float_activation_max = activation_max;
-
- const auto cker_input_shape = convertShape(input_tensor->tensorInfo().shape());
- const auto cker_gamma_shape = convertShape(gamma_tensor->tensorInfo().shape());
- const auto cker_beta_shape = convertShape(beta_tensor->tensorInfo().shape());
- const auto cker_output_shape = convertShape(output_tensor->tensorInfo().shape());
- const float *input_ptr = reinterpret_cast<const float *>(input_tensor->bufferRO());
- const float *gamma_ptr = reinterpret_cast<const float *>(gamma_tensor->bufferRO());
- const float *beta_ptr = reinterpret_cast<const float *>(beta_tensor->bufferRO());
- float *output_ptr = reinterpret_cast<float *>(output_tensor->buffer());
-
- nnfw::cker::InstanceNorm(cker_param, cker_input_shape, input_ptr, cker_gamma_shape, gamma_ptr,
- cker_beta_shape, beta_ptr, cker_output_shape, output_ptr);
-}
-
-void invokeInstanceNorm(const ExecEnv *env, const ir::Operation &node)
-{
- const auto &instancenorm_node =
- nnfw::misc::polymorphic_downcast<const ir::operation::InstanceNorm &>(node);
-
- const auto input_index = node.getInputs().at(instancenorm_node.INPUT);
- const auto gamma_index = node.getInputs().at(instancenorm_node.GAMMA);
- const auto beta_index = node.getInputs().at(instancenorm_node.BETA);
- const auto out_index = node.getOutputs().at(0);
- const auto input_tensor = env->tensorAt(input_index);
- const auto gamma_tensor = env->tensorAt(gamma_index);
- const auto beta_tensor = env->tensorAt(beta_index);
- const auto out_tensor = env->tensorAt(out_index);
- const auto data_type = input_tensor->data_type();
-
- if (data_type == ir::DataType::FLOAT32)
- {
- invoke(input_tensor, gamma_tensor, beta_tensor, out_tensor, instancenorm_node.param());
- }
- else
- {
- throw std::runtime_error{"NYI: Unsupported data type"};
- }
-}
-} // namespace instancenorm
-
-OpKernel *getInstanceNorm()
-{
- static OpKernel kernel = {instancenorm::prepareInstanceNorm, instancenorm::invokeInstanceNorm};
- return &kernel;
-}
-
-} // namespace interp
-} // namespace onert
diff --git a/runtime/onert/core/src/interp/operations/OperationUtil.h b/runtime/onert/core/src/interp/operations/OperationUtil.h
deleted file mode 100644
index 2fdf098f0..000000000
--- a/runtime/onert/core/src/interp/operations/OperationUtil.h
+++ /dev/null
@@ -1,203 +0,0 @@
-/*
- * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef __ONERT_INTERP_OPERATIONS_OPERATION_UTILS_H_
-#define __ONERT_INTERP_OPERATIONS_OPERATION_UTILS_H_
-
-#include "ir/Shape.h"
-#include "ir/InternalType.h"
-#include "ir/Padding.h"
-
-#include <cker/Shape.h>
-#include <cker/Types.h>
-
-namespace onert
-{
-namespace interp
-{
-
-inline nnfw::cker::Shape convertShape(const ir::Shape &shape)
-{
- auto dimensions = std::vector<uint32_t>(shape.dims().begin(), shape.dims().end());
-
- std::vector<int32_t> raw_shape;
- raw_shape.resize(dimensions.size());
-
- for (uint32_t i = 0; i < dimensions.size(); ++i)
- {
- raw_shape[i] = dimensions[i];
- }
-
- return nnfw::cker::GetShape(raw_shape);
-}
-
-inline nnfw::cker::Shape convertExtendShape(const ir::Shape &shape)
-{
- auto dimensions = std::vector<uint32_t>(shape.dims().begin(), shape.dims().end());
-
- const int32_t extended_rank = 4;
- int32_t raw_shape[extended_rank];
- uint32_t start = extended_rank - dimensions.size();
-
- for (uint32_t i = 0; i < extended_rank; ++i)
- {
- if (i < start)
- {
- raw_shape[i] = 1;
- }
- else
- {
- raw_shape[i] = dimensions[i - start];
- }
- }
-
- return nnfw::cker::Shape(extended_rank, raw_shape);
-}
-
-inline nnfw::cker::FusedActivationFunctionType
-convertActivationType(const ir::Activation activation)
-{
- switch (activation)
- {
- case ir::Activation::NONE:
- return nnfw::cker::FusedActivationFunctionType::kNone;
- case ir::Activation::RELU:
- return nnfw::cker::FusedActivationFunctionType::kRelu;
- case ir::Activation::RELU1:
- return nnfw::cker::FusedActivationFunctionType::kRelu1;
- case ir::Activation::RELU6:
- return nnfw::cker::FusedActivationFunctionType::kRelu6;
- default:
- throw std::runtime_error{"CPU backend: Cannot convert activation type"};
- }
-}
-
-template <typename T>
-void calculateActivationRange(ir::Activation activation, T *activation_min, T *activation_max)
-{
- if (activation == ir::Activation::RELU)
- {
- *activation_min = 0;
- *activation_max = std::numeric_limits<T>::max();
- }
- else if (activation == ir::Activation::RELU6)
- {
- *activation_min = 0;
- *activation_max = 6;
- }
- else if (activation == ir::Activation::RELU1)
- {
- *activation_min = -1;
- *activation_max = 1;
- }
- else if (activation == ir::Activation::NONE)
- {
- *activation_min = std::numeric_limits<T>::lowest();
- *activation_max = std::numeric_limits<T>::max();
- }
- else
- {
- throw std::runtime_error{"Unsupported activation type"};
- }
-}
-
-inline ir::Shape calcBroadcastShape(const ir::Shape &lhs, const ir::Shape &rhs, bool &success)
-{
- int lhs_rank = lhs.rank();
- int rhs_rank = rhs.rank();
-
- int out_rank = (lhs_rank > rhs_rank ? lhs_rank : rhs_rank);
- ir::Shape out_shape(out_rank);
-
- int lhs_idim = lhs_rank - 1;
- int rhs_idim = rhs_rank - 1;
- success = true;
- for (int out_idim = out_rank - 1; out_idim >= 0; out_idim--)
- {
- if (lhs_idim == -1 && rhs_idim == -1)
- {
- // invalid result
- success = false;
- break;
- }
-
- if (lhs_idim == -1)
- {
- out_shape.dim(out_idim) = rhs.dim(rhs_idim);
- rhs_idim--;
- }
- else if (rhs_idim == -1)
- {
- out_shape.dim(out_idim) = lhs.dim(lhs_idim);
- lhs_idim--;
- }
- else
- {
- if (lhs.dim(lhs_idim) == rhs.dim(rhs_idim))
- {
- out_shape.dim(out_idim) = lhs.dim(lhs_idim);
- lhs_idim--;
- rhs_idim--;
- }
- else if (lhs.dim(lhs_idim) == 1)
- {
- out_shape.dim(out_idim) = rhs.dim(rhs_idim);
- lhs_idim--;
- rhs_idim--;
- }
- else if (rhs.dim(rhs_idim) == 1)
- {
- out_shape.dim(out_idim) = lhs.dim(lhs_idim);
- lhs_idim--;
- rhs_idim--;
- }
- else
- {
- // invalid result
- success = false;
- break;
- }
- }
- }
-
- if (lhs_idim != -1 || rhs_idim != -1)
- {
- // invalid result
- success = false;
- }
- return out_shape;
-}
-
-inline nnfw::cker::PaddingType convertPaddingType(ir::PaddingType ir_padding_type)
-{
- switch (ir_padding_type)
- {
- case ir::PaddingType::EXPLICIT:
- return nnfw::cker::PaddingType::kNone;
- case ir::PaddingType::SAME:
- return nnfw::cker::PaddingType::kSame;
- case ir::PaddingType::VALID:
- return nnfw::cker::PaddingType::kValid;
- default:
- throw std::runtime_error("Wrong padding type.");
- break;
- }
-}
-
-} // namespace interp
-} // namespace onert
-
-#endif // __ONERT_INTERP_OPERATIONS_OPERATION_UTILS_H_
diff --git a/runtime/onert/core/src/interp/operations/Pad.cc b/runtime/onert/core/src/interp/operations/Pad.cc
deleted file mode 100644
index c8dce698d..000000000
--- a/runtime/onert/core/src/interp/operations/Pad.cc
+++ /dev/null
@@ -1,106 +0,0 @@
-/*
- * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include <cker/operation/Pad.h>
-
-#include "OperationUtil.h"
-
-#include "interp/Registration.h"
-#include "ir/operation/Pad.h"
-
-namespace onert
-{
-namespace interp
-{
-namespace
-{
-
-void preparePad(ExecEnv *env, const ir::Operation &node)
-{
- const auto input_index = node.getInputs().at(ir::operation::Pad::INPUT);
- const auto output_index = node.getOutputs().at(0);
-
- const auto input_tensor = env->tensorAt(input_index);
-
- const auto output_info = env->graph().operands().at(output_index).info();
-
- // Check shape and type lhs is same with rhs
- // TODO Util function to compare TensorInfo
- if (output_info.total_size() == 0)
- {
- throw std::runtime_error{"Interp(Pad): NYI unspecified output shape"};
- }
- else
- {
- env->allocateIfNeeded(output_index, output_info);
- }
-
- const auto output_tensor = env->tensorAt(output_index);
- if (input_tensor->data_type() != output_tensor->data_type())
- {
- throw std::runtime_error{"Interp(Pad): Invalid output type"};
- }
-}
-
-void invoke(const ITensor *input_tensor, const ITensor *pad_tensor, const ITensor *output_tensor)
-{
- const auto input_buffer = input_tensor->bufferRO();
- const auto pad_buffer = pad_tensor->bufferRO();
- auto output_buffer = output_tensor->buffer();
-
- int32_t pad_rank = pad_tensor->dimension(0);
-
- const auto cker_input_shape = convertShape(input_tensor->tensorInfo().shape());
- const auto cker_output_shape = convertShape(output_tensor->tensorInfo().shape());
- const float *input_ptr = reinterpret_cast<const float *>(input_buffer);
- const int32_t *pad_ptr = reinterpret_cast<const int32_t *>(pad_buffer);
- float *output_ptr = reinterpret_cast<float *>(output_buffer);
-
- nnfw::cker::Pad<float>(pad_ptr, pad_rank, cker_input_shape, input_ptr, cker_output_shape,
- output_ptr, nullptr);
-}
-
-void invokePad(const ExecEnv *env, const ir::Operation &node)
-{
- const auto input_index = node.getInputs().at(ir::operation::Pad::INPUT);
- const auto pad_index = node.getInputs().at(ir::operation::Pad::PAD);
- const auto output_index = node.getOutputs().at(0);
-
- const auto input_tensor = env->tensorAt(input_index);
- const auto pad_tensor = env->tensorAt(pad_index);
- const auto output_tensor = env->tensorAt(output_index);
-
- const auto data_type = input_tensor->data_type();
-
- if (data_type == ir::DataType::FLOAT32)
- {
- invoke(input_tensor, pad_tensor, output_tensor);
- }
- else
- {
- throw std::runtime_error{"Interp(Pad): NYI - Unsupported data type"};
- }
-}
-} // namespace
-
-OpKernel *getPad()
-{
- static OpKernel kernel = {preparePad, invokePad};
- return &kernel;
-}
-
-} // namespace interp
-} // namespace onert
diff --git a/runtime/onert/core/src/interp/operations/Pool2D.cc b/runtime/onert/core/src/interp/operations/Pool2D.cc
deleted file mode 100644
index 92f9d70b2..000000000
--- a/runtime/onert/core/src/interp/operations/Pool2D.cc
+++ /dev/null
@@ -1,140 +0,0 @@
-/*
- * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include <cker/operation/AveragePool.h>
-#include <cker/operation/MaxPool.h>
-
-#include "OperationUtil.h"
-
-#include "interp/Registration.h"
-#include "ir/operation/Pool2D.h"
-#include "util/Utils.h"
-#include "util/ShapeInference.h"
-#include "misc/polymorphic_downcast.h"
-
-namespace onert
-{
-namespace interp
-{
-namespace pool2d
-{
-
-void preparePool2D(ExecEnv *env, const ir::Operation &node)
-{
- const auto &pool_node = nnfw::misc::polymorphic_downcast<const ir::operation::Pool2D &>(node);
- const auto in_index = node.getInputs().at(pool_node.INPUT);
- const auto out_index = node.getOutputs().at(0);
-
- const auto in_tensor = env->tensorAt(in_index);
- UNUSED_RELEASE(in_tensor);
-
- assert(in_tensor->num_dimensions() == 4);
-
- const auto output_info = env->graph().operands().at(out_index).info();
- if (output_info.total_size() == 0)
- {
- // Handle unspecified output shape
- const auto infered_output_shape =
- shape_inference::inferPoolShape(in_tensor->tensorInfo().shape(), pool_node.param());
- env->allocateIfNeeded(
- out_index, ir::OperandInfo::createStaticInfo(infered_output_shape, output_info.typeInfo()));
- }
- else
- {
- env->allocateIfNeeded(out_index, output_info);
- }
-
- auto out_tensor = env->tensorAt(out_index);
- UNUSED_RELEASE(out_tensor);
-
- // Handle same ifm & ofm data type only
- assert(in_tensor->data_type() == out_tensor->data_type());
- assert(out_tensor->num_dimensions() == 4);
-}
-
-template <typename T>
-void invoke(const nnfw::cker::PoolParams &params, const nnfw::cker::Shape &in_shape,
- const T *in_ptr, const nnfw::cker::Shape &out_shape, T *out_ptr,
- ir::operation::Pool2D::PoolType op_type)
-{
- switch (op_type)
- {
- case ir::operation::Pool2D::PoolType::AVG:
- nnfw::cker::AveragePool<T>(params, in_shape, in_ptr, out_shape, out_ptr);
- break;
- case ir::operation::Pool2D::PoolType::MAX:
- nnfw::cker::MaxPool<T>(params, in_shape, in_ptr, out_shape, out_ptr);
- break;
- default:
- throw std::runtime_error{"Interp(Pool2D): NYI unsupported operation"};
- break;
- }
-}
-
-void invokePool2DOps(const ExecEnv *env, const ir::Operation &node)
-{
- const auto &pool_node = nnfw::misc::polymorphic_downcast<const ir::operation::Pool2D &>(node);
-
- const auto in_index = node.getInputs().at(0);
- const auto out_index = node.getOutputs().at(0);
-
- // Check lhs shape is same with rhs (with broadcast)
- const auto in_tensor = env->tensorAt(in_index);
- const auto out_tensor = env->tensorAt(out_index);
-
- // TODO support NCHW frontend
- const auto ifm_shape = in_tensor->tensorInfo().shape().asFeature(ir::Layout::NHWC);
- const auto ofm_shape = out_tensor->tensorInfo().shape().asFeature(ir::Layout::NHWC);
- const auto param = pool_node.param();
- const auto padding =
- ir::calculatePadding(param.padding, ifm_shape, ofm_shape, param.stride, param.kw, param.kh);
- // Calculate
- nnfw::cker::PoolParams cker_param;
- cker_param.filter_width = param.kw;
- cker_param.filter_height = param.kh;
- cker_param.padding_values.width = padding.left;
- cker_param.padding_values.height = padding.top;
- cker_param.stride_width = param.stride.horizontal;
- cker_param.stride_height = param.stride.vertical;
-
- const auto data_type = in_tensor->data_type();
- if (data_type == ir::DataType::FLOAT32)
- {
- calculateActivationRange(param.activation, &cker_param.float_activation_min,
- &cker_param.float_activation_max);
-
- const auto in_shape = convertShape(in_tensor->tensorInfo().shape());
- const auto out_shape = convertShape(out_tensor->tensorInfo().shape());
- const float *in_ptr = reinterpret_cast<const float *>(in_tensor->bufferRO());
- float *out_ptr = reinterpret_cast<float *>(out_tensor->buffer());
- // Now, invoke() supports only Pool2D in float
- invoke<float>(cker_param, in_shape, in_ptr, out_shape, out_ptr, param.op_type);
- }
- else
- {
- throw std::runtime_error{"NYI: Support float only"};
- }
-}
-} // namespace pool2d
-
-OpKernel *getPool2D()
-{
- static OpKernel kernel = {pool2d::preparePool2D, pool2d::invokePool2DOps};
- return &kernel;
-}
-
-} // namespace interp
-} // namespace onert
diff --git a/runtime/onert/core/src/interp/operations/Reshape.cc b/runtime/onert/core/src/interp/operations/Reshape.cc
deleted file mode 100644
index 3a118456b..000000000
--- a/runtime/onert/core/src/interp/operations/Reshape.cc
+++ /dev/null
@@ -1,63 +0,0 @@
-/*
- * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "interp/Registration.h"
-
-namespace onert
-{
-namespace interp
-{
-namespace
-{
-
-void prepare(ExecEnv *env, const ir::Operation &node)
-{
- const auto in_index = node.getInputs().at(0);
- const auto out_index = node.getOutputs().at(0);
-
- // Unspecified shape is not supported in operation node spec now
- const auto output_info = env->graph().operands().at(out_index).info();
- env->allocateAndShareIfNeeded(out_index, output_info, in_index);
-
- assert(output_info.total_size() == env->graph().operands().at(in_index).info().total_size());
-}
-
-void invoke(const ExecEnv *env, const ir::Operation &node)
-{
- const auto in_index = node.getInputs().at(0);
- const auto out_index = node.getOutputs().at(0);
-
- if (env->tensorAt(in_index)->bufferRO() == env->tensorAt(out_index)->bufferRO())
- {
- // Same data
- return;
- }
-
- const auto output_info = env->graph().operands().at(out_index).info();
- memcpy(env->tensorAt(out_index)->buffer(), env->tensorAt(in_index)->bufferRO(),
- output_info.total_size());
-}
-
-} // namespace
-
-OpKernel *getReshape()
-{
- static OpKernel kernel = {prepare, invoke};
- return &kernel;
-}
-
-} // namespace interp
-} // namespace onert
diff --git a/runtime/onert/core/src/interp/operations/Softmax.cc b/runtime/onert/core/src/interp/operations/Softmax.cc
deleted file mode 100644
index d30f78deb..000000000
--- a/runtime/onert/core/src/interp/operations/Softmax.cc
+++ /dev/null
@@ -1,123 +0,0 @@
-/*
- * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include <cker/operation/SoftMax.h>
-
-#include "OperationUtil.h"
-
-#include "interp/Registration.h"
-#include "ir/operation/Softmax.h"
-#include "misc/polymorphic_downcast.h"
-
-namespace onert
-{
-namespace interp
-{
-namespace
-{
-
-void prepareSoftMax(ExecEnv *env, const ir::Operation &node)
-{
- const auto in_index = node.getInputs().at(0);
- const auto out_index = node.getOutputs().at(0);
-
- const auto in_tensor = env->tensorAt(in_index);
- UNUSED_RELEASE(in_tensor);
-
- assert((in_tensor->num_dimensions() == 4) || (in_tensor->num_dimensions() == 2));
-
- // Output shape should be same with input
- // Output type is pre-defined in model
- const auto output_shape = env->graph().operands().at(in_index).info().shape();
- const auto output_type = env->graph().operands().at(out_index).info().typeInfo();
-
- const auto output_info = ir::OperandInfo::createStaticInfo(output_shape, output_type);
- env->allocateIfNeeded(out_index, output_info);
-
- auto out_tensor = env->tensorAt(out_index);
- UNUSED_RELEASE(out_tensor);
-
- // Check output shape is same with input
- assert(out_tensor->num_dimensions() == out_tensor->num_dimensions());
- for (uint32_t i = 0; i < in_tensor->num_dimensions(); i++)
- {
- assert(in_tensor->dimension(i) == out_tensor->dimension(i));
- }
-}
-
-void invoke(const ITensor *in_tensor, const ITensor *out_tensor,
- const ir::operation::Softmax::Param &param)
-{
- const float *in_ptr = reinterpret_cast<const float *>(in_tensor->bufferRO());
- float *out_ptr = reinterpret_cast<float *>(out_tensor->buffer());
-
- float beta = param.beta;
-
- if (in_tensor->num_dimensions() == 2)
- {
- uint32_t batch_size = in_tensor->dimension(0);
- uint32_t input_size = in_tensor->dimension(1);
-
- nnfw::cker::Softmax(in_ptr, input_size, batch_size, beta, out_ptr);
- }
- else if (in_tensor->num_dimensions() == 4)
- {
- const auto in_shape = convertShape(in_tensor->tensorInfo().shape());
- const auto out_shape = convertShape(out_tensor->tensorInfo().shape());
-
- nnfw::cker::SoftmaxParams cker_param;
- cker_param.beta = beta;
-
- nnfw::cker::Softmax(cker_param, in_shape, in_ptr, out_shape, out_ptr);
- }
- else
- {
- throw std::runtime_error{"Unsuported input dimension: support 2D or 4D"};
- }
-}
-
-void invokeSoftMax(const ExecEnv *env, const ir::Operation &node)
-{
- const auto &softmax_node = nnfw::misc::polymorphic_downcast<const ir::operation::Softmax &>(node);
-
- const auto in_index = node.getInputs().at(0);
- const auto out_index = node.getOutputs().at(0);
-
- const auto in_tensor = env->tensorAt(in_index);
- const auto out_tensor = env->tensorAt(out_index);
-
- const auto in_data_type = in_tensor->data_type();
- const auto out_data_type = out_tensor->data_type();
- if ((in_data_type == ir::DataType::FLOAT32) && (out_data_type == ir::DataType::FLOAT32))
- {
- invoke(in_tensor, out_tensor, softmax_node.param());
- }
- else
- {
- throw std::runtime_error{"NYI: Support float32 only"};
- }
-}
-
-} // namespace
-
-OpKernel *getSoftmax()
-{
- static OpKernel kernel = {prepareSoftMax, invokeSoftMax};
- return &kernel;
-}
-
-} // namespace interp
-} // namespace onert
diff --git a/runtime/onert/core/src/interp/operations/TransposeConv.cc b/runtime/onert/core/src/interp/operations/TransposeConv.cc
deleted file mode 100644
index cc2ced26b..000000000
--- a/runtime/onert/core/src/interp/operations/TransposeConv.cc
+++ /dev/null
@@ -1,141 +0,0 @@
-/*
- * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include <cker/operation/TransposeConv.h>
-#include <misc/polymorphic_downcast.h>
-
-#include "OperationUtil.h"
-
-#include "interp/Registration.h"
-#include "ir/operation/TransposeConv.h"
-
-namespace onert
-{
-namespace interp
-{
-namespace
-{
-
-void prepareTransposeConv(ExecEnv *env, const ir::Operation &node)
-{
- const auto ifm_index = node.getInputs().at(ir::operation::TransposeConv::INPUT);
- const auto ker_index = node.getInputs().at(ir::operation::TransposeConv::KERNEL);
- const auto ofm_shape_index = node.getInputs().at(ir::operation::TransposeConv::OUTPUT_SHAPE);
- const auto ofm_index = node.getOutputs().at(0);
-
- const auto ifm_tensor = env->tensorAt(ifm_index);
- const auto ker_tensor = env->tensorAt(ker_index);
- const auto ofm_shape_tensor = env->tensorAt(ofm_shape_index);
-
- assert(ifm_tensor->num_dimensions() == 4);
- assert(ker_tensor->num_dimensions() == 4);
- assert(ofm_shape_tensor->num_dimensions() == 1);
-
- UNUSED_RELEASE(ifm_tensor);
- UNUSED_RELEASE(ker_tensor);
- UNUSED_RELEASE(ofm_shape_tensor);
-
- const auto output_info = env->graph().operands().at(ofm_index).info();
- if (output_info.total_size() == 0)
- {
- // TODO: Handle unspecified output shape
- throw std::runtime_error{"Interp(TConv): NYI unspecified output shape"};
- }
- else
- {
- env->allocateIfNeeded(ofm_index, output_info);
- }
-
- auto ofm_tensor = env->tensorAt(ofm_index);
- UNUSED_RELEASE(ofm_tensor);
-
- // Handle same ifm & ofm data type only
- if (ifm_tensor->data_type() != ofm_tensor->data_type())
- {
- throw std::runtime_error{"Interp(TConv): Different I/O data dype"};
- }
-
- if (ofm_tensor->num_dimensions() != 4)
- {
- throw std::runtime_error{"Interp(TConv): Invalid output rank"};
- }
-}
-
-void invoke(const ITensor *ifm_tensor, const ITensor *ker_tensor, const ITensor *ofm_tensor,
- const ir::operation::TransposeConv::Param &param)
-{
- const auto ifm_shape = ifm_tensor->tensorInfo().shape().asFeature(ir::Layout::NHWC);
- const auto ofm_shape = ofm_tensor->tensorInfo().shape().asFeature(ir::Layout::NHWC);
- // Kernel format is [depth_out, kernel_height, kernel_width, depth_in].
- const auto ker_shape = ker_tensor->tensorInfo().shape();
- const auto ker_height = ker_shape.dim(1);
- const auto ker_width = ker_shape.dim(2);
- const auto padding = ir::calculatePadding(param.padding, ofm_shape, ifm_shape, param.stride,
- ker_width, ker_height);
-
- nnfw::cker::TransposeConvParams cker_param;
- cker_param.padding_values.width = padding.left;
- cker_param.padding_values.height = padding.top;
- cker_param.stride_width = param.stride.horizontal;
- cker_param.stride_height = param.stride.vertical;
- cker_param.dilation_width_factor = 1;
- cker_param.dilation_height_factor = 1;
-
- const auto cker_ifm_shape = convertShape(ifm_tensor->tensorInfo().shape());
- const auto cker_ker_shape = convertShape(ker_tensor->tensorInfo().shape());
- const auto cker_ofm_shape = convertShape(ofm_tensor->tensorInfo().shape());
- const float *ifm_ptr = reinterpret_cast<const float *>(ifm_tensor->bufferRO());
- const float *ker_ptr = reinterpret_cast<const float *>(ker_tensor->bufferRO());
- float *ofm_ptr = reinterpret_cast<float *>(ofm_tensor->buffer());
-
- nnfw::cker::TransposeConv(cker_param, cker_ifm_shape, ifm_ptr, cker_ker_shape, ker_ptr,
- cker_ofm_shape, ofm_ptr);
-}
-
-void invokeTransposeConv(const ExecEnv *env, const ir::Operation &node)
-{
- const auto &tconv_node =
- nnfw::misc::polymorphic_downcast<const ir::operation::TransposeConv &>(node);
-
- const auto ifm_index = node.getInputs().at(ir::operation::TransposeConv::INPUT);
- const auto ker_index = node.getInputs().at(ir::operation::TransposeConv::KERNEL);
- const auto ofm_index = node.getOutputs().at(0);
-
- const auto ifm_tensor = env->tensorAt(ifm_index);
- const auto ker_tensor = env->tensorAt(ker_index);
- const auto ofm_tensor = env->tensorAt(ofm_index);
-
- const auto data_type = ifm_tensor->data_type();
- if (data_type == ir::DataType::FLOAT32)
- {
- invoke(ifm_tensor, ker_tensor, ofm_tensor, tconv_node.param());
- }
- else
- {
- throw std::runtime_error{"Interp(TConv): Support float32 only"};
- }
-}
-
-} // namespace
-
-OpKernel *getTransposeConv()
-{
- static OpKernel kernel = {prepareTransposeConv, invokeTransposeConv};
- return &kernel;
-}
-
-} // namespace interp
-} // namespace onert
diff --git a/runtime/onert/core/src/ir/DataType.cc b/runtime/onert/core/src/ir/DataType.cc
index 80c659b3a..07670c720 100644
--- a/runtime/onert/core/src/ir/DataType.cc
+++ b/runtime/onert/core/src/ir/DataType.cc
@@ -41,11 +41,17 @@ size_t sizeOfDataType(DataType data_type)
case DataType::UINT8:
return sizeof(uint8_t);
case DataType::QUANT_INT8_SYMM:
+ case DataType::QUANT_INT8_ASYMM:
+ case DataType::QUANT_INT8_SYMM_PER_CHANNEL:
return sizeof(int8_t);
case DataType::FLOAT16:
return sizeof(float16);
case DataType::INT64:
return sizeof(int64_t);
+ case DataType::QUANT_INT16_ASYMM:
+ return sizeof(int16_t);
+ case DataType::QUANT_INT16_SYMM:
+ return sizeof(int16_t);
default:
throw std::runtime_error{"Unsupported type size"};
}
diff --git a/runtime/onert/core/src/ir/Graph.cc b/runtime/onert/core/src/ir/Graph.cc
index fe8b1b443..306572c99 100644
--- a/runtime/onert/core/src/ir/Graph.cc
+++ b/runtime/onert/core/src/ir/Graph.cc
@@ -16,18 +16,10 @@
#include "ir/Graph.h"
-#include <algorithm>
-#include <bitset>
-#include <sstream>
-
-#include "util/logging.h"
+#include "OperationValidator.h"
#include "verifier/Verifier.h"
-#include "ir/operation/LowerInfo.h"
-#include "ir/operand/LowerInfo.h"
-#include "ir/operand/PermuteFactor.h"
-#include "ir/OperandIndexMap.h"
-#include "ir/GraphIterator.h"
-#include "backend/IConfig.h"
+
+#include "util/Set.h"
namespace onert
{
@@ -36,6 +28,8 @@ namespace ir
Graph::Graph() = default;
+Graph::Graph(const Graph &) = default;
+
Graph::~Graph(void) = default;
OperandIndex Graph::addOperand(const Shape &shape, const TypeInfo &type)
@@ -43,22 +37,91 @@ OperandIndex Graph::addOperand(const Shape &shape, const TypeInfo &type)
return _operands.emplace(shape, type);
}
-OperationIndex Graph::addOperation(std::unique_ptr<Operation> &&node)
+OperandIndex Graph::addOperand(OperandIndex index, std::unique_ptr<Operand> &&operand)
+{
+ return _operands.push(std::move(operand), index);
+}
+
+bool Graph::checkOperandsForOperation(const IOperation &operation)
{
- assert(isBuildingPhase());
- return _operations.push(std::move(node));
+ auto inputs = operation.getInputs() | ir::Remove::UNDEFINED | ir::Remove::DUPLICATED;
+ auto outputs = operation.getOutputs() | ir::Remove::UNDEFINED | ir::Remove::DUPLICATED;
+ for (auto &&input : inputs)
+ if (!operands().exist(input))
+ return false;
+ for (auto &&output : outputs)
+ if (!operands().exist(output))
+ return false;
+ return true;
+}
+
+void Graph::linkOperandToOperation(OperationIndex index, const IOperation &operation)
+{
+ auto inputs = operation.getInputs() | ir::Remove::UNDEFINED | ir::Remove::DUPLICATED;
+ auto outputs = operation.getOutputs() | ir::Remove::UNDEFINED | ir::Remove::DUPLICATED;
+
+ for (auto &&input : inputs)
+ operands().at(input).insertUse(index);
+ for (auto &&output : outputs)
+ operands().at(output).setDef(index);
+}
+
+OperationIndex Graph::addOperation(std::unique_ptr<IOperation> &&operation)
+{
+ const IOperation &op_ref = *operation;
+ if (!checkOperandsForOperation(op_ref))
+ return OperationIndex{};
+ auto ind = _operations.push(std::move(operation));
+ if (ind.valid())
+ linkOperandToOperation(ind, op_ref);
+ return ind;
+}
+
+OperationIndex Graph::addOperation(OperationIndex index, std::unique_ptr<IOperation> &&operation)
+{
+ const IOperation &op_ref = *operation;
+ if (!checkOperandsForOperation(op_ref))
+ return OperationIndex{};
+ auto ind_gen = _operations.push(std::move(operation), index);
+ if (ind_gen.valid())
+ {
+ assert(ind_gen == index);
+ linkOperandToOperation(index, op_ref);
+ }
+ return index;
+}
+
+OperationIndex Graph::replaceOperation(OperationIndex index,
+ std::unique_ptr<IOperation> &&operation)
+{
+ const IOperation &op_ref = *operation;
+ if (!checkOperandsForOperation(op_ref) || !_operations.exist(index))
+ return OperationIndex{};
+
+ // Check that the new operation has the same inputs/outputs as the existing operation
+ const auto &old_op = _operations.at(index);
+ if (!(old_op.getInputs() == op_ref.getInputs() && old_op.getOutputs() == op_ref.getOutputs()))
+ {
+ return OperationIndex{};
+ }
+
+ return _operations.set(index, std::move(operation));
}
void Graph::setOperandValue(const OperandIndex &ind, std::shared_ptr<Data> data)
{
- assert(isBuildingPhase());
assert(_operands.exist(ind));
_operands.at(ind).data(std::move(data));
}
+void Graph::changeShape(const OperandIndex &ind, const ir::Shape &new_shape)
+{
+ assert(_operands.exist(ind));
+ _operands.at(ind).info().shape(new_shape);
+}
+
void Graph::addInput(const OperandIndex &ind, const std::string &name)
{
- assert(isBuildingPhase());
if (!name.empty())
_name_to_input.emplace(name, IOIndex{_inputs.size()});
_inputs.append(ind);
@@ -66,7 +129,6 @@ void Graph::addInput(const OperandIndex &ind, const std::string &name)
void Graph::addOutput(const OperandIndex &ind, const std::string &name)
{
- assert(isBuildingPhase());
if (!name.empty())
_name_to_output.emplace(name, IOIndex{_outputs.size()});
_outputs.append(ind);
@@ -84,62 +146,70 @@ IOIndex Graph::getOutputIndex(const std::string &name) const
return (itr == _name_to_output.end()) ? IOIndex{} : itr->second;
}
-void Graph::finishBuilding(void)
+void Graph::verify(void) const
{
- assert(isBuildingPhase());
- _phase = Phase::MODEL;
-
- initializeUseDef();
- sweepGarbageOperands();
-
// Call graph verifications for the MODEL phase
{
- assert(verifier::DAGChecker().verify(*this));
- assert(verifier::EdgeConsistencyChecker().verify(*this));
+ // Except for edge consistency, the user might have provided a bad model,
+ // so this throws an exception rather than an assertion.
+ if (!verifier::InputOutputChecker().verify(*this))
+ throw std::runtime_error{"One of model input and output operands does not exist."};
+ if (!verifier::DAGChecker().verify(*this))
+ throw std::runtime_error{"The graph is cyclic."};
+ assert(verifier::EdgeChecker().verify(*this));
}
+
+ // Check shape-independent operation features:
+ // - Operand type
+ // - Shape-independent parameters
+ OperationValidator{*this}();
}
void Graph::initializeUseDef()
{
- operations().iterate([&](const OperationIndex &index, const Operation &node) -> void {
- auto outputs = node.getOutputs();
- for (auto output : outputs)
+ operations().iterate([&](const OperationIndex &index, const IOperation &node) -> void {
+ const auto &outputs = node.getOutputs();
+ for (auto &&output : outputs | ir::Remove::UNDEFINED)
{
operands().at(output).setDef(index);
}
- for (auto input : node.getInputs() | ir::Remove::UNDEFINED)
+ for (auto &&input : node.getInputs() | ir::Remove::UNDEFINED)
{
operands().at(input).insertUse(index);
}
});
}
-void Graph::sweepGarbageOperands()
+std::vector<ir::OperationIndex> Graph::topolSortOperations() const
{
- // Remove operands that are not used by any operations, except Graph inputs/outputs
- ir::OperandIndexMap<bool> visited;
-
- operations().iterate([&](const OperationIndex &, const Operation &node) {
- for (auto ind : node.getInputs() + node.getOutputs())
- {
- visited[ind] = true;
- }
- });
-
- // Graph's inputs/outputs are always reachable
- for (auto ind : getInputs() + getOutputs())
- {
- visited[ind] = true;
- }
-
- operands().iterate([&](const OperandIndex &ind, const Operand &) {
- if (!visited[ind])
+ std::vector<ir::OperationIndex> ret;
+ util::Set<ir::OperationIndex> unvisited;
+ operations().iterate(
+ [&](const ir::OperationIndex &index, const ir::IOperation &) { unvisited.add(index); });
+
+ std::function<void(const ir::OperationIndex &, const ir::IOperation &)> dfs =
+ [&](const ir::OperationIndex &index, const ir::IOperation &op) -> void {
+ if (!unvisited.contains(index))
+ return;
+ unvisited.remove(index);
+
+ for (const auto &output : op.getOutputs() | ir::Remove::DUPLICATED | ir::Remove::UNDEFINED)
{
- VERBOSE(Graph::sweepGarbageOperands) << "Sweep garbage operand " << ind.value() << std::endl;
- operands().remove(ind);
+ const auto &operand = operands().at(output);
+ for (const auto &use : operand.getUses())
+ {
+ dfs(use, operations().at(use));
+ }
}
- });
+ ret.push_back(index);
+ };
+ operations().iterate(dfs);
+
+ assert(unvisited.empty()); // All of the nodes must have been visited
+  // Reverse the post-order DFS result to sort it in topological order
+ std::reverse(ret.begin(), ret.end());
+ return ret;
}
} // namespace ir
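For reference, a minimal sketch (not part of the patch) of the Graph lifecycle that replaces finishBuilding(): populate the graph, call verify(), then obtain a dependency ordering with topolSortOperations(). It uses only APIs introduced or exercised in this patch; the ADD operation and tensor shapes are illustrative.

#include "ir/Graph.h"
#include "ir/operation/BinaryArithmetic.h"

#include <memory>

void buildVerifyAndSort()
{
  using namespace onert::ir;

  Graph graph;

  // Operands (shape/type are illustrative)
  Shape shape{1, 2, 2, 1};
  TypeInfo type{DataType::FLOAT32};
  auto lhs = graph.addOperand(shape, type);
  auto rhs = graph.addOperand(shape, type);
  auto res = graph.addOperand(shape, type);

  // One ADD operation
  operation::BinaryArithmetic::Param param;
  param.arithmetic_type = operation::BinaryArithmetic::ArithmeticType::ADD;
  param.activation = Activation::NONE;
  graph.addOperation(std::make_unique<operation::BinaryArithmetic>(
    OperandIndexSequence{lhs, rhs}, OperandIndexSequence{res}, param));

  // Model inputs/outputs
  graph.addInput(lhs);
  graph.addInput(rhs);
  graph.addOutput(res);

  // Throws std::runtime_error for an invalid model instead of asserting
  graph.verify();

  // Reverse post-order DFS: every operation appears after its producers
  for (const auto &op_ind : graph.topolSortOperations())
  {
    const auto &op = graph.operations().at(op_ind);
    (void)op; // e.g. schedule or lower `op` here
  }
}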
diff --git a/runtime/onert/core/src/ir/Graph.test.cc b/runtime/onert/core/src/ir/Graph.test.cc
new file mode 100644
index 000000000..144500745
--- /dev/null
+++ b/runtime/onert/core/src/ir/Graph.test.cc
@@ -0,0 +1,147 @@
+/*
+ * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ir/Graph.h"
+#include "ir/operation/BinaryArithmetic.h"
+
+#include <gtest/gtest.h>
+
+TEST(Graph, neg_inputs_and_outputs)
+{
+ onert::ir::Graph graph;
+
+ onert::ir::OperandIndex index0{0u};
+ onert::ir::OperandIndex index1{1u};
+
+ graph.addInput({index0});
+ graph.addInput({index1});
+
+ onert::ir::OperandIndex index10{10u};
+ onert::ir::OperandIndex index11{11u};
+ onert::ir::OperandIndex index12{12u};
+
+ graph.addOutput({index10});
+ graph.addOutput({index11});
+ graph.addOutput({index12});
+
+ ASSERT_EQ(graph.getInputs().size(), 2);
+ ASSERT_EQ(graph.getOutputs().size(), 3);
+
+ onert::ir::IOIndex io_index0{0};
+ onert::ir::IOIndex io_index1{1};
+ onert::ir::IOIndex io_index2{2};
+
+ ASSERT_EQ(graph.getInputs().at(io_index0), 0);
+ ASSERT_EQ(graph.getInputs().at(io_index1), 1);
+
+ ASSERT_EQ(graph.getOutputs().at(io_index0), 10);
+ ASSERT_EQ(graph.getOutputs().at(io_index1), 11);
+ ASSERT_EQ(graph.getOutputs().at(io_index2), 12);
+
+ EXPECT_THROW(graph.getOutputs().at(onert::ir::IOIndex{3}), std::out_of_range);
+}
+
+using namespace onert::ir;
+
+OperationIndex addAddOperation(Graph &graph, const OperandIndexSequence inputs,
+ const OperandIndexSequence outputs)
+{
+ // Add "ADD" operation
+ operation::BinaryArithmetic::Param param;
+ param.arithmetic_type = operation::BinaryArithmetic::ArithmeticType::ADD;
+ param.activation = Activation::NONE;
+ return graph.addOperation(std::make_unique<operation::BinaryArithmetic>(inputs, outputs, param));
+}
+
+TEST(Graph, OneOpGraphSimpleValid)
+{
+ // Simple Graph with just one Add operation
+
+ Graph graph;
+
+ // Add tensors
+ Shape shape{1, 2, 2, 1};
+ TypeInfo type{DataType::FLOAT32};
+ auto lhs = graph.addOperand(shape, type);
+ auto rhs = graph.addOperand(shape, type);
+ auto res = graph.addOperand(shape, type);
+
+ addAddOperation(graph, {lhs, rhs}, {res});
+
+ // Set model inputs/outputs
+ graph.addInput(lhs);
+ graph.addInput(rhs);
+ graph.addOutput(res);
+
+ graph.verify();
+
+ SUCCEED();
+}
+
+TEST(Graph, neg_InvalidGraph_BadInput)
+{
+ Graph graph;
+
+ // Add tensors
+ Shape shape{1, 2, 2, 1};
+ TypeInfo type{DataType::FLOAT32};
+ auto in = graph.addOperand(shape, type);
+ auto out = graph.addOperand(shape, type);
+
+ // Set model inputs/outputs
+ graph.addInput(in);
+ graph.addOutput(out);
+  graph.addInput(OperandIndex{89}); // Non-existing operand!
+
+ EXPECT_ANY_THROW(graph.verify());
+}
+
+TEST(Graph, neg_InvalidGraph_BadOutput)
+{
+ Graph graph;
+
+ // Add tensors
+ Shape shape{1, 2, 2, 1};
+ TypeInfo type{DataType::FLOAT32};
+ auto in = graph.addOperand(shape, type);
+ auto out = graph.addOperand(shape, type);
+
+ // Set model inputs/outputs
+ graph.addInput(in);
+ graph.addOutput(out);
+  graph.addOutput(OperandIndex{12}); // Non-existing operand!
+
+ EXPECT_ANY_THROW(graph.verify());
+}
+
+TEST(Graph, neg_InvalidAddOperation_BadInputIndex)
+{
+ Graph graph;
+
+ // Add tensors
+ Shape shape{1, 2, 2, 1};
+ TypeInfo type{DataType::FLOAT32};
+ auto lhs = graph.addOperand(shape, type);
+ auto rhs = graph.addOperand(shape, type);
+ auto res = graph.addOperand(shape, type);
+
+ // Set model inputs/outputs
+ graph.addInput(lhs);
+ graph.addInput(rhs);
+ graph.addOutput(res);
+
+ ASSERT_FALSE(addAddOperation(graph, {lhs, OperandIndex{99}}, {res}).valid());
+}
diff --git a/runtime/onert/core/src/ir/GraphIterator.cc b/runtime/onert/core/src/ir/GraphIterator.cc
deleted file mode 100644
index 4bea1a55d..000000000
--- a/runtime/onert/core/src/ir/GraphIterator.cc
+++ /dev/null
@@ -1,121 +0,0 @@
-/*
- * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "GraphIterator.h"
-
-#include "ir/OperationIndexMap.h"
-#include "compiler/LoweredGraph.h"
-
-namespace onert
-{
-namespace ir
-{
-
-//
-// Graph::DefaultIterator
-//
-
-template <bool is_const>
-void DefaultIterator<is_const>::iterate(GraphRef graph, const IterFn &fn) const
-{
- graph.operations().iterate(
- [&](const OperationIndex &index, NodeRef node) -> void { fn(index, node); });
-}
-
-//
-// Graph::PostDfsIterator
-//
-
-template <bool is_const>
-void PostDfsIterator<is_const>::iterate(GraphRef graph, const IterFn &fn) const
-{
- assert(!graph.isBuildingPhase()); // Restrict iteration condition
-
- OperationIndexMap<bool> visited;
- graph.operations().iterate([&](const OperationIndex &index, NodeRef) { visited[index] = false; });
-
- std::function<void(const OperationIndex &, NodeRef)> dfs_recursive =
- [&](const OperationIndex &index, NodeRef node) -> void {
- if (visited[index])
- return;
- visited[index] = true;
-
- for (const auto output : node.getOutputs() | Remove::DUPLICATED)
- {
- const auto &operand = graph.operands().at(output);
- for (const auto &use : operand.getUses())
- {
- dfs_recursive(use, graph.operations().at(use));
- }
- }
-
- fn(index, node);
- };
-
- graph.operations().iterate(dfs_recursive);
-
- // All of the operations(nodes) must have been visited.
- assert(std::all_of(visited.begin(), visited.end(),
- [](const std::pair<const OperationIndex, bool> &v) { return v.second; }));
-}
-
-template <bool is_const>
-void PostDfsIterator<is_const>::iterateOpSeqs(LoweredGraphRef lowered_graph,
- const OpSeqIterFn &fn) const
-{
- std::unordered_map<OpSequenceIndex, bool> visited;
- lowered_graph.op_seqs().iterate(
- [&](const OpSequenceIndex &index, OpSequenceRef) { visited[index] = false; });
-
- std::function<void(const OpSequenceIndex &, OpSequenceRef)> dfs_recursive =
- [&](const OpSequenceIndex &index, OpSequenceRef op_seq) -> void {
- if (visited[index])
- return;
- visited[index] = true;
-
- for (const auto output : op_seq.getOutputs() | Remove::DUPLICATED)
- {
- const auto &operand = lowered_graph.graph().operands().at(output);
- for (const auto &use : operand.getUses())
- {
- const auto use_op_seq_index = lowered_graph.op_seqs().getOperation(use);
- dfs_recursive(use_op_seq_index, lowered_graph.op_seqs().at(use_op_seq_index));
- }
- }
-
- fn(index, op_seq);
- };
-
- lowered_graph.op_seqs().iterate(dfs_recursive);
-
- // All of the operations(nodes) must have been visited.
- assert(std::all_of(visited.begin(), visited.end(),
- [](const std::pair<const OpSequenceIndex, bool> &v) { return v.second; }));
-}
-
-// Explicit instantiations to have implementation in the source file.
-// NOTE If these instatiations were in the top of this file, `iterate` is compiled and saved in
-// `GraphIterator.cc.o` but `iterateOpSeqs`. This happens only when cross-building for Android.
-// (Maybe a bug of NDK toolchain(clang)?)
-
-template class DefaultIterator<true>;
-template class DefaultIterator<false>;
-
-template class PostDfsIterator<true>;
-template class PostDfsIterator<false>;
-
-} // namespace ir
-} // namespace onert
diff --git a/runtime/onert/core/src/ir/GraphIterator.h b/runtime/onert/core/src/ir/GraphIterator.h
deleted file mode 100644
index b54314e0e..000000000
--- a/runtime/onert/core/src/ir/GraphIterator.h
+++ /dev/null
@@ -1,90 +0,0 @@
-/*
- * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef __ONERT_IR_GRAPH_ITERATOR_H__
-#define __ONERT_IR_GRAPH_ITERATOR_H__
-
-#include <type_traits>
-
-#include "ir/Index.h"
-
-namespace onert
-{
-namespace compiler
-{
-class LoweredGraph;
-} // namespace compiler
-} // namespace onert
-
-namespace onert
-{
-namespace ir
-{
-
-class Graph;
-class Operation;
-class OpSequence;
-
-template <bool is_const> class Iterator
-{
-public:
- using GraphRef = typename std::conditional<is_const, const Graph &, Graph &>::type;
- using IndexRef = const OperationIndex &;
- using NodeRef = typename std::conditional<is_const, const Operation &, Operation &>::type;
- using IterFn = std::function<void(IndexRef, NodeRef)>;
-
-public:
- virtual ~Iterator() = default;
- virtual void iterate(GraphRef graph, const IterFn &fn) const = 0;
-};
-
-template <bool is_const = false> class DefaultIterator final : public Iterator<is_const>
-{
-public:
- using GraphRef = typename Iterator<is_const>::GraphRef;
- using IndexRef = typename Iterator<is_const>::IndexRef;
- using NodeRef = typename Iterator<is_const>::NodeRef;
- using IterFn = typename Iterator<is_const>::IterFn;
-
-public:
- void iterate(GraphRef graph, const IterFn &fn) const;
-};
-using DefaultConstIterator = DefaultIterator<true>;
-
-template <bool is_const = false> class PostDfsIterator final : public Iterator<is_const>
-{
-public:
- using GraphRef = typename Iterator<is_const>::GraphRef;
- using IndexRef = typename Iterator<is_const>::IndexRef;
- using NodeRef = typename Iterator<is_const>::NodeRef;
- using IterFn = typename Iterator<is_const>::IterFn;
- using LoweredGraphRef =
- typename std::conditional<is_const, const typename compiler::LoweredGraph &,
- typename compiler::LoweredGraph &>::type;
- using OpSequenceRef = typename std::conditional<is_const, const OpSequence &, OpSequence &>::type;
- using OpSeqIndexRef = const OpSequenceIndex &;
- using OpSeqIterFn = std::function<void(OpSeqIndexRef, OpSequenceRef)>;
-
-public:
- void iterate(GraphRef graph, const IterFn &fn) const;
- void iterateOpSeqs(LoweredGraphRef lowered_graph, const OpSeqIterFn &f) const;
-};
-using PostDfsConstIterator = PostDfsIterator<true>;
-
-} // namespace ir
-} // namespace onert
-
-#endif // __ONERT_IR_GRAPH_ITERATOR_H__
diff --git a/runtime/onert/core/src/ir/LayoutSet.cc b/runtime/onert/core/src/ir/LayoutSet.cc
index bd3f438ad..732460aa2 100644
--- a/runtime/onert/core/src/ir/LayoutSet.cc
+++ b/runtime/onert/core/src/ir/LayoutSet.cc
@@ -23,7 +23,7 @@ namespace ir
LayoutSet::LayoutSet(std::initializer_list<Layout> layouts)
{
- for (auto layout : layouts)
+ for (auto &&layout : layouts)
{
_set.insert(layout);
}
@@ -32,7 +32,7 @@ LayoutSet::LayoutSet(std::initializer_list<Layout> layouts)
LayoutSet LayoutSet::operator|(const LayoutSet &other) const
{
auto ret = *this;
- for (auto layout : other)
+ for (auto &&layout : other)
{
ret.add(layout);
}
@@ -42,7 +42,7 @@ LayoutSet LayoutSet::operator|(const LayoutSet &other) const
LayoutSet LayoutSet::operator&(const LayoutSet &other) const
{
LayoutSet ret;
- for (auto layout : other)
+ for (auto &&layout : other)
{
if (contains(layout))
{
@@ -55,7 +55,7 @@ LayoutSet LayoutSet::operator&(const LayoutSet &other) const
LayoutSet LayoutSet::operator-(const LayoutSet &other) const
{
auto ret = *this;
- for (auto layout : other)
+ for (auto &&layout : other)
{
ret.remove(layout);
}
diff --git a/runtime/onert/core/src/ir/LayoutSet.h b/runtime/onert/core/src/ir/LayoutSet.h
index 6ce4e38c6..be077f2f0 100644
--- a/runtime/onert/core/src/ir/LayoutSet.h
+++ b/runtime/onert/core/src/ir/LayoutSet.h
@@ -17,6 +17,7 @@
#ifndef __ONERT_IR_LAYOUT_SET_H__
#define __ONERT_IR_LAYOUT_SET_H__
+#include <cstdint>
#include <initializer_list>
#include <unordered_set>
diff --git a/runtime/onert/core/src/ir/LayoutSet.test.cc b/runtime/onert/core/src/ir/LayoutSet.test.cc
new file mode 100644
index 000000000..fc956abe8
--- /dev/null
+++ b/runtime/onert/core/src/ir/LayoutSet.test.cc
@@ -0,0 +1,67 @@
+/*
+ * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "LayoutSet.h"
+
+#include <gtest/gtest.h>
+
+using onert::ir::Layout;
+using onert::ir::LayoutSet;
+
+TEST(ir_LayoutSet, neg_add_remove)
+{
+ LayoutSet set{Layout::NCHW};
+ set.remove(Layout::NHWC);
+ ASSERT_EQ(set.size(), 1);
+ set.add(Layout::NHWC);
+ ASSERT_EQ(set.size(), 2);
+ set.remove(Layout::NHWC);
+ ASSERT_EQ(set.size(), 1);
+ set.remove(Layout::NCHW);
+ ASSERT_EQ(set.size(), 0);
+ set.remove(Layout::NCHW);
+ ASSERT_EQ(set.size(), 0);
+}
+
+TEST(ir_LayoutSet, neg_add_twice)
+{
+ LayoutSet set;
+ set.add(Layout::NHWC);
+ ASSERT_EQ(set.size(), 1);
+ set.add(Layout::NHWC);
+ ASSERT_EQ(set.size(), 1);
+}
+
+TEST(ir_LayoutSet, set_operators)
+{
+ LayoutSet set1{Layout::NCHW};
+ LayoutSet set2{Layout::NHWC};
+ LayoutSet set3 = set1 | set2;
+
+ ASSERT_EQ(set3.size(), 2);
+
+ ASSERT_EQ((set3 - set1).size(), 1);
+ ASSERT_EQ((set3 - set1).contains(Layout::NHWC), true);
+ ASSERT_EQ((set3 - set2).size(), 1);
+ ASSERT_EQ((set3 - set2).contains(Layout::NCHW), true);
+ ASSERT_EQ((set3 - set3).size(), 0);
+
+ ASSERT_EQ((set3 & set1).size(), 1);
+ ASSERT_EQ((set3 & set1).contains(Layout::NCHW), true);
+ ASSERT_EQ((set3 & set2).size(), 1);
+ ASSERT_EQ((set3 & set2).contains(Layout::NHWC), true);
+ ASSERT_EQ((set1 & set2).size(), 0);
+}
diff --git a/runtime/onert/core/src/ir/MockNode.h b/runtime/onert/core/src/ir/MockNode.h
new file mode 100644
index 000000000..0e7ed977b
--- /dev/null
+++ b/runtime/onert/core/src/ir/MockNode.h
@@ -0,0 +1,47 @@
+/*
+ * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_TEST_GRAPH_MOCK_NODE_H__
+#define __ONERT_TEST_GRAPH_MOCK_NODE_H__
+
+#include "ir/Operation.h"
+#include "ir/OperandIndexSequence.h"
+
+namespace onert_test
+{
+namespace ir
+{
+
+class SimpleMock : public onert::ir::Operation
+{
+public:
+ SimpleMock(const onert::ir::OperandIndexSequence &inputs,
+ const onert::ir::OperandIndexSequence &outputs)
+ : Operation{onert::ir::OperandConstraint::createAny()}
+ {
+ setInputs(inputs);
+ setOutputs(outputs);
+ }
+
+public:
+ void accept(onert::ir::OperationVisitor &) const override {}
+ onert::ir::OpCode opcode() const final { return onert::ir::OpCode::Invalid; }
+};
+
+} // namespace ir
+} // namespace onert_test
+
+#endif // __ONERT_TEST_GRAPH_MOCK_NODE_H__
diff --git a/runtime/onert/core/src/ir/OpSequence.cc b/runtime/onert/core/src/ir/OpSequence.cc
deleted file mode 100644
index e2b989d8c..000000000
--- a/runtime/onert/core/src/ir/OpSequence.cc
+++ /dev/null
@@ -1,95 +0,0 @@
-/*
- * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "ir/OpSequence.h"
-
-#include "ir/Operations.h"
-#include "ir/OperationVisitor.h"
-#include <sstream>
-
-namespace
-{
-
-std::string getStrFromIndice(const onert::ir::OperandIndexSequence &indice)
-{
- std::string str;
- for (const auto &ind : indice)
- {
- str += std::to_string(ind.value());
- str.push_back(',');
- }
- if (str.back() == ',')
- str.pop_back();
-
- return str;
-}
-}
-
-namespace onert
-{
-namespace ir
-{
-
-OpSequence::OpSequence(Layout layout) : _layout{layout}, _has_dynamic_tensor{false}
-{
- // DO NOTHING
-}
-
-void OpSequence::accept(OperationVisitor &v) const { v.visit(*this); }
-
-// TODO: Impl Dumper instead of this method
-std::string getStrFromOpSeq(const OpSequence &op_seq, const Operations &operations)
-{
- // " OpSequence IN(0,1,2) -> { op0(0,1,2:3), op1(3:4), op2(4:5) } -> OUT(5)"
- std::stringstream ss;
- ss << " OpSequence IN(" << getStrFromIndice(op_seq.getInputs()) << ") -> {";
- for (const auto &op_idx : op_seq)
- {
- ss << " " << op_idx.value() << "(" << operations.at(op_idx).name() << ":"
- << getStrFromIndice(operations.at(op_idx).getInputs()) << ":"
- << getStrFromIndice(operations.at(op_idx).getOutputs()) << ")";
- }
- ss << " } -> OUT(" << getStrFromIndice(op_seq.getOutputs()) << ")";
- return ss.str();
-}
-
-void OpSequence::remove(const OperationIndex &index)
-{
- assert(exist(index));
- for (auto it = _operations.cbegin(); it != _operations.cend(); ++it)
- {
- if (*it == index)
- {
- _operations.erase(it);
- break;
- }
- }
-}
-
-bool OpSequence::exist(const OperationIndex &index) const
-{
- for (const auto &inner_op_idx : _operations)
- {
- if (inner_op_idx == index)
- {
- return true;
- }
- }
- return false;
-}
-
-} // namespace ir
-} // namespace onert
diff --git a/runtime/onert/core/src/ir/OpSequences.cc b/runtime/onert/core/src/ir/OpSequences.cc
deleted file mode 100644
index 68884783e..000000000
--- a/runtime/onert/core/src/ir/OpSequences.cc
+++ /dev/null
@@ -1,124 +0,0 @@
-/*
- * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "ir/OpSequences.h"
-#include "util/logging.h"
-#include <memory>
-
-#include <cassert>
-#include <string>
-
-namespace onert
-{
-namespace ir
-{
-
-OpSequenceIndex OpSequences::emplace(const OperationIndex &index, Layout layout)
-{
- std::unique_ptr<OpSequence> op_seq = std::make_unique<OpSequence>(layout);
- op_seq->appendOperation(index);
- const OpSequenceIndex &seq_index = push(std::move(op_seq));
- cacheSequenceIndex(seq_index, index);
- return seq_index;
-}
-
-OpSequenceIndex OpSequences::emplace(std::unique_ptr<OpSequence> &&op_seq)
-{
- auto &operations = op_seq->operations();
- const OpSequenceIndex &seq_index = push(std::move(op_seq));
- for (const auto &op_idx : operations)
- {
- cacheSequenceIndex(seq_index, op_idx);
- }
- return seq_index;
-}
-
-void OpSequences::cacheSequenceIndex(const OpSequenceIndex &seq_index,
- const OperationIndex &op_index) const
-{
- _seq_indexes.emplace(op_index, seq_index);
-}
-
-OpSequenceIndex *OpSequences::findSequenceIndex(const OperationIndex &operation_index) const
-{
- // If opration_index is cached, return sequence_index from cache
- if (_seq_indexes.count(operation_index))
- {
- auto &op_seq_index = _seq_indexes.at(operation_index);
- if (_objects.count(op_seq_index) && _objects.at(op_seq_index)->exist(operation_index))
- {
- return &op_seq_index;
- }
- else
- {
- _seq_indexes.erase(operation_index);
- return nullptr;
- }
- }
- return nullptr;
-}
-
-bool OpSequences::containsOperation(const OperationIndex &operation_index) const
-{
- return findOperation(operation_index).valid();
-}
-
-OpSequenceIndex OpSequences::getOperation(const OperationIndex &operation_index) const
-{
- OpSequenceIndex ret = findOperation(operation_index);
- assert(ret.valid());
- return ret;
-}
-
-void OpSequences::removeFromOpSequence(const OperationIndex &operation_index)
-{
- const auto op_seq_index = findOperation(operation_index);
- auto &op_seq = at(op_seq_index);
- _seq_indexes.erase(operation_index);
- op_seq.remove(operation_index);
- if (op_seq.size() == 0)
- {
- remove(op_seq_index);
- }
-}
-
-OpSequenceIndex OpSequences::findOperation(const OperationIndex &operation_index) const
-{
- if (OpSequenceIndex *op_seq_index = findSequenceIndex(operation_index))
- return *op_seq_index;
-
- for (auto &e : _objects)
- {
- OpSequence &object = *e.second;
- auto it = find(object.operations().begin(), object.operations().end(), operation_index);
- if (it != object.operations().end())
- {
- cacheSequenceIndex(e.first, operation_index);
- return e.first;
- }
- }
- throw std::runtime_error("Operation not found");
-}
-
-void dumpOpSequences(const OpSequences &op_seqs, const Operations &operations)
-{
- op_seqs.iterate([&](const OpSequenceIndex &idx, const OpSequence &op_seq) {
- VERBOSE(OpSequences) << idx.value() << "] " << getStrFromOpSeq(op_seq, operations) << std::endl;
- });
-}
-
-} // namespace ir
-} // namespace onert
diff --git a/runtime/onert/core/src/ir/Operand.cc b/runtime/onert/core/src/ir/Operand.cc
index e29c7a6ec..18981dbf1 100644
--- a/runtime/onert/core/src/ir/Operand.cc
+++ b/runtime/onert/core/src/ir/Operand.cc
@@ -46,5 +46,11 @@ void Operand::setDef(const OperationIndex &idx) { _def = idx; }
void Operand::unsetDef() { _def = OperationIndex{}; }
+void Operand::clearDefUse()
+{
+ unsetDef();
+ _uses.clear();
+}
+
} // namespace ir
} // namespace onert
diff --git a/runtime/onert/core/src/ir/Operand.test.cc b/runtime/onert/core/src/ir/Operand.test.cc
new file mode 100644
index 000000000..0b858792a
--- /dev/null
+++ b/runtime/onert/core/src/ir/Operand.test.cc
@@ -0,0 +1,86 @@
+/*
+ * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ir/Graph.h"
+
+#include "MockNode.h"
+#include "verifier/Verifier.h"
+
+#include <gtest/gtest.h>
+
+#include <memory>
+#include <typeindex>
+
+namespace
+{
+
+using IndexSet = onert::ir::OperandIndexSequence;
+using Mock = onert_test::ir::SimpleMock;
+
+} // namespace
+
+TEST(ir_Operand, neg_usedef)
+{
+ onert::ir::Graph graph;
+ onert::ir::verifier::DAGChecker verifier;
+
+ onert::ir::Shape shape(3);
+ onert::ir::TypeInfo type{onert::ir::DataType::INT32};
+
+ // Model Input/Output
+ auto input_operand = graph.addOperand(shape, type);
+ auto output_operand = graph.addOperand(shape, type);
+
+ graph.addInput(input_operand);
+ graph.addOutput(output_operand);
+
+ // MockNode1
+ auto operand_index1 = graph.addOperand(shape, type);
+ auto mocknode_index1 =
+ graph.addOperation(std::make_unique<Mock>(IndexSet{input_operand}, IndexSet{operand_index1}));
+
+ // MockNode2
+ auto operand_index2 = graph.addOperand(shape, type);
+ auto mocknode_index2 =
+ graph.addOperation(std::make_unique<Mock>(IndexSet{input_operand}, IndexSet{operand_index2}));
+
+ // MockNode3(two input)
+ auto multiinput_index = graph.addOperation(
+ std::make_unique<Mock>(IndexSet{operand_index1, operand_index2}, IndexSet{output_operand}));
+
+ graph.verify();
+
+ ASSERT_TRUE(verifier.verify(graph));
+
+ // Check def
+ ASSERT_EQ(graph.operands().at(operand_index1).getDef(), mocknode_index1);
+ ASSERT_EQ(graph.operands().at(operand_index2).getDef(), mocknode_index2);
+ ASSERT_EQ(graph.operands().at(output_operand).getDef(), multiinput_index);
+
+ ASSERT_NE(graph.operands().at(operand_index1).getDef(), mocknode_index2);
+ ASSERT_NE(graph.operands().at(operand_index1).getDef(), multiinput_index);
+
+ // Check use
+ ASSERT_EQ(graph.operands().at(input_operand).getUses().contains(mocknode_index1), true);
+ ASSERT_EQ(graph.operands().at(input_operand).getUses().contains(mocknode_index2), true);
+ ASSERT_EQ(graph.operands().at(input_operand).getUses().contains(multiinput_index), false);
+ ASSERT_EQ(graph.operands().at(operand_index1).getUses().contains(multiinput_index), true);
+ ASSERT_EQ(graph.operands().at(operand_index2).getUses().contains(multiinput_index), true);
+
+ ASSERT_EQ(graph.operands().at(input_operand).getUses().size(), 2);
+ ASSERT_EQ(graph.operands().at(operand_index1).getUses().size(), 1);
+ ASSERT_EQ(graph.operands().at(output_operand).getUses().size(), 0);
+}
diff --git a/runtime/onert/core/src/ir/OperandIndexSequence.cc b/runtime/onert/core/src/ir/OperandIndexSequence.cc
index 73f928280..a15b6d0d6 100644
--- a/runtime/onert/core/src/ir/OperandIndexSequence.cc
+++ b/runtime/onert/core/src/ir/OperandIndexSequence.cc
@@ -31,7 +31,7 @@ OperandIndexSequence::OperandIndexSequence(std::initializer_list<OperandIndex> l
OperandIndexSequence::OperandIndexSequence(std::initializer_list<int32_t> list)
{
- for (auto val : list)
+ for (auto &&val : list)
{
_vec.emplace_back(static_cast<uint32_t>(val));
}
@@ -39,7 +39,7 @@ OperandIndexSequence::OperandIndexSequence(std::initializer_list<int32_t> list)
OperandIndexSequence::OperandIndexSequence(std::initializer_list<uint32_t> list)
{
- for (auto val : list)
+ for (auto &&val : list)
{
_vec.emplace_back(val);
}
@@ -55,6 +55,11 @@ void OperandIndexSequence::replace(const OperandIndex &from, const OperandIndex
std::replace(_vec.begin(), _vec.end(), from, to);
}
+bool OperandIndexSequence::operator==(const OperandIndexSequence &other) const
+{
+ return _vec == other._vec;
+}
+
OperandIndexSequence OperandIndexSequence::operator+(const OperandIndexSequence &other) const
{
OperandIndexSequence ret = *this;
@@ -62,10 +67,10 @@ OperandIndexSequence OperandIndexSequence::operator+(const OperandIndexSequence
return ret;
}
-std::ostream &operator<<(std::ostream &o, const OperandIndexSequence &op_seq)
+std::ostream &operator<<(std::ostream &o, const OperandIndexSequence &operand_seq)
{
std::string delimeter;
- for (const auto &ind : op_seq._vec)
+ for (const auto &ind : operand_seq._vec)
{
o << delimeter << ind;
delimeter = ',';
diff --git a/runtime/onert/core/src/ir/OperandIndexSequence.test.cc b/runtime/onert/core/src/ir/OperandIndexSequence.test.cc
new file mode 100644
index 000000000..588c4e419
--- /dev/null
+++ b/runtime/onert/core/src/ir/OperandIndexSequence.test.cc
@@ -0,0 +1,52 @@
+/*
+ * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ir/OperandIndexSequence.h"
+
+#include <gtest/gtest.h>
+
+using onert::ir::OperandIndex;
+using onert::ir::OperandIndexSequence;
+
+TEST(ir_OperandIndexSequence, neg_append)
+{
+ OperandIndexSequence iset{0, 2, 4, 8};
+
+ ASSERT_EQ(iset.size(), 4);
+
+ iset.append(OperandIndex{10});
+
+ ASSERT_EQ(iset.size(), 5);
+
+ onert::ir::IOIndex index1{1};
+ onert::ir::IOIndex index2{4};
+
+ ASSERT_EQ(iset.at(index1), 2);
+ ASSERT_EQ(iset.at(index2), 10);
+
+ ASSERT_TRUE(iset.contains(OperandIndex{2}));
+ ASSERT_TRUE(iset.contains(OperandIndex{10}));
+ ASSERT_FALSE(iset.contains(OperandIndex{11}));
+}
+
+TEST(graph_OperandIndexSequence, neg_replace)
+{
+ OperandIndexSequence iset{0, 1, 2, 3};
+
+ iset.replace(OperandIndex{1}, OperandIndex{9});
+ ASSERT_FALSE(iset.contains(OperandIndex{1}));
+ ASSERT_TRUE(iset.contains(OperandIndex{9}));
+}
diff --git a/runtime/onert/core/src/ir/Operands.cc b/runtime/onert/core/src/ir/Operands.cc
index ab32e478a..f8cfd16ef 100644
--- a/runtime/onert/core/src/ir/Operands.cc
+++ b/runtime/onert/core/src/ir/Operands.cc
@@ -29,7 +29,7 @@ Operands::Operands(const Operands &obj)
obj.iterate([&](const OperandIndex &index, const Operand &operand) {
_objects.emplace(index, std::make_unique<Operand>(operand));
});
- _index_count = obj._index_count;
+ _next_index = obj._next_index;
}
} // namespace ir
diff --git a/runtime/onert/core/src/ir/Operands.test.cc b/runtime/onert/core/src/ir/Operands.test.cc
new file mode 100644
index 000000000..aff228b10
--- /dev/null
+++ b/runtime/onert/core/src/ir/Operands.test.cc
@@ -0,0 +1,45 @@
+/*
+ * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ir/Operands.h"
+
+#include <gtest/gtest.h>
+
+TEST(ir_Operands, neg_set_test)
+{
+ onert::ir::Operands set;
+
+ onert::ir::Shape shape0{1, 2, 3};
+
+ onert::ir::Shape shape1(4);
+ shape1.dim(0) = 10;
+ shape1.dim(1) = 20;
+ shape1.dim(2) = 30;
+ shape1.dim(3) = 40;
+
+ onert::ir::TypeInfo type{onert::ir::DataType::INT32};
+
+ set.emplace(shape0, type);
+ set.emplace(shape1, type);
+
+ ASSERT_EQ(set.exist(onert::ir::OperandIndex{0u}), true);
+ ASSERT_EQ(set.exist(onert::ir::OperandIndex{1u}), true);
+ ASSERT_EQ(set.exist(onert::ir::OperandIndex{2u}), false);
+
+ ASSERT_EQ(set.at(onert::ir::OperandIndex{0u}).shape().dim(0), 1);
+ ASSERT_EQ(set.at(onert::ir::OperandIndex{0u}).shape().dim(1), 2);
+ ASSERT_EQ(set.at(onert::ir::OperandIndex{0u}).shape().dim(2), 3);
+}
diff --git a/runtime/onert/core/src/ir/Operation.cc b/runtime/onert/core/src/ir/Operation.cc
index 04be8c0d9..64792525d 100644
--- a/runtime/onert/core/src/ir/Operation.cc
+++ b/runtime/onert/core/src/ir/Operation.cc
@@ -24,22 +24,33 @@ namespace ir
{
Operation::Operation(OperandConstraint input_constr, const OperandIndexSequence &inputs,
- const OperandIndexSequence &outputs)
- : _input_constr{input_constr}, _inputs{inputs}, _outputs{outputs}
+ const OperandIndexSequence &outputs, OperandConstraint output_constr)
+ : _input_constr{input_constr}, _output_constr{output_constr}
{
+ setInputs(inputs);
+ setOutputs(outputs);
}
-Operation::Operation(OperandConstraint input_constr) : _input_constr{input_constr} {}
+Operation::Operation(OperandConstraint input_constr, OperandConstraint output_constr)
+ : _input_constr{input_constr}, _output_constr{output_constr}
+{
+}
Operation::~Operation() = default;
void Operation::setInputs(const OperandIndexSequence &indexes)
{
- assert(_input_constr.check(indexes.size()));
+ if (!_input_constr.check(indexes.size()))
+ throw std::runtime_error{"Invalid number of input tensors for this operation."};
_inputs = indexes;
}
-void Operation::setOutputs(const OperandIndexSequence &indexes) { _outputs = indexes; }
+void Operation::setOutputs(const OperandIndexSequence &indexes)
+{
+ if (!_output_constr.check(indexes.size()))
+ throw std::runtime_error{"Invalid number of output tensors for this operation."};
+ _outputs = indexes;
+}
void Operation::replaceInputs(const OperandIndex &from, const OperandIndex &to)
{
diff --git a/runtime/onert/core/src/ir/Operation.test.cc b/runtime/onert/core/src/ir/Operation.test.cc
new file mode 100644
index 000000000..b3c4e852d
--- /dev/null
+++ b/runtime/onert/core/src/ir/Operation.test.cc
@@ -0,0 +1,98 @@
+/*
+ * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ir/Graph.h"
+#include "ir/Index.h"
+#include "ir/OperandIndexSequence.h"
+#include "ir/operation/Concat.h"
+#include "ir/operation/Conv2D.h"
+
+#include <gtest/gtest.h>
+
+#include <memory>
+#include <stdexcept>
+
+using Index = onert::ir::IOIndex;
+using IndexSet = onert::ir::OperandIndexSequence;
+
+TEST(ir_Operation_setIO, operation_setIO_conv)
+{
+ onert::ir::Graph graph;
+
+ onert::ir::Shape shape{3};
+ onert::ir::TypeInfo type{onert::ir::DataType::INT32};
+
+ // Add Conv
+ using Graph = onert::ir::operation::Conv2D;
+
+ auto input_operand = graph.addOperand(shape, type);
+ auto kernel_operand = graph.addOperand(shape, type);
+ auto bias_operand = graph.addOperand(shape, type);
+ IndexSet inputs{input_operand, kernel_operand, bias_operand};
+
+ Graph::Param conv_params;
+ conv_params.padding.type = onert::ir::PaddingType::SAME;
+ conv_params.stride.horizontal = 1;
+ conv_params.stride.vertical = 1;
+ conv_params.activation = onert::ir::Activation::NONE;
+
+ auto output_operand = graph.addOperand(shape, type).value();
+ IndexSet outputs{output_operand};
+
+ auto conv = std::make_unique<Graph>(inputs, outputs, conv_params);
+
+ ASSERT_NE(conv, nullptr);
+ ASSERT_EQ(conv->getInputs().at(Index{0}).value(), inputs.at(0).value());
+ conv->setInputs({8, 9, 10});
+ ASSERT_NE(conv->getInputs().at(Index{0}).value(), inputs.at(0).value());
+ ASSERT_EQ(conv->getInputs().at(Index{0}).value(), 8);
+}
+
+TEST(ir_Operation_setIO, neg_operation_setIO_concat)
+{
+ onert::ir::Graph graph;
+
+ onert::ir::Shape shape{3};
+
+ onert::ir::TypeInfo type{onert::ir::DataType::INT32};
+
+ using Graph = onert::ir::operation::Concat;
+
+ // Add Concat
+ IndexSet inputs;
+ for (int i = 0; i < 6; ++i)
+ {
+ inputs.append(graph.addOperand(shape, type));
+ }
+
+ Graph::Param concat_params{0};
+
+ auto output_operand = graph.addOperand(shape, type).value();
+ IndexSet outputs{output_operand};
+
+ auto concat = std::make_unique<Graph>(inputs, outputs, concat_params);
+
+ ASSERT_NE(concat, nullptr);
+ ASSERT_EQ(concat->getInputs().size(), 6);
+ ASSERT_EQ(concat->getInputs().at(Index{0}).value(), inputs.at(0).value());
+
+ concat->setInputs({80, 6, 9, 11});
+ ASSERT_EQ(concat->getInputs().size(), 4);
+ ASSERT_NE(concat->getInputs().at(Index{0}).value(), inputs.at(0).value());
+ ASSERT_EQ(concat->getInputs().at(Index{0}).value(), 80);
+ ASSERT_EQ(concat->getInputs().at(Index{2}).value(), 9);
+ ASSERT_THROW(concat->getInputs().at(Index{5}), std::out_of_range);
+}
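Related to the Operation::setInputs()/setOutputs() change above, here is a minimal sketch of how a caller can now handle an arity violation, which surfaces as std::runtime_error instead of a debug-only assert. The tryBuildConcat helper and the axis value are hypothetical, not part of the patch.

#include "ir/Graph.h"
#include "ir/operation/Concat.h"

#include <memory>
#include <stdexcept>

// Hypothetical helper: returns true only if the Concat was constructed and added.
bool tryBuildConcat(onert::ir::Graph &graph, const onert::ir::OperandIndexSequence &inputs,
                    const onert::ir::OperandIndexSequence &outputs)
{
  using onert::ir::operation::Concat;
  try
  {
    Concat::Param param{0}; // axis 0 (illustrative)
    auto concat = std::make_unique<Concat>(inputs, outputs, param);
    return graph.addOperation(std::move(concat)).valid();
  }
  catch (const std::runtime_error &)
  {
    // Thrown by setInputs()/setOutputs() when the operand count violates the
    // operation's OperandConstraint (previously an assert, i.e. debug builds only).
    return false;
  }
}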
diff --git a/runtime/onert/core/src/ir/OperationCloner.cc b/runtime/onert/core/src/ir/OperationCloner.cc
index b4e60f0bc..64e1cc807 100644
--- a/runtime/onert/core/src/ir/OperationCloner.cc
+++ b/runtime/onert/core/src/ir/OperationCloner.cc
@@ -23,6 +23,23 @@ namespace onert
namespace ir
{
+namespace
+{
+
+class OperationCloner : public OperationVisitor
+{
+public:
+#define OP(Name) void visit(const operation::Name &o) override;
+#include "ir/Operations.lst"
+#undef OP
+
+public:
+ std::unique_ptr<Operation> releaseClone();
+
+private:
+ std::unique_ptr<Operation> _return_op;
+};
+
#define OP(Name) \
void OperationCloner::visit(const operation::Name &o) \
{ \
@@ -38,5 +55,14 @@ std::unique_ptr<Operation> OperationCloner::releaseClone()
return std::move(_return_op);
}
+} // namespace
+
+std::unique_ptr<Operation> clone(const IOperation &operation)
+{
+ OperationCloner cloner;
+ operation.accept(cloner);
+ return cloner.releaseClone();
+}
+
} // namespace ir
} // namespace onert
diff --git a/runtime/onert/core/src/ir/OperationCloner.h b/runtime/onert/core/src/ir/OperationCloner.h
index 0e8cda2a0..49297a05c 100644
--- a/runtime/onert/core/src/ir/OperationCloner.h
+++ b/runtime/onert/core/src/ir/OperationCloner.h
@@ -26,19 +26,7 @@ namespace onert
namespace ir
{
-class OperationCloner : public OperationVisitor
-{
-public:
-#define OP(Name) void visit(const operation::Name &o) override;
-#include "ir/Operations.lst"
-#undef OP
-
-public:
- std::unique_ptr<Operation> releaseClone();
-
-private:
- std::unique_ptr<Operation> _return_op;
-};
+std::unique_ptr<Operation> clone(const IOperation &operation);
} // namespace ir
} // namespace onert
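A minimal sketch (not part of the patch) of the new free clone() function declared above. The cloneOperations helper is hypothetical and assumes the destination graph's operands are copied elsewhere; it only shows how clone() pairs with the index-preserving Graph::addOperation overload.

#include "OperationCloner.h"

#include "ir/Graph.h"

// Hypothetical helper: deep-copies every operation of `src` into `dst`,
// reusing the original operation indices.
void cloneOperations(const onert::ir::Graph &src, onert::ir::Graph &dst)
{
  src.operations().iterate(
    [&](const onert::ir::OperationIndex &index, const onert::ir::IOperation &op) {
      dst.addOperation(index, onert::ir::clone(op));
    });
}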
diff --git a/runtime/onert/core/src/ir/OperationDumper.cc b/runtime/onert/core/src/ir/OperationDumper.cc
index 48361f464..5aa4693ad 100644
--- a/runtime/onert/core/src/ir/OperationDumper.cc
+++ b/runtime/onert/core/src/ir/OperationDumper.cc
@@ -29,19 +29,21 @@ using namespace operation;
namespace
{
-void dumpUnaryInputOp(const Operation &node, const std::string &adding_input = "")
+
+// Dump all inputs and outputs.
+// Use this function when the operation has no special inputs or outputs.
+void dumpOpGeneric(const Operation &node, const std::string &adding_input = "")
{
VERBOSE(LIR) << "* " << node.name() << std::endl;
- VERBOSE(LIR) << " - Inputs : Input(" << node.getInputs().at(0) << ") " << adding_input
- << std::endl;
- VERBOSE(LIR) << " - Output : Output(" << node.getOutputs().at(0) << ")" << std::endl;
+ VERBOSE(LIR) << " - Inputs : Input(" << node.getInputs() << ") " << adding_input << std::endl;
+ VERBOSE(LIR) << " - Output : Output(" << node.getOutputs() << ")" << std::endl;
}
-void dumpBinaryInputOp(const Operation &node, const std::string &adding_input = "")
+void dumpUnaryInputOp(const Operation &node, const std::string &adding_input = "")
{
VERBOSE(LIR) << "* " << node.name() << std::endl;
- VERBOSE(LIR) << " - Inputs : Input(" << node.getInputs().at(0) << ", " << node.getInputs().at(0)
- << ") " << adding_input << std::endl;
+ VERBOSE(LIR) << " - Inputs : Input(" << node.getInputs().at(0) << ") " << adding_input
+ << std::endl;
VERBOSE(LIR) << " - Output : Output(" << node.getOutputs().at(0) << ")" << std::endl;
}
@@ -53,18 +55,6 @@ void dumpConvOp(const Operation &node, const std::string &padding_type)
<< node.getInputs().at(Conv2D::Input::BIAS) << ")" << std::endl;
VERBOSE(LIR) << " - Output : OFM(" << node.getOutputs().at(0) << ")" << std::endl;
}
-
-void dumpPackingOp(const Operation &node)
-{
- VERBOSE(LIR) << "* " << node.name() << std::endl;
- std::string inputs;
- for (auto i : node.getInputs())
- {
- inputs += std::to_string(i.value()) + ",";
- }
- VERBOSE(LIR) << " - Inputs : Inputs(" << inputs << ")" << std::endl;
- VERBOSE(LIR) << " - Output : Output(" << node.getOutputs().at(0) << ")" << std::endl;
-}
} // namespace
OperationDumper::OperationDumper(const std::string &start_msg)
@@ -72,41 +62,62 @@ OperationDumper::OperationDumper(const std::string &start_msg)
VERBOSE(LIR) << start_msg << std::endl;
}
-void OperationDumper::visit(const ArgMax &node) { dumpUnaryInputOp(node); }
+void OperationDumper::visit(const ArgMinMax &node)
+{
+ std::string min_max = node.param().is_arg_max ? "(Max)" : "(Min)";
+ VERBOSE(LIR) << "* " << node.name() << min_max << std::endl;
+ VERBOSE(LIR) << " - Inputs : Input(" << node.getInputs().at(ArgMinMax::INPUT) << ") Axis("
+ << node.getInputs().at(ArgMinMax::AXIS) << ") " << std::endl;
+ VERBOSE(LIR) << " - Output : Output(" << node.getOutputs().at(0) << ")" << std::endl;
+}
void OperationDumper::visit(const BatchToSpaceND &node)
{
std::string block_size =
- "BlockSize(" +
- std::to_string(node.getInputs().at(BatchToSpaceND::Input::BLOCK_SIZE).value()) + ")";
- dumpUnaryInputOp(node, block_size);
+ "BlockSize(" + std::to_string(node.getInputs().at(BatchToSpaceND::Input::BLOCK_SIZE).value()) +
+ ")";
+ dumpOpGeneric(node, block_size);
}
-void OperationDumper::visit(const BinaryArithmetic &node) { dumpBinaryInputOp(node); }
+void OperationDumper::visit(const BCQFullyConnected &node)
+{
+ VERBOSE(LIR) << "* " << node.name() << std::endl;
+ VERBOSE(LIR) << " - Inputs : IFM(" << node.getInputs().at(BCQFullyConnected::Input::INPUT)
+ << ") WeightsBinary("
+ << node.getInputs().at(BCQFullyConnected::Input::WEIGHTS_BINARY)
+ << ") WeightsScales("
+ << node.getInputs().at(BCQFullyConnected::Input::WEIGHTS_SCALES)
+ << ") WeightsClusters("
+ << node.getInputs().at(BCQFullyConnected::Input::WEIGHTS_CLUSTERS) << ") Bias("
+ << node.getInputs().at(BCQFullyConnected::Input::BIAS) << ")" << std::endl;
+ VERBOSE(LIR) << " - Output : OFM(" << node.getOutputs().at(0) << ")" << std::endl;
+}
+
+void OperationDumper::visit(const BinaryArithmetic &node) { dumpOpGeneric(node); }
-void OperationDumper::visit(const operation::BroadcastTo &node) { dumpBinaryInputOp(node); }
+void OperationDumper::visit(const operation::BroadcastTo &node) { dumpOpGeneric(node); }
-void OperationDumper::visit(const Comparison &node) { dumpBinaryInputOp(node); }
+void OperationDumper::visit(const Comparison &node) { dumpOpGeneric(node); }
-void OperationDumper::visit(const Concat &node) { dumpPackingOp(node); }
+void OperationDumper::visit(const Concat &node) { dumpOpGeneric(node); }
void OperationDumper::visit(const Conv2D &node)
{
std::string padding_type =
- node.param().padding.type == PaddingType::EXPLICIT ? "Explicit" : "Implicit";
+ node.param().padding.type == PaddingType::EXPLICIT ? "Explicit" : "Implicit";
dumpConvOp(node, padding_type);
}
-void OperationDumper::visit(const ConvertFp16ToFp32 &node) { dumpUnaryInputOp(node); }
+void OperationDumper::visit(const ConvertFp16ToFp32 &node) { dumpOpGeneric(node); }
-void OperationDumper::visit(const ConvertFp32ToFp16 &node) { dumpUnaryInputOp(node); }
+void OperationDumper::visit(const ConvertFp32ToFp16 &node) { dumpOpGeneric(node); }
-void OperationDumper::visit(const DepthToSpace &node) { dumpUnaryInputOp(node); }
+void OperationDumper::visit(const DepthToSpace &node) { dumpOpGeneric(node); }
void OperationDumper::visit(const DepthwiseConv2D &node)
{
std::string padding_type =
- node.param().padding.type == PaddingType::EXPLICIT ? "Explicit" : "Implicit";
+ node.param().padding.type == PaddingType::EXPLICIT ? "Explicit" : "Implicit";
dumpConvOp(node, padding_type);
}
@@ -122,12 +133,12 @@ void OperationDumper::visit(const ElementwiseActivation &node)
{
params = " alpha value(" + std::to_string(node.param().alpha) + ")";
}
- dumpUnaryInputOp(node, params);
+ dumpOpGeneric(node, params);
}
-void OperationDumper::visit(const ElementwiseBinary &node) { dumpBinaryInputOp(node); }
+void OperationDumper::visit(const ElementwiseBinary &node) { dumpOpGeneric(node); }
-void OperationDumper::visit(const ElementwiseUnary &node) { dumpUnaryInputOp(node); }
+void OperationDumper::visit(const ElementwiseUnary &node) { dumpOpGeneric(node); }
void OperationDumper::visit(const EmbeddingLookup &node)
{
@@ -141,22 +152,31 @@ void OperationDumper::visit(const EmbeddingLookup &node)
void OperationDumper::visit(const ExpandDims &node)
{
std::string axis =
- "AXIS(" + std::to_string(node.getInputs().at(ExpandDims::Input::AXIS).value()) + ")";
+ "AXIS(" + std::to_string(node.getInputs().at(ExpandDims::Input::AXIS).value()) + ")";
dumpUnaryInputOp(node, axis);
}
+void OperationDumper::visit(const Fill &node)
+{
+ VERBOSE(LIR) << "* " << node.name() << std::endl;
+ VERBOSE(LIR) << " - Inputs : Shape(" << node.getInputs().at(Fill::Input::SHAPE) << ") Value("
+ << node.getInputs().at(Fill::Input::VALUE) << ")" << std::endl;
+ VERBOSE(LIR) << " - Output : Output(" << node.getOutputs().at(0) << ")" << std::endl;
+}
+
void OperationDumper::visit(const FullyConnected &node)
{
- std::string inputs =
- "Weight(" + std::to_string(node.getInputs().at(FullyConnected::Input::WEIGHT).value()) +
- ") Bias(" + std::to_string(node.getInputs().at(FullyConnected::Input::BIAS).value()) + ")";
- dumpUnaryInputOp(node, inputs);
+ VERBOSE(LIR) << "* " << node.name() << std::endl;
+  VERBOSE(LIR) << "  - Inputs : Input(" << node.getInputs().at(FullyConnected::Input::INPUT) << ") Weight("
+ << node.getInputs().at(FullyConnected::Input::WEIGHT) << ") Bias("
+ << node.getInputs().at(FullyConnected::Input::BIAS) << ")" << std::endl;
+ VERBOSE(LIR) << " - Output : Output(" << node.getOutputs().at(0) << ")" << std::endl;
}
void OperationDumper::visit(const Gather &node)
{
std::string indices =
- "Indices(" + std::to_string(node.getInputs().at(Gather::Input::INDICES).value()) + ")";
+ "Indices(" + std::to_string(node.getInputs().at(Gather::Input::INDICES).value()) + ")";
dumpUnaryInputOp(node, indices);
}
@@ -174,50 +194,70 @@ void OperationDumper::visit(const HashtableLookup &node)
void OperationDumper::visit(const InstanceNorm &node)
{
std::string inputs =
- "Gamma(" + std::to_string(node.getInputs().at(InstanceNorm::Input::GAMMA).value()) +
- ") Beta(" + std::to_string(node.getInputs().at(InstanceNorm::Input::BETA).value()) + ")";
+ "Gamma(" + std::to_string(node.getInputs().at(InstanceNorm::Input::GAMMA).value()) + ") Beta(" +
+ std::to_string(node.getInputs().at(InstanceNorm::Input::BETA).value()) + ")";
dumpUnaryInputOp(node, inputs);
}
-void OperationDumper::visit(const L2Normalization &node) { dumpUnaryInputOp(node); }
+void OperationDumper::visit(const L2Normalization &node) { dumpOpGeneric(node); }
-void OperationDumper::visit(const LocalResponseNormalization &node) { dumpUnaryInputOp(node); }
+void OperationDumper::visit(const LocalResponseNormalization &node) { dumpOpGeneric(node); }
+
+void OperationDumper::visit(const Loss &node)
+{
+ VERBOSE(LIR) << "* " << node.name() << std::endl;
+ VERBOSE(LIR) << " - Inputs : Prediction(" << node.getInputs().at(Loss::Input::Y_PRED) << ") True("
+ << node.getInputs().at(Loss::Input::Y_TRUE) << ")" << std::endl;
+ VERBOSE(LIR) << " - Outputs : Output(" << node.getOutputs().at(0) << ")" << std::endl;
+}
void OperationDumper::visit(const LSTM &node)
{
+ VERBOSE(LIR) << "* " << node.name() << std::endl;
VERBOSE(LIR)
- << " - Inputs : Input(" << node.getInputs().at(LSTM::Input::INPUT)
- << ") Input To Input Weights(" << node.getInputs().at(LSTM::Input::INPUT_TO_INPUT_WEIGHTS)
- << ") Input To Forget Weights(" << node.getInputs().at(LSTM::Input::INPUT_TO_FORGET_WEIGHTS)
- << ") Input To Cell Weights(" << node.getInputs().at(LSTM::Input::INPUT_TO_CELL_WEIGHTS)
- << ") Input To Output Weights(" << node.getInputs().at(LSTM::Input::INPUT_TO_OUTPUT_WEIGHTS)
- << ") Recurrent To Input Weights("
- << node.getInputs().at(LSTM::Input::RECURRENT_TO_INPUT_WEIGHTS)
- << ") Recurrent To Forget Weights("
- << node.getInputs().at(LSTM::Input::RECURRENT_TO_FORGET_WEIGHTS)
- << ") Recurrent To Cell Weights("
- << node.getInputs().at(LSTM::Input::RECURRENT_TO_CELL_WEIGHTS)
- << ") Recurrent To Output Weights("
- << node.getInputs().at(LSTM::Input::RECURRENT_TO_OUTPUT_WEIGHTS) << ") Cell To Input Weights("
- << node.getInputs().at(LSTM::Input::CELL_TO_INPUT_WEIGHTS) << ") Cell To Forget Weights("
- << node.getInputs().at(LSTM::Input::CELL_TO_FORGET_WEIGHTS) << ") Cell To OUTPUT Weights("
- << node.getInputs().at(LSTM::Input::CELL_TO_OUTPUT_WEIGHTS) << ") Input Gate Bias("
- << node.getInputs().at(LSTM::Input::INPUT_GATE_BIAS) << ") Forget Gate Bias("
- << node.getInputs().at(LSTM::Input::FORGET_GATE_BIAS) << ") Cell Bias("
- << node.getInputs().at(LSTM::Input::CELL_BIAS) << ") Output Gate Bias("
- << node.getInputs().at(LSTM::Input::OUTPUT_GATE_BIAS) << ") Projection Weights("
- << node.getInputs().at(LSTM::Input::PROJECTION_WEIGHTS) << ") Projection Bias("
- << node.getInputs().at(LSTM::Input::PROJECTION_BIAS) << ") Output State In("
- << node.getInputs().at(LSTM::Input::OUTPUT_STATE_IN) << ") Cell State In("
- << node.getInputs().at(LSTM::Input::CELL_STATE_IN) << ")" << std::endl;
+ << " - Inputs : Input(" << node.getInputs().at(LSTM::Input::INPUT)
+ << ") Input To Input Weights(" << node.getInputs().at(LSTM::Input::INPUT_TO_INPUT_WEIGHTS)
+ << ") Input To Forget Weights(" << node.getInputs().at(LSTM::Input::INPUT_TO_FORGET_WEIGHTS)
+ << ") Input To Cell Weights(" << node.getInputs().at(LSTM::Input::INPUT_TO_CELL_WEIGHTS)
+ << ") Input To Output Weights(" << node.getInputs().at(LSTM::Input::INPUT_TO_OUTPUT_WEIGHTS)
+ << ") Recurrent To Input Weights("
+ << node.getInputs().at(LSTM::Input::RECURRENT_TO_INPUT_WEIGHTS)
+ << ") Recurrent To Forget Weights("
+ << node.getInputs().at(LSTM::Input::RECURRENT_TO_FORGET_WEIGHTS)
+ << ") Recurrent To Cell Weights(" << node.getInputs().at(LSTM::Input::RECURRENT_TO_CELL_WEIGHTS)
+ << ") Recurrent To Output Weights("
+ << node.getInputs().at(LSTM::Input::RECURRENT_TO_OUTPUT_WEIGHTS) << ") Cell To Input Weights("
+ << node.getInputs().at(LSTM::Input::CELL_TO_INPUT_WEIGHTS) << ") Cell To Forget Weights("
+ << node.getInputs().at(LSTM::Input::CELL_TO_FORGET_WEIGHTS) << ") Cell To OUTPUT Weights("
+ << node.getInputs().at(LSTM::Input::CELL_TO_OUTPUT_WEIGHTS) << ") Input Gate Bias("
+ << node.getInputs().at(LSTM::Input::INPUT_GATE_BIAS) << ") Forget Gate Bias("
+ << node.getInputs().at(LSTM::Input::FORGET_GATE_BIAS) << ") Cell Bias("
+ << node.getInputs().at(LSTM::Input::CELL_BIAS) << ") Output Gate Bias("
+ << node.getInputs().at(LSTM::Input::OUTPUT_GATE_BIAS) << ") Projection Weights("
+ << node.getInputs().at(LSTM::Input::PROJECTION_WEIGHTS) << ") Projection Bias("
+ << node.getInputs().at(LSTM::Input::PROJECTION_BIAS) << ") Output State In("
+ << node.getInputs().at(LSTM::Input::OUTPUT_STATE_IN) << ") Cell State In("
+ << node.getInputs().at(LSTM::Input::CELL_STATE_IN);
+ if (node.getInputs().size() == 24)
+ {
+ VERBOSE(LIR) << ") Input Layer Normalization Weights("
+ << node.getInputs().at(LSTM::Input::INPUT_LAYER_NORMALIZATION_WEIGHTS)
+ << ") Forget Layer Normalization Weights("
+ << node.getInputs().at(LSTM::Input::FORGET_LAYER_NORMALIZATION_WEIGHTS)
+ << ") Cell Layer Normalization Weights("
+ << node.getInputs().at(LSTM::Input::CELL_LAYER_NORMALIZATION_WEIGHTS)
+                 << ") Output Layer Normalization Weights("
+ << node.getInputs().at(LSTM::Input::OUTPUT_LAYER_NORMALIZATION_WEIGHTS);
+ }
+ VERBOSE(LIR) << ")" << std::endl;
VERBOSE(LIR) << " - Output : Scratch Buffer("
<< node.getOutputs().at(LSTM::Output::SCRATCH_BUFFER) << ") Output State Out("
- << node.getInputs().at(LSTM::Output::OUTPUT_STATE_OUT) << ") Cell State Out("
- << node.getInputs().at(LSTM::Output::CELL_STATE_OUT) << ") Output("
- << node.getInputs().at(LSTM::Output::OUTPUT) << ")" << std::endl;
+ << node.getOutputs().at(LSTM::Output::OUTPUT_STATE_OUT) << ") Cell State Out("
+ << node.getOutputs().at(LSTM::Output::CELL_STATE_OUT) << ") Output("
+ << node.getOutputs().at(LSTM::Output::OUTPUT) << ")" << std::endl;
}
-void OperationDumper::visit(const Pack &node) { dumpPackingOp(node); }
+void OperationDumper::visit(const Pack &node) { dumpOpGeneric(node); }
void OperationDumper::visit(const Pad &node)
{
@@ -249,23 +289,23 @@ void OperationDumper::visit(const Permute &node)
void OperationDumper::visit(const Pool2D &node)
{
std::string padding_type =
- node.param().padding.type == PaddingType::EXPLICIT ? "Explicit" : "Implicit";
+ node.param().padding.type == PaddingType::EXPLICIT ? "Explicit" : "Implicit";
VERBOSE(LIR) << "* " << node.name() << "(" << padding_type << ")" << std::endl;
VERBOSE(LIR) << " - Inputs : IFM(" << node.getInputs().at(Pool2D::Input::INPUT) << ")"
<< std::endl;
VERBOSE(LIR) << " - Output : OFM(" << node.getOutputs().at(0) << ")" << std::endl;
}
-void OperationDumper::visit(const Pow &node) { dumpBinaryInputOp(node); }
+void OperationDumper::visit(const Pow &node) { dumpOpGeneric(node); }
void OperationDumper::visit(const PReLU &node)
{
std::string alpha =
- "Alpha(" + std::to_string(node.getInputs().at(PReLU::Input::ALPHA).value()) + ")";
- dumpUnaryInputOp(node, alpha);
+ "Alpha(" + std::to_string(node.getInputs().at(PReLU::Input::ALPHA).value()) + ")";
+ dumpOpGeneric(node, alpha);
}
-void OperationDumper::visit(const Rank &node) { dumpUnaryInputOp(node); }
+void OperationDumper::visit(const Rank &node) { dumpOpGeneric(node); }
void OperationDumper::visit(const Reduce &node) { dumpUnaryInputOp(node); }
@@ -273,18 +313,20 @@ void OperationDumper::visit(const Reshape &node)
{
// optional param
std::string shape =
- node.getInputs().size() == 2
- ? "Shape(" + std::to_string(node.getInputs().at(Reshape::Input::SHAPE).value()) + ")"
- : "Shape(not provided)";
+ node.getInputs().size() == 2
+ ? "Shape(" + std::to_string(node.getInputs().at(Reshape::Input::SHAPE).value()) + ")"
+ : "Shape(not provided)";
dumpUnaryInputOp(node, shape);
}
-void OperationDumper::visit(const ResizeBilinear &node) { dumpUnaryInputOp(node); }
+void OperationDumper::visit(const ResizeBilinear &node) { dumpOpGeneric(node); }
+
+void OperationDumper::visit(const ResizeNearestNeighbor &node) { dumpOpGeneric(node); }
void OperationDumper::visit(const Reverse &node)
{
std::string axis =
- "Axis(" + std::to_string(node.getInputs().at(Reverse::Input::AXIS).value()) + ")";
+ "Axis(" + std::to_string(node.getInputs().at(Reverse::Input::AXIS).value()) + ")";
dumpUnaryInputOp(node, axis);
}
@@ -320,25 +362,24 @@ void OperationDumper::visit(const Select &node)
VERBOSE(LIR) << " - Output : Output(" << node.getOutputs().at(0) << ")" << std::endl;
}
-void OperationDumper::visit(const ir::operation::Shape &node) { dumpUnaryInputOp(node); }
+void OperationDumper::visit(const ir::operation::Shape &node) { dumpOpGeneric(node); }
-void OperationDumper::visit(const Softmax &node) { dumpUnaryInputOp(node); }
+void OperationDumper::visit(const Softmax &node) { dumpOpGeneric(node); }
void OperationDumper::visit(const SpaceToBatchND &node)
{
std::string inputs =
- "BlockSize(" +
- std::to_string(node.getInputs().at(SpaceToBatchND::Input::BLOCK_SIZE).value()) +
- ") Paddings(" + std::to_string(node.getInputs().at(SpaceToBatchND::Input::PADDINGS).value()) +
- ")";
+ "BlockSize(" + std::to_string(node.getInputs().at(SpaceToBatchND::Input::BLOCK_SIZE).value()) +
+ ") Paddings(" + std::to_string(node.getInputs().at(SpaceToBatchND::Input::PADDINGS).value()) +
+ ")";
dumpUnaryInputOp(node, inputs);
}
-void OperationDumper::visit(const SpaceToDepth &node) { dumpUnaryInputOp(node); }
+void OperationDumper::visit(const SpaceToDepth &node) { dumpOpGeneric(node); }
-void OperationDumper::visit(const Split &node) { dumpUnaryInputOp(node); }
+void OperationDumper::visit(const Split &node) { dumpOpGeneric(node); }
-void OperationDumper::visit(const SquaredDifference &node) { dumpBinaryInputOp(node); }
+void OperationDumper::visit(const SquaredDifference &node) { dumpOpGeneric(node); }
void OperationDumper::visit(const StatelessRandomUniform &node)
{
@@ -349,7 +390,7 @@ void OperationDumper::visit(const StatelessRandomUniform &node)
VERBOSE(LIR) << " - Output : Output(" << node.getOutputs().at(0) << ")" << std::endl;
}
-void OperationDumper::visit(const Squeeze &node) { dumpUnaryInputOp(node); }
+void OperationDumper::visit(const Squeeze &node) { dumpOpGeneric(node); }
void OperationDumper::visit(const Slice &node) { dumpUnaryInputOp(node); }
@@ -358,7 +399,7 @@ void OperationDumper::visit(const StridedSlice &node) { dumpUnaryInputOp(node);
void OperationDumper::visit(const Tile &node)
{
std::string multiples =
- "Multiples(" + std::to_string(node.getInputs().at(Tile::Input::MULTIPLES).value()) + ")";
+ "Multiples(" + std::to_string(node.getInputs().at(Tile::Input::MULTIPLES).value()) + ")";
dumpUnaryInputOp(node, multiples);
}
@@ -375,7 +416,7 @@ void OperationDumper::visit(const TopKV2 &node)
void OperationDumper::visit(const TransposeConv &node)
{
std::string padding_type =
- node.param().padding.type == PaddingType::EXPLICIT ? "Explicit" : "Implicit";
+ node.param().padding.type == PaddingType::EXPLICIT ? "Explicit" : "Implicit";
VERBOSE(LIR) << "* TransposeConv(" << padding_type << ")" << std::endl;
VERBOSE(LIR) << " - Inputs : Output Shape("
<< node.getInputs().at(TransposeConv::Input::OUTPUT_SHAPE) << ") KERNEL("
@@ -384,22 +425,14 @@ void OperationDumper::visit(const TransposeConv &node)
VERBOSE(LIR) << " - Output : OFM(" << node.getOutputs().at(0) << ")" << std::endl;
}
-void OperationDumper::visit(const Transpose &node) { dumpUnaryInputOp(node); }
+void OperationDumper::visit(const Transpose &node) { dumpOpGeneric(node); }
void OperationDumper::visit(const Unpack &node)
{
VERBOSE(LIR) << "* " << node.name() << std::endl;
VERBOSE(LIR) << " - Inputs : Input(" << node.getInputs().at(Unpack::Input::INPUT) << ")"
<< std::endl;
- std::string outputs;
- const auto &output_indices = node.getOutputs();
- for (auto it = std::begin(output_indices); it != std::end(output_indices); ++it)
- {
- outputs += std::to_string(it->value());
- if (std::next(it) != std::end(output_indices))
- outputs += ", ";
- }
- VERBOSE(LIR) << " - Outputs : Outputs(" << outputs << ")" << std::endl;
+ VERBOSE(LIR) << " - Output : Outputs(" << node.getOutputs() << ")" << std::endl;
}
void OperationDumper::visit(const OneHot &node)
@@ -413,51 +446,21 @@ void OperationDumper::visit(const OneHot &node)
void OperationDumper::visit(const If &node)
{
VERBOSE(LIR) << "* " << node.name() << std::endl;
- std::string inputs;
- const auto &input_indices = node.getInputs();
- for (auto it = std::begin(input_indices); it != std::end(input_indices); ++it)
- {
- inputs += std::to_string(it->value());
- if (std::next(it) != std::end(input_indices))
- inputs += ", ";
- }
VERBOSE(LIR) << " - Inputs : "
<< "Then subgraph (" << node.param().then_subg_index << ") Else subgraph ("
- << node.param().else_subg_index << ") Inputs(" << inputs << ")" << std::endl;
- std::string outputs;
- const auto &output_indices = node.getOutputs();
- for (auto it = std::begin(output_indices); it != std::end(output_indices); ++it)
- {
- outputs += std::to_string(it->value());
- if (std::next(it) != std::end(output_indices))
- outputs += ", ";
- }
- VERBOSE(LIR) << " - Output : Outputs(" << outputs << ")" << std::endl;
+ << node.param().else_subg_index << ") Inputs(" << node.getInputs() << ")"
+ << std::endl;
+ VERBOSE(LIR) << " - Output : Outputs(" << node.getOutputs() << ")" << std::endl;
}
void OperationDumper::visit(const While &node)
{
VERBOSE(LIR) << "* " << node.name() << std::endl;
- std::string inputs;
- const auto &input_indices = node.getInputs();
- for (auto it = std::begin(input_indices); it != std::end(input_indices); ++it)
- {
- inputs += std::to_string(it->value());
- if (std::next(it) != std::end(input_indices))
- inputs += ", ";
- }
VERBOSE(LIR) << " - Inputs : "
<< "Cond subgraph (" << node.param().cond_subg_index << ") Body subgraph ("
- << node.param().cond_subg_index << ") Inputs(" << inputs << ")" << std::endl;
- std::string outputs;
- const auto &output_indices = node.getOutputs();
- for (auto it = std::begin(output_indices); it != std::end(output_indices); ++it)
- {
- outputs += std::to_string(it->value());
- if (std::next(it) != std::end(output_indices))
- outputs += ", ";
- }
- VERBOSE(LIR) << " - Output : Outputs(" << outputs << ")" << std::endl;
+ << node.param().body_subg_index << ") Inputs(" << node.getInputs() << ")"
+ << std::endl;
+ VERBOSE(LIR) << " - Output : Outputs(" << node.getOutputs() << ")" << std::endl;
}
} // namespace ir
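The simplified Unpack/If/While dumpers above stream node.getInputs()/node.getOutputs() directly into VERBOSE(LIR), which presumes a stream-insertion overload for OperandIndexSequence. A minimal standalone sketch of the comma-joining behaviour the removed loops used to implement (a plain std::vector stands in for OperandIndexSequence; this is an illustration, not the runtime's actual overload):

#include <cstdint>
#include <iostream>
#include <iterator>
#include <sstream>
#include <vector>

// Stand-in for onert::ir::OperandIndexSequence (hypothetical; for illustration only)
using IndexSequence = std::vector<uint32_t>;

std::ostream &operator<<(std::ostream &o, const IndexSequence &seq)
{
  for (auto it = seq.begin(); it != seq.end(); ++it)
  {
    o << *it;
    if (std::next(it) != seq.end())
      o << ", ";
  }
  return o;
}

int main()
{
  const IndexSequence outputs{5, 6, 7};
  std::ostringstream os;
  os << " - Output : Outputs(" << outputs << ")";
  std::cout << os.str() << std::endl; // prints " - Output : Outputs(5, 6, 7)"
  return 0;
}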
diff --git a/runtime/onert/core/src/ir/OperationDumper.h b/runtime/onert/core/src/ir/OperationDumper.h
index e8ab3b3cd..99bf869d5 100644
--- a/runtime/onert/core/src/ir/OperationDumper.h
+++ b/runtime/onert/core/src/ir/OperationDumper.h
@@ -31,8 +31,9 @@ public:
OperationDumper(const std::string &start_msg);
public:
- void visit(const operation::ArgMax &) override;
+ void visit(const operation::ArgMinMax &) override;
void visit(const operation::BatchToSpaceND &node) override;
+ void visit(const operation::BCQFullyConnected &node) override;
void visit(const operation::BinaryArithmetic &node) override;
void visit(const operation::BroadcastTo &) override;
void visit(const operation::Comparison &) override;
@@ -47,12 +48,14 @@ public:
void visit(const operation::ElementwiseUnary &) override;
void visit(const operation::EmbeddingLookup &) override;
void visit(const operation::ExpandDims &) override;
+ void visit(const operation::Fill &) override;
void visit(const operation::FullyConnected &node) override;
void visit(const operation::Gather &) override;
void visit(const operation::HashtableLookup &) override;
void visit(const operation::InstanceNorm &) override;
void visit(const operation::L2Normalization &) override;
void visit(const operation::LocalResponseNormalization &) override;
+ void visit(const operation::Loss &node) override;
void visit(const operation::LSTM &) override;
void visit(const operation::Pack &) override;
void visit(const operation::Pad &) override;
@@ -65,6 +68,7 @@ public:
void visit(const operation::Reduce &) override;
void visit(const operation::Reshape &node) override;
void visit(const operation::ResizeBilinear &) override;
+ void visit(const operation::ResizeNearestNeighbor &) override;
void visit(const operation::Reverse &) override;
void visit(const operation::RNN &) override;
void visit(const operation::Select &node) override;
diff --git a/runtime/onert/core/src/ir/OperationValidator.cc b/runtime/onert/core/src/ir/OperationValidator.cc
new file mode 100644
index 000000000..5598c4043
--- /dev/null
+++ b/runtime/onert/core/src/ir/OperationValidator.cc
@@ -0,0 +1,546 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "OperationValidator.h"
+
+#include "ir/Graph.h"
+#include "util/logging.h"
+
+#define OP_REQUIRES(EXP) \
+ do \
+ { \
+ if (!(EXP)) \
+ throw std::runtime_error("OperationValidator failed at line " + std::to_string(__LINE__)); \
+ } while (0)
+
+namespace onert
+{
+namespace ir
+{
+
+OperationValidator::OperationValidator(const Graph &graph)
+ : _operations{graph.operations()}, _operands{graph.operands()}
+{
+}
+
+void OperationValidator::operator()()
+{
+ _operations.iterate([&](const OperationIndex &, const IOperation &node) { node.accept(*this); });
+}
+
+DataType OperationValidator::operandType(const OperandIndex &idx)
+{
+ return _operands.at(idx).typeInfo().type();
+}
+
+bool OperationValidator::isConstant(const OperandIndex &idx)
+{
+ return _operands.at(idx).isConstant();
+}
+
+bool OperationValidator::isSameType(const OperandIndex &idx1, const OperandIndex &idx2)
+{
+ return operandType(idx1) == operandType(idx2);
+}
+
+bool OperationValidator::isSameQuantParam(const OperandIndex &idx1, const OperandIndex &idx2)
+{
+ if (_operands.at(idx1).typeInfo().scale() != _operands.at(idx2).typeInfo().scale())
+ return false;
+
+ if (_operands.at(idx1).typeInfo().zero_point() != _operands.at(idx2).typeInfo().zero_point())
+ return false;
+
+ return true;
+}
+
+bool OperationValidator::isValidType(const OperandIndex &idx, const DataType &type)
+{
+ return operandType(idx) == type;
+}
+
+bool OperationValidator::isValidType(const OperandIndex &idx,
+ std::initializer_list<DataType> valid_types)
+{
+ for (auto &&type_to_check : valid_types)
+ {
+ if (isValidType(idx, type_to_check))
+ {
+ return true;
+ }
+ }
+
+ return false;
+}
+
+void OperationValidator::visit(const operation::AddN &node)
+{
+ const auto output_index(node.getOutputs().at(0));
+
+ int size = node.getInputs().size();
+ for (int i = 0; i < size; i++)
+ {
+ const auto input_index(node.getInputs().at(i));
+ OP_REQUIRES(isValidType(input_index, {DataType::FLOAT32, DataType::INT32}));
+ OP_REQUIRES(isSameType(input_index, output_index));
+ }
+}
+
+void OperationValidator::visit(const operation::ArgMinMax &node)
+{
+ const auto input_index(node.getInputs().at(operation::ArgMinMax::Input::INPUT));
+ const auto axis_index(node.getInputs().at(operation::ArgMinMax::Input::AXIS));
+ const auto output_index(node.getOutputs().at(0));
+ const auto output_type = node.param().output_type;
+
+ OP_REQUIRES(isValidType(input_index, {DataType::FLOAT32, DataType::INT32, DataType::UINT8,
+ DataType::QUANT_UINT8_ASYMM, DataType::QUANT_INT8_ASYMM}));
+ OP_REQUIRES(isValidType(axis_index, {DataType::INT32, DataType::INT64}));
+ OP_REQUIRES(isValidType(output_index, {DataType::INT32, DataType::INT64}));
+ OP_REQUIRES(isValidType(output_index, output_type));
+}
+
+void OperationValidator::visit(const operation::BatchMatMul &node)
+{
+ const auto lhs_index(node.getInputs().at(operation::BatchMatMul::Input::LHS));
+ const auto rhs_index(node.getInputs().at(operation::BatchMatMul::Input::RHS));
+ const auto output_index(node.getOutputs().at(0));
+
+ // Constant lhs and rhs are not implemented yet
+ OP_REQUIRES(!isConstant(lhs_index) && !isConstant(rhs_index));
+
+ // Allow hybrid quantization (lhs: float / rhs: qint8 / out: float)
+ OP_REQUIRES(isValidType(
+ lhs_index, {DataType::FLOAT32, DataType::QUANT_UINT8_ASYMM, DataType::QUANT_INT8_ASYMM}));
+ OP_REQUIRES(isSameType(lhs_index, rhs_index) ||
+ ((operandType(lhs_index) == DataType::FLOAT32) &&
+ (operandType(rhs_index) == DataType::QUANT_INT8_ASYMM)));
+ OP_REQUIRES(isSameType(lhs_index, output_index));
+}
+
+void OperationValidator::visit(const operation::BatchToSpaceND &node)
+{
+ const auto input_index{node.getInputs().at(operation::BatchToSpaceND::Input::INPUT)};
+ const auto output_index{node.getOutputs().at(0)};
+
+ OP_REQUIRES(isSameType(input_index, output_index));
+}
+
+void OperationValidator::visit(const operation::BinaryArithmetic &node)
+{
+ const auto output_index{node.getOutputs().at(0)};
+ const auto lhs_index{node.getInputs().at(operation::BinaryArithmetic::Input::LHS)};
+ const auto rhs_index{node.getInputs().at(operation::BinaryArithmetic::Input::RHS)};
+
+ OP_REQUIRES(isSameType(lhs_index, rhs_index));
+ OP_REQUIRES(isSameType(lhs_index, output_index));
+}
+
+void OperationValidator::visit(const operation::Comparison &node)
+{
+ const auto output_index{node.getOutputs().at(0)};
+
+ const auto lhs_index{node.getInputs().at(operation::Comparison::Input::INPUT0)};
+ const auto rhs_index{node.getInputs().at(operation::Comparison::Input::INPUT1)};
+
+ OP_REQUIRES(isSameType(lhs_index, rhs_index));
+ OP_REQUIRES(isValidType(output_index, DataType::BOOL8));
+}
+
+void OperationValidator::visit(const operation::Concat &node)
+{
+ const auto output_index{node.getOutputs().at(0)};
+
+ for (auto &&input_index : node.getInputs())
+ {
+ OP_REQUIRES(isSameType(input_index, output_index));
+
+ // Int8 quantization requires same scale and zero point
+ if (isValidType(output_index, DataType::QUANT_INT8_ASYMM))
+ {
+ OP_REQUIRES(isSameQuantParam(input_index, output_index));
+ }
+ }
+}
+
+void OperationValidator::visit(const operation::Conv2D &node)
+{
+ const auto input_index{node.getInputs().at(operation::Conv2D::Input::INPUT)};
+ const auto kernel_index{node.getInputs().at(operation::Conv2D::Input::KERNEL)};
+ const auto output_index{node.getOutputs().at(0)};
+
+ uint32_t stride_horizontal = node.param().stride.horizontal;
+ uint32_t stride_vertical = node.param().stride.vertical;
+ uint32_t dilation_width = node.param().dilation.width_factor;
+ uint32_t dilation_height = node.param().dilation.height_factor;
+
+ OP_REQUIRES((stride_horizontal > 0) && (stride_vertical > 0));
+ OP_REQUIRES((dilation_width > 0) && (dilation_height > 0));
+ OP_REQUIRES(isSameType(input_index, output_index));
+
+ if (isConstant(kernel_index) && operandType(kernel_index) == DataType::QUANT_INT8_ASYMM)
+ {
+ for (const auto zeropoint : _operands.at(kernel_index).typeInfo().zero_points())
+ OP_REQUIRES(zeropoint == 0);
+ }
+}
+
+void OperationValidator::visit(const operation::DepthToSpace &node)
+{
+ const auto input_index{node.getInputs().at(operation::DepthToSpace::Input::INPUT)};
+ const auto output_index{node.getOutputs().at(0)};
+
+ int32_t block_size = node.param().block_size;
+
+ OP_REQUIRES(isValidType(input_index, {DataType::FLOAT32, DataType::INT32, DataType::INT64,
+ DataType::QUANT_UINT8_ASYMM, DataType::QUANT_INT8_ASYMM}));
+ OP_REQUIRES(isSameType(input_index, output_index));
+
+ OP_REQUIRES(block_size > 0);
+}
+
+void OperationValidator::visit(const operation::DetectionPostProcess &node)
+{
+ const auto &param = node.param();
+
+ // FIXME: number of classes should be 1 for now.
+ OP_REQUIRES(param.num_classes == 1);
+}
+
+void OperationValidator::visit(const operation::DepthwiseConv2D &node)
+{
+ const auto input_index{node.getInputs().at(operation::DepthwiseConv2D::Input::INPUT)};
+ const auto kernel_index{node.getInputs().at(operation::DepthwiseConv2D::Input::KERNEL)};
+ const auto output_index{node.getOutputs().at(0)};
+
+ uint32_t stride_horizontal = node.param().stride.horizontal;
+ uint32_t stride_vertical = node.param().stride.vertical;
+ uint32_t dilation_width = node.param().dilation.width_factor;
+ uint32_t dilation_height = node.param().dilation.height_factor;
+
+ OP_REQUIRES((stride_horizontal > 0) && (stride_vertical > 0));
+ OP_REQUIRES((dilation_width > 0) && (dilation_height > 0));
+ OP_REQUIRES(isSameType(input_index, output_index));
+
+ if (isConstant(kernel_index) && operandType(kernel_index) == DataType::QUANT_INT8_ASYMM)
+ {
+ for (const auto zeropoint : _operands.at(kernel_index).typeInfo().zero_points())
+ OP_REQUIRES(zeropoint == 0);
+ }
+}
+
+void OperationValidator::visit(const operation::ElementwiseActivation &node)
+{
+ const auto output_index{node.getOutputs().at(0)};
+ const auto input_index{node.getInputs().at(0)};
+
+ // Check if I/O types match
+ OP_REQUIRES(isSameType(output_index, input_index));
+
+ switch (node.param().op_type)
+ {
+ case operation::ElementwiseActivation::Type::ELU:
+ OP_REQUIRES(isValidType(input_index, DataType::FLOAT32));
+ break;
+ case operation::ElementwiseActivation::Type::LEAKY_RELU:
+ OP_REQUIRES(
+ isValidType(input_index, {DataType::FLOAT32, DataType::QUANT_UINT8_ASYMM,
+ DataType::QUANT_INT8_ASYMM, DataType::QUANT_INT16_ASYMM}));
+ break;
+ case operation::ElementwiseActivation::Type::LOGISTIC:
+ OP_REQUIRES(
+ isValidType(input_index, {DataType::FLOAT32, DataType::QUANT_UINT8_ASYMM,
+ DataType::QUANT_INT8_ASYMM, DataType::QUANT_INT16_ASYMM}));
+ break;
+ case operation::ElementwiseActivation::Type::RELU:
+ OP_REQUIRES(isValidType(
+ input_index, {DataType::FLOAT32, DataType::QUANT_UINT8_ASYMM, DataType::QUANT_INT8_ASYMM}));
+ break;
+ case operation::ElementwiseActivation::Type::TANH:
+ OP_REQUIRES(
+ isValidType(input_index, {DataType::FLOAT32, DataType::QUANT_UINT8_ASYMM,
+ DataType::QUANT_INT8_ASYMM, DataType::QUANT_INT16_ASYMM}));
+ break;
+ }
+}
+
+void OperationValidator::visit(const operation::ElementwiseBinary &node)
+{
+ const auto output_index{node.getOutputs().at(0)};
+ const auto lhs_index{node.getInputs().at(operation::ElementwiseBinary::Input::LHS)};
+ const auto rhs_index{node.getInputs().at(operation::ElementwiseBinary::Input::RHS)};
+
+ OP_REQUIRES(isSameType(lhs_index, rhs_index));
+ OP_REQUIRES(isSameType(lhs_index, output_index));
+
+ const auto op_type = node.param().op_type;
+ if (op_type == operation::ElementwiseBinary::ElementwiseBinaryType::LOGICAL_AND ||
+ op_type == operation::ElementwiseBinary::ElementwiseBinaryType::LOGICAL_OR)
+ {
+ OP_REQUIRES(isValidType(lhs_index, DataType::BOOL8));
+ }
+}
+
+void OperationValidator::visit(const operation::ElementwiseUnary &node)
+{
+ const auto output_index{node.getOutputs().at(0)};
+ const auto input_index{node.getInputs().at(operation::ElementwiseUnary::Input::INPUT)};
+
+ // Check if I/O types match
+ if (node.param().op_type == operation::ElementwiseUnary::Type::DEQUANTIZE)
+ {
+ // NNAPI allows QUANT_INT8_SYMM type input
+ OP_REQUIRES(isValidType(input_index, {DataType::QUANT_UINT8_ASYMM, DataType::QUANT_INT8_SYMM,
+ DataType::QUANT_INT8_ASYMM}));
+ OP_REQUIRES(isValidType(output_index, DataType::FLOAT32));
+ }
+ else if (node.param().op_type == operation::ElementwiseUnary::Type::QUANTIZE)
+ {
+ OP_REQUIRES(isValidType(
+ input_index, {DataType::FLOAT32, DataType::QUANT_UINT8_ASYMM, DataType::QUANT_INT8_ASYMM}));
+ OP_REQUIRES(
+ isValidType(output_index, {DataType::QUANT_UINT8_ASYMM, DataType::QUANT_INT8_ASYMM}));
+ }
+ else if (node.param().op_type == operation::ElementwiseUnary::Type::FLOOR)
+ {
+ OP_REQUIRES(isValidType(input_index, DataType::FLOAT32));
+ OP_REQUIRES(isSameType(output_index, input_index));
+ }
+ else if (node.param().op_type != operation::ElementwiseUnary::Type::CAST)
+ {
+ OP_REQUIRES(isSameType(output_index, input_index));
+ }
+}
+
+void OperationValidator::visit(const operation::EmbeddingLookup &node)
+{
+ const auto lookups_index{node.getInputs().at(operation::EmbeddingLookup::Input::LOOKUPS)};
+ const auto values_index{node.getInputs().at(operation::EmbeddingLookup::Input::VALUES)};
+ const auto output_index{node.getOutputs().at(0)};
+
+ OP_REQUIRES(isValidType(lookups_index, DataType::INT32));
+
+ // TFLite: allows a hybrid combination of value table and output types
+ // NNAPI: requires the value table and output to have the same type
+ OP_REQUIRES(
+ isSameType(values_index, output_index) ||
+ (isValidType(output_index, DataType::FLOAT32) &&
+ (isValidType(values_index, {DataType::QUANT_INT8_ASYMM, DataType::QUANT_INT8_SYMM}))));
+}
+
+void OperationValidator::visit(const operation::ExpandDims &node)
+{
+ const auto output_index{node.getOutputs().at(0)};
+ const auto input_index{node.getInputs().at(operation::ExpandDims::Input::INPUT)};
+ const auto axis_index{node.getInputs().at(operation::ExpandDims::Input::AXIS)};
+
+ OP_REQUIRES(isSameType(output_index, input_index));
+ OP_REQUIRES(isValidType(axis_index, {DataType::INT32, DataType::INT64}));
+}
+
+void OperationValidator::visit(const operation::Fill &node)
+{
+ const auto output_index{node.getOutputs().at(0)};
+ const auto input_index{node.getInputs().at(operation::Fill::Input::SHAPE)};
+ const auto value_index{node.getInputs().at(operation::Fill::Input::VALUE)};
+
+ OP_REQUIRES(isSameType(output_index, value_index));
+ OP_REQUIRES(isValidType(input_index, {DataType::INT32, DataType::INT64}));
+ OP_REQUIRES(isValidType(output_index,
+ {DataType::FLOAT32, DataType::INT32, DataType::INT64, DataType::BOOL8}));
+}
+
+void OperationValidator::visit(const operation::HashtableLookup &node)
+{
+ const auto hits_index{node.getOutputs().at(operation::HashtableLookup::Output::HITS)};
+ const auto lookups_index{node.getInputs().at(operation::HashtableLookup::Input::LOOKUPS)};
+ const auto keys_index{node.getInputs().at(operation::HashtableLookup::Input::KEYS)};
+
+ OP_REQUIRES(isValidType(lookups_index, DataType::INT32));
+ OP_REQUIRES(isValidType(keys_index, DataType::INT32));
+ OP_REQUIRES(isValidType(hits_index, DataType::QUANT_UINT8_ASYMM));
+}
+
+void OperationValidator::visit(const operation::Pack &node)
+{
+ const auto num{node.param().num};
+
+ OP_REQUIRES(num == static_cast<int32_t>(node.getInputs().size()));
+}
+
+void OperationValidator::visit(const operation::Pad &node)
+{
+ const auto output_index{node.getOutputs().at(0)};
+ const auto input_index{node.getInputs().at(operation::Pad::Input::INPUT)};
+ const auto pad_index{node.getInputs().at(operation::Pad::Input::PAD)};
+ bool isQuantType =
+ isValidType(output_index, {DataType::QUANT_UINT8_ASYMM, DataType::QUANT_INT8_ASYMM});
+ bool isPadV2 = node.getInputs().size() == 3;
+
+ OP_REQUIRES(isValidType(pad_index, DataType::INT32));
+ OP_REQUIRES(isSameType(input_index, output_index));
+
+ if (isQuantType)
+ OP_REQUIRES(isSameQuantParam(input_index, output_index));
+
+ if (isPadV2)
+ {
+ const auto value_index{node.getInputs().at(operation::Pad::Input::VALUE)};
+ const bool cond_same = isSameType(input_index, value_index);
+ const bool cond_same_quant = (!isQuantType || isSameQuantParam(input_index, value_index));
+ const auto input_t = operandType(input_index);
+ const auto value_t = operandType(value_index);
+ // NNAPI accepts this case. scale and zeroPoint are assumed to be the same as in input0.
+ const bool cond_quant8 =
+ ((input_t == DataType::QUANT_UINT8_ASYMM || input_t == DataType::QUANT_INT8_ASYMM) &&
+ value_t == DataType::INT32);
+ OP_REQUIRES((cond_same && cond_same_quant) || cond_quant8);
+ }
+}
+
+void OperationValidator::visit(const operation::Rank &node)
+{
+ const auto output_index{node.getOutputs().at(0)};
+
+ OP_REQUIRES(isValidType(output_index, DataType::INT32));
+}
+
+void OperationValidator::visit(const operation::ResizeBilinear &node)
+{
+ auto align_corners = node.param().align_corners;
+ auto half_pixel_centers = node.param().half_pixel_centers;
+
+ OP_REQUIRES(!align_corners || !half_pixel_centers);
+}
+
+void OperationValidator::visit(const operation::Reverse &node)
+{
+ const auto output_index{node.getOutputs().at(0)};
+ const auto input_index{node.getInputs().at(operation::Reverse::Input::INPUT)};
+ const auto axis_index{node.getInputs().at(operation::Reverse::Input::AXIS)};
+
+ OP_REQUIRES(isValidType(axis_index, DataType::INT32));
+ OP_REQUIRES(isSameType(output_index, input_index));
+}
+
+void OperationValidator::visit(const operation::Select &node)
+{
+ const auto condition_index{node.getInputs().at(operation::Select::Input::CONDITION)};
+ const auto input_true_index{node.getInputs().at(operation::Select::Input::INPUT_TRUE)};
+ const auto input_false_index{node.getInputs().at(operation::Select::Input::INPUT_FALSE)};
+
+ OP_REQUIRES(isValidType(condition_index, DataType::BOOL8));
+ OP_REQUIRES(isSameType(input_true_index, input_false_index));
+}
+
+void OperationValidator::visit(const operation::Shape &node)
+{
+ const auto output_index{node.getOutputs().at(0)};
+
+ OP_REQUIRES(isValidType(output_index, {DataType::UINT32, DataType::INT32, DataType::INT64}));
+}
+
+void OperationValidator::visit(const operation::Slice &node)
+{
+ const auto begins_index{node.getInputs().at(operation::Slice::BEGINS)};
+ const auto sizes_index{node.getInputs().at(operation::Slice::SIZES)};
+
+ OP_REQUIRES(isValidType(begins_index, {DataType::INT32, DataType::INT64}));
+ OP_REQUIRES(isSameType(begins_index, sizes_index));
+}
+
+void OperationValidator::visit(const operation::Softmax &node)
+{
+ const auto output_index{node.getOutputs().at(0)};
+ const auto input_index{node.getInputs().at(operation::Softmax::INPUT)};
+
+ OP_REQUIRES(isSameType(input_index, output_index));
+ OP_REQUIRES(isValidType(
+ output_index, {DataType::FLOAT32, DataType::QUANT_UINT8_ASYMM, DataType::QUANT_INT8_ASYMM}));
+}
+
+void OperationValidator::visit(const operation::SpaceToBatchND &node)
+{
+ const auto block_size_index{node.getInputs().at(operation::SpaceToBatchND::Input::BLOCK_SIZE)};
+ const auto paddings_index{node.getInputs().at(operation::SpaceToBatchND::Input::PADDINGS)};
+
+ // Non-constant block_size and paddings are not implemented yet
+ OP_REQUIRES(isConstant(block_size_index));
+ OP_REQUIRES(isConstant(paddings_index));
+}
+
+void OperationValidator::visit(const operation::SpaceToDepth &node)
+{
+ const auto block_size = node.param().block_size;
+ OP_REQUIRES(block_size >= 1);
+}
+
+void OperationValidator::visit(const operation::Split &node)
+{
+ const auto num_splits = node.param().num_splits;
+
+ OP_REQUIRES(num_splits > 0 && num_splits <= 0xFFFF);
+ OP_REQUIRES(node.getOutputs().size() == static_cast<uint32_t>(num_splits));
+}
+
+void OperationValidator::visit(const operation::SquaredDifference &node)
+{
+ const auto output_index{node.getOutputs().at(0)};
+ const auto lhs_index{node.getInputs().at(operation::SquaredDifference::Input::LHS)};
+ const auto rhs_index{node.getInputs().at(operation::SquaredDifference::Input::RHS)};
+
+ OP_REQUIRES(isSameType(output_index, lhs_index));
+ OP_REQUIRES(isSameType(lhs_index, rhs_index));
+}
+
+void OperationValidator::visit(const operation::StatelessRandomUniform &node)
+{
+ const auto output_index{node.getOutputs().at(0)};
+ const auto shape_index{node.getInputs().at(operation::StatelessRandomUniform::Input::SHAPE)};
+ const auto seed_index{node.getInputs().at(operation::StatelessRandomUniform::Input::SEED)};
+
+ OP_REQUIRES(isValidType(output_index, DataType::FLOAT32));
+ OP_REQUIRES(isValidType(shape_index, DataType::INT32));
+ OP_REQUIRES(isValidType(seed_index, DataType::INT32));
+}
+
+void OperationValidator::visit(const operation::StridedSlice &node)
+{
+ const auto output_index{node.getOutputs().at(0)};
+ const auto input_index{node.getInputs().at(operation::StridedSlice::Input::INPUT)};
+
+ OP_REQUIRES(isSameType(output_index, input_index));
+}
+
+void OperationValidator::visit(const operation::TransposeConv &node)
+{
+ OP_REQUIRES((node.param().padding.type == PaddingType::SAME) ||
+ (node.param().padding.type == PaddingType::VALID));
+}
+
+void OperationValidator::visit(const operation::Unpack &node)
+{
+ const auto num{node.param().num};
+ OP_REQUIRES(num == static_cast<int32_t>(node.getOutputs().size()));
+}
+
+void OperationValidator::visit(const operation::While &node)
+{
+ OP_REQUIRES(node.getInputs().size() == node.getOutputs().size());
+}
+
+} // namespace ir
+} // namespace onert
diff --git a/runtime/onert/core/src/ir/OperationValidator.h b/runtime/onert/core/src/ir/OperationValidator.h
new file mode 100644
index 000000000..b9bcc4ee8
--- /dev/null
+++ b/runtime/onert/core/src/ir/OperationValidator.h
@@ -0,0 +1,101 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_IR_OPERATION_VALIDATOR_H__
+#define __ONERT_IR_OPERATION_VALIDATOR_H__
+
+#include "ir/OperationVisitor.h"
+#include "ir/Operations.h"
+#include "ir/Operands.h"
+
+namespace onert
+{
+namespace ir
+{
+class Graph;
+class Operands;
+} // namespace ir
+} // namespace onert
+
+namespace onert
+{
+namespace ir
+{
+
+class OperationValidator : public OperationVisitor
+{
+public:
+ OperationValidator(void) = delete;
+ OperationValidator(const Graph &graph);
+
+public:
+ void operator()();
+
+public:
+ void visit(const operation::AddN &node) override;
+ void visit(const operation::ArgMinMax &node) override;
+ void visit(const operation::BatchMatMul &node) override;
+ void visit(const operation::BatchToSpaceND &node) override;
+ void visit(const operation::BinaryArithmetic &node) override;
+ void visit(const operation::Comparison &node) override;
+ void visit(const operation::Concat &node) override;
+ void visit(const operation::Conv2D &node) override;
+ void visit(const operation::DepthToSpace &node) override;
+ void visit(const operation::DepthwiseConv2D &node) override;
+ void visit(const operation::DetectionPostProcess &node) override;
+ void visit(const operation::ElementwiseActivation &node) override;
+ void visit(const operation::ElementwiseBinary &node) override;
+ void visit(const operation::ElementwiseUnary &node) override;
+ void visit(const operation::EmbeddingLookup &node) override;
+ void visit(const operation::ExpandDims &node) override;
+ void visit(const operation::Fill &node) override;
+ void visit(const operation::HashtableLookup &node) override;
+ void visit(const operation::Pack &node) override;
+ void visit(const operation::Pad &node) override;
+ void visit(const operation::Rank &node) override;
+ void visit(const operation::ResizeBilinear &node) override;
+ void visit(const operation::Reverse &node) override;
+ void visit(const operation::Select &node) override;
+ void visit(const operation::Shape &node) override;
+ void visit(const operation::Slice &node) override;
+ void visit(const operation::Softmax &node) override;
+ void visit(const operation::SpaceToBatchND &node) override;
+ void visit(const operation::SpaceToDepth &node) override;
+ void visit(const operation::Split &node) override;
+ void visit(const operation::SquaredDifference &node) override;
+ void visit(const operation::StatelessRandomUniform &node) override;
+ void visit(const operation::StridedSlice &node) override;
+ void visit(const operation::TransposeConv &node) override;
+ void visit(const operation::Unpack &node) override;
+ void visit(const operation::While &node) override;
+
+private:
+ DataType operandType(const OperandIndex &idx);
+ bool isConstant(const OperandIndex &idx);
+ bool isSameType(const OperandIndex &idx1, const OperandIndex &idx2);
+ bool isSameQuantParam(const OperandIndex &idx1, const OperandIndex &idx2);
+ bool isValidType(const OperandIndex &idx, const DataType &type);
+ bool isValidType(const OperandIndex &idx, std::initializer_list<DataType> valid_types);
+
+private:
+ const Operations &_operations;
+ const Operands &_operands;
+};
+
+} // namespace ir
+} // namespace onert
+
+#endif // __ONERT_IR_OPERATION_VALIDATOR_H__
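A minimal usage sketch for the validator added above, assuming a fully built onert::ir::Graph named graph: the constructor captures the graph's operations and operands, operator() visits every operation, and the first failed OP_REQUIRES check throws std::runtime_error carrying the failing source line. The wrapper function name is hypothetical.

#include "OperationValidator.h"
#include "ir/Graph.h"

#include <iostream>
#include <stdexcept>

void validateGraph(const onert::ir::Graph &graph)
{
  onert::ir::OperationValidator validator{graph};
  try
  {
    validator(); // throws on the first operation that violates a type/param constraint
  }
  catch (const std::runtime_error &e)
  {
    std::cerr << e.what() << std::endl; // e.g. "OperationValidator failed at line ..."
    throw;
  }
}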
diff --git a/runtime/onert/core/src/ir/Operations.cc b/runtime/onert/core/src/ir/Operations.cc
index 64d0bd6f0..1b4691f58 100644
--- a/runtime/onert/core/src/ir/Operations.cc
+++ b/runtime/onert/core/src/ir/Operations.cc
@@ -25,12 +25,9 @@ namespace ir
Operations::Operations(const Operations &obj)
{
- obj.iterate([&](const OperationIndex &index, const Operation &op) {
- OperationCloner cloner;
- op.accept(cloner);
- _objects.emplace(index, cloner.releaseClone());
- });
- _index_count = obj._index_count;
+ obj.iterate(
+ [&](const OperationIndex &index, const IOperation &op) { _objects.emplace(index, clone(op)); });
+ _next_index = obj._next_index;
}
} // namespace ir
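The rewritten copy constructor deep-copies every operation through a clone() helper instead of the older visitor-based OperationCloner. A self-contained sketch of the general pattern, using plain stand-in classes rather than the runtime's actual IOperation/clone API:

#include <map>
#include <memory>

// Hypothetical stand-in for a polymorphic IR node such as IOperation
struct Node
{
  virtual ~Node() = default;
  virtual std::unique_ptr<Node> clone() const = 0;
};

struct AddNode : Node
{
  std::unique_ptr<Node> clone() const override { return std::make_unique<AddNode>(*this); }
};

// Hypothetical stand-in for an index-to-operation container such as Operations
struct Container
{
  std::map<int, std::unique_ptr<Node>> _objects;

  Container() = default;
  Container(const Container &obj)
  {
    // Deep copy: each element is cloned so both containers own independent nodes
    for (const auto &[index, node] : obj._objects)
      _objects.emplace(index, node->clone());
  }
};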
diff --git a/runtime/onert/core/src/ir/Operations.test.cc b/runtime/onert/core/src/ir/Operations.test.cc
new file mode 100644
index 000000000..e57872689
--- /dev/null
+++ b/runtime/onert/core/src/ir/Operations.test.cc
@@ -0,0 +1,42 @@
+/*
+ * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ir/Operations.h"
+
+#include "MockNode.h"
+
+#include <gtest/gtest.h>
+
+using onert::ir::Operation;
+using onert::ir::OperationIndex;
+using onert::ir::Operations;
+
+TEST(ir_Operations, basic)
+{
+ Operations ops;
+ ops.push(std::unique_ptr<Operation>(new onert_test::ir::SimpleMock({1, 2, 3, 4}, {5, 6, 7})));
+ OperationIndex idx{0u};
+ ASSERT_EQ(ops.at(idx).getInputs().size(), 4);
+ ASSERT_EQ(ops.at(idx).getOutputs().size(), 3);
+}
+
+TEST(ir_Operations, neg_at)
+{
+ Operations ops;
+ ops.push(std::unique_ptr<Operation>(new onert_test::ir::SimpleMock({1, 2, 3, 4}, {5, 6, 7})));
+ OperationIndex idx{99u};
+ EXPECT_THROW(ops.at(idx), std::out_of_range);
+}
diff --git a/runtime/onert/core/src/ir/Padding.cc b/runtime/onert/core/src/ir/Padding.cc
index d74f80217..b2b004e7a 100644
--- a/runtime/onert/core/src/ir/Padding.cc
+++ b/runtime/onert/core/src/ir/Padding.cc
@@ -66,14 +66,14 @@ inline ExplicitPadding samePaddingUsingIFM(const FeatureShape &ifm_shape, const
const int32_t vertical_expected_output = (ifm_shape.H + stride.vertical - 1) / stride.vertical;
const int32_t horizontal_expected_output =
- (ifm_shape.W + stride.horizontal - 1) / stride.horizontal;
+ (ifm_shape.W + stride.horizontal - 1) / stride.horizontal;
const int32_t vertical_needed_input =
- (vertical_expected_output - 1) * stride.vertical + effective_filter_h_size;
+ (vertical_expected_output - 1) * stride.vertical + effective_filter_h_size;
const int32_t vertical_total_padding = std::max(0, vertical_needed_input - ifm_shape.H);
const int32_t horizontal_needed_input =
- (horizontal_expected_output - 1) * stride.horizontal + effective_filter_w_size;
+ (horizontal_expected_output - 1) * stride.horizontal + effective_filter_w_size;
const int32_t horizontal_total_padding = std::max(0, horizontal_needed_input - ifm_shape.W);
padding.top = vertical_total_padding / 2;
@@ -90,7 +90,7 @@ inline ExplicitPadding samePadding(const FeatureShape &ifm_shape, const FeatureS
{
const int32_t vertical_expected_output = (ifm_shape.H + stride.vertical - 1) / stride.vertical;
const int32_t horizontal_expected_output =
- (ifm_shape.W + stride.horizontal - 1) / stride.horizontal;
+ (ifm_shape.W + stride.horizontal - 1) / stride.horizontal;
assert(vertical_expected_output == ofm_shape.H);
assert(horizontal_expected_output == ofm_shape.W);
@@ -129,7 +129,7 @@ Padding::Padding(PaddingType paddingType) : type{paddingType}, param{0, 0, 0, 0}
}
Padding::Padding(uint32_t left, uint32_t right, uint32_t top, uint32_t bottom)
- : type{PaddingType::EXPLICIT}, param{left, right, top, bottom}
+ : type{PaddingType::EXPLICIT}, param{left, right, top, bottom}
{
// DO NOTHING
}
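A worked numeric example of the SAME-padding arithmetic reformatted in samePaddingUsingIFM above, using illustrative values (IFM height 5, vertical stride 2, effective filter height 3); only the vertical path shown in the hunk is reproduced:

#include <algorithm>
#include <cassert>
#include <cstdint>

int main()
{
  const int32_t H = 5, stride_vertical = 2, effective_filter_h_size = 3;

  const int32_t vertical_expected_output = (H + stride_vertical - 1) / stride_vertical; // (5 + 2 - 1) / 2 = 3
  const int32_t vertical_needed_input =
    (vertical_expected_output - 1) * stride_vertical + effective_filter_h_size; // 2 * 2 + 3 = 7
  const int32_t vertical_total_padding = std::max(0, vertical_needed_input - H); // max(0, 7 - 5) = 2

  assert(vertical_total_padding / 2 == 1); // padding.top
  return 0;
}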
diff --git a/runtime/onert/core/src/ir/Shape.cc b/runtime/onert/core/src/ir/Shape.cc
index 322df7b4c..e4e4c154b 100644
--- a/runtime/onert/core/src/ir/Shape.cc
+++ b/runtime/onert/core/src/ir/Shape.cc
@@ -26,10 +26,10 @@ namespace onert
namespace ir
{
-int32_t const Shape::UNSPECIFIED_DIM = -1;
+int32_t const Shape::kUnspecifiedDim = -1;
// NNFW_MAX_RANK is 6
-int32_t const Shape::MAX_RANK = 6;
+int32_t const Shape::kMaxRank = 6;
FeatureShape Shape::asFeature(Layout layout) const
{
@@ -80,34 +80,37 @@ uint64_t Shape::num_elements() const
{
// if any dimension is unspecified, the total number of elements cannot be calculated
if (std::any_of(_dimensions.begin(), _dimensions.end(),
- [](const int32_t &v) { return v == UNSPECIFIED_DIM; }))
+ [](const int32_t &v) { return v == kUnspecifiedDim; }))
throw std::runtime_error("num_elements() cannot calculate when any dimension is unspecified");
return std::accumulate(_dimensions.cbegin(), _dimensions.cend(), UINT64_C(1),
std::multiplies<uint64_t>());
}
-Shape permuteShape(const Shape &shape, Layout frontend_layout, Layout backend_layout)
+Shape permuteShape(const Shape &shape, Layout from, Layout to)
{
- assert(shape.rank() <= Shape::MAX_RANK);
- Shape backend_shape{shape};
- if (shape.rank() >= 4 && frontend_layout == Layout::NHWC && backend_layout == Layout::NCHW)
+ assert(shape.rank() <= Shape::kMaxRank);
+ Shape ret{shape};
+ if (from == to)
+ return ret;
+ if (shape.rank() < 4)
+ return ret;
+ // Permutation changing layout beyond 4-D is not supported yet
+ assert(shape.rank() <= 4);
+ if (from == Layout::NHWC && to == Layout::NCHW)
{
- // Permutation changing layout beyond 4-D is not supported yet
- assert(shape.rank() <= 4);
- backend_shape.dim(1) = shape.dim(3);
- backend_shape.dim(2) = shape.dim(1);
- backend_shape.dim(3) = shape.dim(2);
+ ret.dim(1) = shape.dim(3);
+ ret.dim(2) = shape.dim(1);
+ ret.dim(3) = shape.dim(2);
}
- else if (shape.rank() >= 4 && frontend_layout == Layout::NCHW && backend_layout == Layout::NHWC)
+ else if (from == Layout::NCHW && to == Layout::NHWC)
{
- // Permutation changing layout beyond 4-D is not supported yet
- assert(shape.rank() <= 4);
- backend_shape.dim(1) = shape.dim(2);
- backend_shape.dim(2) = shape.dim(3);
- backend_shape.dim(3) = shape.dim(1);
+ ret.dim(1) = shape.dim(2);
+ ret.dim(2) = shape.dim(3);
+ ret.dim(3) = shape.dim(1);
}
- return backend_shape;
+ // In other cases (either `from` or `to` is UNKNOWN), just return the original shape
+ return ret;
}
} // namespace ir
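A small usage sketch of the rewritten permuteShape, following the construction style used in the Shape test below; the shape values are chosen for illustration, and it assumes ir/Shape.h declares permuteShape and pulls in the Layout enum:

#include "ir/Shape.h"

#include <cassert>

int main()
{
  onert::ir::Shape nhwc(4);
  nhwc.dim(0) = 1;
  nhwc.dim(1) = 224;
  nhwc.dim(2) = 224;
  nhwc.dim(3) = 3;

  // NHWC -> NCHW moves the channel dimension ahead of height and width
  const auto nchw = onert::ir::permuteShape(nhwc, onert::ir::Layout::NHWC, onert::ir::Layout::NCHW);
  assert(nchw.dim(1) == 3 && nchw.dim(2) == 224 && nchw.dim(3) == 224);

  // Identical layouts (or rank < 4) return the shape unchanged
  const auto same = onert::ir::permuteShape(nhwc, onert::ir::Layout::NHWC, onert::ir::Layout::NHWC);
  assert(same.dim(3) == 3);
  return 0;
}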
diff --git a/runtime/onert/core/src/ir/Shape.test.cc b/runtime/onert/core/src/ir/Shape.test.cc
new file mode 100644
index 000000000..4788522d3
--- /dev/null
+++ b/runtime/onert/core/src/ir/Shape.test.cc
@@ -0,0 +1,58 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ir/Shape.h"
+
+#include <gtest/gtest.h>
+
+TEST(ShapeTest, basic_test)
+{
+ {
+ onert::ir::Shape shape(3);
+
+ shape.dim(0) = 1;
+ shape.dim(1) = 2;
+ shape.dim(2) = 3;
+
+ ASSERT_EQ(shape.rank(), 3);
+ ASSERT_EQ(shape.num_elements(), 6);
+ ASSERT_EQ(onert::ir::rankMaybeUnspecified(shape), false);
+ ASSERT_EQ(shape.hasUnspecifiedDims(), false);
+ }
+ {
+ onert::ir::Shape shape; // scalar or rank is unspecified
+
+ ASSERT_EQ(shape.rank(), 0);
+ ASSERT_EQ(shape.num_elements(), 1);
+ ASSERT_EQ(onert::ir::rankMaybeUnspecified(shape), true);
+ ASSERT_EQ(shape.hasUnspecifiedDims(), false);
+ }
+}
+
+TEST(ShapeTest, neg_basic_test)
+{
+ {
+ onert::ir::Shape shape(2);
+
+ shape.dim(0) = 1;
+ shape.dim(1) = onert::ir::Shape::kUnspecifiedDim;
+
+ ASSERT_EQ(shape.rank(), 2);
+ ASSERT_EQ(onert::ir::rankMaybeUnspecified(shape), false);
+ ASSERT_EQ(shape.hasUnspecifiedDims(), true);
+ EXPECT_ANY_THROW(shape.num_elements());
+ }
+}
diff --git a/runtime/onert/core/src/ir/TypeInfo.cc b/runtime/onert/core/src/ir/TypeInfo.cc
index ab8af287e..5d1c7ba8b 100644
--- a/runtime/onert/core/src/ir/TypeInfo.cc
+++ b/runtime/onert/core/src/ir/TypeInfo.cc
@@ -28,7 +28,7 @@ bool operator==(const TypeInfo &lhs, const TypeInfo &rhs)
return false;
}
- if (lhs.offset() != rhs.offset())
+ if (lhs.zero_point() != rhs.zero_point())
{
return false;
}
diff --git a/runtime/onert/core/src/ir/operation/AddN.cc b/runtime/onert/core/src/ir/operation/AddN.cc
new file mode 100644
index 000000000..a51e12dff
--- /dev/null
+++ b/runtime/onert/core/src/ir/operation/AddN.cc
@@ -0,0 +1,36 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ir/operation/AddN.h"
+#include "ir/OperationVisitor.h"
+
+namespace onert
+{
+namespace ir
+{
+namespace operation
+{
+
+void AddN::accept(OperationVisitor &v) const { v.visit(*this); }
+
+AddN::AddN(const OperandIndexSequence &inputs, const OperandIndexSequence &outputs)
+ : Operation{OperandConstraint::createExact(inputs.size()), inputs, outputs}
+{
+}
+
+} // namespace operation
+} // namespace ir
+} // namespace onert
diff --git a/runtime/onert/core/src/ir/operation/ArgMax.cc b/runtime/onert/core/src/ir/operation/ArgMinMax.cc
index 1275ae43a..2f18ff2e2 100644
--- a/runtime/onert/core/src/ir/operation/ArgMax.cc
+++ b/runtime/onert/core/src/ir/operation/ArgMinMax.cc
@@ -14,10 +14,7 @@
* limitations under the License.
*/
-#include "ir/operation/ArgMax.h"
-
-#include <cassert>
-
+#include "ir/operation/ArgMinMax.h"
#include "ir/OperationVisitor.h"
namespace onert
@@ -27,11 +24,11 @@ namespace ir
namespace operation
{
-void ArgMax::accept(OperationVisitor &v) const { v.visit(*this); }
+void ArgMinMax::accept(OperationVisitor &v) const { v.visit(*this); }
-ArgMax::ArgMax(const OperandIndexSequence &inputs, const OperandIndexSequence &outputs,
- const Param &param)
- : Operation{OperandConstraint::createExact(1u), inputs, outputs}, _param{param}
+ArgMinMax::ArgMinMax(const OperandIndexSequence &inputs, const OperandIndexSequence &outputs,
+ const Param &param)
+ : Operation{OperandConstraint::createExact(2u), inputs, outputs}, _param{param}
{
}
diff --git a/runtime/onert/core/src/ir/operation/BCQFullyConnected.cc b/runtime/onert/core/src/ir/operation/BCQFullyConnected.cc
index 9dc54e6e9..ccda674ad 100644
--- a/runtime/onert/core/src/ir/operation/BCQFullyConnected.cc
+++ b/runtime/onert/core/src/ir/operation/BCQFullyConnected.cc
@@ -15,9 +15,6 @@
*/
#include "ir/operation/BCQFullyConnected.h"
-
-#include <cassert>
-
#include "ir/OperationVisitor.h"
namespace onert
@@ -31,7 +28,7 @@ void BCQFullyConnected::accept(OperationVisitor &v) const { v.visit(*this); }
BCQFullyConnected::BCQFullyConnected(const OperandIndexSequence &inputs,
const OperandIndexSequence &outputs, const Param &param)
- : Operation{OperandConstraint::createExact(5u), inputs, outputs}, _param{param}
+ : Operation{OperandConstraint::createExact(5u), inputs, outputs}, _param{param}
{
}
diff --git a/runtime/onert/core/src/ir/operation/BCQGather.cc b/runtime/onert/core/src/ir/operation/BCQGather.cc
index 80efa6460..1ca5b0c9f 100644
--- a/runtime/onert/core/src/ir/operation/BCQGather.cc
+++ b/runtime/onert/core/src/ir/operation/BCQGather.cc
@@ -15,9 +15,6 @@
*/
#include "ir/operation/BCQGather.h"
-
-#include <cassert>
-
#include "ir/OperationVisitor.h"
namespace onert
@@ -31,7 +28,7 @@ void BCQGather::accept(OperationVisitor &v) const { v.visit(*this); }
BCQGather::BCQGather(const OperandIndexSequence &inputs, const OperandIndexSequence &outputs,
const Param &param)
- : Operation{OperandConstraint::createExact(4u), inputs, outputs}, _param{param}
+ : Operation{OperandConstraint::createExact(4u), inputs, outputs}, _param{param}
{
}
diff --git a/runtime/onert/core/src/ir/operation/BatchMatMul.cc b/runtime/onert/core/src/ir/operation/BatchMatMul.cc
index b9616158d..20c5682f9 100644
--- a/runtime/onert/core/src/ir/operation/BatchMatMul.cc
+++ b/runtime/onert/core/src/ir/operation/BatchMatMul.cc
@@ -28,7 +28,7 @@ void BatchMatMul::accept(OperationVisitor &v) const { v.visit(*this); }
BatchMatMul::BatchMatMul(const OperandIndexSequence &inputs, const OperandIndexSequence &outputs,
const Param &param)
- : Operation{OperandConstraint::createExact(2u), inputs, outputs}, _param{param}
+ : Operation{OperandConstraint::createExact(2u), inputs, outputs}, _param{param}
{
}
diff --git a/runtime/onert/core/src/ir/operation/BatchToSpaceND.cc b/runtime/onert/core/src/ir/operation/BatchToSpaceND.cc
index 9ef2b125f..3c5578ac4 100644
--- a/runtime/onert/core/src/ir/operation/BatchToSpaceND.cc
+++ b/runtime/onert/core/src/ir/operation/BatchToSpaceND.cc
@@ -15,9 +15,6 @@
*/
#include "ir/operation/BatchToSpaceND.h"
-
-#include <cassert>
-
#include "ir/OperationVisitor.h"
namespace onert
@@ -31,7 +28,7 @@ void BatchToSpaceND::accept(OperationVisitor &v) const { v.visit(*this); }
BatchToSpaceND::BatchToSpaceND(const OperandIndexSequence &inputs,
const OperandIndexSequence &outputs)
- : Operation{OperandConstraint::createExact(3u), inputs, outputs}
+ : Operation{OperandConstraint::createInRange(2u, 3u), inputs, outputs}
{
}
diff --git a/runtime/onert/core/src/ir/operation/BinaryArithmetic.cc b/runtime/onert/core/src/ir/operation/BinaryArithmetic.cc
index 2b1422c73..5eb3fc3d7 100644
--- a/runtime/onert/core/src/ir/operation/BinaryArithmetic.cc
+++ b/runtime/onert/core/src/ir/operation/BinaryArithmetic.cc
@@ -15,12 +15,10 @@
*/
#include "ir/operation/BinaryArithmetic.h"
+#include "ir/OperationVisitor.h"
-#include <cassert>
#include <unordered_map>
-#include "ir/OperationVisitor.h"
-
namespace onert
{
namespace ir
@@ -32,7 +30,7 @@ void BinaryArithmetic::accept(OperationVisitor &v) const { v.visit(*this); }
BinaryArithmetic::BinaryArithmetic(const OperandIndexSequence &inputs,
const OperandIndexSequence &outputs, const Param &param)
- : Operation{OperandConstraint::createExact(2u), inputs, outputs}, _param{param}
+ : Operation{OperandConstraint::createExact(2u), inputs, outputs}, _param{param}
{
}
@@ -40,10 +38,10 @@ std::string BinaryArithmetic::name() const
{
using ArithmeticType = onert::ir::operation::BinaryArithmetic::ArithmeticType;
static const std::unordered_map<ArithmeticType, std::string> name_map{
- {ArithmeticType::ADD, std::string{"Add"}},
- {ArithmeticType::SUB, std::string{"Sub"}},
- {ArithmeticType::MUL, std::string{"Mul"}},
- {ArithmeticType::DIV, std::string{"Div"}}};
+ {ArithmeticType::ADD, std::string{"Add"}},
+ {ArithmeticType::SUB, std::string{"Sub"}},
+ {ArithmeticType::MUL, std::string{"Mul"}},
+ {ArithmeticType::DIV, std::string{"Div"}}};
return name_map.at(_param.arithmetic_type);
}
diff --git a/runtime/onert/core/src/ir/operation/BroadcastTo.cc b/runtime/onert/core/src/ir/operation/BroadcastTo.cc
index a8f5e59cf..eab6c0611 100644
--- a/runtime/onert/core/src/ir/operation/BroadcastTo.cc
+++ b/runtime/onert/core/src/ir/operation/BroadcastTo.cc
@@ -15,9 +15,6 @@
*/
#include "ir/operation/BroadcastTo.h"
-
-#include <cassert>
-
#include "ir/OperationVisitor.h"
namespace onert
@@ -29,7 +26,7 @@ namespace operation
void BroadcastTo::accept(OperationVisitor &v) const { v.visit(*this); }
BroadcastTo::BroadcastTo(const OperandIndexSequence &inputs, const OperandIndexSequence &outputs)
- : Operation{OperandConstraint::createExact(2u), inputs, outputs}
+ : Operation{OperandConstraint::createExact(2u), inputs, outputs}
{
}
diff --git a/runtime/onert/core/src/ir/operation/Bulk.cc b/runtime/onert/core/src/ir/operation/Bulk.cc
new file mode 100644
index 000000000..4b96c9d94
--- /dev/null
+++ b/runtime/onert/core/src/ir/operation/Bulk.cc
@@ -0,0 +1,36 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ir/operation/Bulk.h"
+#include "ir/OperationVisitor.h"
+
+namespace onert
+{
+namespace ir
+{
+namespace operation
+{
+void Bulk::accept(OperationVisitor &v) const { v.visit(*this); }
+
+Bulk::Bulk(const OperandIndexSequence &inputs, const OperandIndexSequence &outputs,
+ const Bulk::Param &param)
+ : Operation{OperandConstraint::createAny(), inputs, outputs}, _param{param}
+{
+}
+
+} // namespace operation
+} // namespace ir
+} // namespace onert
diff --git a/runtime/onert/core/src/ir/operation/Comparison.cc b/runtime/onert/core/src/ir/operation/Comparison.cc
index 2f6775411..33365657c 100644
--- a/runtime/onert/core/src/ir/operation/Comparison.cc
+++ b/runtime/onert/core/src/ir/operation/Comparison.cc
@@ -15,9 +15,6 @@
*/
#include "ir/operation/Comparison.h"
-
-#include <cassert>
-
#include "ir/OperationVisitor.h"
namespace onert
@@ -31,7 +28,7 @@ void Comparison::accept(OperationVisitor &v) const { v.visit(*this); }
Comparison::Comparison(const OperandIndexSequence &inputs, const OperandIndexSequence &outputs,
const Param &param)
- : Operation{OperandConstraint::createExact(2u), inputs, outputs}, _param{param}
+ : Operation{OperandConstraint::createExact(2u), inputs, outputs}, _param{param}
{
}
diff --git a/runtime/onert/core/src/ir/operation/Concat.cc b/runtime/onert/core/src/ir/operation/Concat.cc
index 608bc29a6..3a21e36f2 100644
--- a/runtime/onert/core/src/ir/operation/Concat.cc
+++ b/runtime/onert/core/src/ir/operation/Concat.cc
@@ -15,9 +15,6 @@
*/
#include "ir/operation/Concat.h"
-
-#include <cassert>
-
#include "ir/OperationVisitor.h"
namespace onert
@@ -31,7 +28,7 @@ void Concat::accept(OperationVisitor &v) const { v.visit(*this); }
Concat::Concat(const OperandIndexSequence &inputs, const OperandIndexSequence &outputs,
const Param &param)
- : Operation{OperandConstraint::createAtLeast(1u), inputs, outputs}, _param{param}
+ : Operation{OperandConstraint::createAtLeast(1u), inputs, outputs}, _param{param}
{
}
diff --git a/runtime/onert/core/src/ir/operation/Conv2D.cc b/runtime/onert/core/src/ir/operation/Conv2D.cc
index 3a2e1d1fe..d615ae416 100644
--- a/runtime/onert/core/src/ir/operation/Conv2D.cc
+++ b/runtime/onert/core/src/ir/operation/Conv2D.cc
@@ -15,9 +15,6 @@
*/
#include "ir/operation/Conv2D.h"
-
-#include <cassert>
-
#include "ir/OperationVisitor.h"
namespace onert
@@ -31,7 +28,7 @@ void Conv2D::accept(OperationVisitor &v) const { v.visit(*this); }
Conv2D::Conv2D(const OperandIndexSequence &inputs, const OperandIndexSequence &outputs,
const Param &param)
- : Operation{OperandConstraint::createExact(3u), inputs, outputs}, _param{param}
+ : Operation{OperandConstraint::createExact(3u), inputs, outputs}, _param{param}
{
}
diff --git a/runtime/onert/core/src/ir/operation/ConvertFp16ToFp32.cc b/runtime/onert/core/src/ir/operation/ConvertFp16ToFp32.cc
index 676e039fa..365745ea8 100644
--- a/runtime/onert/core/src/ir/operation/ConvertFp16ToFp32.cc
+++ b/runtime/onert/core/src/ir/operation/ConvertFp16ToFp32.cc
@@ -15,9 +15,6 @@
*/
#include "ir/operation/ConvertFp16ToFp32.h"
-
-#include <cassert>
-
#include "ir/OperationVisitor.h"
namespace onert
@@ -31,7 +28,7 @@ void ConvertFp16ToFp32::accept(OperationVisitor &v) const { v.visit(*this); }
ConvertFp16ToFp32::ConvertFp16ToFp32(const OperandIndexSequence &inputs,
const OperandIndexSequence &outputs)
- : Operation{OperandConstraint::createExact(1u), inputs, outputs}
+ : Operation{OperandConstraint::createExact(1u), inputs, outputs}
{
}
diff --git a/runtime/onert/core/src/ir/operation/ConvertFp32ToFp16.cc b/runtime/onert/core/src/ir/operation/ConvertFp32ToFp16.cc
index bcfcbfc04..d4fc7031c 100644
--- a/runtime/onert/core/src/ir/operation/ConvertFp32ToFp16.cc
+++ b/runtime/onert/core/src/ir/operation/ConvertFp32ToFp16.cc
@@ -15,9 +15,6 @@
*/
#include "ir/operation/ConvertFp32ToFp16.h"
-
-#include <cassert>
-
#include "ir/OperationVisitor.h"
namespace onert
@@ -31,7 +28,7 @@ void ConvertFp32ToFp16::accept(OperationVisitor &v) const { v.visit(*this); }
ConvertFp32ToFp16::ConvertFp32ToFp16(const OperandIndexSequence &inputs,
const OperandIndexSequence &outputs)
- : Operation{OperandConstraint::createExact(1u), inputs, outputs}
+ : Operation{OperandConstraint::createExact(1u), inputs, outputs}
{
}
diff --git a/runtime/onert/core/src/ir/operation/Custom.cc b/runtime/onert/core/src/ir/operation/Custom.cc
index 25c53e1ba..06c84f81a 100644
--- a/runtime/onert/core/src/ir/operation/Custom.cc
+++ b/runtime/onert/core/src/ir/operation/Custom.cc
@@ -29,7 +29,7 @@ void Custom::accept(OperationVisitor &v) const { v.visit(*this); }
Custom::Custom(OperandConstraint input_constr, const OperandIndexSequence &inputs,
const OperandIndexSequence &outputs, std::string id, const Userdata &userdata)
- : Operation{input_constr, inputs, outputs}, _id(std::move(id)), _userdata(userdata)
+ : Operation{input_constr, inputs, outputs}, _id(std::move(id)), _userdata(userdata)
{
}
diff --git a/runtime/onert/core/src/ir/operation/DepthToSpace.cc b/runtime/onert/core/src/ir/operation/DepthToSpace.cc
index f2d6c7c1b..e3edea777 100644
--- a/runtime/onert/core/src/ir/operation/DepthToSpace.cc
+++ b/runtime/onert/core/src/ir/operation/DepthToSpace.cc
@@ -15,9 +15,6 @@
*/
#include "ir/operation/DepthToSpace.h"
-
-#include <cassert>
-
#include "ir/OperationVisitor.h"
namespace onert
@@ -31,7 +28,7 @@ void DepthToSpace::accept(OperationVisitor &v) const { v.visit(*this); }
DepthToSpace::DepthToSpace(const OperandIndexSequence &inputs, const OperandIndexSequence &outputs,
const Param &param)
- : Operation{OperandConstraint::createExact(1u), inputs, outputs}, _param{param}
+ : Operation{OperandConstraint::createExact(1u), inputs, outputs}, _param{param}
{
}
diff --git a/runtime/onert/core/src/ir/operation/DepthwiseConv2D.cc b/runtime/onert/core/src/ir/operation/DepthwiseConv2D.cc
index d587a5591..0e7137306 100644
--- a/runtime/onert/core/src/ir/operation/DepthwiseConv2D.cc
+++ b/runtime/onert/core/src/ir/operation/DepthwiseConv2D.cc
@@ -15,9 +15,6 @@
*/
#include "ir/operation/DepthwiseConv2D.h"
-
-#include <cassert>
-
#include "ir/OperationVisitor.h"
namespace onert
@@ -31,7 +28,7 @@ void DepthwiseConv2D::accept(OperationVisitor &v) const { v.visit(*this); }
DepthwiseConv2D::DepthwiseConv2D(const OperandIndexSequence &inputs,
const OperandIndexSequence &outputs, const Param &param)
- : Operation{OperandConstraint::createExact(3u), inputs, outputs}, _param{param}
+ : Operation{OperandConstraint::createExact(3u), inputs, outputs}, _param{param}
{
}
diff --git a/runtime/onert/core/src/ir/operation/DetectionPostProcess.cc b/runtime/onert/core/src/ir/operation/DetectionPostProcess.cc
new file mode 100644
index 000000000..cd708796d
--- /dev/null
+++ b/runtime/onert/core/src/ir/operation/DetectionPostProcess.cc
@@ -0,0 +1,37 @@
+/*
+ * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ir/operation/DetectionPostProcess.h"
+#include "ir/OperationVisitor.h"
+
+namespace onert
+{
+namespace ir
+{
+namespace operation
+{
+
+DetectionPostProcess::DetectionPostProcess(const OperandIndexSequence &inputs,
+ const OperandIndexSequence &outputs, const Param &param)
+ : Operation(OperandConstraint::createExact(3u), inputs, outputs), _param(param)
+{
+}
+
+void DetectionPostProcess::accept(OperationVisitor &v) const { v.visit(*this); }
+
+} // namespace operation
+} // namespace ir
+} // namespace onert
diff --git a/runtime/onert/core/src/ir/operation/Einsum.cc b/runtime/onert/core/src/ir/operation/Einsum.cc
index 3c1473aaa..b50f070e7 100644
--- a/runtime/onert/core/src/ir/operation/Einsum.cc
+++ b/runtime/onert/core/src/ir/operation/Einsum.cc
@@ -28,7 +28,7 @@ void Einsum::accept(OperationVisitor &v) const { v.visit(*this); }
Einsum::Einsum(const OperandIndexSequence &inputs, const OperandIndexSequence &outputs,
const Param &param)
- : Operation{OperandConstraint::createAtLeast(1u), inputs, outputs}, _param{param}
+ : Operation{OperandConstraint::createAtLeast(1u), inputs, outputs}, _param{param}
{
}
diff --git a/runtime/onert/core/src/ir/operation/ElementwiseActivation.cc b/runtime/onert/core/src/ir/operation/ElementwiseActivation.cc
index f6718b656..e83c26e28 100644
--- a/runtime/onert/core/src/ir/operation/ElementwiseActivation.cc
+++ b/runtime/onert/core/src/ir/operation/ElementwiseActivation.cc
@@ -15,12 +15,10 @@
*/
#include "ir/operation/ElementwiseActivation.h"
+#include "ir/OperationVisitor.h"
-#include <cassert>
#include <unordered_map>
-#include "ir/OperationVisitor.h"
-
namespace onert
{
namespace ir
@@ -33,13 +31,14 @@ void ElementwiseActivation::accept(OperationVisitor &v) const { v.visit(*this);
ElementwiseActivation::ElementwiseActivation(const OperandIndexSequence &inputs,
const OperandIndexSequence &outputs,
const Param &param)
- : Operation{OperandConstraint::createExact(1u), inputs, outputs}, _param{param}
+ : Operation{OperandConstraint::createExact(1u), inputs, outputs}, _param{param}
{
if (param.op_type == Type::LOGISTIC)
{
- assert(param.alpha == 0.0f && param.beta == 0.0f && "Logistic will be supported only as "
- "sigmoid function(L=1, k=1, x0=0). So, do "
- "not use alpha and beta");
+ assert(param.alpha == 0.0f && param.beta == 0.0f &&
+ "Logistic will be supported only as "
+ "sigmoid function(L=1, k=1, x0=0). So, do "
+ "not use alpha and beta");
}
else if (param.op_type == Type::RELU)
{
@@ -47,9 +46,10 @@ ElementwiseActivation::ElementwiseActivation(const OperandIndexSequence &inputs,
}
else if (param.op_type == Type::TANH)
{
- assert(param.alpha == 1.0f && param.beta == 1.0f && "f(x) = alpha * tanh(beta * x), Tanh is "
- "supported only the values of alpha and "
- "beta are 1.f");
+ assert(param.alpha == 1.0f && param.beta == 1.0f &&
+ "f(x) = alpha * tanh(beta * x), Tanh is "
+ "supported only the values of alpha and "
+ "beta are 1.f");
}
}
@@ -57,11 +57,11 @@ std::string ElementwiseActivation::name() const
{
using ElementwiseActivationType = onert::ir::operation::ElementwiseActivation::Type;
static const std::unordered_map<Type, std::string> name_map{
- {ElementwiseActivationType::ELU, "ELU"},
- {ElementwiseActivationType::LOGISTIC, "Logistic"},
- {ElementwiseActivationType::RELU, "ReLU"},
- {ElementwiseActivationType::TANH, "Tanh"},
- {ElementwiseActivationType::LEAKY_RELU, "LeakyRelu"}};
+ {ElementwiseActivationType::ELU, "ELU"},
+ {ElementwiseActivationType::LOGISTIC, "Logistic"},
+ {ElementwiseActivationType::RELU, "ReLU"},
+ {ElementwiseActivationType::TANH, "Tanh"},
+ {ElementwiseActivationType::LEAKY_RELU, "LeakyRelu"}};
return name_map.at(_param.op_type);
}
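
Reviewer aside (not part of the patch): besides the reindentation, the reflowed asserts above use the common C++ idiom of `&&`-ing a string literal onto the condition so that the message appears in the assertion-failure diagnostic without affecting the truth value. A minimal, self-contained sketch of that parameter-validation pattern, with hypothetical names:

```cpp
#include <cassert>

// Hypothetical activation parameters, mirroring the checks in the hunk above.
struct ActivationParam
{
  float alpha;
  float beta;
};

void validateLogistic(const ActivationParam &param)
{
  // The string literal is always truthy, so it does not change the condition;
  // it only shows up in the diagnostic when the assert fires.
  assert(param.alpha == 0.0f && param.beta == 0.0f &&
         "Logistic is supported only as the standard sigmoid (alpha = 0, beta = 0)");
}

int main()
{
  validateLogistic({0.0f, 0.0f}); // passes silently
  return 0;
}
```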
diff --git a/runtime/onert/core/src/ir/operation/ElementwiseBinary.cc b/runtime/onert/core/src/ir/operation/ElementwiseBinary.cc
index 3287fc0a3..d445171fb 100644
--- a/runtime/onert/core/src/ir/operation/ElementwiseBinary.cc
+++ b/runtime/onert/core/src/ir/operation/ElementwiseBinary.cc
@@ -15,12 +15,10 @@
*/
#include "ir/operation/ElementwiseBinary.h"
+#include "ir/OperationVisitor.h"
-#include <cassert>
#include <unordered_map>
-#include "ir/OperationVisitor.h"
-
namespace onert
{
namespace ir
@@ -32,7 +30,7 @@ void ElementwiseBinary::accept(OperationVisitor &v) const { v.visit(*this); }
ElementwiseBinary::ElementwiseBinary(const OperandIndexSequence &inputs,
const OperandIndexSequence &outputs, const Param &param)
- : Operation{OperandConstraint::createExact(2u), inputs, outputs}, _param{param}
+ : Operation{OperandConstraint::createExact(2u), inputs, outputs}, _param{param}
{
}
@@ -40,10 +38,12 @@ std::string ElementwiseBinary::name() const
{
using ElementwiseBinaryType = onert::ir::operation::ElementwiseBinary::ElementwiseBinaryType;
static const std::unordered_map<ElementwiseBinaryType, std::string> name_map{
- {ElementwiseBinaryType::LOGICAL_AND, std::string{"LogicalAnd"}},
- {ElementwiseBinaryType::LOGICAL_OR, std::string{"LogicalOr"}},
- {ElementwiseBinaryType::MAX, std::string{"Max"}},
- {ElementwiseBinaryType::MIN, std::string{"Min"}}};
+ {ElementwiseBinaryType::FLOOR_DIV, std::string{"FloorDiv"}},
+ {ElementwiseBinaryType::FLOOR_MOD, std::string{"FloorMod"}},
+ {ElementwiseBinaryType::LOGICAL_AND, std::string{"LogicalAnd"}},
+ {ElementwiseBinaryType::LOGICAL_OR, std::string{"LogicalOr"}},
+ {ElementwiseBinaryType::MAX, std::string{"Max"}},
+ {ElementwiseBinaryType::MIN, std::string{"Min"}}};
return name_map.at(_param.op_type);
}
diff --git a/runtime/onert/core/src/ir/operation/ElementwiseUnary.cc b/runtime/onert/core/src/ir/operation/ElementwiseUnary.cc
index 7dfcd4a98..fd463e0fe 100644
--- a/runtime/onert/core/src/ir/operation/ElementwiseUnary.cc
+++ b/runtime/onert/core/src/ir/operation/ElementwiseUnary.cc
@@ -15,12 +15,10 @@
*/
#include "ir/operation/ElementwiseUnary.h"
+#include "ir/OperationVisitor.h"
-#include <cassert>
#include <unordered_map>
-#include "ir/OperationVisitor.h"
-
namespace onert
{
namespace ir
@@ -32,7 +30,9 @@ void ElementwiseUnary::accept(OperationVisitor &v) const { v.visit(*this); }
ElementwiseUnary::ElementwiseUnary(const OperandIndexSequence &inputs,
const OperandIndexSequence &outputs, const Param &param)
- : Operation{OperandConstraint::createExact(1u), inputs, outputs}, _param{param}
+ : Operation{OperandConstraint::createExact(1u), inputs, outputs,
+ OperandConstraint::createExact(1u)},
+ _param{param}
{
}
@@ -40,23 +40,23 @@ std::string ElementwiseUnary::name() const
{
using ElementwiseUnaryType = onert::ir::operation::ElementwiseUnary::Type;
static const std::unordered_map<ElementwiseUnaryType, std::string> name_map{
- {ElementwiseUnaryType::ABS, std::string{"Abs"}},
- {ElementwiseUnaryType::CAST, std::string{"Cast"}},
- {ElementwiseUnaryType::COS, std::string{"Cos"}},
- {ElementwiseUnaryType::DEQUANTIZE, std::string{"Dequantize"}},
- {ElementwiseUnaryType::ERF, std::string{"Erf"}},
- {ElementwiseUnaryType::EXP, std::string{"Exp"}},
- {ElementwiseUnaryType::FLOOR, std::string{"Floor"}},
- {ElementwiseUnaryType::LOG, std::string{"Log"}},
- {ElementwiseUnaryType::LOGICAL_NOT, std::string{"LogicalNot"}},
- {ElementwiseUnaryType::NEG, std::string{"Neg"}},
- {ElementwiseUnaryType::QUANTIZE, std::string{"Quantize"}},
- {ElementwiseUnaryType::ROUND, std::string{"Round"}},
- {ElementwiseUnaryType::RSQRT, std::string{"RSqrt"}},
- {ElementwiseUnaryType::SIN, std::string{"Sin"}},
- {ElementwiseUnaryType::SQRT, std::string{"Sqrt"}},
- {ElementwiseUnaryType::SQURE, std::string{"Squre"}},
- {ElementwiseUnaryType::ZEROS_LIKE, std::string{"ZerosLike"}}};
+ {ElementwiseUnaryType::ABS, std::string{"Abs"}},
+ {ElementwiseUnaryType::CAST, std::string{"Cast"}},
+ {ElementwiseUnaryType::COS, std::string{"Cos"}},
+ {ElementwiseUnaryType::DEQUANTIZE, std::string{"Dequantize"}},
+ {ElementwiseUnaryType::ERF, std::string{"Erf"}},
+ {ElementwiseUnaryType::EXP, std::string{"Exp"}},
+ {ElementwiseUnaryType::FLOOR, std::string{"Floor"}},
+ {ElementwiseUnaryType::LOG, std::string{"Log"}},
+ {ElementwiseUnaryType::LOGICAL_NOT, std::string{"LogicalNot"}},
+ {ElementwiseUnaryType::NEG, std::string{"Neg"}},
+ {ElementwiseUnaryType::QUANTIZE, std::string{"Quantize"}},
+ {ElementwiseUnaryType::ROUND, std::string{"Round"}},
+ {ElementwiseUnaryType::RSQRT, std::string{"RSqrt"}},
+ {ElementwiseUnaryType::SIN, std::string{"Sin"}},
+ {ElementwiseUnaryType::SQRT, std::string{"Sqrt"}},
+ {ElementwiseUnaryType::SQUARE, std::string{"Square"}},
+ {ElementwiseUnaryType::ZEROS_LIKE, std::string{"ZerosLike"}}};
return name_map.at(_param.op_type);
}
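
Reviewer aside (not part of the patch): in addition to the reformatting, this hunk corrects the `SQURE`/`Squre` spelling to `SQUARE`/`Square`. The `name()` overrides in these files all follow one pattern: a function-local `static const std::unordered_map` from an op-type enum to a display string, looked up with `.at()`. A standalone sketch of that pattern, with hypothetical names:

```cpp
#include <iostream>
#include <string>
#include <unordered_map>

// Hypothetical element-wise op kinds, standing in for the enums used in the patch.
enum class UnaryKind { Abs, Sqrt, Square };

std::string toName(UnaryKind kind)
{
  // Built once on first use and reused on every subsequent call.
  static const std::unordered_map<UnaryKind, std::string> name_map{
    {UnaryKind::Abs, "Abs"}, {UnaryKind::Sqrt, "Sqrt"}, {UnaryKind::Square, "Square"}};
  return name_map.at(kind);
}

int main()
{
  std::cout << toName(UnaryKind::Square) << '\n'; // prints "Square"
  return 0;
}
```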
diff --git a/runtime/onert/core/src/ir/operation/EmbeddingLookup.cc b/runtime/onert/core/src/ir/operation/EmbeddingLookup.cc
index b300b004e..66b80b2c5 100644
--- a/runtime/onert/core/src/ir/operation/EmbeddingLookup.cc
+++ b/runtime/onert/core/src/ir/operation/EmbeddingLookup.cc
@@ -15,9 +15,6 @@
*/
#include "ir/operation/EmbeddingLookup.h"
-
-#include <cassert>
-
#include "ir/OperationVisitor.h"
namespace onert
@@ -31,7 +28,7 @@ void EmbeddingLookup::accept(OperationVisitor &v) const { v.visit(*this); }
EmbeddingLookup::EmbeddingLookup(const OperandIndexSequence &inputs,
const OperandIndexSequence &outputs)
- : Operation{OperandConstraint::createExact(2u), inputs, outputs}
+ : Operation{OperandConstraint::createExact(2u), inputs, outputs}
{
}
diff --git a/runtime/onert/core/src/ir/operation/ExpandDims.cc b/runtime/onert/core/src/ir/operation/ExpandDims.cc
index 3f555bd23..e421bc383 100644
--- a/runtime/onert/core/src/ir/operation/ExpandDims.cc
+++ b/runtime/onert/core/src/ir/operation/ExpandDims.cc
@@ -15,9 +15,6 @@
*/
#include "ir/operation/ExpandDims.h"
-
-#include <cassert>
-
#include "ir/OperationVisitor.h"
namespace onert
@@ -30,7 +27,7 @@ namespace operation
void ExpandDims::accept(OperationVisitor &v) const { v.visit(*this); }
ExpandDims::ExpandDims(const OperandIndexSequence &inputs, const OperandIndexSequence &outputs)
- : Operation{OperandConstraint::createExact(2u), inputs, outputs}
+ : Operation{OperandConstraint::createExact(2u), inputs, outputs}
{
}
diff --git a/runtime/onert/core/src/ir/operation/Fill.cc b/runtime/onert/core/src/ir/operation/Fill.cc
index c44f45aab..60355c609 100644
--- a/runtime/onert/core/src/ir/operation/Fill.cc
+++ b/runtime/onert/core/src/ir/operation/Fill.cc
@@ -15,9 +15,6 @@
*/
#include "ir/operation/Fill.h"
-
-#include <cassert>
-
#include "ir/OperationVisitor.h"
namespace onert
@@ -30,7 +27,7 @@ namespace operation
void Fill::accept(OperationVisitor &v) const { v.visit(*this); }
Fill::Fill(const OperandIndexSequence &inputs, const OperandIndexSequence &outputs)
- : Operation{OperandConstraint::createExact(1u), inputs, outputs}
+ : Operation{OperandConstraint::createExact(2u), inputs, outputs}
{
}
diff --git a/runtime/onert/core/src/ir/operation/FullyConnected.cc b/runtime/onert/core/src/ir/operation/FullyConnected.cc
index 118ae554a..3533df097 100644
--- a/runtime/onert/core/src/ir/operation/FullyConnected.cc
+++ b/runtime/onert/core/src/ir/operation/FullyConnected.cc
@@ -15,9 +15,6 @@
*/
#include "ir/operation/FullyConnected.h"
-
-#include <cassert>
-
#include "ir/OperationVisitor.h"
namespace onert
@@ -31,7 +28,7 @@ void FullyConnected::accept(OperationVisitor &v) const { v.visit(*this); }
FullyConnected::FullyConnected(const OperandIndexSequence &inputs,
const OperandIndexSequence &outputs, const Param &param)
- : Operation{OperandConstraint::createExact(3u), inputs, outputs}, _param{param}
+ : Operation{OperandConstraint::createInRange(2u, 3u), inputs, outputs}, _param{param}
{
}
diff --git a/runtime/onert/core/src/ir/operation/FusedBatchNorm.cc b/runtime/onert/core/src/ir/operation/FusedBatchNorm.cc
index 7b9301ea6..b5679f308 100644
--- a/runtime/onert/core/src/ir/operation/FusedBatchNorm.cc
+++ b/runtime/onert/core/src/ir/operation/FusedBatchNorm.cc
@@ -28,7 +28,7 @@ void FusedBatchNorm::accept(OperationVisitor &v) const { v.visit(*this); }
FusedBatchNorm::FusedBatchNorm(const OperandIndexSequence &inputs,
const OperandIndexSequence &outputs, const Param &param)
- : Operation{OperandConstraint::createAtLeast(5u), inputs, outputs}, _param{param}
+ : Operation{OperandConstraint::createAtLeast(5u), inputs, outputs}, _param{param}
{
}
diff --git a/runtime/onert/core/src/ir/operation/Gather.cc b/runtime/onert/core/src/ir/operation/Gather.cc
index 11d46e75b..e0c4630a0 100644
--- a/runtime/onert/core/src/ir/operation/Gather.cc
+++ b/runtime/onert/core/src/ir/operation/Gather.cc
@@ -15,9 +15,6 @@
*/
#include "ir/operation/Gather.h"
-
-#include <cassert>
-
#include "ir/OperationVisitor.h"
namespace onert
@@ -31,7 +28,7 @@ void Gather::accept(OperationVisitor &v) const { v.visit(*this); }
Gather::Gather(const OperandIndexSequence &inputs, const OperandIndexSequence &outputs,
const Param &param)
- : Operation{OperandConstraint::createExact(2u), inputs, outputs}, _param{param}
+ : Operation{OperandConstraint::createExact(2u), inputs, outputs}, _param{param}
{
}
diff --git a/runtime/onert/core/src/ir/operation/HashtableLookup.cc b/runtime/onert/core/src/ir/operation/HashtableLookup.cc
index e9a7a82ff..5d1589cd1 100644
--- a/runtime/onert/core/src/ir/operation/HashtableLookup.cc
+++ b/runtime/onert/core/src/ir/operation/HashtableLookup.cc
@@ -15,9 +15,6 @@
*/
#include "ir/operation/HashtableLookup.h"
-
-#include <cassert>
-
#include "ir/OperationVisitor.h"
namespace onert
@@ -31,7 +28,7 @@ void HashtableLookup::accept(OperationVisitor &v) const { v.visit(*this); }
HashtableLookup::HashtableLookup(const OperandIndexSequence &inputs,
const OperandIndexSequence &outputs)
- : Operation{OperandConstraint::createExact(3u), inputs, outputs}
+ : Operation{OperandConstraint::createExact(3u), inputs, outputs}
{
}
diff --git a/runtime/onert/core/src/ir/operation/If.cc b/runtime/onert/core/src/ir/operation/If.cc
index 599751dfd..380c87dbe 100644
--- a/runtime/onert/core/src/ir/operation/If.cc
+++ b/runtime/onert/core/src/ir/operation/If.cc
@@ -24,7 +24,7 @@ namespace operation
{
void If::accept(OperationVisitor &v) const { v.visit(*this); }
If::If(const OperandIndexSequence &inputs, const OperandIndexSequence &outputs, const Param &param)
- : Operation{OperandConstraint::createAny(), inputs, outputs}, _param{param}
+ : Operation{OperandConstraint::createAny(), inputs, outputs}, _param{param}
{
}
} // namespace operation
diff --git a/runtime/onert/core/src/ir/operation/InstanceNorm.cc b/runtime/onert/core/src/ir/operation/InstanceNorm.cc
index 2334560ef..9fb55383e 100644
--- a/runtime/onert/core/src/ir/operation/InstanceNorm.cc
+++ b/runtime/onert/core/src/ir/operation/InstanceNorm.cc
@@ -15,9 +15,6 @@
*/
#include "ir/operation/InstanceNorm.h"
-
-#include <cassert>
-
#include "ir/OperationVisitor.h"
namespace onert
@@ -31,7 +28,7 @@ void InstanceNorm::accept(OperationVisitor &v) const { v.visit(*this); }
InstanceNorm::InstanceNorm(const OperandIndexSequence &inputs, const OperandIndexSequence &outputs,
const Param &param)
- : Operation{OperandConstraint::createExact(3u), inputs, outputs}, _param{param}
+ : Operation{OperandConstraint::createExact(3u), inputs, outputs}, _param{param}
{
}
diff --git a/runtime/onert/core/src/ir/operation/L2Normalization.cc b/runtime/onert/core/src/ir/operation/L2Normalization.cc
index 9a7d3eb61..6725df596 100644
--- a/runtime/onert/core/src/ir/operation/L2Normalization.cc
+++ b/runtime/onert/core/src/ir/operation/L2Normalization.cc
@@ -15,9 +15,6 @@
*/
#include "ir/operation/L2Normalization.h"
-
-#include <cassert>
-
#include "ir/OperationVisitor.h"
namespace onert
@@ -31,7 +28,7 @@ void L2Normalization::accept(OperationVisitor &v) const { v.visit(*this); }
L2Normalization::L2Normalization(const OperandIndexSequence &inputs,
const OperandIndexSequence &outputs)
- : Operation{OperandConstraint::createExact(1u), inputs, outputs}
+ : Operation{OperandConstraint::createExact(1u), inputs, outputs}
{
}
diff --git a/runtime/onert/core/src/ir/operation/LSTM.cc b/runtime/onert/core/src/ir/operation/LSTM.cc
index 30a865326..06e66158b 100644
--- a/runtime/onert/core/src/ir/operation/LSTM.cc
+++ b/runtime/onert/core/src/ir/operation/LSTM.cc
@@ -15,9 +15,6 @@
*/
#include "ir/operation/LSTM.h"
-
-#include <cassert>
-
#include "ir/OperationVisitor.h"
namespace onert
@@ -31,8 +28,16 @@ void LSTM::accept(OperationVisitor &v) const { v.visit(*this); }
LSTM::LSTM(const OperandIndexSequence &inputs, const OperandIndexSequence &outputs,
const Param &param)
- : Operation{OperandConstraint::createExact(23u), inputs, outputs}, _param{param}
+ : Operation{OperandConstraint::createInRange(20u, 24u), inputs, outputs}, _param{param}
+{
+}
+
+std::string LSTM::name() const
{
+ if (getOutputs().at(Output::SCRATCH_BUFFER).undefined())
+ return std::string{"UnidirectionalSequenceLSTM"};
+ else
+ return Operation::name();
}
} // namespace operation
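
Reviewer aside (not part of the patch): the hunk above relaxes the LSTM input constraint to `createInRange(20u, 24u)` and adds a `name()` override that reports `"UnidirectionalSequenceLSTM"` when the `SCRATCH_BUFFER` output is undefined. A minimal sketch of that "optional output changes the reported name" idea, using `std::optional` as a hypothetical stand-in for an undefined operand:

```cpp
#include <iostream>
#include <optional>
#include <string>

// Hypothetical stand-in for an operation whose optional output selects its display name.
struct LstmLikeOp
{
  std::optional<int> scratch_buffer; // absent => behaves like the sequence variant

  std::string name() const
  {
    return scratch_buffer.has_value() ? "LSTM" : "UnidirectionalSequenceLSTM";
  }
};

int main()
{
  std::cout << LstmLikeOp{std::nullopt}.name() << '\n'; // UnidirectionalSequenceLSTM
  std::cout << LstmLikeOp{42}.name() << '\n';           // LSTM
  return 0;
}
```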
diff --git a/runtime/onert/core/src/ir/operation/LocalResponseNormalization.cc b/runtime/onert/core/src/ir/operation/LocalResponseNormalization.cc
index 1ae97c142..73fca9938 100644
--- a/runtime/onert/core/src/ir/operation/LocalResponseNormalization.cc
+++ b/runtime/onert/core/src/ir/operation/LocalResponseNormalization.cc
@@ -15,9 +15,6 @@
*/
#include "ir/operation/LocalResponseNormalization.h"
-
-#include <cassert>
-
#include "ir/OperationVisitor.h"
namespace onert
@@ -32,7 +29,7 @@ void LocalResponseNormalization::accept(OperationVisitor &v) const { v.visit(*th
LocalResponseNormalization::LocalResponseNormalization(const OperandIndexSequence &inputs,
const OperandIndexSequence &outputs,
const Param &param)
- : Operation{OperandConstraint::createExact(1u), inputs, outputs}, _param{param}
+ : Operation{OperandConstraint::createExact(1u), inputs, outputs}, _param{param}
{
}
diff --git a/runtime/onert/core/src/ir/operation/LogSoftmax.cc b/runtime/onert/core/src/ir/operation/LogSoftmax.cc
index 73c6580ec..d580e63e1 100644
--- a/runtime/onert/core/src/ir/operation/LogSoftmax.cc
+++ b/runtime/onert/core/src/ir/operation/LogSoftmax.cc
@@ -15,9 +15,6 @@
*/
#include "ir/operation/LogSoftmax.h"
-
-#include <cassert>
-
#include "ir/OperationVisitor.h"
namespace onert
@@ -31,7 +28,7 @@ void LogSoftmax::accept(OperationVisitor &v) const { v.visit(*this); }
LogSoftmax::LogSoftmax(const OperandIndexSequence &inputs, const OperandIndexSequence &outputs,
const Param &param)
- : Operation{OperandConstraint::createExact(1u), inputs, outputs}, _param{param}
+ : Operation{OperandConstraint::createExact(1u), inputs, outputs}, _param{param}
{
}
diff --git a/runtime/onert/core/src/ir/operation/Loss.cc b/runtime/onert/core/src/ir/operation/Loss.cc
new file mode 100644
index 000000000..2a0d6c4c8
--- /dev/null
+++ b/runtime/onert/core/src/ir/operation/Loss.cc
@@ -0,0 +1,39 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ir/operation/Loss.h"
+#include "ir/OperationVisitor.h"
+
+#include <unordered_map>
+
+namespace onert
+{
+namespace ir
+{
+namespace operation
+{
+
+void Loss::accept(OperationVisitor &v) const { v.visit(*this); }
+
+Loss::Loss(const OperandIndexSequence &inputs, const OperandIndexSequence &outputs)
+ : Operation{OperandConstraint::createAtLeast(2u), inputs, outputs}
+{
+ assert(inputs.size() == 2);
+}
+
+} // namespace operation
+} // namespace ir
+} // namespace onert
diff --git a/runtime/onert/core/src/ir/operation/MatrixBandPart.cc b/runtime/onert/core/src/ir/operation/MatrixBandPart.cc
index bac31f13e..e52bddc1f 100644
--- a/runtime/onert/core/src/ir/operation/MatrixBandPart.cc
+++ b/runtime/onert/core/src/ir/operation/MatrixBandPart.cc
@@ -15,9 +15,6 @@
*/
#include "ir/operation/MatrixBandPart.h"
-
-#include <cassert>
-
#include "ir/OperationVisitor.h"
namespace onert
@@ -31,7 +28,7 @@ void MatrixBandPart::accept(OperationVisitor &v) const { v.visit(*this); }
MatrixBandPart::MatrixBandPart(const OperandIndexSequence &inputs,
const OperandIndexSequence &outputs)
- : Operation{OperandConstraint::createExact(3u), inputs, outputs}
+ : Operation{OperandConstraint::createExact(3u), inputs, outputs}
{
}
diff --git a/runtime/onert/core/src/ir/operation/OneHot.cc b/runtime/onert/core/src/ir/operation/OneHot.cc
index 22935e7d6..90898f1ed 100644
--- a/runtime/onert/core/src/ir/operation/OneHot.cc
+++ b/runtime/onert/core/src/ir/operation/OneHot.cc
@@ -28,7 +28,7 @@ void OneHot::accept(OperationVisitor &v) const { v.visit(*this); }
OneHot::OneHot(const OperandIndexSequence &inputs, const OperandIndexSequence &outputs,
const Param &param)
- : Operation{OperandConstraint::createExact(4u), inputs, outputs}, _param{param}
+ : Operation{OperandConstraint::createExact(4u), inputs, outputs}, _param{param}
{
}
diff --git a/runtime/onert/core/src/ir/operation/PReLU.cc b/runtime/onert/core/src/ir/operation/PReLU.cc
index a2e37e0ad..87bd12e60 100644
--- a/runtime/onert/core/src/ir/operation/PReLU.cc
+++ b/runtime/onert/core/src/ir/operation/PReLU.cc
@@ -15,9 +15,6 @@
*/
#include "ir/operation/PReLU.h"
-
-#include <cassert>
-
#include "ir/OperationVisitor.h"
namespace onert
@@ -30,7 +27,7 @@ namespace operation
void PReLU::accept(OperationVisitor &v) const { v.visit(*this); }
PReLU::PReLU(const OperandIndexSequence &inputs, const OperandIndexSequence &outputs)
- : Operation{OperandConstraint::createExact(2u), inputs, outputs}
+ : Operation{OperandConstraint::createExact(2u), inputs, outputs}
{
}
diff --git a/runtime/onert/core/src/ir/operation/Pack.cc b/runtime/onert/core/src/ir/operation/Pack.cc
index f0908a2c6..00feadfb0 100644
--- a/runtime/onert/core/src/ir/operation/Pack.cc
+++ b/runtime/onert/core/src/ir/operation/Pack.cc
@@ -25,7 +25,7 @@ namespace operation
void Pack::accept(OperationVisitor &v) const { v.visit(*this); }
Pack::Pack(const OperandIndexSequence &inputs, const OperandIndexSequence &outputs,
const Param &param)
- : Operation{OperandConstraint::createAtLeast(3u), inputs, outputs}, _param{param}
+ : Operation{OperandConstraint::createAtLeast(1u), inputs, outputs}, _param{param}
{
}
} // namespace operation
diff --git a/runtime/onert/core/src/ir/operation/Pad.cc b/runtime/onert/core/src/ir/operation/Pad.cc
index 0c56e92e3..a3f2d9752 100644
--- a/runtime/onert/core/src/ir/operation/Pad.cc
+++ b/runtime/onert/core/src/ir/operation/Pad.cc
@@ -30,7 +30,7 @@ void Pad::accept(OperationVisitor &v) const { v.visit(*this); }
// PAD: 2 inputs
// PADV2: 3 inputs
Pad::Pad(const OperandIndexSequence &inputs, const OperandIndexSequence &outputs)
- : Operation{OperandConstraint::createInRange(2u, 3u), inputs, outputs}
+ : Operation{OperandConstraint::createInRange(2u, 3u), inputs, outputs}
{
}
diff --git a/runtime/onert/core/src/ir/operation/Permute.cc b/runtime/onert/core/src/ir/operation/Permute.cc
index eefb6c542..813fbaf30 100644
--- a/runtime/onert/core/src/ir/operation/Permute.cc
+++ b/runtime/onert/core/src/ir/operation/Permute.cc
@@ -15,9 +15,6 @@
*/
#include "ir/operation/Permute.h"
-
-#include <cassert>
-
#include "ir/OperationVisitor.h"
namespace onert
@@ -30,7 +27,7 @@ namespace operation
void Permute::accept(OperationVisitor &v) const { v.visit(*this); }
Permute::Permute(const OperandIndex &input, const OperandIndex &output, Type type)
- : Operation{OperandConstraint::createExact(1u)}, _type{type}
+ : Operation{OperandConstraint::createExact(1u)}, _type{type}
{
setInputs({input});
setOutputs({output});
diff --git a/runtime/onert/core/src/ir/operation/Pool2D.cc b/runtime/onert/core/src/ir/operation/Pool2D.cc
index 761d14c3d..e32b876e6 100644
--- a/runtime/onert/core/src/ir/operation/Pool2D.cc
+++ b/runtime/onert/core/src/ir/operation/Pool2D.cc
@@ -15,12 +15,10 @@
*/
#include "ir/operation/Pool2D.h"
+#include "ir/OperationVisitor.h"
-#include <cassert>
#include <unordered_map>
-#include "ir/OperationVisitor.h"
-
namespace onert
{
namespace ir
@@ -32,7 +30,7 @@ void Pool2D::accept(OperationVisitor &v) const { v.visit(*this); }
Pool2D::Pool2D(const OperandIndexSequence &inputs, const OperandIndexSequence &outputs,
const Param &param)
- : Operation{OperandConstraint::createExact(1u), inputs, outputs}, _param{param}
+ : Operation{OperandConstraint::createExact(1u), inputs, outputs}, _param{param}
{
}
@@ -40,9 +38,9 @@ std::string Pool2D::name() const
{
using PoolType = onert::ir::operation::Pool2D::PoolType;
static const std::unordered_map<PoolType, std::string> name_map{
- {PoolType::AVG, "Avg" + std::string{toString(opcode())}},
- {PoolType::L2, "L2" + std::string{toString(opcode())}},
- {PoolType::MAX, "Max" + std::string{toString(opcode())}}};
+ {PoolType::AVG, "Avg" + std::string{toString(opcode())}},
+ {PoolType::L2, "L2" + std::string{toString(opcode())}},
+ {PoolType::MAX, "Max" + std::string{toString(opcode())}}};
return name_map.at(_param.op_type);
}
diff --git a/runtime/onert/core/src/ir/operation/Pow.cc b/runtime/onert/core/src/ir/operation/Pow.cc
index 940b1391a..f7c159a12 100644
--- a/runtime/onert/core/src/ir/operation/Pow.cc
+++ b/runtime/onert/core/src/ir/operation/Pow.cc
@@ -15,9 +15,6 @@
*/
#include "ir/operation/Pow.h"
-
-#include <cassert>
-
#include "ir/OperationVisitor.h"
namespace onert
@@ -30,7 +27,7 @@ namespace operation
void Pow::accept(OperationVisitor &v) const { v.visit(*this); }
Pow::Pow(const OperandIndexSequence &inputs, const OperandIndexSequence &outputs)
- : Operation{OperandConstraint::createExact(2u), inputs, outputs}
+ : Operation{OperandConstraint::createExact(2u), inputs, outputs}
{
}
diff --git a/runtime/onert/core/src/ir/operation/RNN.cc b/runtime/onert/core/src/ir/operation/RNN.cc
index 298c5e745..988a50669 100644
--- a/runtime/onert/core/src/ir/operation/RNN.cc
+++ b/runtime/onert/core/src/ir/operation/RNN.cc
@@ -15,9 +15,6 @@
*/
#include "ir/operation/RNN.h"
-
-#include <cassert>
-
#include "ir/OperationVisitor.h"
namespace onert
@@ -31,7 +28,7 @@ void RNN::accept(OperationVisitor &v) const { v.visit(*this); }
RNN::RNN(const OperandIndexSequence &inputs, const OperandIndexSequence &outputs,
const Param &param)
- : Operation{OperandConstraint::createExact(5u), inputs, outputs}, _param{param}
+ : Operation{OperandConstraint::createExact(5u), inputs, outputs}, _param{param}
{
}
diff --git a/runtime/onert/core/src/ir/operation/Range.cc b/runtime/onert/core/src/ir/operation/Range.cc
index 96ab04c1b..8ced92a0b 100644
--- a/runtime/onert/core/src/ir/operation/Range.cc
+++ b/runtime/onert/core/src/ir/operation/Range.cc
@@ -15,9 +15,6 @@
*/
#include "ir/operation/Range.h"
-
-#include <cassert>
-
#include "ir/OperationVisitor.h"
namespace onert
@@ -30,7 +27,7 @@ namespace operation
void Range::accept(OperationVisitor &v) const { v.visit(*this); }
Range::Range(const OperandIndexSequence &inputs, const OperandIndexSequence &outputs)
- : Operation{OperandConstraint::createExact(3u), inputs, outputs}
+ : Operation{OperandConstraint::createExact(3u), inputs, outputs}
{
}
diff --git a/runtime/onert/core/src/ir/operation/Rank.cc b/runtime/onert/core/src/ir/operation/Rank.cc
index c357e9018..40797bf29 100644
--- a/runtime/onert/core/src/ir/operation/Rank.cc
+++ b/runtime/onert/core/src/ir/operation/Rank.cc
@@ -15,9 +15,6 @@
*/
#include "ir/operation/Rank.h"
-
-#include <cassert>
-
#include "ir/OperationVisitor.h"
namespace onert
@@ -30,7 +27,7 @@ namespace operation
void Rank::accept(OperationVisitor &v) const { v.visit(*this); }
Rank::Rank(const OperandIndexSequence &inputs, const OperandIndexSequence &outputs)
- : Operation{OperandConstraint::createExact(1u), inputs, outputs}
+ : Operation{OperandConstraint::createExact(1u), inputs, outputs}
{
}
diff --git a/runtime/onert/core/src/ir/operation/Reduce.cc b/runtime/onert/core/src/ir/operation/Reduce.cc
index d6a1d953c..8da1940fa 100644
--- a/runtime/onert/core/src/ir/operation/Reduce.cc
+++ b/runtime/onert/core/src/ir/operation/Reduce.cc
@@ -15,12 +15,10 @@
*/
#include "ir/operation/Reduce.h"
+#include "ir/OperationVisitor.h"
-#include <cassert>
#include <unordered_map>
-#include "ir/OperationVisitor.h"
-
namespace onert
{
namespace ir
@@ -32,7 +30,7 @@ void Reduce::accept(OperationVisitor &v) const { v.visit(*this); }
Reduce::Reduce(const OperandIndexSequence &inputs, const OperandIndexSequence &outputs,
const Param &param)
- : Operation{OperandConstraint::createExact(2u), inputs, outputs}, _param{param}
+ : Operation{OperandConstraint::createExact(2u), inputs, outputs}, _param{param}
{
}
@@ -40,13 +38,13 @@ std::string Reduce::name() const
{
using ReduceType = onert::ir::operation::Reduce::ReduceType;
static const std::unordered_map<ReduceType, std::string> name_map{
- {ReduceType::ALL, std::string{toString(opcode())} + "All"},
- {ReduceType::ANY, std::string{toString(opcode())} + "Any"},
- {ReduceType::MAX, std::string{toString(opcode())} + "Max"},
- {ReduceType::MEAN, std::string{toString(opcode())} + "Mean"},
- {ReduceType::MIN, std::string{toString(opcode())} + "Min"},
- {ReduceType::PROD, std::string{toString(opcode())} + "Prod"},
- {ReduceType::SUM, std::string{toString(opcode())} + "SUM"}};
+ {ReduceType::ALL, std::string{toString(opcode())} + "All"},
+ {ReduceType::ANY, std::string{toString(opcode())} + "Any"},
+ {ReduceType::MAX, std::string{toString(opcode())} + "Max"},
+ {ReduceType::MEAN, std::string{toString(opcode())} + "Mean"},
+ {ReduceType::MIN, std::string{toString(opcode())} + "Min"},
+ {ReduceType::PROD, std::string{toString(opcode())} + "Prod"},
+ {ReduceType::SUM, std::string{toString(opcode())} + "SUM"}};
return name_map.at(_param.reduce_type);
// return std::string(toString(opcode())) + reduce_type_str_map.at(_param.reduce_type);
}
diff --git a/runtime/onert/core/src/ir/operation/Reshape.cc b/runtime/onert/core/src/ir/operation/Reshape.cc
index 92aa89ac6..0ed4affa1 100644
--- a/runtime/onert/core/src/ir/operation/Reshape.cc
+++ b/runtime/onert/core/src/ir/operation/Reshape.cc
@@ -15,9 +15,6 @@
*/
#include "ir/operation/Reshape.h"
-
-#include <cassert>
-
#include "ir/OperationVisitor.h"
namespace onert
@@ -31,7 +28,7 @@ void Reshape::accept(OperationVisitor &v) const { v.visit(*this); }
Reshape::Reshape(const OperandIndexSequence &inputs, const OperandIndexSequence &outputs,
const Param &param)
- : Operation{OperandConstraint::createExact(2u), inputs, outputs}, _param(param)
+ : Operation{OperandConstraint::createExact(2u), inputs, outputs}, _param(param)
{
}
diff --git a/runtime/onert/core/src/ir/operation/ResizeBilinear.cc b/runtime/onert/core/src/ir/operation/ResizeBilinear.cc
index d0d89f45f..7d256f447 100644
--- a/runtime/onert/core/src/ir/operation/ResizeBilinear.cc
+++ b/runtime/onert/core/src/ir/operation/ResizeBilinear.cc
@@ -15,9 +15,6 @@
*/
#include "ir/operation/ResizeBilinear.h"
-
-#include <cassert>
-
#include "ir/OperationVisitor.h"
namespace onert
@@ -31,7 +28,7 @@ void ResizeBilinear::accept(OperationVisitor &v) const { v.visit(*this); }
ResizeBilinear::ResizeBilinear(const OperandIndexSequence &inputs,
const OperandIndexSequence &outputs, const Param &param)
- : Operation{OperandConstraint::createExact(1u), inputs, outputs}, _param{param}
+ : Operation{OperandConstraint::createInRange(1u, 2u), inputs, outputs}, _param{param}
{
}
diff --git a/runtime/onert/core/src/ir/operation/ResizeNearestNeighbor.cc b/runtime/onert/core/src/ir/operation/ResizeNearestNeighbor.cc
index 9f17af97c..58be87b95 100644
--- a/runtime/onert/core/src/ir/operation/ResizeNearestNeighbor.cc
+++ b/runtime/onert/core/src/ir/operation/ResizeNearestNeighbor.cc
@@ -15,9 +15,6 @@
*/
#include "ir/operation/ResizeNearestNeighbor.h"
-
-#include <cassert>
-
#include "ir/OperationVisitor.h"
namespace onert
@@ -32,7 +29,7 @@ void ResizeNearestNeighbor::accept(OperationVisitor &v) const { v.visit(*this);
ResizeNearestNeighbor::ResizeNearestNeighbor(const OperandIndexSequence &inputs,
const OperandIndexSequence &outputs,
const Param &param)
- : Operation{OperandConstraint::createExact(1u), inputs, outputs}, _param{param}
+ : Operation{OperandConstraint::createInRange(1u, 2u), inputs, outputs}, _param{param}
{
}
diff --git a/runtime/onert/core/src/ir/operation/Reverse.cc b/runtime/onert/core/src/ir/operation/Reverse.cc
index 4b3c1e1af..6c3746426 100644
--- a/runtime/onert/core/src/ir/operation/Reverse.cc
+++ b/runtime/onert/core/src/ir/operation/Reverse.cc
@@ -15,9 +15,6 @@
*/
#include "ir/operation/Reverse.h"
-
-#include <cassert>
-
#include "ir/OperationVisitor.h"
namespace onert
@@ -30,7 +27,7 @@ namespace operation
void Reverse::accept(OperationVisitor &v) const { v.visit(*this); }
Reverse::Reverse(const OperandIndexSequence &inputs, const OperandIndexSequence &outputs)
- : Operation{OperandConstraint::createExact(2u), inputs, outputs}
+ : Operation{OperandConstraint::createExact(2u), inputs, outputs}
{
}
diff --git a/runtime/onert/core/src/ir/operation/Select.cc b/runtime/onert/core/src/ir/operation/Select.cc
index 1f22b5234..59684190c 100644
--- a/runtime/onert/core/src/ir/operation/Select.cc
+++ b/runtime/onert/core/src/ir/operation/Select.cc
@@ -28,7 +28,7 @@ namespace operation
void Select::accept(OperationVisitor &v) const { v.visit(*this); }
Select::Select(const OperandIndexSequence &inputs, const OperandIndexSequence &outputs)
- : Operation{OperandConstraint::createExact(3u), inputs, outputs}
+ : Operation{OperandConstraint::createExact(3u), inputs, outputs}
{
}
diff --git a/runtime/onert/core/src/ir/operation/Shape.cc b/runtime/onert/core/src/ir/operation/Shape.cc
index 2a63d6dcf..f90924488 100644
--- a/runtime/onert/core/src/ir/operation/Shape.cc
+++ b/runtime/onert/core/src/ir/operation/Shape.cc
@@ -15,9 +15,6 @@
*/
#include "ir/operation/Shape.h"
-
-#include <cassert>
-
#include "ir/OperationVisitor.h"
namespace onert
@@ -30,7 +27,7 @@ namespace operation
void Shape::accept(OperationVisitor &v) const { v.visit(*this); }
Shape::Shape(const OperandIndexSequence &inputs, const OperandIndexSequence &outputs)
- : Operation{OperandConstraint::createExact(1u), inputs, outputs}
+ : Operation{OperandConstraint::createExact(1u), inputs, outputs}
{
}
diff --git a/runtime/onert/core/src/ir/operation/Slice.cc b/runtime/onert/core/src/ir/operation/Slice.cc
index 888b563fb..1362c0f91 100644
--- a/runtime/onert/core/src/ir/operation/Slice.cc
+++ b/runtime/onert/core/src/ir/operation/Slice.cc
@@ -27,7 +27,7 @@ namespace operation
void Slice::accept(OperationVisitor &v) const { v.visit(*this); }
Slice::Slice(const OperandIndexSequence &inputs, const OperandIndexSequence &outputs)
- : Operation{OperandConstraint::createExact(3u), inputs, outputs}
+ : Operation{OperandConstraint::createExact(3u), inputs, outputs}
{
}
diff --git a/runtime/onert/core/src/ir/operation/Softmax.cc b/runtime/onert/core/src/ir/operation/Softmax.cc
index 3f1aa0af1..c06c85309 100644
--- a/runtime/onert/core/src/ir/operation/Softmax.cc
+++ b/runtime/onert/core/src/ir/operation/Softmax.cc
@@ -15,9 +15,6 @@
*/
#include "ir/operation/Softmax.h"
-
-#include <cassert>
-
#include "ir/OperationVisitor.h"
namespace onert
@@ -31,7 +28,7 @@ void Softmax::accept(OperationVisitor &v) const { v.visit(*this); }
Softmax::Softmax(const OperandIndexSequence &inputs, const OperandIndexSequence &outputs,
const Param &param)
- : Operation{OperandConstraint::createExact(1u), inputs, outputs}, _param{param}
+ : Operation{OperandConstraint::createExact(1u), inputs, outputs}, _param{param}
{
}
diff --git a/runtime/onert/core/src/ir/operation/SpaceToBatchND.cc b/runtime/onert/core/src/ir/operation/SpaceToBatchND.cc
index 53fab4fa9..94acccb0c 100644
--- a/runtime/onert/core/src/ir/operation/SpaceToBatchND.cc
+++ b/runtime/onert/core/src/ir/operation/SpaceToBatchND.cc
@@ -15,9 +15,6 @@
*/
#include "ir/operation/SpaceToBatchND.h"
-
-#include <cassert>
-
#include "ir/OperationVisitor.h"
namespace onert
@@ -31,7 +28,7 @@ void SpaceToBatchND::accept(OperationVisitor &v) const { v.visit(*this); }
SpaceToBatchND::SpaceToBatchND(const OperandIndexSequence &inputs,
const OperandIndexSequence &outputs)
- : Operation{OperandConstraint::createExact(3u), inputs, outputs}
+ : Operation{OperandConstraint::createExact(3u), inputs, outputs}
{
}
diff --git a/runtime/onert/core/src/ir/operation/SpaceToDepth.cc b/runtime/onert/core/src/ir/operation/SpaceToDepth.cc
index d8a45aee5..08e7e5190 100644
--- a/runtime/onert/core/src/ir/operation/SpaceToDepth.cc
+++ b/runtime/onert/core/src/ir/operation/SpaceToDepth.cc
@@ -15,9 +15,6 @@
*/
#include "ir/operation/SpaceToDepth.h"
-
-#include <cassert>
-
#include "ir/OperationVisitor.h"
namespace onert
@@ -31,7 +28,7 @@ void SpaceToDepth::accept(OperationVisitor &v) const { v.visit(*this); }
SpaceToDepth::SpaceToDepth(const OperandIndexSequence &inputs, const OperandIndexSequence &outputs,
const Param &param)
- : Operation{OperandConstraint::createExact(1u), inputs, outputs}, _param{param}
+ : Operation{OperandConstraint::createExact(1u), inputs, outputs}, _param{param}
{
}
diff --git a/runtime/onert/core/src/ir/operation/Split.cc b/runtime/onert/core/src/ir/operation/Split.cc
index 244884e41..3e371188d 100644
--- a/runtime/onert/core/src/ir/operation/Split.cc
+++ b/runtime/onert/core/src/ir/operation/Split.cc
@@ -13,9 +13,10 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
+
#include "ir/operation/Split.h"
-#include <cassert>
#include "ir/OperationVisitor.h"
+
namespace onert
{
namespace ir
@@ -25,7 +26,7 @@ namespace operation
void Split::accept(OperationVisitor &v) const { v.visit(*this); }
Split::Split(const OperandIndexSequence &inputs, const OperandIndexSequence &outputs,
const Param &param)
- : Operation{OperandConstraint::createExact(1u), inputs, outputs}, _param{param}
+ : Operation{OperandConstraint::createExact(2u), inputs, outputs}, _param{param}
{
}
} // namespace operation
diff --git a/runtime/onert/core/src/ir/operation/SplitV.cc b/runtime/onert/core/src/ir/operation/SplitV.cc
index e638c9ac9..be13f167e 100644
--- a/runtime/onert/core/src/ir/operation/SplitV.cc
+++ b/runtime/onert/core/src/ir/operation/SplitV.cc
@@ -13,9 +13,10 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
+
#include "ir/operation/SplitV.h"
-#include <cassert>
#include "ir/OperationVisitor.h"
+
namespace onert
{
namespace ir
@@ -25,7 +26,7 @@ namespace operation
void SplitV::accept(OperationVisitor &v) const { v.visit(*this); }
SplitV::SplitV(const OperandIndexSequence &inputs, const OperandIndexSequence &outputs,
const Param &param)
- : Operation{OperandConstraint::createExact(3u), inputs, outputs}, _param{param}
+ : Operation{OperandConstraint::createExact(3u), inputs, outputs}, _param{param}
{
}
} // namespace operation
diff --git a/runtime/onert/core/src/ir/operation/SquaredDifference.cc b/runtime/onert/core/src/ir/operation/SquaredDifference.cc
index 49e58aaf2..db93903c7 100644
--- a/runtime/onert/core/src/ir/operation/SquaredDifference.cc
+++ b/runtime/onert/core/src/ir/operation/SquaredDifference.cc
@@ -15,9 +15,6 @@
*/
#include "ir/operation/SquaredDifference.h"
-
-#include <cassert>
-
#include "ir/OperationVisitor.h"
namespace onert
@@ -31,7 +28,7 @@ void SquaredDifference::accept(OperationVisitor &v) const { v.visit(*this); }
SquaredDifference::SquaredDifference(const OperandIndexSequence &inputs,
const OperandIndexSequence &outputs)
- : Operation{OperandConstraint::createExact(2u), inputs, outputs}
+ : Operation{OperandConstraint::createExact(2u), inputs, outputs}
{
}
diff --git a/runtime/onert/core/src/ir/operation/Squeeze.cc b/runtime/onert/core/src/ir/operation/Squeeze.cc
index 8cf928fb4..e059c4bee 100644
--- a/runtime/onert/core/src/ir/operation/Squeeze.cc
+++ b/runtime/onert/core/src/ir/operation/Squeeze.cc
@@ -28,7 +28,7 @@ void Squeeze::accept(OperationVisitor &v) const { v.visit(*this); }
Squeeze::Squeeze(const OperandIndexSequence &inputs, const OperandIndexSequence &outputs,
const Param &param)
- : Operation{OperandConstraint::createExact(1u), inputs, outputs}, _param(param)
+ : Operation{OperandConstraint::createExact(1u), inputs, outputs}, _param(param)
{
}
diff --git a/runtime/onert/core/src/ir/operation/StatelessRandomUniform.cc b/runtime/onert/core/src/ir/operation/StatelessRandomUniform.cc
index cbb0ff251..94be0be86 100644
--- a/runtime/onert/core/src/ir/operation/StatelessRandomUniform.cc
+++ b/runtime/onert/core/src/ir/operation/StatelessRandomUniform.cc
@@ -15,9 +15,6 @@
*/
#include "ir/operation/StatelessRandomUniform.h"
-
-#include <cassert>
-
#include "ir/OperationVisitor.h"
namespace onert
@@ -30,7 +27,7 @@ void StatelessRandomUniform::accept(OperationVisitor &v) const { v.visit(*this);
StatelessRandomUniform::StatelessRandomUniform(const OperandIndexSequence &inputs,
const OperandIndexSequence &outputs)
- : Operation{OperandConstraint::createExact(2u), inputs, outputs}
+ : Operation{OperandConstraint::createExact(2u), inputs, outputs}
{
}
diff --git a/runtime/onert/core/src/ir/operation/StridedSlice.cc b/runtime/onert/core/src/ir/operation/StridedSlice.cc
index 2a7905995..a38282c93 100644
--- a/runtime/onert/core/src/ir/operation/StridedSlice.cc
+++ b/runtime/onert/core/src/ir/operation/StridedSlice.cc
@@ -15,9 +15,6 @@
*/
#include "ir/operation/StridedSlice.h"
-
-#include <cassert>
-
#include "ir/OperationVisitor.h"
namespace onert
@@ -31,7 +28,7 @@ void StridedSlice::accept(OperationVisitor &v) const { v.visit(*this); }
StridedSlice::StridedSlice(const OperandIndexSequence &inputs, const OperandIndexSequence &outputs,
const Param &param)
- : Operation{OperandConstraint::createExact(4u), inputs, outputs}, _param{param}
+ : Operation{OperandConstraint::createExact(4u), inputs, outputs}, _param{param}
{
}
diff --git a/runtime/onert/core/src/ir/operation/Tile.cc b/runtime/onert/core/src/ir/operation/Tile.cc
index 5ba3df2ad..51c1ff1dc 100644
--- a/runtime/onert/core/src/ir/operation/Tile.cc
+++ b/runtime/onert/core/src/ir/operation/Tile.cc
@@ -15,9 +15,6 @@
*/
#include "ir/operation/Tile.h"
-
-#include <cassert>
-
#include "ir/OperationVisitor.h"
namespace onert
@@ -30,7 +27,7 @@ namespace operation
void Tile::accept(OperationVisitor &v) const { v.visit(*this); }
Tile::Tile(const OperandIndexSequence &inputs, const OperandIndexSequence &outputs)
- : Operation{OperandConstraint::createExact(2u), inputs, outputs}
+ : Operation{OperandConstraint::createExact(2u), inputs, outputs}
{
}
diff --git a/runtime/onert/core/src/ir/operation/TopKV2.cc b/runtime/onert/core/src/ir/operation/TopKV2.cc
index a5e6c6a85..e1723d180 100644
--- a/runtime/onert/core/src/ir/operation/TopKV2.cc
+++ b/runtime/onert/core/src/ir/operation/TopKV2.cc
@@ -15,9 +15,6 @@
*/
#include "ir/operation/TopKV2.h"
-
-#include <cassert>
-
#include "ir/OperationVisitor.h"
namespace onert
@@ -31,7 +28,7 @@ void TopKV2::accept(OperationVisitor &v) const { v.visit(*this); }
TopKV2::TopKV2(const OperandIndexSequence &inputs, const OperandIndexSequence &outputs,
const Param &param)
- : Operation{OperandConstraint::createExact(1u), inputs, outputs}, _param{param}
+ : Operation{OperandConstraint::createExact(1u), inputs, outputs}, _param{param}
{
}
diff --git a/runtime/onert/core/src/ir/operation/Transpose.cc b/runtime/onert/core/src/ir/operation/Transpose.cc
index 3a663fbce..dbc5ef2aa 100644
--- a/runtime/onert/core/src/ir/operation/Transpose.cc
+++ b/runtime/onert/core/src/ir/operation/Transpose.cc
@@ -15,9 +15,6 @@
*/
#include "ir/operation/Transpose.h"
-
-#include <cassert>
-
#include "ir/OperationVisitor.h"
namespace onert
@@ -29,9 +26,8 @@ namespace operation
void Transpose::accept(OperationVisitor &v) const { v.visit(*this); }
-Transpose::Transpose(const OperandIndexSequence &inputs, const OperandIndexSequence &outputs,
- const Param &param)
- : Operation{OperandConstraint::createExact(2u), inputs, outputs}, _param{param}
+Transpose::Transpose(const OperandIndexSequence &inputs, const OperandIndexSequence &outputs)
+ : Operation{OperandConstraint::createExact(2u), inputs, outputs}
{
}
diff --git a/runtime/onert/core/src/ir/operation/TransposeConv.cc b/runtime/onert/core/src/ir/operation/TransposeConv.cc
index 7f29ca44e..944cc365d 100644
--- a/runtime/onert/core/src/ir/operation/TransposeConv.cc
+++ b/runtime/onert/core/src/ir/operation/TransposeConv.cc
@@ -15,9 +15,6 @@
*/
#include "ir/operation/TransposeConv.h"
-
-#include <cassert>
-
#include "ir/OperationVisitor.h"
namespace onert
@@ -31,7 +28,7 @@ void TransposeConv::accept(OperationVisitor &v) const { v.visit(*this); }
TransposeConv::TransposeConv(const OperandIndexSequence &inputs,
const OperandIndexSequence &outputs, const Param &param)
- : Operation{OperandConstraint::createExact(3u), inputs, outputs}, _param{param}
+ : Operation{OperandConstraint::createExact(3u), inputs, outputs}, _param{param}
{
}
diff --git a/runtime/onert/core/src/ir/operation/Unpack.cc b/runtime/onert/core/src/ir/operation/Unpack.cc
index 67aa54ab5..185eddce3 100644
--- a/runtime/onert/core/src/ir/operation/Unpack.cc
+++ b/runtime/onert/core/src/ir/operation/Unpack.cc
@@ -13,6 +13,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
+
#include "ir/operation/Unpack.h"
#include "ir/OperationVisitor.h"
@@ -25,7 +26,7 @@ namespace operation
void Unpack::accept(OperationVisitor &v) const { v.visit(*this); }
Unpack::Unpack(const OperandIndexSequence &inputs, const OperandIndexSequence &outputs,
const Param &param)
- : Operation{OperandConstraint::createExact(1u), inputs, outputs}, _param{param}
+ : Operation{OperandConstraint::createExact(1u), inputs, outputs}, _param{param}
{
}
} // namespace operation
diff --git a/runtime/onert/core/src/ir/operation/While.cc b/runtime/onert/core/src/ir/operation/While.cc
index 2505c60e3..f35996b07 100644
--- a/runtime/onert/core/src/ir/operation/While.cc
+++ b/runtime/onert/core/src/ir/operation/While.cc
@@ -13,6 +13,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
+
#include "ir/operation/While.h"
#include "ir/OperationVisitor.h"
@@ -25,7 +26,7 @@ namespace operation
void While::accept(OperationVisitor &v) const { v.visit(*this); }
While::While(const OperandIndexSequence &inputs, const OperandIndexSequence &outputs,
const Param &param)
- : Operation{OperandConstraint::createAny(), inputs, outputs}, _param{param}
+ : Operation{OperandConstraint::createAny(), inputs, outputs}, _param{param}
{
}
} // namespace operation
diff --git a/runtime/onert/core/src/ir/train/LossCode.cc b/runtime/onert/core/src/ir/train/LossCode.cc
new file mode 100644
index 000000000..eccae8cd7
--- /dev/null
+++ b/runtime/onert/core/src/ir/train/LossCode.cc
@@ -0,0 +1,39 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ir/train/LossCode.h"
+
+#include <unordered_map>
+
+namespace onert
+{
+namespace ir
+{
+namespace train
+{
+
+std::string toString(LossCode code)
+{
+ static const std::unordered_map<LossCode, const char *> map{
+ {LossCode::Undefined, "Undefined"},
+ {LossCode::MeanSquaredError, "MeanSquaredError"},
+ {LossCode::CategoricalCrossentropy, "CategoricalCrossentropy"}};
+ return map.at(code);
+}
+
+} // namespace train
+} // namespace ir
+} // namespace onert
diff --git a/runtime/onert/core/src/ir/train/OptimizerCode.cc b/runtime/onert/core/src/ir/train/OptimizerCode.cc
new file mode 100644
index 000000000..4ab689085
--- /dev/null
+++ b/runtime/onert/core/src/ir/train/OptimizerCode.cc
@@ -0,0 +1,39 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ir/train/OptimizerCode.h"
+
+#include <unordered_map>
+
+namespace onert
+{
+namespace ir
+{
+namespace train
+{
+
+std::string toString(OptimizerCode code)
+{
+ static const std::unordered_map<OptimizerCode, const char *> map{
+ {OptimizerCode::Undefined, "Undefined"},
+ {OptimizerCode::SGD, "SGD"},
+ {OptimizerCode::Adam, "Adam"}};
+ return map.at(code);
+}
+
+} // namespace train
+} // namespace ir
+} // namespace onert
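
Reviewer aside (not part of the patch): the new `LossCode.cc` and `OptimizerCode.cc` use the same function-local map as the `name()` helpers earlier in this diff. One property worth noting is that `unordered_map::at` throws `std::out_of_range` if an enumerator is ever added to the enum without updating the map. A hypothetical usage sketch showing that failure mode:

```cpp
#include <iostream>
#include <stdexcept>
#include <string>
#include <unordered_map>

// Hypothetical mirror of the OptimizerCode/toString pair added above.
enum class OptimizerCode { Undefined, SGD, Adam };

std::string toString(OptimizerCode code)
{
  static const std::unordered_map<OptimizerCode, const char *> map{
    {OptimizerCode::Undefined, "Undefined"}, {OptimizerCode::SGD, "SGD"}, {OptimizerCode::Adam, "Adam"}};
  return map.at(code); // throws std::out_of_range for an unmapped code
}

int main()
{
  std::cout << toString(OptimizerCode::Adam) << '\n'; // Adam

  try
  {
    // Simulate an enumerator that was added to the enum but not to the map.
    toString(static_cast<OptimizerCode>(99));
  }
  catch (const std::out_of_range &)
  {
    std::cout << "unmapped optimizer code\n";
  }
  return 0;
}
```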
diff --git a/runtime/onert/core/src/ir/train/TrainableGraph.cc b/runtime/onert/core/src/ir/train/TrainableGraph.cc
new file mode 100644
index 000000000..5ecdcc2cb
--- /dev/null
+++ b/runtime/onert/core/src/ir/train/TrainableGraph.cc
@@ -0,0 +1,337 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ir/train/TrainableGraph.h"
+
+#include "ir/OperandIndexMap.h"
+#include "util/Utils.h"
+#include "util/Set.h"
+#include "../verifier/Verifier.h"
+
+#include <algorithm>
+#include <set>
+#include <map>
+#include <misc/polymorphic_downcast.h>
+
+namespace onert
+{
+namespace ir
+{
+namespace train
+{
+
+TrainableGraph::TrainableGraph() : _graph{} {}
+
+TrainableGraph::TrainableGraph(const TrainableGraph &tgraph)
+ : _graph{tgraph._graph}, _backward_operands{tgraph._backward_operands},
+ _training_defuses{tgraph._training_defuses}, _losses{tgraph._losses}
+{
+ tgraph.operations().iterate(
+ [&](const onert::ir::OperationIndex &index, const onert::ir::IOperation &op) {
+ replaceOperation(index, dynamic_cast<const ITrainableOperation &>(op).clone());
+ });
+}
+
+TrainableGraph::TrainableGraph(const Graph &graph) : _graph{graph} {}
+
+OperandIndex TrainableGraph::addOperand(const Shape &shape, const TypeInfo &type)
+{
+ return _graph.addOperand(shape, type);
+}
+
+OperandIndex TrainableGraph::addOperand(OperandIndex index, std::unique_ptr<Operand> &&operand)
+{
+ return _graph.addOperand(index, std::move(operand));
+}
+
+OperationIndex TrainableGraph::addOperation(std::unique_ptr<ITrainableOperation> &&operation)
+{
+ return _graph.addOperation(std::move(operation));
+}
+
+OperationIndex TrainableGraph::replaceOperation(OperationIndex index,
+ std::unique_ptr<ITrainableOperation> &&operation)
+{
+ return _graph.replaceOperation(index, std::move(operation));
+}
+
+OperandIndex TrainableGraph::addBackwardOperand(OperandIndex index,
+ std::unique_ptr<Operand> &&bwd_operand)
+{
+ return _backward_operands.push(std::move(bwd_operand), index);
+}
+
+IOIndex TrainableGraph::getInputIndex(const std::string &name) const
+{
+ return _graph.getInputIndex(name);
+}
+
+IOIndex TrainableGraph::getOutputIndex(const std::string &name) const
+{
+ return _graph.getOutputIndex(name);
+}
+
+void TrainableGraph::changeShape(const OperandIndex &index, const ir::Shape &new_shape)
+{
+ _graph.changeShape(index, new_shape);
+}
+
+void TrainableGraph::changeBackwardShape(const OperandIndex &index, const ir::Shape &new_shape)
+{
+ assert(_backward_operands.exist(index));
+ _backward_operands.at(index).info().shape(new_shape);
+}
+
+void TrainableGraph::addInput(const OperandIndex &ind, const std::string &name)
+{
+ _graph.addInput(ind, name);
+}
+
+void TrainableGraph::addOutput(const OperandIndex &ind, const std::string &name)
+{
+ _graph.addOutput(ind, name);
+}
+
+void TrainableGraph::verify(void) const
+{
+ _graph.verify();
+
+ operations().iterate([](const onert::ir::OperationIndex &, const onert::ir::IOperation &op) {
+ try
+ {
+ UNUSED_RELEASE(dynamic_cast<const onert::ir::train::ITrainableOperation &>(op));
+ }
+ catch (const std::bad_cast &)
+ {
+ throw std::runtime_error("TrainableGraph: " + op.name() + " is not a trainable operation");
+ }
+ });
+
+ verifyTrainingUseDefs();
+}
+
+void TrainableGraph::removeOperand(const OperandIndex &ind) { _graph.removeOperand(ind); }
+
+void TrainableGraph::setLayout(Layout layout) { _graph.setLayout(layout); }
+
+const ITrainableOperation &TrainableGraph::operation(OperationIndex index) const
+{
+ // NOTE Virtual inherited objects cannot be static_casted.
+ return dynamic_cast<const ITrainableOperation &>(_graph.operations().at(index));
+}
+
+void TrainableGraph::enableBackward(const OperationIndex &index)
+{
+ auto op = dynamic_cast<ir::train::ITrainableOperation *>(&_graph.operations().at(index));
+ assert(op);
+ op->enableBackward();
+}
+
+void TrainableGraph::disableBackward(const OperationIndex &index)
+{
+ auto &op = dynamic_cast<ir::train::ITrainableOperation &>(_graph.operations().at(index));
+ op.disableBackward();
+}
+
+void TrainableGraph::setTrainingUseDefs(const UseDefChains &training_defuses)
+{
+ _training_defuses.clear();
+ // TODO Replace this loop with `std::unordered_map::insert_range` since C++23
+ for (const auto &defuse_chain : training_defuses)
+ {
+ _training_defuses.emplace(defuse_chain.first, defuse_chain.second);
+ }
+}
+
+void TrainableGraph::validateTopologicalOrder(std::vector<ir::OperationIndex> order,
+ bool is_forward) const
+{
+ if (!is_forward)
+ std::reverse(order.begin(), order.end());
+
+ const std::string order_type = is_forward ? "forward" : "backward";
+
+ std::map<ir::OperationIndex, uint32_t> position;
+ for (uint32_t p = 0; p < order.size(); ++p)
+ {
+ auto index = order[p];
+ // TODO: replace this with `std::map::contains` after C++20
+ if (position.find(index) != position.end())
+ throw std::runtime_error{"Invalid " + order_type + " topological order: duplicate node @" +
+ std::to_string(index.value())};
+
+ position[index] = p;
+ }
+
+ operations().iterate([&](const ir::OperationIndex &index, const ir::IOperation &op) {
+ if (position.count(index) == 0)
+ return;
+
+ uint32_t p = position[index];
+
+ for (const auto &output : op.getOutputs() | ir::Remove::DUPLICATED | ir::Remove::UNDEFINED)
+ {
+ const auto &operand = operands().at(output);
+ for (const auto &use : operand.getUses())
+ {
+ if (position.count(use) == 0)
+ continue;
+
+ uint32_t q = position[use];
+ if (p > q)
+ throw std::runtime_error{
+ "Invalid " + order_type + " topological order: inversion between @" +
+ std::to_string(index.value()) + " and @" + std::to_string(use.value())};
+ }
+ }
+ });
+}
+
+void TrainableGraph::validateForwardTopologicalOrder(
+ const std::vector<ir::OperationIndex> &order) const
+{
+ validateTopologicalOrder(order, true);
+}
+
+void TrainableGraph::validateBackwardTopologicalOrder(
+ const std::vector<ir::OperationIndex> &order) const
+{
+ validateTopologicalOrder(order, false);
+}
+
+void TrainableGraph::verifyTrainingUseDefs() const
+{
+ if (!verifier::DAGChecker().verify(_training_defuses))
+ throw std::runtime_error{"The training def-uses is cyclic."};
+ assert(verifier::EdgeChecker().verify(_training_defuses));
+}
+
+std::vector<ir::OperationIndex> TrainableGraph::topolSortOperations() const
+{
+ auto ret = _graph.topolSortOperations();
+ validateForwardTopologicalOrder(ret);
+
+ return ret;
+}
+
+std::vector<ir::OperationIndex> TrainableGraph::btopolSortOperations() const
+{
+ std::vector<ir::OperationIndex> ret;
+ util::Set<ir::OperationIndex> unvisited;
+ ir::OperationIndex loss_idx;
+ operations().iterate([&](const ir::OperationIndex &index, const ir::IOperation &op) {
+ unvisited.add(index);
+ if (op.opcode() == ir::OpCode::Loss)
+ {
+ assert(!loss_idx.valid()); // There should be only one Loss operation
+ loss_idx = index;
+ }
+ });
+
+ std::function<void(const ir::OperationIndex &, const ir::IOperation &)> dfs =
+ [&](const ir::OperationIndex &index, const ir::IOperation &op) -> void {
+ if (!unvisited.contains(index))
+ return;
+ unvisited.remove(index);
+
+ for (const auto &input : op.getInputs() | ir::Remove::DUPLICATED | ir::Remove::UNDEFINED)
+ {
+ const auto &operand = operands().at(input);
+ const auto &def = operand.getDef();
+ if (!def.valid())
+ continue;
+ dfs(def, operations().at(def));
+ }
+
+ ret.push_back(index);
+ };
+
+ dfs(loss_idx, operations().at(loss_idx));
+ std::reverse(ret.begin(), ret.end());
+ validateBackwardTopologicalOrder(ret);
+
+ return ret;
+}
+
+std::vector<ir::OperationIndex> TrainableGraph::essentialBackwardOrder() const
+{
+ auto backward_order = btopolSortOperations();
+ // Drop every node that is not reachable from a node required for backwarding
+ backward_order = truncateBackwardOrder(backward_order, [&](const OperationIndex &index) {
+ return operation(index).isRequiredForBackward();
+ });
+
+ return truncateBackwardOrder(backward_order);
+}
+
+std::vector<ir::OperationIndex> TrainableGraph::truncateBackwardOrder(
+ std::vector<ir::OperationIndex> backward_order,
+ std::function<bool(const ir::OperationIndex &)> alive_cond) const
+{
+ auto forward_order = backward_order;
+ std::reverse(forward_order.begin(), forward_order.end());
+ std::set<ir::OperationIndex> alive;
+
+ for (const auto &index : forward_order)
+ {
+ if (alive_cond(index))
+ alive.insert(index);
+
+ // TODO: replace this with `std::set::contains` after C++20
+ if (alive.find(index) != alive.end())
+ {
+ const auto &op = operations().at(index);
+ for (const auto &output : op.getOutputs())
+ {
+ const auto &operand = operands().at(output);
+ for (const auto &use : operand.getUses())
+ alive.insert(use);
+ }
+ }
+ }
+
+ // TODO: replace this with `std::erase_if(std::vector)` after C++20
+ backward_order.erase(
+ std::remove_if(backward_order.begin(), backward_order.end(),
+ [&](const auto &index) { return alive.find(index) == alive.end(); }),
+ backward_order.end());
+ return backward_order;
+}
+
+std::vector<ir::OperationIndex>
+TrainableGraph::truncateBackwardOrder(const std::vector<ir::OperationIndex> &backward_order) const
+{
+ return truncateBackwardOrder(backward_order, [&](const ir::OperationIndex &index) {
+ const auto &trainable_op = operation(index);
+
+ return trainable_op.hasTrainableParameter();
+ });
+}
+
+void TrainableGraph::addLoss(const OperandIndex &loss_ind, const IOIndex &pred_ioind)
+{
+ _losses.emplace(pred_ioind, loss_ind);
+}
+
+OperandIndex TrainableGraph::getLossIndex(const IOIndex &pred_ioind) const
+{
+ auto itr = _losses.find(pred_ioind);
+ return (itr == _losses.end()) ? OperandIndex{} : itr->second;
+}
+
+} // namespace train
+} // namespace ir
+} // namespace onert
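The liveness propagation behind truncateBackwardOrder and essentialBackwardOrder above boils down to a few lines: walk the reversed order forward, keep a node alive if it satisfies the predicate or consumes the output of an already-alive node, then drop everything else from the backward order. The sketch below is a simplified stand-in under assumed NodeId/Edge types, not the onert API:

// Minimal standalone sketch of the truncation logic, under the assumptions stated above.
#include <algorithm>
#include <cstdint>
#include <functional>
#include <iostream>
#include <set>
#include <vector>

using NodeId = uint32_t;
struct Edge
{
  NodeId producer; // node whose output is consumed
  NodeId consumer; // node that uses that output
};

std::vector<NodeId> truncate(std::vector<NodeId> backward_order, const std::vector<Edge> &edges,
                             const std::function<bool(NodeId)> &alive_cond)
{
  auto forward_order = backward_order;
  std::reverse(forward_order.begin(), forward_order.end());

  std::set<NodeId> alive;
  for (const NodeId n : forward_order)
  {
    if (alive_cond(n))
      alive.insert(n);
    if (alive.count(n) != 0)
      for (const auto &e : edges)
        if (e.producer == n)
          alive.insert(e.consumer); // consumers of an alive node become alive too
  }

  backward_order.erase(std::remove_if(backward_order.begin(), backward_order.end(),
                                      [&](NodeId n) { return alive.count(n) == 0; }),
                       backward_order.end());
  return backward_order;
}

int main()
{
  // 0 -> 1 -> 2 -> 3(Loss); only node 1 has trainable parameters
  const std::vector<Edge> edges{{0, 1}, {1, 2}, {2, 3}};
  const auto essential = truncate({3, 2, 1, 0}, edges, [](NodeId n) { return n == 1; });
  for (const NodeId n : essential)
    std::cout << n << ' '; // prints "3 2 1 ": node 0 is not needed for the backward pass
  std::cout << '\n';
}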
diff --git a/runtime/onert/core/src/ir/train/TrainableGraph.test.cc b/runtime/onert/core/src/ir/train/TrainableGraph.test.cc
new file mode 100644
index 000000000..84df22890
--- /dev/null
+++ b/runtime/onert/core/src/ir/train/TrainableGraph.test.cc
@@ -0,0 +1,378 @@
+/*
+ * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ir/train/TrainableGraph.h"
+#include "ir/train/operation/BinaryArithmetic.h"
+#include "ir/train/operation/ElementwiseActivation.h"
+#include "ir/train/operation/FullyConnected.h"
+#include "ir/train/operation/Loss.h"
+#include "ir/train/LossInfo.h"
+
+#include <gtest/gtest.h>
+
+using namespace onert::ir;
+
+OperationIndex addAddOperation(train::TrainableGraph &tgraph, const OperandIndexSequence inputs,
+ const OperandIndexSequence outputs)
+{
+ // Add "ADD" operation
+ operation::BinaryArithmetic::Param param;
+ param.arithmetic_type = operation::BinaryArithmetic::ArithmeticType::ADD;
+ param.activation = Activation::NONE;
+ auto add_op = operation::BinaryArithmetic(inputs, outputs, param);
+ return tgraph.addOperation(std::make_unique<train::operation::BinaryArithmetic>(add_op));
+}
+
+OperationIndex addElementwiseActivationOperation(train::TrainableGraph &tgraph,
+ const OperandIndexSequence inputs,
+ const OperandIndexSequence outputs)
+{
+ // Add "ElementwiseActivation" operation
+ operation::ElementwiseActivation::Param param;
+ auto ea_op = operation::ElementwiseActivation(inputs, outputs, param);
+ return tgraph.addOperation(std::make_unique<train::operation::ElementwiseActivation>(ea_op));
+}
+
+OperationIndex addFullyConnectedOperation(train::TrainableGraph &tgraph,
+ const OperandIndexSequence inputs,
+ const OperandIndexSequence outputs)
+{
+ // Add "FullyConnected" operation
+ operation::FullyConnected::Param param;
+ param.weights_format = FullyConnectedWeightsFormat::Default;
+ param.activation = Activation::NONE;
+ auto fc_op = operation::FullyConnected(inputs, outputs, param);
+ return tgraph.addOperation(std::make_unique<train::operation::FullyConnected>(fc_op));
+}
+
+OperationIndex addLossOperation(train::TrainableGraph &tgraph, const OperandIndexSequence inputs,
+ const OperandIndexSequence outputs)
+{
+ // Add "Loss" operation
+ auto loss_op = operation::Loss(inputs, outputs);
+ return tgraph.addOperation(std::make_unique<train::operation::Loss>(loss_op, train::LossInfo{}));
+}
+
+TEST(TrainableGraph, topological_sort_linear)
+{
+ train::TrainableGraph tgraph;
+
+ Shape shape{1, 2, 2, 1};
+ TypeInfo type{DataType::FLOAT32};
+
+ /*
+ (input) ⎼[EA]⎼> (y_pred)
+ ╲
+ [Loss]⎼> (output)
+ ╱
+ (y_true)
+ */
+
+ auto input = tgraph.addOperand(shape, type);
+ auto y_pred = tgraph.addOperand(shape, type);
+ auto y_true = tgraph.addOperand(shape, type);
+ auto output = tgraph.addOperand(shape, type);
+
+ tgraph.addInput({input});
+ tgraph.addInput({y_true});
+ tgraph.addOutput({output});
+
+ addElementwiseActivationOperation(tgraph, {input}, {y_pred});
+ addLossOperation(tgraph, {y_pred, y_true}, {output});
+
+ EXPECT_NO_THROW(tgraph.topolSortOperations());
+ EXPECT_NO_THROW(tgraph.btopolSortOperations());
+}
+
+TEST(TrainableGraph, topological_sort_nonlinear)
+{
+ train::TrainableGraph tgraph;
+
+ Shape shape{1, 2, 2, 1};
+ TypeInfo type{DataType::FLOAT32};
+
+ /*
+ [EA]⎼> (lhs)
+ ╱ ╲
+ (input) ⎼[EA]⎼> (split) [Add]⎼> (y_pred)
+ ╲ ╱ ╲
+ [EA]⎼> (rhs) [Loss]⎼> (output)
+ ╱
+ (y_true)
+ */
+
+ auto input = tgraph.addOperand(shape, type);
+ auto split = tgraph.addOperand(shape, type);
+ auto lhs = tgraph.addOperand(shape, type);
+ auto rhs = tgraph.addOperand(shape, type);
+ auto y_pred = tgraph.addOperand(shape, type);
+ auto y_true = tgraph.addOperand(shape, type);
+ auto output = tgraph.addOperand(shape, type);
+
+ tgraph.addInput({input});
+ tgraph.addInput({y_true});
+ tgraph.addOutput({output});
+
+ addElementwiseActivationOperation(tgraph, {input}, {split});
+ addElementwiseActivationOperation(tgraph, {split}, {lhs});
+ addElementwiseActivationOperation(tgraph, {split}, {rhs});
+ addAddOperation(tgraph, {lhs, rhs}, {y_pred});
+ addLossOperation(tgraph, {y_pred, y_true}, {output});
+
+ EXPECT_NO_THROW(tgraph.topolSortOperations());
+ EXPECT_NO_THROW(tgraph.btopolSortOperations());
+}
+
+TEST(TrainableGraph, neg_topological_sort_cycle)
+{
+ train::TrainableGraph tgraph;
+
+ Shape shape{1, 2, 2, 1};
+ TypeInfo type{DataType::FLOAT32};
+
+ /*
+ (input) ⎼[Add]⎼> (v) ⎼[EA]
+ | |
+ v
+ (u) <⎼[EA]⎼ (y_pred)
+ ╲
+ [Loss]⎼> (output)
+ ╱
+ (y_true)
+ */
+
+ auto input = tgraph.addOperand(shape, type);
+ auto u = tgraph.addOperand(shape, type);
+ auto v = tgraph.addOperand(shape, type);
+ auto y_pred = tgraph.addOperand(shape, type);
+ auto y_true = tgraph.addOperand(shape, type);
+ auto output = tgraph.addOperand(shape, type);
+
+ tgraph.addInput({input});
+ tgraph.addInput({y_true});
+ tgraph.addOutput({output});
+
+ addAddOperation(tgraph, {input, u}, {v});
+ addElementwiseActivationOperation(tgraph, {v}, {y_pred});
+ addElementwiseActivationOperation(tgraph, {y_pred}, {u});
+ addLossOperation(tgraph, {y_pred, y_true}, {output});
+
+ EXPECT_ANY_THROW(tgraph.topolSortOperations());
+ EXPECT_ANY_THROW(tgraph.btopolSortOperations());
+}
+
+TEST(TrainableGraph, truncating_backward_topological_order_nonlinear)
+{
+ {
+ train::TrainableGraph tgraph;
+
+ Shape shape{1, 2, 2, 1};
+ TypeInfo type{DataType::FLOAT32};
+
+ /*
+ [EA1]⎼> (u)
+ ╱ ╲
+ ╱ (weight1) ⎼[FC1]⎼> (v)
+ ╱ ╱ ╲
+ ╱ (bias1) [Add]⎼> (y_pred)
+ (input) ╱ ╲
+ ╲ ╱ [Loss]⎼> (output)
+ [EA2]⎼> (w) ╱ ╱
+ ╲ ╱ (y_true)
+ (weight2) ⎼[FC2]⎼> (x)
+ ╱
+ (bias2)
+ */
+
+ auto input = tgraph.addOperand(shape, type);
+ auto u = tgraph.addOperand(shape, type);
+ auto weight1 = tgraph.addOperand(shape, type);
+ auto bias1 = tgraph.addOperand(shape, type);
+ auto v = tgraph.addOperand(shape, type);
+ auto w = tgraph.addOperand(shape, type);
+ auto weight2 = tgraph.addOperand(shape, type);
+ auto bias2 = tgraph.addOperand(shape, type);
+ auto x = tgraph.addOperand(shape, type);
+ auto y_pred = tgraph.addOperand(shape, type);
+ auto y_true = tgraph.addOperand(shape, type);
+ auto output = tgraph.addOperand(shape, type);
+
+ tgraph.addInput({input});
+ tgraph.addInput({weight1});
+ tgraph.addInput({bias1});
+ tgraph.addInput({weight2});
+ tgraph.addInput({bias2});
+ tgraph.addInput({y_true});
+ tgraph.addOutput({output});
+
+ auto ea1 = addElementwiseActivationOperation(tgraph, {input}, {u});
+ auto fc1 = addFullyConnectedOperation(tgraph, {u, weight1, bias1}, {v});
+ auto ea2 = addElementwiseActivationOperation(tgraph, {input}, {w});
+ auto fc2 = addFullyConnectedOperation(tgraph, {w, weight2, bias2}, {x});
+ auto add = addAddOperation(tgraph, {v, x}, {y_pred});
+ auto loss = addLossOperation(tgraph, {y_pred, y_true}, {output});
+
+ std::vector<OperationIndex> expected_truncation_1{loss, add, fc1, fc2};
+ std::vector<OperationIndex> expected_truncation_2{loss, add, fc2, fc1};
+ std::vector<OperationIndex> truncation =
+ tgraph.truncateBackwardOrder(tgraph.btopolSortOperations());
+
+ ASSERT_TRUE(truncation == expected_truncation_1 || truncation == expected_truncation_2);
+ }
+
+ {
+ train::TrainableGraph tgraph;
+
+ Shape shape{1, 2, 2, 1};
+ TypeInfo type{DataType::FLOAT32};
+
+ /*
+ (input1) ⎼[FC3]⎼> (r) ⎼⎼[Add]⎼> (s) ⎼[EA1]⎼> (u)
+ ╱ ╱ ╲
+ (weight3) ╱ (weight1) ⎼[FC1]⎼> (v)
+ ╱ ╱ ╲
+ ╱ ╱ ╲
+ ╱ (bias1) [Add]⎼> (y_pred)
+ (input) ╱ ╲
+ ╲ ╱ [Loss]⎼> (output)
+ ╲ ╱ ╱
+ [Add]⎼> (t) ⎼[EA2]⎼> (w) ╱ ╱
+ ╱ ╲ ╱ (y_true)
+ (input2) (weight2) ⎼[FC2]⎼> (x)
+ ╱
+ (bias2)
+ */
+
+ auto input1 = tgraph.addOperand(shape, type);
+ auto weight3 = tgraph.addOperand(shape, type);
+ auto r = tgraph.addOperand(shape, type);
+ auto input = tgraph.addOperand(shape, type);
+ auto s = tgraph.addOperand(shape, type);
+ auto input2 = tgraph.addOperand(shape, type);
+ auto t = tgraph.addOperand(shape, type);
+ auto u = tgraph.addOperand(shape, type);
+ auto weight1 = tgraph.addOperand(shape, type);
+ auto bias1 = tgraph.addOperand(shape, type);
+ auto v = tgraph.addOperand(shape, type);
+ auto w = tgraph.addOperand(shape, type);
+ auto weight2 = tgraph.addOperand(shape, type);
+ auto bias2 = tgraph.addOperand(shape, type);
+ auto x = tgraph.addOperand(shape, type);
+ auto y_pred = tgraph.addOperand(shape, type);
+ auto y_true = tgraph.addOperand(shape, type);
+ auto output = tgraph.addOperand(shape, type);
+
+ tgraph.addInput({input});
+ tgraph.addInput({weight1});
+ tgraph.addInput({bias1});
+ tgraph.addInput({weight2});
+ tgraph.addInput({bias2});
+ tgraph.addInput({y_true});
+ tgraph.addOutput({output});
+
+ auto fc3 = addFullyConnectedOperation(tgraph, {input1, weight3}, {r});
+ auto add1 = addAddOperation(tgraph, {r, input}, {s});
+ auto add2 = addAddOperation(tgraph, {input, input2}, {t});
+ auto ea1 = addElementwiseActivationOperation(tgraph, {s}, {u});
+ auto fc1 = addFullyConnectedOperation(tgraph, {u, weight1, bias1}, {v});
+ auto ea2 = addElementwiseActivationOperation(tgraph, {t}, {w});
+ auto fc2 = addFullyConnectedOperation(tgraph, {w, weight2, bias2}, {x});
+ auto add = addAddOperation(tgraph, {v, x}, {y_pred});
+ auto loss = addLossOperation(tgraph, {y_pred, y_true}, {output});
+
+ // These expected indices are based on the DFS visiting order
+ std::vector<OperationIndex> expected_truncation_1{loss, add, fc1, ea1, add1, fc3, fc2};
+ std::vector<OperationIndex> expected_truncation_2{loss, add, fc2, fc1, ea1, add1, fc3};
+ std::vector<OperationIndex> truncation =
+ tgraph.truncateBackwardOrder(tgraph.btopolSortOperations());
+
+ ASSERT_TRUE(truncation == expected_truncation_1 || truncation == expected_truncation_2);
+ }
+}
+
+TEST(TrainableGraph, essential_backward_topological_order_nonlinear)
+{
+ {
+ train::TrainableGraph tgraph;
+
+ Shape shape{1, 2, 2, 1};
+ TypeInfo type{DataType::FLOAT32};
+
+ /*
+ (input1) ⎼[FC3]⎼> (r) ⎼⎼[Add]⎼> (s) ⎼[EA1]⎼> (u)
+ ╱ ╱ ╲
+ (weight3) ╱ (weight1) ⎼[FC1]⎼> (v)
+ ╱ ╱ ╲
+ ╱ ╱ ╲
+ ╱ (bias1) [Add]⎼> (y_pred)
+ (input) ╱ ╲
+ ╲ ╱ [Loss]⎼> (output)
+ ╲ ╱ ╱
+ [Add]⎼> (t) ⎼[EA2]⎼> (w) ╱ ╱
+ ╱ ╲ ╱ (y_true)
+ (input2) (weight2) ⎼[FC2]⎼> (x)
+ ╱
+ (bias2)
+ */
+
+ auto input1 = tgraph.addOperand(shape, type);
+ auto weight3 = tgraph.addOperand(shape, type);
+ auto r = tgraph.addOperand(shape, type);
+ auto input = tgraph.addOperand(shape, type);
+ auto s = tgraph.addOperand(shape, type);
+ auto input2 = tgraph.addOperand(shape, type);
+ auto t = tgraph.addOperand(shape, type);
+ auto u = tgraph.addOperand(shape, type);
+ auto weight1 = tgraph.addOperand(shape, type);
+ auto bias1 = tgraph.addOperand(shape, type);
+ auto v = tgraph.addOperand(shape, type);
+ auto w = tgraph.addOperand(shape, type);
+ auto weight2 = tgraph.addOperand(shape, type);
+ auto bias2 = tgraph.addOperand(shape, type);
+ auto x = tgraph.addOperand(shape, type);
+ auto y_pred = tgraph.addOperand(shape, type);
+ auto y_true = tgraph.addOperand(shape, type);
+ auto output = tgraph.addOperand(shape, type);
+
+ tgraph.addInput({input});
+ tgraph.addInput({weight1});
+ tgraph.addInput({bias1});
+ tgraph.addInput({weight2});
+ tgraph.addInput({bias2});
+ tgraph.addInput({y_true});
+ tgraph.addOutput({output});
+
+ auto fc3 = addFullyConnectedOperation(tgraph, {input1, weight3}, {r});
+ auto add1 = addAddOperation(tgraph, {r, input}, {s});
+ auto add2 = addAddOperation(tgraph, {input, input2}, {t});
+ auto ea1 = addElementwiseActivationOperation(tgraph, {s}, {u});
+ auto fc1 = addFullyConnectedOperation(tgraph, {u, weight1, bias1}, {v});
+ auto ea2 = addElementwiseActivationOperation(tgraph, {t}, {w});
+ auto fc2 = addFullyConnectedOperation(tgraph, {w, weight2, bias2}, {x});
+ auto add = addAddOperation(tgraph, {v, x}, {y_pred});
+ auto loss = addLossOperation(tgraph, {y_pred, y_true}, {output});
+
+ tgraph.enableBackward(fc2);
+ tgraph.enableBackward(fc3);
+
+ // These expected indices are based on the DFS visiting order
+ std::vector<OperationIndex> expected_truncation_1{loss, add, fc1, ea1, add1, fc3, fc2};
+ std::vector<OperationIndex> expected_truncation_2{loss, add, fc2, fc1, ea1, add1, fc3};
+ std::vector<OperationIndex> essential = tgraph.essentialBackwardOrder();
+
+ ASSERT_TRUE(essential == expected_truncation_1 || essential == expected_truncation_2);
+ }
+}
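Since btopolSortOperations visits sibling branches via DFS in an unspecified order, the tests above enumerate each acceptable permutation explicitly. A lighter-weight alternative, sketched here with a hypothetical helper (not part of these tests), is to assert only the pairwise ordering constraints that any valid result must satisfy:

// Standalone sketch: check an order against explicit "before/after" constraints.
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <utility>
#include <vector>

using Id = uint32_t;

// Returns true if every (before, after) pair appears in that relative order.
bool respectsConstraints(const std::vector<Id> &order,
                         const std::vector<std::pair<Id, Id>> &constraints)
{
  auto pos = [&](Id id) { return std::find(order.begin(), order.end(), id) - order.begin(); };
  for (const auto &[before, after] : constraints)
    if (pos(before) >= pos(after))
      return false;
  return true;
}

int main()
{
  // loss(0) -> add(1) -> {fc1(2), fc2(3)}: fc1 and fc2 may appear in either order
  const std::vector<std::pair<Id, Id>> constraints{{0, 1}, {1, 2}, {1, 3}};
  assert(respectsConstraints({0, 1, 2, 3}, constraints));
  assert(respectsConstraints({0, 1, 3, 2}, constraints));
  assert(!respectsConstraints({1, 0, 2, 3}, constraints));
  return 0;
}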
diff --git a/runtime/onert/core/src/ir/train/TrainingInfo.cc b/runtime/onert/core/src/ir/train/TrainingInfo.cc
new file mode 100644
index 000000000..102781173
--- /dev/null
+++ b/runtime/onert/core/src/ir/train/TrainingInfo.cc
@@ -0,0 +1,49 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ir/train/TrainingInfo.h"
+
+namespace onert
+{
+namespace ir
+{
+namespace train
+{
+
+bool TrainingInfo::isValid() const
+{
+ if (_batch_size == 0)
+ return false;
+
+ if (_optimizer_info.optim_code == OptimizerCode::Undefined)
+ return false;
+
+ if (_optimizer_info.learning_rate <= 0.0f)
+ return false;
+
+ if (_loss_info.loss_code == LossCode::Undefined)
+ return false;
+
+ if (_loss_info.reduction_type == LossReductionType::Undefined)
+ return false;
+
+ // If there are invalid combinations, add more condition checks here
+ return true;
+}
+
+} // namespace train
+} // namespace ir
+} // namespace onert
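TrainingInfo::isValid above is a plain short-circuit over required fields. The sketch below mirrors that pattern with illustrative stand-in types (TrainingConfig, Optimizer, Loss are hypothetical, not the onert API) to show how a caller would reject an unconfigured setup before starting a training session:

// Hypothetical mirror of the "validate before training" pattern.
#include <cstdint>
#include <iostream>

enum class Optimizer { Undefined, SGD, Adam };
enum class Loss { Undefined, MSE, CategoricalCrossentropy };

struct TrainingConfig // hypothetical stand-in, not ir::train::TrainingInfo
{
  uint32_t batch_size = 0;
  Optimizer optimizer = Optimizer::Undefined;
  float learning_rate = 0.0f;
  Loss loss = Loss::Undefined;

  bool isValid() const
  {
    if (batch_size == 0)
      return false;
    if (optimizer == Optimizer::Undefined)
      return false;
    if (learning_rate <= 0.0f)
      return false;
    if (loss == Loss::Undefined)
      return false;
    return true;
  }
};

int main()
{
  TrainingConfig cfg;
  std::cout << std::boolalpha << cfg.isValid() << '\n'; // false: nothing is configured

  cfg.batch_size = 32;
  cfg.optimizer = Optimizer::SGD;
  cfg.learning_rate = 0.01f;
  cfg.loss = Loss::MSE;
  std::cout << cfg.isValid() << '\n'; // true: ready to start a training session
}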
diff --git a/runtime/onert/core/src/ir/train/UseDefChain.cc b/runtime/onert/core/src/ir/train/UseDefChain.cc
new file mode 100644
index 000000000..9cb9bb7c9
--- /dev/null
+++ b/runtime/onert/core/src/ir/train/UseDefChain.cc
@@ -0,0 +1,52 @@
+/*
+ * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ir/train/UseDefChain.h"
+
+namespace onert
+{
+namespace ir
+{
+namespace train
+{
+
+void UseDefChain::insertTrainingUse(const TrainingOperationIndex &idx) { _uses.insert(idx); }
+
+void UseDefChain::removeTrainingUse(const TrainingOperationIndex &idx) { _uses.erase(idx); }
+
+void UseDefChain::insertTrainingDef(const TrainingOperationIndex &idx)
+{
+ // defs must be valid
+ assert(idx.valid());
+ _defs.insert(idx);
+}
+
+void UseDefChain::removeTrainingDef(const TrainingOperationIndex &idx) { _defs.erase(idx); }
+
+void UseDefChain::clearTrainingUseDefs()
+{
+ _uses.clear();
+ _defs.clear();
+}
+
+bool UseDefChain::operator==(const UseDefChain &other) const
+{
+ return &_operand == &other._operand && _uses == other._uses && _defs == other._defs;
+}
+
+} // namespace train
+} // namespace ir
+} // namespace onert
diff --git a/runtime/onert/core/src/ir/train/UseDefGenerator.cc b/runtime/onert/core/src/ir/train/UseDefGenerator.cc
new file mode 100644
index 000000000..615b1650c
--- /dev/null
+++ b/runtime/onert/core/src/ir/train/UseDefGenerator.cc
@@ -0,0 +1,187 @@
+/*
+ * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "UseDefGenerator.h"
+
+#include "ir/train/TrainableGraph.h"
+#include "ir/train/Index.h"
+#include "../verifier/Verifier.h"
+
+#include <cassert>
+#include <memory>
+
+// TODO Reduce duplicate code
+
+namespace onert
+{
+namespace ir
+{
+namespace train
+{
+
+UseDefGenerator::UseDefGenerator(const TrainableGraph &tgraph)
+ : _tgraph{tgraph}, _node_to_idx{}, _training_usedefs{}
+{
+ const auto order = _tgraph.topolSortOperations();
+ for (const auto &index : order)
+ {
+ const auto &node = _tgraph.operation(index);
+ assert(_node_to_idx.find(&node) == _node_to_idx.end());
+ _node_to_idx[&node] = index;
+ }
+
+ // Check whether loss exists
+ assert(std::any_of(order.begin(), order.end(),
+ [&](const auto &index) {
+ return _tgraph.operation(index).opcode() == ir::OpCode::Loss;
+ }) &&
+ "Loss does not exist");
+}
+
+UseDefChains UseDefGenerator::operator()()
+{
+ const auto &graph = _tgraph.graph();
+ assert(ir::verifier::EdgeChecker().verify(graph));
+
+ _training_usedefs.clear();
+ graph.operands().iterate([&](const ir::OperandIndex &idx, const ir::Operand &operand) {
+ // Initialize as an empty UseDefChain
+ const auto empty_usedef_chain = UseDefChain{operand};
+ _training_usedefs.emplace(TrainingOperandIndex{idx, true}, empty_usedef_chain);
+ _training_usedefs.emplace(TrainingOperandIndex{idx, false}, empty_usedef_chain);
+ });
+
+ initForForwardingNodes();
+
+ initForBackwardingNodes();
+
+ return _training_usedefs;
+}
+
+void UseDefGenerator::visit(const train::operation::Loss &node)
+{
+ assert(_node_to_idx.find(&node) != _node_to_idx.end());
+ const auto &op_index = _node_to_idx.at(&node);
+ const auto backwarding_op_index = TrainingOperationIndex{op_index, false};
+
+ for (const auto &in_index : node.getInputs() | ir::Remove::UNDEFINED | ir::Remove::DUPLICATED)
+ {
+ // Insert use of forwarding inputs
+ const auto in_forwarding_index = TrainingOperandIndex{in_index, true};
+ insertUse(in_forwarding_index, backwarding_op_index);
+ }
+
+ // Set def of backwarding(backprop) y_pred
+ const auto &y_pred_index = node.getInputs().at(train::operation::Loss::Input::Y_PRED);
+ assert(!_tgraph.operands().at(y_pred_index).isConstant());
+ const auto y_pred_outgoing_index = TrainingOperandIndex{y_pred_index, false};
+ insertBackPropDef(y_pred_outgoing_index, backwarding_op_index);
+
+ // Set def of backwarding(backprop) y_true
+ const auto &y_true_index = node.getInputs().at(train::operation::Loss::Input::Y_TRUE);
+ assert(!_tgraph.operands().at(y_true_index).isConstant());
+ const auto y_true_outgoing_index = TrainingOperandIndex{y_true_index, false};
+ insertBackPropDef(y_true_outgoing_index, backwarding_op_index);
+
+ // Remove use of backwarding output
+ const auto &out_index = node.getOutputs().at(0);
+ const auto incoming_index = TrainingOperandIndex{out_index, false};
+ auto &usedef_chain = _training_usedefs.at(incoming_index);
+ usedef_chain.removeTrainingUse(backwarding_op_index);
+}
+
+void UseDefGenerator::insertUse(const TrainingOperandIndex &operand_index,
+ const TrainingOperationIndex &op_index)
+{
+ assert(_training_usedefs.find(operand_index) != _training_usedefs.end());
+ auto &usedef_chain = _training_usedefs.at(operand_index);
+ usedef_chain.insertTrainingUse(op_index);
+}
+
+void UseDefGenerator::insertDef(const TrainingOperandIndex &operand_index,
+ const TrainingOperationIndex &op_index)
+{
+ assert(operand_index.valid());
+
+ assert(_training_usedefs.find(operand_index) != _training_usedefs.end());
+ auto &usedef_chain = _training_usedefs.at(operand_index);
+ usedef_chain.insertTrainingDef(op_index);
+}
+
+void UseDefGenerator::insertBackPropDef(const TrainingOperandIndex &operand_index,
+ const TrainingOperationIndex &op_index)
+{
+ // NOTE There is no need to set the def of constant backwarding (backprop) inputs
+ // because they are not back-propagated.
+ if (!_tgraph.operands().at(operand_index.index()).isConstant())
+ {
+ insertDef(operand_index, op_index);
+ }
+}
+
+void UseDefGenerator::initForForwardingNodes()
+{
+ // Initialize training def-uses of forwarding operands, considering only forwarding nodes
+ // (i.e. forwarding nodes that do not have any corresponding backwarding node)
+ _tgraph.operands().iterate([&](const ir::OperandIndex &idx, const ir::Operand &operand) {
+ // Append forwarding def-uses as it is
+ const bool is_forward = true;
+ const auto forwarding_operand_index = TrainingOperandIndex{idx, is_forward};
+
+ const auto def = operand.getDef();
+ if (def.valid())
+ {
+ insertDef(forwarding_operand_index, TrainingOperationIndex{def, is_forward});
+ }
+
+ assert(_training_usedefs.at(forwarding_operand_index).getTrainingUses().size() == 0);
+ const auto uses = operand.getUses();
+ for (const auto &use : uses)
+ insertUse(forwarding_operand_index, TrainingOperationIndex{use, is_forward});
+ });
+}
+
+void UseDefGenerator::initForBackwardingNodes()
+{
+ const auto backward_order = _tgraph.essentialBackwardOrder();
+ // Initialize training uses of forwarding operands and def-uses of backwarding operands for
+ // backwarding nodes (i.e. backwarding nodes that do not have any corresponding forwarding node)
+ for (const auto &op_index : backward_order)
+ {
+ const auto &node = _tgraph.operation(op_index);
+
+ // Insert use of backwarding operands (output only)
+ {
+ if (node.getOutputs().size() > 1)
+ throw std::runtime_error(
+ "UseDefGenerator does not support multiple outputs of training operation");
+
+ const auto &output = node.getOutputs().at(0);
+ const auto backwarding_op_index = TrainingOperationIndex{op_index, false};
+ const auto incoming_index = TrainingOperandIndex{output, false};
+ insertUse(incoming_index, backwarding_op_index);
+ }
+
+ // Insert uses of forwarding operands and insert defs of backwarding operands
+ node.accept(*this);
+ }
+}
+
+} // namespace train
+} // namespace ir
+} // namespace onert
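UseDefGenerator keys every chain by a (operand index, is_forward) pair, so each operand ends up with one forwarding entry and one backwarding (backprop) entry. The condensed sketch below uses plain std::pair keys in place of TrainingOperandIndex/TrainingOperationIndex to illustrate the bookkeeping; the indices and comments describe an assumed FC-plus-Loss graph, not a real model:

// Simplified sketch of the (operand, is_forward) indexed use-def map.
#include <cstdint>
#include <iostream>
#include <map>
#include <set>
#include <utility>

using OperandId = uint32_t;
using OperationId = uint32_t;
using TrainingOperandKey = std::pair<OperandId, bool>;     // (operand, is_forward)
using TrainingOperationKey = std::pair<OperationId, bool>; // (operation, is_forward)

struct Chain
{
  std::set<TrainingOperationKey> uses;
  std::set<TrainingOperationKey> defs;
};

int main()
{
  std::map<TrainingOperandKey, Chain> usedefs;

  // Forwarding y_pred (operand 0) is defined by the forwarding FC node (operation 0)...
  usedefs[{0, true}].defs.insert({0, true});
  // ...used by the forwarding Loss node (operation 1)...
  usedefs[{0, true}].uses.insert({1, true});
  // ...also read by the backwarding Loss node...
  usedefs[{0, true}].uses.insert({1, false});
  // ...and its backprop counterpart is defined by the backwarding Loss node.
  usedefs[{0, false}].defs.insert({1, false});

  for (const auto &[key, chain] : usedefs)
    std::cout << "operand " << key.first << (key.second ? " (forward)" : " (backprop)") << ": "
              << chain.defs.size() << " def(s), " << chain.uses.size() << " use(s)\n";
}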
diff --git a/runtime/onert/core/src/ir/train/UseDefGenerator.h b/runtime/onert/core/src/ir/train/UseDefGenerator.h
new file mode 100644
index 000000000..369d9a223
--- /dev/null
+++ b/runtime/onert/core/src/ir/train/UseDefGenerator.h
@@ -0,0 +1,87 @@
+/*
+ * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_IR_TRAIN_USEDEFGENERATOR_H__
+#define __ONERT_IR_TRAIN_USEDEFGENERATOR_H__
+
+#include "ir/train/TrainableOperationVisitor.h"
+
+#include "ir/train/UseDefChains.h"
+#include "ir/train/Operations.Include.h"
+
+namespace onert
+{
+namespace ir
+{
+namespace train
+{
+class TrainableGraph;
+} // namespace train
+} // namespace ir
+} // namespace onert
+
+namespace onert
+{
+namespace ir
+{
+namespace train
+{
+
+struct UseDefGeneratorBase : public TrainableOperationVisitor
+{
+ virtual ~UseDefGeneratorBase() = default;
+
+protected:
+#define OP(InternalName) \
+ virtual void visit(const operation::InternalName &) override \
+ { \
+ throw std::runtime_error("UseDefGenerator: NYI for operation '" #InternalName "'"); \
+ }
+#include "ir/train/Operations.lst"
+#undef OP
+};
+
+class UseDefGenerator : public UseDefGeneratorBase
+{
+public:
+ UseDefGenerator(void) = delete;
+ UseDefGenerator(const TrainableGraph &tgraph);
+
+public:
+ UseDefChains operator()();
+
+public:
+ void visit(const train::operation::Loss &node) override;
+
+private:
+ void insertUse(const TrainingOperandIndex &operand_index, const TrainingOperationIndex &op_index);
+ void insertDef(const TrainingOperandIndex &operand_index, const TrainingOperationIndex &op_index);
+ void insertBackPropDef(const TrainingOperandIndex &operand_index,
+ const TrainingOperationIndex &op_index);
+ void initForForwardingNodes();
+ void initForBackwardingNodes();
+
+private:
+ const TrainableGraph &_tgraph;
+ std::unordered_map<const ITrainableOperation *, OperationIndex> _node_to_idx;
+ UseDefChains _training_usedefs;
+};
+
+} // namespace train
+} // namespace ir
+} // namespace onert
+
+#endif // __ONERT_IR_TRAIN_USEDEFGENERATOR_H__
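UseDefGeneratorBase relies on the X-macro pattern: `Operations.lst` is included with OP defined so that a throwing visit override is stamped out for every operation, and only the operations the generator actually supports (currently Loss) override it with real logic. A toy illustration of the same technique follows; the OPS list and method names are made up for the example, not the onert macros:

// Toy X-macro example: one list, many generated default methods.
#include <iostream>
#include <stdexcept>
#include <string>

#define OPS(OP) OP(Conv2D) OP(FullyConnected) OP(Softmax)

struct Visitor
{
// Generate a "not yet implemented" virtual method per listed operation.
#define OP(Name) \
  virtual void visit##Name() { throw std::runtime_error("NYI for operation '" #Name "'"); }
  OPS(OP)
#undef OP
  virtual ~Visitor() = default;
};

struct MyVisitor : Visitor
{
  void visitSoftmax() override { std::cout << "Softmax handled\n"; }
};

int main()
{
  MyVisitor v;
  v.visitSoftmax(); // handled by the explicit override
  try
  {
    v.visitConv2D(); // falls back to the generated NYI stub
  }
  catch (const std::exception &e)
  {
    std::cout << e.what() << '\n';
  }
}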
diff --git a/runtime/onert/core/src/ir/train/operation/BinaryArithmetic.cc b/runtime/onert/core/src/ir/train/operation/BinaryArithmetic.cc
new file mode 100644
index 000000000..473d38735
--- /dev/null
+++ b/runtime/onert/core/src/ir/train/operation/BinaryArithmetic.cc
@@ -0,0 +1,49 @@
+/*
+ * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ir/train/operation/BinaryArithmetic.h"
+
+#include "ir/OperationVisitor.h"
+#include "ir/train/TrainableOperationVisitor.h"
+
+namespace onert
+{
+namespace ir
+{
+namespace train
+{
+namespace operation
+{
+
+std::unique_ptr<ITrainableOperation> BinaryArithmetic::clone() const
+{
+ return std::make_unique<BinaryArithmetic>(*this);
+}
+
+void BinaryArithmetic::accept(OperationVisitor &v) const { v.visit(*this); }
+
+void BinaryArithmetic::accept(TrainableOperationVisitor &v) const { v.visit(*this); }
+
+BinaryArithmetic::BinaryArithmetic(const OperationType &operation)
+ : OperationType{operation.getInputs(), operation.getOutputs(), operation.param()}
+{
+ // DO NOTHING
+}
+
+} // namespace operation
+} // namespace train
+} // namespace ir
+} // namespace onert
diff --git a/runtime/onert/core/src/ir/train/operation/Conv2D.cc b/runtime/onert/core/src/ir/train/operation/Conv2D.cc
new file mode 100644
index 000000000..923861ae3
--- /dev/null
+++ b/runtime/onert/core/src/ir/train/operation/Conv2D.cc
@@ -0,0 +1,49 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ir/train/operation/Conv2D.h"
+
+#include "ir/OperationVisitor.h"
+#include "ir/train/TrainableOperationVisitor.h"
+
+namespace onert
+{
+namespace ir
+{
+namespace train
+{
+namespace operation
+{
+
+std::unique_ptr<ITrainableOperation> Conv2D::clone() const
+{
+ return std::make_unique<Conv2D>(*this);
+}
+
+void Conv2D::accept(OperationVisitor &v) const { v.visit(*this); }
+
+void Conv2D::accept(TrainableOperationVisitor &v) const { v.visit(*this); }
+
+Conv2D::Conv2D(const OperationType &operation)
+ : OperationType{operation.getInputs(), operation.getOutputs(), operation.param()}
+{
+ // DO NOTHING
+}
+
+} // namespace operation
+} // namespace train
+} // namespace ir
+} // namespace onert
diff --git a/runtime/onert/core/src/ir/train/operation/DepthwiseConv2D.cc b/runtime/onert/core/src/ir/train/operation/DepthwiseConv2D.cc
new file mode 100644
index 000000000..2a7289619
--- /dev/null
+++ b/runtime/onert/core/src/ir/train/operation/DepthwiseConv2D.cc
@@ -0,0 +1,49 @@
+/*
+ * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ir/train/operation/DepthwiseConv2D.h"
+
+#include "ir/OperationVisitor.h"
+#include "ir/train/TrainableOperationVisitor.h"
+
+namespace onert
+{
+namespace ir
+{
+namespace train
+{
+namespace operation
+{
+
+std::unique_ptr<ITrainableOperation> DepthwiseConv2D::clone() const
+{
+ return std::make_unique<DepthwiseConv2D>(*this);
+}
+
+void DepthwiseConv2D::accept(OperationVisitor &v) const { v.visit(*this); }
+
+void DepthwiseConv2D::accept(TrainableOperationVisitor &v) const { v.visit(*this); }
+
+DepthwiseConv2D::DepthwiseConv2D(const OperationType &operation)
+ : OperationType{operation.getInputs(), operation.getOutputs(), operation.param()}
+{
+ // DO NOTHING
+}
+
+} // namespace operation
+} // namespace train
+} // namespace ir
+} // namespace onert
diff --git a/runtime/onert/core/src/ir/train/operation/ElementwiseActivation.cc b/runtime/onert/core/src/ir/train/operation/ElementwiseActivation.cc
new file mode 100644
index 000000000..1dae3f674
--- /dev/null
+++ b/runtime/onert/core/src/ir/train/operation/ElementwiseActivation.cc
@@ -0,0 +1,49 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ir/train/operation/ElementwiseActivation.h"
+
+#include "ir/OperationVisitor.h"
+#include "ir/train/TrainableOperationVisitor.h"
+
+namespace onert
+{
+namespace ir
+{
+namespace train
+{
+namespace operation
+{
+
+std::unique_ptr<ITrainableOperation> ElementwiseActivation::clone() const
+{
+ return std::make_unique<ElementwiseActivation>(*this);
+}
+
+void ElementwiseActivation::accept(OperationVisitor &v) const { v.visit(*this); }
+
+void ElementwiseActivation::accept(TrainableOperationVisitor &v) const { v.visit(*this); }
+
+ElementwiseActivation::ElementwiseActivation(const OperationType &operation)
+ : OperationType{operation.getInputs(), operation.getOutputs(), operation.param()}
+{
+ // DO NOTHING
+}
+
+} // namespace operation
+} // namespace train
+} // namespace ir
+} // namespace onert
diff --git a/runtime/onert/core/src/ir/train/operation/FullyConnected.cc b/runtime/onert/core/src/ir/train/operation/FullyConnected.cc
new file mode 100644
index 000000000..a26f7c489
--- /dev/null
+++ b/runtime/onert/core/src/ir/train/operation/FullyConnected.cc
@@ -0,0 +1,49 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ir/train/operation/FullyConnected.h"
+
+#include "ir/OperationVisitor.h"
+#include "ir/train/TrainableOperationVisitor.h"
+
+namespace onert
+{
+namespace ir
+{
+namespace train
+{
+namespace operation
+{
+
+std::unique_ptr<ITrainableOperation> FullyConnected::clone() const
+{
+ return std::make_unique<FullyConnected>(*this);
+}
+
+void FullyConnected::accept(OperationVisitor &v) const { v.visit(*this); }
+
+void FullyConnected::accept(TrainableOperationVisitor &v) const { v.visit(*this); }
+
+FullyConnected::FullyConnected(const OperationType &operation)
+ : OperationType{operation.getInputs(), operation.getOutputs(), operation.param()}
+{
+ // DO NOTHING
+}
+
+} // namespace operation
+} // namespace train
+} // namespace ir
+} // namespace onert
diff --git a/runtime/onert/core/src/ir/train/operation/Loss.cc b/runtime/onert/core/src/ir/train/operation/Loss.cc
new file mode 100644
index 000000000..3a89e0ff6
--- /dev/null
+++ b/runtime/onert/core/src/ir/train/operation/Loss.cc
@@ -0,0 +1,48 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ir/train/operation/Loss.h"
+
+#include "ir/OperationVisitor.h"
+#include "ir/train/TrainableOperationVisitor.h"
+
+#include <misc/polymorphic_downcast.h>
+
+namespace onert
+{
+namespace ir
+{
+namespace train
+{
+namespace operation
+{
+
+std::unique_ptr<ITrainableOperation> Loss::clone() const { return std::make_unique<Loss>(*this); }
+
+void Loss::accept(OperationVisitor &v) const { v.visit(*this); }
+
+void Loss::accept(TrainableOperationVisitor &v) const { v.visit(*this); }
+
+Loss::Loss(const OperationType &operation, const LossInfo &param)
+ : OperationType{operation.getInputs(), operation.getOutputs()}, _param{param}
+{
+ // DO NOTHING
+}
+
+} // namespace operation
+} // namespace train
+} // namespace ir
+} // namespace onert
diff --git a/runtime/onert/core/src/ir/train/operation/Pad.cc b/runtime/onert/core/src/ir/train/operation/Pad.cc
new file mode 100644
index 000000000..56394f5ef
--- /dev/null
+++ b/runtime/onert/core/src/ir/train/operation/Pad.cc
@@ -0,0 +1,46 @@
+/*
+ * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ir/train/operation/Pad.h"
+
+#include "ir/OperationVisitor.h"
+#include "ir/train/TrainableOperationVisitor.h"
+
+namespace onert
+{
+namespace ir
+{
+namespace train
+{
+namespace operation
+{
+
+std::unique_ptr<ITrainableOperation> Pad::clone() const { return std::make_unique<Pad>(*this); }
+
+void Pad::accept(OperationVisitor &v) const { v.visit(*this); }
+
+void Pad::accept(TrainableOperationVisitor &v) const { v.visit(*this); }
+
+Pad::Pad(const OperationType &operation)
+ : OperationType{operation.getInputs(), operation.getOutputs()}
+{
+ // DO NOTHING
+}
+
+} // namespace operation
+} // namespace train
+} // namespace ir
+} // namespace onert
diff --git a/runtime/onert/core/src/ir/train/operation/Permute.cc b/runtime/onert/core/src/ir/train/operation/Permute.cc
new file mode 100644
index 000000000..adc23aa49
--- /dev/null
+++ b/runtime/onert/core/src/ir/train/operation/Permute.cc
@@ -0,0 +1,50 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ir/train/operation/Permute.h"
+
+#include "ir/OperationVisitor.h"
+#include "ir/train/TrainableOperationVisitor.h"
+
+namespace onert
+{
+namespace ir
+{
+namespace train
+{
+namespace operation
+{
+
+std::unique_ptr<ITrainableOperation> Permute::clone() const
+{
+ return std::make_unique<Permute>(*this);
+}
+
+void Permute::accept(OperationVisitor &v) const { v.visit(*this); }
+
+void Permute::accept(TrainableOperationVisitor &v) const { v.visit(*this); }
+
+Permute::Permute(const OperationType &operation)
+ : OperationType{operation.getInputs().at(0), operation.getOutputs().at(0),
+ operation.getPermuteType()}
+{
+ // DO NOTHING
+}
+
+} // namespace operation
+} // namespace train
+} // namespace ir
+} // namespace onert
diff --git a/runtime/onert/core/src/ir/train/operation/Pool2D.cc b/runtime/onert/core/src/ir/train/operation/Pool2D.cc
new file mode 100644
index 000000000..021574f19
--- /dev/null
+++ b/runtime/onert/core/src/ir/train/operation/Pool2D.cc
@@ -0,0 +1,49 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ir/train/operation/Pool2D.h"
+
+#include "ir/OperationVisitor.h"
+#include "ir/train/TrainableOperationVisitor.h"
+
+namespace onert
+{
+namespace ir
+{
+namespace train
+{
+namespace operation
+{
+
+std::unique_ptr<ITrainableOperation> Pool2D::clone() const
+{
+ return std::make_unique<Pool2D>(*this);
+}
+
+void Pool2D::accept(OperationVisitor &v) const { v.visit(*this); }
+
+void Pool2D::accept(TrainableOperationVisitor &v) const { v.visit(*this); }
+
+Pool2D::Pool2D(const OperationType &operation)
+ : OperationType{operation.getInputs(), operation.getOutputs(), operation.param()}
+{
+ // DO NOTHING
+}
+
+} // namespace operation
+} // namespace train
+} // namespace ir
+} // namespace onert
diff --git a/runtime/onert/core/src/ir/train/operation/Reduce.cc b/runtime/onert/core/src/ir/train/operation/Reduce.cc
new file mode 100644
index 000000000..51986a0c2
--- /dev/null
+++ b/runtime/onert/core/src/ir/train/operation/Reduce.cc
@@ -0,0 +1,49 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ir/train/operation/Reduce.h"
+
+#include "ir/OperationVisitor.h"
+#include "ir/train/TrainableOperationVisitor.h"
+
+namespace onert
+{
+namespace ir
+{
+namespace train
+{
+namespace operation
+{
+
+std::unique_ptr<ITrainableOperation> Reduce::clone() const
+{
+ return std::make_unique<Reduce>(*this);
+}
+
+void Reduce::accept(OperationVisitor &v) const { v.visit(*this); }
+
+void Reduce::accept(TrainableOperationVisitor &v) const { v.visit(*this); }
+
+Reduce::Reduce(const OperationType &operation)
+ : OperationType{operation.getInputs(), operation.getOutputs(), operation.param()}
+{
+ // DO NOTHING
+}
+
+} // namespace operation
+} // namespace train
+} // namespace ir
+} // namespace onert
diff --git a/runtime/onert/core/src/ir/train/operation/Reshape.cc b/runtime/onert/core/src/ir/train/operation/Reshape.cc
new file mode 100644
index 000000000..c76158607
--- /dev/null
+++ b/runtime/onert/core/src/ir/train/operation/Reshape.cc
@@ -0,0 +1,49 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ir/train/operation/Reshape.h"
+
+#include "ir/OperationVisitor.h"
+#include "ir/train/TrainableOperationVisitor.h"
+
+namespace onert
+{
+namespace ir
+{
+namespace train
+{
+namespace operation
+{
+
+std::unique_ptr<ITrainableOperation> Reshape::clone() const
+{
+ return std::make_unique<Reshape>(*this);
+}
+
+void Reshape::accept(OperationVisitor &v) const { v.visit(*this); }
+
+void Reshape::accept(TrainableOperationVisitor &v) const { v.visit(*this); }
+
+Reshape::Reshape(const OperationType &operation)
+ : OperationType{operation.getInputs(), operation.getOutputs(), operation.param()}
+{
+ // DO NOTHING
+}
+
+} // namespace operation
+} // namespace train
+} // namespace ir
+} // namespace onert
diff --git a/runtime/onert/core/src/ir/train/operation/Softmax.cc b/runtime/onert/core/src/ir/train/operation/Softmax.cc
new file mode 100644
index 000000000..dbd403879
--- /dev/null
+++ b/runtime/onert/core/src/ir/train/operation/Softmax.cc
@@ -0,0 +1,49 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ir/train/operation/Softmax.h"
+
+#include "ir/OperationVisitor.h"
+#include "ir/train/TrainableOperationVisitor.h"
+
+namespace onert
+{
+namespace ir
+{
+namespace train
+{
+namespace operation
+{
+
+std::unique_ptr<ITrainableOperation> Softmax::clone() const
+{
+ return std::make_unique<Softmax>(*this);
+}
+
+void Softmax::accept(OperationVisitor &v) const { v.visit(*this); }
+
+void Softmax::accept(TrainableOperationVisitor &v) const { v.visit(*this); }
+
+Softmax::Softmax(const OperationType &operation)
+ : OperationType{operation.getInputs(), operation.getOutputs(), operation.param()}
+{
+ // DO NOTHING
+}
+
+} // namespace operation
+} // namespace train
+} // namespace ir
+} // namespace onert
diff --git a/runtime/onert/core/src/ir/train/operation/UntrainableOperation.test.cc b/runtime/onert/core/src/ir/train/operation/UntrainableOperation.test.cc
new file mode 100644
index 000000000..e3472ec51
--- /dev/null
+++ b/runtime/onert/core/src/ir/train/operation/UntrainableOperation.test.cc
@@ -0,0 +1,1239 @@
+/*
+ * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ir/train/operation/UntrainableOperation.h"
+
+#include "ir/Operations.Include.h"
+
+#include <gtest/gtest.h>
+
+using namespace ::onert::ir;
+
+operation::AddN generateAddN()
+{
+ return operation::AddN{OperandIndexSequence{1, 2}, OperandIndexSequence{0}};
+}
+
+operation::ArgMinMax generateArgMinMax()
+{
+ operation::ArgMinMax::Param param;
+ param.output_type = DataType::FLOAT32;
+
+ return operation::ArgMinMax{OperandIndexSequence{1, 2}, OperandIndexSequence{0}, param};
+}
+
+operation::BatchMatMul generateBatchMatMul()
+{
+ operation::BatchMatMul::Param param;
+ param.adj_x = true;
+ param.adj_y = true;
+
+ return operation::BatchMatMul{OperandIndexSequence{1, 2}, OperandIndexSequence{0}, param};
+}
+
+operation::BatchToSpaceND generateBatchToSpaceND()
+{
+ return operation::BatchToSpaceND{OperandIndexSequence{1, 2}, OperandIndexSequence{0}};
+}
+
+operation::BCQFullyConnected generateBCQFullyConnected()
+{
+ operation::BCQFullyConnected::Param param;
+ param.activation = Activation::NONE;
+ param.weights_hidden_size = 1;
+
+ return operation::BCQFullyConnected{OperandIndexSequence{1, 2, 3, 4, 5}, OperandIndexSequence{0},
+ param};
+}
+
+operation::BCQGather generateBCQGather()
+{
+ operation::BCQGather::Param param;
+ param.axis = 0;
+ param.input_hidden_size = 1;
+
+ return operation::BCQGather{OperandIndexSequence{1, 2, 3, 4}, OperandIndexSequence{0}, param};
+}
+
+operation::BinaryArithmetic generateBinaryArithmetic()
+{
+ operation::BinaryArithmetic::Param param;
+ param.activation = Activation::NONE;
+ param.arithmetic_type = operation::BinaryArithmetic::ArithmeticType::ADD;
+
+ return operation::BinaryArithmetic{OperandIndexSequence{1, 2}, OperandIndexSequence{0}, param};
+}
+
+operation::BroadcastTo generateBroadcastTo()
+{
+ return operation::BroadcastTo{OperandIndexSequence{1, 2}, OperandIndexSequence{0}};
+}
+
+operation::Bulk generateBulk()
+{
+ operation::Bulk::Param param;
+ param.binary_path = "";
+ param.origin_input_shapes = std::vector<onert::ir::Shape>{};
+ param.origin_output_shapes = std::vector<onert::ir::Shape>{};
+
+ return operation::Bulk{OperandIndexSequence{1, 2}, OperandIndexSequence{0}, param};
+}
+
+operation::Comparison generateComparison()
+{
+ operation::Comparison::Param param;
+ param.comparison_type = operation::Comparison::ComparisonType::Equal;
+
+ return operation::Comparison{OperandIndexSequence{1, 2}, OperandIndexSequence{0}, param};
+}
+
+operation::Concat generateConcat()
+{
+ operation::Concat::Param param;
+ param.axis = 0;
+
+ return operation::Concat{OperandIndexSequence{1, 2, 3}, OperandIndexSequence{0}, param};
+}
+
+operation::Conv2D generateConv2D()
+{
+ operation::Conv2D::Param param;
+ param.activation = Activation::NONE;
+ param.dilation = Dilation{};
+ param.padding = Padding{};
+ param.stride = Stride{};
+
+ return operation::Conv2D{OperandIndexSequence{1, 2, 3}, OperandIndexSequence{0}, param};
+}
+
+operation::ConvertFp16ToFp32 generateConvertFp16ToFp32()
+{
+ return operation::ConvertFp16ToFp32{OperandIndexSequence{1}, OperandIndexSequence{0}};
+}
+
+operation::ConvertFp32ToFp16 generateConvertFp32ToFp16()
+{
+ return operation::ConvertFp32ToFp16{OperandIndexSequence{1}, OperandIndexSequence{0}};
+}
+
+operation::Custom generateCustom()
+{
+ return operation::Custom{OperandConstraint::createExact(1u), OperandIndexSequence{1},
+ OperandIndexSequence{0}, std::string("id"),
+ operation::Custom::Userdata{}};
+}
+
+operation::DepthToSpace generateDepthToSpace()
+{
+ operation::DepthToSpace::Param param;
+ param.block_size = 1;
+
+ return operation::DepthToSpace{OperandIndexSequence{1}, OperandIndexSequence{0}, param};
+}
+
+operation::DepthwiseConv2D generateDepthwiseConv2D()
+{
+ operation::DepthwiseConv2D::Param param;
+ param.activation = Activation::NONE;
+ param.dilation = Dilation{};
+ param.multiplier = 1u;
+ param.padding = Padding{};
+ param.stride = Stride{};
+
+ return operation::DepthwiseConv2D{OperandIndexSequence{1, 2, 3}, OperandIndexSequence{0}, param};
+}
+
+operation::DetectionPostProcess generateDetectionPostProcess()
+{
+ operation::DetectionPostProcess::Param param;
+
+ return operation::DetectionPostProcess{OperandIndexSequence{1, 2, 3}, OperandIndexSequence{0},
+ param};
+}
+
+operation::Einsum generateEinsum()
+{
+ operation::Einsum::Param param;
+ param.equation = "";
+
+ return operation::Einsum{OperandIndexSequence{1}, OperandIndexSequence{0}, param};
+}
+
+operation::ElementwiseActivation generateElementwiseActivation()
+{
+ operation::ElementwiseActivation::Param param;
+ param.alpha = 0.f;
+ param.beta = 0.f;
+ param.op_type = operation::ElementwiseActivation::Type::ELU;
+
+ return operation::ElementwiseActivation{OperandIndexSequence{1}, OperandIndexSequence{0}, param};
+}
+
+operation::ElementwiseBinary generateElementwiseBinary()
+{
+ operation::ElementwiseBinary::Param param;
+ param.op_type = operation::ElementwiseBinary::ElementwiseBinaryType::FLOOR_DIV;
+
+ return operation::ElementwiseBinary{OperandIndexSequence{1, 2}, OperandIndexSequence{0}, param};
+}
+
+operation::ElementwiseUnary generateElementwiseUnary()
+{
+ operation::ElementwiseUnary::Param param;
+ param.op_type = operation::ElementwiseUnary::Type::ABS;
+
+ return operation::ElementwiseUnary{OperandIndexSequence{1}, OperandIndexSequence{0}, param};
+}
+
+operation::EmbeddingLookup generateEmbeddingLookup()
+{
+ return operation::EmbeddingLookup{OperandIndexSequence{1, 2}, OperandIndexSequence{0}};
+}
+
+operation::ExpandDims generateExpandDims()
+{
+ return operation::ExpandDims{OperandIndexSequence{1, 2}, OperandIndexSequence{0}};
+}
+
+operation::Fill generateFill()
+{
+ return operation::Fill{OperandIndexSequence{1, 2}, OperandIndexSequence{0}};
+}
+
+operation::FullyConnected generateFullyConnected()
+{
+ operation::FullyConnected::Param param;
+ param.activation = Activation::NONE;
+ param.weights_format = FullyConnectedWeightsFormat::Default;
+
+ return operation::FullyConnected{OperandIndexSequence{1, 2, 3}, OperandIndexSequence{0}, param};
+}
+
+operation::FusedBatchNorm generateFusedBatchNorm()
+{
+ operation::FusedBatchNorm::Param param;
+ param.is_training = false;
+ param.epsilon = 0.f;
+ param.data_format = "";
+
+ return operation::FusedBatchNorm{OperandIndexSequence{1, 2, 3, 4, 5}, OperandIndexSequence{0},
+ param};
+}
+
+operation::Gather generateGather()
+{
+ operation::Gather::Param param;
+ param.axis = 0;
+
+ return operation::Gather{OperandIndexSequence{1, 2}, OperandIndexSequence{0}, param};
+}
+
+operation::HashtableLookup generateHashtableLookup()
+{
+ return operation::HashtableLookup{OperandIndexSequence{2, 3, 4}, OperandIndexSequence{0, 1}};
+}
+
+operation::If generateIf()
+{
+ operation::If::Param param;
+ param.else_subg_index = 1;
+ param.then_subg_index = 2;
+
+ return operation::If{OperandIndexSequence{1, 2, 3}, OperandIndexSequence{0}, param};
+}
+
+operation::InstanceNorm generateInstanceNorm()
+{
+ operation::InstanceNorm::Param param;
+ param.activation = Activation::NONE;
+ param.epsilon = 0.f;
+
+ return operation::InstanceNorm{OperandIndexSequence{1, 2, 3}, OperandIndexSequence{0}, param};
+}
+
+operation::L2Normalization generateL2Normalization()
+{
+ return operation::L2Normalization{OperandIndexSequence{1}, OperandIndexSequence{0}};
+}
+
+operation::LocalResponseNormalization generateLocalResponseNormalization()
+{
+ operation::LocalResponseNormalization::Param param;
+ param.alpha = 0.f;
+ param.beta = 0.f;
+ param.bias = 0.f;
+ param.radius = 1;
+
+ return operation::LocalResponseNormalization{OperandIndexSequence{1}, OperandIndexSequence{0},
+ param};
+}
+
+operation::LogSoftmax generateLogSoftmax()
+{
+ operation::LogSoftmax::Param param;
+ param.axis = 0;
+ param.beta = 0.f;
+
+ return operation::LogSoftmax{OperandIndexSequence{1}, OperandIndexSequence{0}, param};
+}
+
+operation::LSTM generateLSTM()
+{
+ operation::LSTM::Param param;
+ param.activation = Activation::NONE;
+ param.cell_threshold = 1.f;
+ param.projection_threshold = 1.f;
+ param.time_major = true;
+
+ return operation::LSTM{
+ OperandIndexSequence{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20},
+ OperandIndexSequence{0}, param};
+}
+
+operation::MatrixBandPart generateMatrixBandPart()
+{
+ return operation::MatrixBandPart{OperandIndexSequence{1, 2, 3}, OperandIndexSequence{0}};
+}
+
+operation::OneHot generateOneHot()
+{
+ operation::OneHot::Param param;
+ param.axis = 0;
+
+ return operation::OneHot{OperandIndexSequence{1, 2, 3, 4}, OperandIndexSequence{0}, param};
+}
+
+operation::Pack generatePack()
+{
+ operation::Pack::Param param;
+ param.axis = 0;
+ param.num = 1;
+
+ return operation::Pack{OperandIndexSequence{1}, OperandIndexSequence{0}, param};
+}
+
+operation::Pad generatePad()
+{
+ return operation::Pad{OperandIndexSequence{1, 2, 3}, OperandIndexSequence{0}};
+}
+
+operation::Permute generatePermute()
+{
+ return operation::Permute{OperandIndex{1}, OperandIndex{0}, operation::Permute::Type::COPY};
+}
+
+operation::Pool2D generatePool2D()
+{
+ operation::Pool2D::Param param;
+ param.activation = Activation::NONE;
+ param.kh = 1;
+ param.kw = 1;
+ param.op_type = operation::Pool2D::PoolType::AVG;
+ param.padding = Padding{};
+ param.stride = Stride{};
+
+ return operation::Pool2D{OperandIndexSequence{1}, OperandIndexSequence{0}, param};
+}
+
+operation::Pow generatePow()
+{
+ return operation::Pow{OperandIndexSequence{1, 2}, OperandIndexSequence{0}};
+}
+
+operation::PReLU generatePReLU()
+{
+ return operation::PReLU{OperandIndexSequence{1, 2}, OperandIndexSequence{0}};
+}
+
+operation::Range generateRange()
+{
+ return operation::Range{OperandIndexSequence{1, 2, 3}, OperandIndexSequence{0}};
+}
+
+operation::Rank generateRank()
+{
+ return operation::Rank{OperandIndexSequence{1}, OperandIndexSequence{0}};
+}
+
+operation::Reduce generateReduce()
+{
+ operation::Reduce::Param param;
+ param.keep_dims = true;
+ param.reduce_type = operation::Reduce::ReduceType::ALL;
+
+ return operation::Reduce{OperandIndexSequence{1, 2}, OperandIndexSequence{0}, param};
+}
+
+operation::Reshape generateReshape()
+{
+ operation::Reshape::Param param;
+ param.new_shape = std::vector<int32_t>{1};
+
+ return operation::Reshape{OperandIndexSequence{1, 2}, OperandIndexSequence{0}, param};
+}
+
+operation::ResizeBilinear generateResizeBilinear()
+{
+ operation::ResizeBilinear::Param param;
+ param.align_corners = true;
+ param.half_pixel_centers = true;
+ param.height_out = 1;
+ param.width_out = 1;
+
+ return operation::ResizeBilinear{OperandIndexSequence{1, 2}, OperandIndexSequence{0}, param};
+}
+
+operation::ResizeNearestNeighbor generateResizeNearestNeighbor()
+{
+ operation::ResizeNearestNeighbor::Param param;
+ param.align_corners = true;
+ param.height_out = 1;
+ param.width_out = 1;
+
+ return operation::ResizeNearestNeighbor{OperandIndexSequence{1, 2}, OperandIndexSequence{0},
+ param};
+}
+
+operation::Reverse generateReverse()
+{
+ return operation::Reverse{OperandIndexSequence{1, 2}, OperandIndexSequence{0}};
+}
+
+operation::RNN generateRNN()
+{
+ operation::RNN::Param param;
+ param.activation = Activation::NONE;
+
+ return operation::RNN{OperandIndexSequence{1, 2, 3, 4, 5}, OperandIndexSequence{0}, param};
+}
+
+operation::Select generateSelect()
+{
+ return operation::Select{OperandIndexSequence{1, 2, 3}, OperandIndexSequence{0}};
+}
+
+operation::Shape generateShape()
+{
+ return operation::Shape{OperandIndexSequence{1}, OperandIndexSequence{0}};
+}
+
+operation::Slice generateSlice()
+{
+ return operation::Slice{OperandIndexSequence{1, 2, 3}, OperandIndexSequence{0}};
+}
+
+operation::Softmax generateSoftmax()
+{
+ operation::Softmax::Param param;
+ param.beta = 0.1f;
+
+ return operation::Softmax{OperandIndexSequence{1}, OperandIndexSequence{0}, param};
+}
+
+operation::SpaceToBatchND generateSpaceToBatchND()
+{
+ return operation::SpaceToBatchND{OperandIndexSequence{1, 2, 3}, OperandIndexSequence{0}};
+}
+
+operation::SpaceToDepth generateSpaceToDepth()
+{
+ operation::SpaceToDepth::Param param;
+ param.block_size = 1;
+
+ return operation::SpaceToDepth{OperandIndexSequence{1}, OperandIndexSequence{0}, param};
+}
+
+operation::Split generateSplit()
+{
+ operation::Split::Param param;
+ param.num_splits = 1;
+
+ return operation::Split{OperandIndexSequence{1, 2}, OperandIndexSequence{0}, param};
+}
+
+operation::SplitV generateSplitV()
+{
+ operation::SplitV::Param param;
+ param.num_splits = 1;
+
+ return operation::SplitV{OperandIndexSequence{1, 2, 3}, OperandIndexSequence{0}, param};
+}
+
+operation::SquaredDifference generateSquaredDifference()
+{
+ return operation::SquaredDifference{OperandIndexSequence{1, 2}, OperandIndexSequence{0}};
+}
+
+operation::Squeeze generateSqueeze()
+{
+ operation::Squeeze::Param param;
+ param.dims[0] = 1;
+ param.ndim = 1;
+
+ return operation::Squeeze{OperandIndexSequence{1}, OperandIndexSequence{0}, param};
+}
+
+operation::StatelessRandomUniform generateStatelessRandomUniform()
+{
+ return operation::StatelessRandomUniform{OperandIndexSequence{1, 2}, OperandIndexSequence{0}};
+}
+
+operation::StridedSlice generateStridedSlice()
+{
+ operation::StridedSlice::Param param;
+ param.begin_mask = 1;
+ param.end_mask = 1;
+ param.shrink_axis_mask = 1;
+
+ return operation::StridedSlice{OperandIndexSequence{1, 2, 3, 4}, OperandIndexSequence{0}, param};
+}
+
+operation::Tile generateTile()
+{
+ return operation::Tile{OperandIndexSequence{1, 2}, OperandIndexSequence{0}};
+}
+
+operation::TopKV2 generateTopKV2()
+{
+ operation::TopKV2::Param param;
+ param.k = 1;
+
+ return operation::TopKV2{OperandIndexSequence{1}, OperandIndexSequence{0}, param};
+}
+
+operation::Transpose generateTranspose()
+{
+ return operation::Transpose{OperandIndexSequence{1, 2}, OperandIndexSequence{0}};
+}
+
+operation::TransposeConv generateTransposeConv()
+{
+ operation::TransposeConv::Param param;
+ param.padding = Padding();
+ param.stride = Stride();
+
+ return operation::TransposeConv{OperandIndexSequence{1, 2, 3}, OperandIndexSequence{0}, param};
+}
+
+operation::Unpack generateUnpack()
+{
+ operation::Unpack::Param param;
+ param.axis = 0;
+ param.num = 1;
+
+ return operation::Unpack{OperandIndexSequence{1}, OperandIndexSequence{0}, param};
+}
+
+operation::While generateWhile()
+{
+ operation::While::Param param;
+ param.cond_subg_index = 1;
+ param.body_subg_index = 2;
+
+ return operation::While{OperandIndexSequence{1, 2, 3}, OperandIndexSequence{0}, param};
+}
+
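+// Mock visitor: including Operations.lst below expands OP() into a visit() override for every
+// operation in the IR, and each override simply raises visit_flag.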
+class MockOperationVisitor : public OperationVisitor
+{
+public:
+ void invoke(Operation &op) { op.accept(*this); }
+
+#define OP(InternalName) \
+ virtual void visit(const operation::InternalName &) override { visit_flag = true; }
+#include "ir/Operations.lst"
+#undef OP
+
+public:
+ // TODO Replace this flag with using GMOCK if necessary
+ bool visit_flag{false};
+};
+
+template <typename OperationType> auto generateUntrainableOperation(const OperationType &op)
+{
+ return std::make_unique<train::operation::UntrainableOperation<OperationType>>(op);
+}
+
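+// Wraps the given operation in UntrainableOperation and checks that opcode, inputs, outputs,
+// clone() and visitor dispatch are all forwarded unchanged.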
+template <typename OperationType> void verifyOp(const OperationType &op)
+{
+ auto untrainable = generateUntrainableOperation(op);
+ EXPECT_EQ(untrainable->opcode(), op.opcode());
+ EXPECT_EQ(untrainable->getInputs(), op.getInputs());
+ EXPECT_EQ(untrainable->getOutputs(), op.getOutputs());
+
+ // Check clone
+ auto clone = untrainable->clone();
+ EXPECT_TRUE(clone != nullptr);
+ EXPECT_EQ(clone->hasTrainableParameter(), untrainable->hasTrainableParameter());
+ EXPECT_EQ(clone->opcode(), untrainable->opcode());
+ EXPECT_EQ(clone->getInputs(), untrainable->getInputs());
+ EXPECT_EQ(clone->getOutputs(), untrainable->getOutputs());
+
+ // Check downcast
+ const auto derived =
+ dynamic_cast<train::operation::UntrainableOperation<OperationType> *>(clone.get());
+ EXPECT_TRUE(derived != nullptr);
+ EXPECT_EQ(clone->hasTrainableParameter(), untrainable->hasTrainableParameter());
+ EXPECT_EQ(derived->opcode(), op.opcode());
+ EXPECT_EQ(derived->getInputs(), op.getInputs());
+ EXPECT_EQ(derived->getOutputs(), op.getOutputs());
+
+ // Check visitor
+ MockOperationVisitor visitor;
+
+ visitor.visit_flag = false;
+ visitor.invoke(*untrainable);
+ EXPECT_TRUE(visitor.visit_flag);
+}
+
+TEST(UntrainableOperation, testAllOps)
+{
+ const auto addn = generateAddN();
+ verifyOp(addn);
+
+ const auto argminmax = generateArgMinMax();
+ verifyOp(argminmax);
+
+ const auto batch_matmul = generateBatchMatMul();
+ verifyOp(batch_matmul);
+
+ const auto batch_to_spacend = generateBatchToSpaceND();
+ verifyOp(batch_to_spacend);
+
+ const auto bcq_fc = generateBCQFullyConnected();
+ verifyOp(bcq_fc);
+
+ const auto bcq_gather = generateBCQGather();
+ verifyOp(bcq_gather);
+
+ const auto binary_arithmetic = generateBinaryArithmetic();
+ verifyOp(binary_arithmetic);
+
+ const auto broadcast = generateBroadcastTo();
+ verifyOp(broadcast);
+
+ const auto bulk = generateBulk();
+ verifyOp(bulk);
+
+ const auto comparison = generateComparison();
+ verifyOp(comparison);
+
+ const auto concat = generateConcat();
+ verifyOp(concat);
+
+ const auto conv2d = generateConv2D();
+ verifyOp(conv2d);
+
+ const auto fp16_to_fp32 = generateConvertFp16ToFp32();
+ verifyOp(fp16_to_fp32);
+
+ const auto fp32_to_fp16 = generateConvertFp32ToFp16();
+ verifyOp(fp32_to_fp16);
+
+ const auto custom = generateCustom();
+ verifyOp(custom);
+
+ const auto depth_to_space = generateDepthToSpace();
+ verifyOp(depth_to_space);
+
+ const auto depthwise_conv2d = generateDepthwiseConv2D();
+ verifyOp(depthwise_conv2d);
+
+ const auto detection = generateDetectionPostProcess();
+ verifyOp(detection);
+
+ const auto einsum = generateEinsum();
+ verifyOp(einsum);
+
+ const auto activation = generateElementwiseActivation();
+ verifyOp(activation);
+
+ const auto binary = generateElementwiseBinary();
+ verifyOp(binary);
+
+ const auto unary = generateElementwiseUnary();
+ verifyOp(unary);
+
+ const auto embed = generateEmbeddingLookup();
+ verifyOp(embed);
+
+ const auto expand_dims = generateExpandDims();
+ verifyOp(expand_dims);
+
+ const auto fill = generateFill();
+ verifyOp(fill);
+
+ const auto fc = generateFullyConnected();
+ verifyOp(fc);
+
+ const auto fused_batch_norm = generateFusedBatchNorm();
+ verifyOp(fused_batch_norm);
+
+ const auto gather = generateGather();
+ verifyOp(gather);
+
+ const auto hashtable = generateHashtableLookup();
+ verifyOp(hashtable);
+
+ const auto if_op = generateIf();
+ verifyOp(if_op);
+
+ const auto in_norm = generateInstanceNorm();
+ verifyOp(in_norm);
+
+ const auto l2_norm = generateL2Normalization();
+ verifyOp(l2_norm);
+
+ const auto local_norm = generateLocalResponseNormalization();
+ verifyOp(local_norm);
+
+ const auto log_softmax = generateLogSoftmax();
+ verifyOp(log_softmax);
+
+ const auto lstm = generateLSTM();
+ verifyOp(lstm);
+
+  const auto matrix_band_part = generateMatrixBandPart();
+  verifyOp(matrix_band_part);
+
+ const auto one_hot = generateOneHot();
+ verifyOp(one_hot);
+
+ const auto pack = generatePack();
+ verifyOp(pack);
+
+ const auto pad = generatePad();
+ verifyOp(pad);
+
+ const auto permute = generatePermute();
+ verifyOp(permute);
+
+ const auto pool2d = generatePool2D();
+ verifyOp(pool2d);
+
+ const auto pow = generatePow();
+ verifyOp(pow);
+
+ const auto prelu = generatePReLU();
+ verifyOp(prelu);
+
+ const auto range = generateRange();
+ verifyOp(range);
+
+ const auto rank = generateRank();
+ verifyOp(rank);
+
+ const auto reduce = generateReduce();
+ verifyOp(reduce);
+
+ const auto reshape = generateReshape();
+ verifyOp(reshape);
+
+ const auto resize_bilinear = generateResizeBilinear();
+ verifyOp(resize_bilinear);
+
+ const auto resize_nearest_neighbor = generateResizeNearestNeighbor();
+ verifyOp(resize_nearest_neighbor);
+
+ const auto reverse = generateReverse();
+ verifyOp(reverse);
+
+ const auto rnn = generateRNN();
+ verifyOp(rnn);
+
+ const auto select = generateSelect();
+ verifyOp(select);
+
+ const auto shape = generateShape();
+ verifyOp(shape);
+
+ const auto slice = generateSlice();
+ verifyOp(slice);
+
+ const auto softmax = generateSoftmax();
+ verifyOp(softmax);
+
+ const auto space_to_batchnd = generateSpaceToBatchND();
+ verifyOp(space_to_batchnd);
+
+ const auto space_to_depth = generateSpaceToDepth();
+ verifyOp(space_to_depth);
+
+ const auto split = generateSplit();
+ verifyOp(split);
+
+ const auto splitv = generateSplitV();
+ verifyOp(splitv);
+
+ const auto squared_diff = generateSquaredDifference();
+ verifyOp(squared_diff);
+
+ const auto squeeze = generateSqueeze();
+ verifyOp(squeeze);
+
+ const auto stateless_random_uniform = generateStatelessRandomUniform();
+ verifyOp(stateless_random_uniform);
+
+ const auto strided_slice = generateStridedSlice();
+ verifyOp(strided_slice);
+
+ const auto tile = generateTile();
+ verifyOp(tile);
+
+ const auto topkv2 = generateTopKV2();
+ verifyOp(topkv2);
+
+ const auto transpose = generateTranspose();
+ verifyOp(transpose);
+
+ const auto transpose_conv = generateTransposeConv();
+ verifyOp(transpose_conv);
+
+ const auto unpack = generateUnpack();
+ verifyOp(unpack);
+
+ const auto while_op = generateWhile();
+ verifyOp(while_op);
+}
+
+class MockTrainableOperationVisitor : public train::TrainableOperationVisitor
+{
+public:
+ void invoke(train::ITrainableOperation &op) { op.accept(*this); }
+
+#define OP(InternalName) \
+ virtual void visit(const train::operation::InternalName &) override {}
+#include "ir/train/Operations.lst"
+#undef OP
+};
+
+TEST(UntrainableOperation, neg_TrainableOperationVisitor)
+{
+ MockTrainableOperationVisitor visitor;
+
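+  // Every untrainable wrapper is expected to throw when a trainable visitor calls accept() on it.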
+ {
+ const auto addn = generateAddN();
+ auto untrainable = generateUntrainableOperation(addn);
+ EXPECT_ANY_THROW(visitor.invoke(*untrainable));
+ }
+
+ {
+ auto argminmax = generateArgMinMax();
+ auto untrainable = generateUntrainableOperation(argminmax);
+ EXPECT_ANY_THROW(visitor.invoke(*untrainable));
+ }
+
+ {
+ const auto batch_matmul = generateBatchMatMul();
+ auto untrainable = generateUntrainableOperation(batch_matmul);
+ EXPECT_ANY_THROW(visitor.invoke(*untrainable));
+ }
+
+ {
+ const auto batch_to_spacend = generateBatchToSpaceND();
+ auto untrainable = generateUntrainableOperation(batch_to_spacend);
+ EXPECT_ANY_THROW(visitor.invoke(*untrainable));
+ }
+
+ {
+ const auto bcq_fc = generateBCQFullyConnected();
+ auto untrainable = generateUntrainableOperation(bcq_fc);
+ EXPECT_ANY_THROW(visitor.invoke(*untrainable));
+ }
+
+ {
+ const auto bcq_gather = generateBCQGather();
+ auto untrainable = generateUntrainableOperation(bcq_gather);
+ EXPECT_ANY_THROW(visitor.invoke(*untrainable));
+ }
+
+ {
+ const auto binary_arithmetic = generateBinaryArithmetic();
+ auto untrainable = generateUntrainableOperation(binary_arithmetic);
+ EXPECT_ANY_THROW(visitor.invoke(*untrainable));
+ }
+
+ {
+ const auto broadcast = generateBroadcastTo();
+ auto untrainable = generateUntrainableOperation(broadcast);
+ EXPECT_ANY_THROW(visitor.invoke(*untrainable));
+ }
+
+ {
+ const auto bulk = generateBulk();
+ auto untrainable = generateUntrainableOperation(bulk);
+ EXPECT_ANY_THROW(visitor.invoke(*untrainable));
+ }
+
+ {
+ const auto comparison = generateComparison();
+ auto untrainable = generateUntrainableOperation(comparison);
+ EXPECT_ANY_THROW(visitor.invoke(*untrainable));
+ }
+
+ {
+ const auto concat = generateConcat();
+ auto untrainable = generateUntrainableOperation(concat);
+ EXPECT_ANY_THROW(visitor.invoke(*untrainable));
+ }
+
+ {
+ const auto conv2d = generateConv2D();
+ auto untrainable = generateUntrainableOperation(conv2d);
+ EXPECT_ANY_THROW(visitor.invoke(*untrainable));
+ }
+
+ {
+ const auto fp16_to_fp32 = generateConvertFp16ToFp32();
+ auto untrainable = generateUntrainableOperation(fp16_to_fp32);
+ EXPECT_ANY_THROW(visitor.invoke(*untrainable));
+ }
+
+ {
+ const auto fp32_to_fp16 = generateConvertFp32ToFp16();
+ auto untrainable = generateUntrainableOperation(fp32_to_fp16);
+ EXPECT_ANY_THROW(visitor.invoke(*untrainable));
+ }
+
+ {
+ const auto custom = generateCustom();
+ auto untrainable = generateUntrainableOperation(custom);
+ EXPECT_ANY_THROW(visitor.invoke(*untrainable));
+ }
+
+ {
+ const auto depth_to_space = generateDepthToSpace();
+ auto untrainable = generateUntrainableOperation(depth_to_space);
+ EXPECT_ANY_THROW(visitor.invoke(*untrainable));
+ }
+
+ {
+ const auto depthwise_conv2d = generateDepthwiseConv2D();
+ auto untrainable = generateUntrainableOperation(depthwise_conv2d);
+ EXPECT_ANY_THROW(visitor.invoke(*untrainable));
+ }
+
+ {
+ const auto detection = generateDetectionPostProcess();
+ auto untrainable = generateUntrainableOperation(detection);
+ EXPECT_ANY_THROW(visitor.invoke(*untrainable));
+ }
+
+ {
+ const auto einsum = generateEinsum();
+ auto untrainable = generateUntrainableOperation(einsum);
+ EXPECT_ANY_THROW(visitor.invoke(*untrainable));
+ }
+
+ {
+ const auto activation = generateElementwiseActivation();
+ auto untrainable = generateUntrainableOperation(activation);
+ EXPECT_ANY_THROW(visitor.invoke(*untrainable));
+ }
+
+ {
+ const auto binary = generateElementwiseBinary();
+ auto untrainable = generateUntrainableOperation(binary);
+ EXPECT_ANY_THROW(visitor.invoke(*untrainable));
+ }
+
+ {
+ const auto unary = generateElementwiseUnary();
+ auto untrainable = generateUntrainableOperation(unary);
+ EXPECT_ANY_THROW(visitor.invoke(*untrainable));
+ }
+
+ {
+ const auto embed = generateEmbeddingLookup();
+ auto untrainable = generateUntrainableOperation(embed);
+ EXPECT_ANY_THROW(visitor.invoke(*untrainable));
+ }
+
+ {
+ const auto expand_dims = generateExpandDims();
+ auto untrainable = generateUntrainableOperation(expand_dims);
+ EXPECT_ANY_THROW(visitor.invoke(*untrainable));
+ }
+
+ {
+ const auto fill = generateFill();
+ auto untrainable = generateUntrainableOperation(fill);
+ EXPECT_ANY_THROW(visitor.invoke(*untrainable));
+ }
+
+ {
+ const auto fc = generateFullyConnected();
+ auto untrainable = generateUntrainableOperation(fc);
+ EXPECT_ANY_THROW(visitor.invoke(*untrainable));
+ }
+
+ {
+ const auto fused_batch_norm = generateFusedBatchNorm();
+ auto untrainable = generateUntrainableOperation(fused_batch_norm);
+ EXPECT_ANY_THROW(visitor.invoke(*untrainable));
+ }
+
+ {
+ const auto gather = generateGather();
+ auto untrainable = generateUntrainableOperation(gather);
+ EXPECT_ANY_THROW(visitor.invoke(*untrainable));
+ }
+
+ {
+ const auto hashtable = generateHashtableLookup();
+ auto untrainable = generateUntrainableOperation(hashtable);
+ EXPECT_ANY_THROW(visitor.invoke(*untrainable));
+ }
+
+ {
+ const auto if_op = generateIf();
+ auto untrainable = generateUntrainableOperation(if_op);
+ EXPECT_ANY_THROW(visitor.invoke(*untrainable));
+ }
+
+ {
+ const auto in_norm = generateInstanceNorm();
+ auto untrainable = generateUntrainableOperation(in_norm);
+ EXPECT_ANY_THROW(visitor.invoke(*untrainable));
+ }
+
+ {
+ const auto l2_norm = generateL2Normalization();
+ auto untrainable = generateUntrainableOperation(l2_norm);
+ EXPECT_ANY_THROW(visitor.invoke(*untrainable));
+ }
+
+ {
+ const auto local_norm = generateLocalResponseNormalization();
+ auto untrainable = generateUntrainableOperation(local_norm);
+ EXPECT_ANY_THROW(visitor.invoke(*untrainable));
+ }
+
+ {
+ const auto log_softmax = generateLogSoftmax();
+ auto untrainable = generateUntrainableOperation(log_softmax);
+ EXPECT_ANY_THROW(visitor.invoke(*untrainable));
+ }
+
+ {
+ const auto lstm = generateLSTM();
+ auto untrainable = generateUntrainableOperation(lstm);
+ EXPECT_ANY_THROW(visitor.invoke(*untrainable));
+ }
+
+ {
+ const auto matrix_band_part = generateMatrixBandPart();
+ auto untrainable = generateUntrainableOperation(matrix_band_part);
+ EXPECT_ANY_THROW(visitor.invoke(*untrainable));
+ }
+
+ {
+ const auto one_hot = generateOneHot();
+ auto untrainable = generateUntrainableOperation(one_hot);
+ EXPECT_ANY_THROW(visitor.invoke(*untrainable));
+ }
+
+ {
+ const auto pack = generatePack();
+ auto untrainable = generateUntrainableOperation(pack);
+ EXPECT_ANY_THROW(visitor.invoke(*untrainable));
+ }
+
+ {
+ const auto pad = generatePad();
+ auto untrainable = generateUntrainableOperation(pad);
+ EXPECT_ANY_THROW(visitor.invoke(*untrainable));
+ }
+
+ {
+ const auto permute = generatePermute();
+ auto untrainable = generateUntrainableOperation(permute);
+ EXPECT_ANY_THROW(visitor.invoke(*untrainable));
+ }
+
+ {
+ const auto pool2d = generatePool2D();
+ auto untrainable = generateUntrainableOperation(pool2d);
+ EXPECT_ANY_THROW(visitor.invoke(*untrainable));
+ }
+
+ {
+ const auto pow = generatePow();
+ auto untrainable = generateUntrainableOperation(pow);
+ EXPECT_ANY_THROW(visitor.invoke(*untrainable));
+ }
+
+ {
+ const auto prelu = generatePReLU();
+ auto untrainable = generateUntrainableOperation(prelu);
+ EXPECT_ANY_THROW(visitor.invoke(*untrainable));
+ }
+
+ {
+ const auto range = generateRange();
+ auto untrainable = generateUntrainableOperation(range);
+ EXPECT_ANY_THROW(visitor.invoke(*untrainable));
+ }
+
+ {
+ const auto rank = generateRank();
+ auto untrainable = generateUntrainableOperation(rank);
+ EXPECT_ANY_THROW(visitor.invoke(*untrainable));
+ }
+
+ {
+ const auto reduce = generateReduce();
+ auto untrainable = generateUntrainableOperation(reduce);
+ EXPECT_ANY_THROW(visitor.invoke(*untrainable));
+ }
+
+ {
+ const auto reshape = generateReshape();
+ auto untrainable = generateUntrainableOperation(reshape);
+ EXPECT_ANY_THROW(visitor.invoke(*untrainable));
+ }
+
+ {
+ const auto resize_bilinear = generateResizeBilinear();
+ auto untrainable = generateUntrainableOperation(resize_bilinear);
+ EXPECT_ANY_THROW(visitor.invoke(*untrainable));
+ }
+
+ {
+ const auto resize_nearest_neighbor = generateResizeNearestNeighbor();
+ auto untrainable = generateUntrainableOperation(resize_nearest_neighbor);
+ EXPECT_ANY_THROW(visitor.invoke(*untrainable));
+ }
+
+ {
+ const auto reverse = generateReverse();
+ auto untrainable = generateUntrainableOperation(reverse);
+ EXPECT_ANY_THROW(visitor.invoke(*untrainable));
+ }
+
+ {
+ const auto rnn = generateRNN();
+ auto untrainable = generateUntrainableOperation(rnn);
+ EXPECT_ANY_THROW(visitor.invoke(*untrainable));
+ }
+
+ {
+ const auto select = generateSelect();
+ auto untrainable = generateUntrainableOperation(select);
+ EXPECT_ANY_THROW(visitor.invoke(*untrainable));
+ }
+
+ {
+ const auto shape = generateShape();
+ auto untrainable = generateUntrainableOperation(shape);
+ EXPECT_ANY_THROW(visitor.invoke(*untrainable));
+ }
+
+ {
+ const auto slice = generateSlice();
+ auto untrainable = generateUntrainableOperation(slice);
+ EXPECT_ANY_THROW(visitor.invoke(*untrainable));
+ }
+
+ {
+ const auto softmax = generateSoftmax();
+ auto untrainable = generateUntrainableOperation(softmax);
+ EXPECT_ANY_THROW(visitor.invoke(*untrainable));
+ }
+
+ {
+ const auto space_to_batchnd = generateSpaceToBatchND();
+ auto untrainable = generateUntrainableOperation(space_to_batchnd);
+ EXPECT_ANY_THROW(visitor.invoke(*untrainable));
+ }
+
+ {
+ const auto space_to_depth = generateSpaceToDepth();
+ auto untrainable = generateUntrainableOperation(space_to_depth);
+ EXPECT_ANY_THROW(visitor.invoke(*untrainable));
+ }
+
+ {
+ const auto split = generateSplit();
+ auto untrainable = generateUntrainableOperation(split);
+ EXPECT_ANY_THROW(visitor.invoke(*untrainable));
+ }
+
+ {
+ const auto splitv = generateSplitV();
+ auto untrainable = generateUntrainableOperation(splitv);
+ EXPECT_ANY_THROW(visitor.invoke(*untrainable));
+ }
+
+ {
+ const auto squared_diff = generateSquaredDifference();
+ auto untrainable = generateUntrainableOperation(squared_diff);
+ EXPECT_ANY_THROW(visitor.invoke(*untrainable));
+ }
+
+ {
+ const auto squeeze = generateSqueeze();
+ auto untrainable = generateUntrainableOperation(squeeze);
+ EXPECT_ANY_THROW(visitor.invoke(*untrainable));
+ }
+
+ {
+ const auto stateless_random_uniform = generateStatelessRandomUniform();
+ auto untrainable = generateUntrainableOperation(stateless_random_uniform);
+ EXPECT_ANY_THROW(visitor.invoke(*untrainable));
+ }
+
+ {
+ const auto strided_slice = generateStridedSlice();
+ auto untrainable = generateUntrainableOperation(strided_slice);
+ EXPECT_ANY_THROW(visitor.invoke(*untrainable));
+ }
+
+ {
+ const auto tile = generateTile();
+ auto untrainable = generateUntrainableOperation(tile);
+ EXPECT_ANY_THROW(visitor.invoke(*untrainable));
+ }
+
+ {
+ const auto topkv2 = generateTopKV2();
+ auto untrainable = generateUntrainableOperation(topkv2);
+ EXPECT_ANY_THROW(visitor.invoke(*untrainable));
+ }
+
+ {
+ const auto transpose = generateTranspose();
+ auto untrainable = generateUntrainableOperation(transpose);
+ EXPECT_ANY_THROW(visitor.invoke(*untrainable));
+ }
+
+ {
+ const auto transpose_conv = generateTransposeConv();
+ auto untrainable = generateUntrainableOperation(transpose_conv);
+ EXPECT_ANY_THROW(visitor.invoke(*untrainable));
+ }
+
+ {
+ const auto unpack = generateUnpack();
+ auto untrainable = generateUntrainableOperation(unpack);
+ EXPECT_ANY_THROW(visitor.invoke(*untrainable));
+ }
+
+ {
+ const auto while_op = generateWhile();
+ auto untrainable = generateUntrainableOperation(while_op);
+ EXPECT_ANY_THROW(visitor.invoke(*untrainable));
+ }
+}
diff --git a/runtime/onert/core/src/ir/verifier/Verifier.cc b/runtime/onert/core/src/ir/verifier/Verifier.cc
index 09cbdcf2f..bcded0c68 100644
--- a/runtime/onert/core/src/ir/verifier/Verifier.cc
+++ b/runtime/onert/core/src/ir/verifier/Verifier.cc
@@ -21,6 +21,69 @@
#include "util/logging.h"
+namespace
+{
+
+using namespace onert::ir;
+
+std::set<train::TrainingOperationIndex>
+extractOperations(const train::UseDefChains &training_usedefs)
+{
+ // Extract TrainingOperations from training_usedefs
+ std::set<train::TrainingOperationIndex> operations;
+ for (const auto &pair : training_usedefs)
+ {
+ const auto &output = pair.first;
+ const auto &usedefs = pair.second;
+ const auto &defs = usedefs.getTrainingDefs();
+ for (const auto &node_index : defs)
+ if (node_index.valid() && output.valid())
+ operations.insert(node_index);
+ }
+
+ return operations;
+}
+
+std::unordered_map<train::TrainingOperationIndex, std::vector<train::TrainingOperandIndex>>
+extractNodeInputs(const train::UseDefChains &training_usedefs)
+{
+ // Extract inputs of TrainingOperations from training_usedefs
+ std::unordered_map<train::TrainingOperationIndex, std::vector<train::TrainingOperandIndex>>
+ node_inputs;
+ for (const auto &pair : training_usedefs)
+ {
+ const auto &input = pair.first;
+ const auto &usedefs = pair.second;
+ const auto &uses = usedefs.getTrainingUses();
+ for (const auto &node_index : uses)
+ if (node_index.valid() && input.valid())
+ node_inputs[node_index].emplace_back(input);
+ }
+
+ return node_inputs;
+}
+
+std::unordered_map<train::TrainingOperationIndex, std::vector<train::TrainingOperandIndex>>
+extractNodeOutputs(const train::UseDefChains &training_usedefs)
+{
+ // Extract outputs of TrainingOperations from training_usedefs
+ std::unordered_map<train::TrainingOperationIndex, std::vector<train::TrainingOperandIndex>>
+ node_outputs;
+ for (const auto &pair : training_usedefs)
+ {
+ const auto &output = pair.first;
+ const auto &usedefs = pair.second;
+ const auto &defs = usedefs.getTrainingDefs();
+ for (const auto &node_index : defs)
+ if (node_index.valid() && output.valid())
+ node_outputs[node_index].emplace_back(output);
+ }
+
+ return node_outputs;
+}
+
+} // namespace
+
namespace onert
{
namespace ir
@@ -39,11 +102,11 @@ bool DAGChecker::verify(const Graph &graph) const noexcept
OperationIndexMap<bool> visited;
operations.iterate(
- [&](const OperationIndex &index, const Operation &) { visited[index] = false; });
+ [&](const OperationIndex &index, const IOperation &) { visited[index] = false; });
OperationIndexMap<bool> on_stack = visited; // Copy from visited
- std::function<void(const OperationIndex &index, const Operation &)> dfs_recursive =
- [&](const OperationIndex &index, const Operation &node) -> void {
+ std::function<void(const OperationIndex &index, const IOperation &)> dfs_recursive =
+ [&](const OperationIndex &index, const IOperation &node) -> void {
if (on_stack[index])
cyclic = true;
if (visited[index])
@@ -51,7 +114,7 @@ bool DAGChecker::verify(const Graph &graph) const noexcept
visited[index] = true;
on_stack[index] = true;
- for (auto output : node.getOutputs() | Remove::DUPLICATED)
+ for (auto &&output : node.getOutputs() | Remove::DUPLICATED | Remove::UNDEFINED)
{
const auto &operand = graph.operands().at(output);
for (const auto &use : operand.getUses())
@@ -68,16 +131,56 @@ bool DAGChecker::verify(const Graph &graph) const noexcept
return !cyclic;
}
+// TODO Merge with the above DAGChecker::verify(const Graph &)
+bool DAGChecker::verify(const train::UseDefChains &training_usedefs) const noexcept
+{
+ bool cyclic = false;
+ const auto operations = extractOperations(training_usedefs);
+ auto outputs_map = extractNodeOutputs(training_usedefs);
+
+ std::unordered_map<train::TrainingOperationIndex, bool> visited;
+ for (const auto &node_index : operations)
+ visited[node_index] = false;
+ auto on_stack = visited; // Copy from visited
+
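+  // Recursive DFS over the use-def chains; reaching a node that is still on the recursion
+  // stack means the chains contain a cycle.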
+ std::function<void(const train::TrainingOperationIndex &index)> dfs_recursive =
+ [&](const train::TrainingOperationIndex &index) -> void {
+ if (on_stack[index])
+ cyclic = true;
+ if (visited[index])
+ return;
+ visited[index] = true;
+ on_stack[index] = true;
+
+ auto &node_outputs = outputs_map[index];
+ for (const auto &output : node_outputs)
+ {
+ const auto &uses = training_usedefs.at(output).getTrainingUses();
+ for (const auto &use : uses)
+ {
+ dfs_recursive(use);
+ }
+ }
+
+ on_stack[index] = false;
+ };
+
+ for (const auto &node_index : operations)
+ dfs_recursive(node_index);
+
+ return !cyclic;
+}
+
//
// EdgeConsistencyVerifier
//
-bool EdgeConsistencyChecker::verify(const Graph &graph) const noexcept
+bool EdgeChecker::verify(const Graph &graph) const noexcept
{
auto &operations = graph.operations();
uint32_t errors = 0;
- operations.iterate([&](const OperationIndex &index, const Operation &node) {
- for (auto operand_index : node.getInputs() | ir::Remove::UNDEFINED)
+ operations.iterate([&](const OperationIndex &index, const IOperation &node) {
+ for (auto &&operand_index : node.getInputs() | ir::Remove::UNDEFINED)
{
try
{
@@ -85,44 +188,117 @@ bool EdgeConsistencyChecker::verify(const Graph &graph) const noexcept
bool operand_has_use = operand.getUses().contains(index);
if (!operand_has_use)
{
- VERBOSE(EdgeConsistencyChecker) << "[ERROR] EDGE MISMATCH : Missing USE edge - Operand "
- << operand_index << " to Operation " << index
- << std::endl;
+ VERBOSE(EdgeChecker) << "[ERROR] EDGE MISMATCH : Missing USE edge - Operand "
+ << operand_index << " to Operation " << index << std::endl;
errors += 1;
}
}
catch (const std::out_of_range &e)
{
- VERBOSE(EdgeConsistencyChecker)
- << "[ERROR] OPEARAND NOT FOUND : Operation " << index << " has Operand "
- << operand_index << ", but the operand object is not present in the graph" << std::endl;
+        VERBOSE(EdgeChecker) << "[ERROR] OPERAND NOT FOUND : Operation " << index
+ << " has Operand " << operand_index
+ << ", but the operand object is not present in the graph" << std::endl;
errors += 1;
}
}
- for (auto operand_index : node.getOutputs())
+ for (auto &&operand_index : node.getOutputs() | ir::Remove::UNDEFINED)
{
try
{
auto &operand = graph.operands().at(operand_index);
if (operand.getDef() != index)
{
- VERBOSE(EdgeConsistencyChecker) << "[ERROR] EDGE MISMATCH : Missing DEF edge - Operand"
- << operand_index << " to Operation " << index
- << std::endl;
+ VERBOSE(EdgeChecker) << "[ERROR] EDGE MISMATCH : Missing DEF edge - Operand"
+ << operand_index << " to Operation " << index << std::endl;
errors += 1;
}
}
catch (const std::out_of_range &e)
{
- VERBOSE(EdgeConsistencyChecker)
- << "[ERROR] OPEARAND NOT FOUND : Operation " << index << " has Operand "
- << operand_index << ", but the operand object is not present in the graph" << std::endl;
+        VERBOSE(EdgeChecker) << "[ERROR] OPERAND NOT FOUND : Operation " << index
+ << " has Operand " << operand_index
+ << ", but the operand object is not present in the graph" << std::endl;
errors += 1;
}
}
});
- VERBOSE(EdgeConsistencyChecker) << "Total Number of errors : " << errors << std::endl;
+ VERBOSE(EdgeChecker) << "Total Number of errors : " << errors << std::endl;
+
+ return errors == 0;
+}
+
+bool InputOutputChecker::verify(const Graph &graph) const noexcept
+{
+ for (auto &&operand_ind :
+ (graph.getInputs() + graph.getOutputs()) | Remove::DUPLICATED | Remove::UNDEFINED)
+ {
+ if (!graph.operands().exist(operand_ind))
+ {
+ VERBOSE(InputOutputChecker) << "Input or Output tensor " << operand_ind << " does not exist.";
+ return false;
+ }
+ }
+ return true;
+}
+
+// TODO Merge with the above EdgeChecker::verify(const Graph &)
+bool EdgeChecker::verify(const train::UseDefChains &training_usedefs) const noexcept
+{
+ const auto operations = extractOperations(training_usedefs);
+ auto inputs_map = extractNodeInputs(training_usedefs);
+ auto outputs_map = extractNodeOutputs(training_usedefs);
+ uint32_t errors = 0;
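+  // For each operation, check both directions of the chains: every input must record this
+  // operation among its uses and every output must record it among its defs.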
+ for (const auto &index : operations)
+ {
+ const auto &node_inputs = inputs_map[index];
+ for (const auto &operand_index : node_inputs)
+ {
+ try
+ {
+ const auto &uses = training_usedefs.at(operand_index).getTrainingUses();
+ bool operand_has_use = (uses.find(index) != uses.end());
+ if (!operand_has_use)
+ {
+ VERBOSE(EdgeChecker) << "[ERROR] EDGE MISMATCH : Missing USE edge - Operand "
+ << operand_index << " to Operation " << index << std::endl;
+ errors += 1;
+ }
+ }
+ catch (const std::out_of_range &e)
+ {
+        VERBOSE(EdgeChecker) << "[ERROR] OPERAND NOT FOUND : Operation " << index
+ << " has Operand " << operand_index
+ << ", but the operand object is not present in the graph" << std::endl;
+ errors += 1;
+ }
+ }
+
+ const auto &node_outputs = outputs_map[index];
+ for (const auto &operand_index : node_outputs)
+ {
+ try
+ {
+ const auto &defs = training_usedefs.at(operand_index).getTrainingDefs();
+ bool operand_has_def = (defs.find(index) != defs.end());
+ if (!operand_has_def)
+ {
+ VERBOSE(EdgeChecker) << "[ERROR] EDGE MISMATCH : Missing DEF edge - Operand"
+ << operand_index << " to Operation " << index << std::endl;
+ errors += 1;
+ }
+ }
+ catch (const std::out_of_range &e)
+ {
+        VERBOSE(EdgeChecker) << "[ERROR] OPERAND NOT FOUND : Operation " << index
+ << " has Operand " << operand_index
+ << ", but the operand object is not present in the graph" << std::endl;
+ errors += 1;
+ }
+ }
+ }
+
+ VERBOSE(EdgeChecker) << "Total Number of errors : " << errors << std::endl;
return errors == 0;
}
diff --git a/runtime/onert/core/src/ir/verifier/Verifier.h b/runtime/onert/core/src/ir/verifier/Verifier.h
index 0c7b57b04..9f1dd8e60 100644
--- a/runtime/onert/core/src/ir/verifier/Verifier.h
+++ b/runtime/onert/core/src/ir/verifier/Verifier.h
@@ -17,6 +17,8 @@
#ifndef __ONERT_GRAPH_VERIFIER_VERIFIER_H__
#define __ONERT_GRAPH_VERIFIER_VERIFIER_H__
+#include "ir/train/UseDefChains.h"
+
namespace onert
{
namespace ir
@@ -53,9 +55,20 @@ class DAGChecker : public IVerifier
{
public:
bool verify(const Graph &graph) const noexcept override;
+  bool verify(const train::UseDefChains &training_usedefs) const noexcept;
};
-class EdgeConsistencyChecker : public IVerifier
+class EdgeChecker : public IVerifier
+{
+public:
+ bool verify(const Graph &graph) const noexcept override;
+  bool verify(const train::UseDefChains &training_usedefs) const noexcept;
+};
+
+/**
+ * @brief Check that model input and output operands really exist in the graph
+ */
+class InputOutputChecker : public IVerifier
{
public:
bool verify(const Graph &graph) const noexcept override;
diff --git a/runtime/onert/core/src/ir/verifier/Verifier.test.cc b/runtime/onert/core/src/ir/verifier/Verifier.test.cc
new file mode 100644
index 000000000..1ec71cd55
--- /dev/null
+++ b/runtime/onert/core/src/ir/verifier/Verifier.test.cc
@@ -0,0 +1,93 @@
+/*
+ * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "Verifier.h"
+
+#include "../MockNode.h"
+
+#include "ir/Graph.h"
+
+#include <gtest/gtest.h>
+
+#include <memory>
+
+using IndexSet = onert::ir::OperandIndexSequence;
+using Mock = onert_test::ir::SimpleMock;
+
+TEST(Verifier, dag_checker)
+{
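+  // A single mock operation from the graph input to the graph output forms a trivial acyclic
+  // graph, so the checker must accept it.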
+ onert::ir::Graph graph;
+
+ onert::ir::Shape shape{3};
+ onert::ir::TypeInfo type{onert::ir::DataType::INT32};
+
+ auto operand1 = graph.addOperand(shape, type);
+ auto operand2 = graph.addOperand(shape, type);
+
+ graph.addInput(operand1);
+ graph.addOutput(operand2);
+
+ graph.addOperation(std::make_unique<Mock>(IndexSet{operand1}, IndexSet{operand2}));
+
+ onert::ir::verifier::DAGChecker verifier;
+
+ ASSERT_TRUE(verifier.verify(graph));
+}
+
+TEST(Verifier, neg_edge_consistency_checker_1)
+{
+ onert::ir::Graph graph;
+
+ onert::ir::Shape shape{3};
+ onert::ir::TypeInfo type{onert::ir::DataType::INT32};
+
+ auto operand1 = graph.addOperand(shape, type);
+ auto operand2 = graph.addOperand(shape, type);
+
+ graph.addInput(operand1);
+ graph.addOutput(operand2);
+
+ auto mock_op = std::make_unique<Mock>(IndexSet{operand1}, IndexSet{operand2});
+ auto op_ind = graph.addOperation(std::move(mock_op));
+
+ graph.operands().at(operand1).removeUse(op_ind); // Manipulate the operand alone
+
+ onert::ir::verifier::EdgeChecker verifier;
+ ASSERT_FALSE(verifier.verify(graph));
+}
+
+TEST(Verifier, neg_edge_consistency_checker_2)
+{
+ onert::ir::Graph graph;
+
+ onert::ir::Shape shape{3};
+ onert::ir::TypeInfo type{onert::ir::DataType::INT32};
+
+ auto operand1 = graph.addOperand(shape, type);
+ auto operand2 = graph.addOperand(shape, type);
+
+ graph.addInput(operand1);
+ graph.addOutput(operand2);
+
+ auto mock_op = std::make_unique<Mock>(IndexSet{operand1}, IndexSet{operand2});
+ auto mock_op_ptr = mock_op.get();
+ auto op_ind = graph.addOperation(std::move(mock_op));
+
+ mock_op_ptr->setInputs({operand2}); // Manipulate the operation alone
+
+ onert::ir::verifier::EdgeChecker verifier;
+ ASSERT_FALSE(verifier.verify(graph));
+}
diff --git a/runtime/onert/core/src/loader/BaseLoader.h b/runtime/onert/core/src/loader/BaseLoader.h
new file mode 100644
index 000000000..c3a50b0d8
--- /dev/null
+++ b/runtime/onert/core/src/loader/BaseLoader.h
@@ -0,0 +1,1794 @@
+/*
+ * Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_LOADER_BASE_LOADER_H__
+#define __ONERT_LOADER_BASE_LOADER_H__
+
+#include "ir/Graph.h"
+#include "ir/Shape.h"
+#include "ir/Operations.Include.h"
+
+#include "flatbuffers/flexbuffers.h"
+
+#include <map>
+#include <memory>
+#include <fstream>
+#include <limits>
+#include <fcntl.h>
+#include <sys/stat.h>
+#include <sys/mman.h>
+#include <unistd.h>
+#include <util/logging.h>
+
+namespace onert
+{
+namespace loader
+{
+
+template <typename LoaderDomain> class BaseLoader
+{
+protected:
+ using Verifier = typename LoaderDomain::Verifier;
+ using ActivationFunctionType = typename LoaderDomain::ActivationFunctionType;
+ using Buffer = typename LoaderDomain::Buffer;
+ using BuiltinOperator = typename LoaderDomain::BuiltinOperator;
+ using CustomOptionsFormat = typename LoaderDomain::CustomOptionsFormat;
+ using Metadata = typename LoaderDomain::Metadata;
+ using Model = typename LoaderDomain::Model;
+ using Operator = typename LoaderDomain::Operator;
+ using Padding = typename LoaderDomain::Padding;
+ using Pool2DOptions = typename LoaderDomain::Pool2DOptions;
+ using SubGraph = typename LoaderDomain::SubGraph;
+ using Tensor = typename LoaderDomain::Tensor;
+ using TensorType = typename LoaderDomain::TensorType;
+ using DimensionType = typename LoaderDomain::DimensionType;
+ using SparseIndexVector = typename LoaderDomain::SparseIndexVector;
+
+protected:
+ bool isOptionalInputTensor(std::int32_t idx) { return idx == -1; }
+ virtual bool allowOptionalInputTensor(BuiltinOperator) = 0;
+
+public:
+ /**
+ * @brief Construct a new Loader object
+ *
+ * @param model reference to model
+ */
+ explicit BaseLoader(std::unique_ptr<ir::Model> &model)
+ : _base{nullptr}, _pagesize(getpagesize()), _fd(-1), _model(model), _domain_model{nullptr}
+ {
+ _use_mmaped_data = util::getConfigBool(util::config::USE_MMAPED_DATA);
+ }
+
+ /**
+ * @brief Load a model from file
+ *
+ * @param file_path
+ */
+ void loadFromFile(const std::string &file_path);
+ /**
+ * @brief Load a model from a buffer
+ *
+ * @param buffer buffer pointer
+ * @param size buffer size
+ */
+ void loadFromBuffer(uint8_t *buffer, size_t size);
+
+protected:
+ ~BaseLoader() = default;
+ void loadModel();
+
+ // Helper functions
+ ir::Activation convertActivation(ActivationFunctionType type);
+ ir::DataType tensorTypeToDataType(TensorType type);
+ ir::OperandIndex tensorIdxToOperandIdx(int32_t tensorIdx);
+ flexbuffers::Map getCustomOpAttrMap(const Operator *op);
+
+ // Create operands form tflite::Tensor
+ ir::OperandIndex loadOperand(const Tensor *tensor, ir::Graph &subg);
+ void loadQuantization(const Tensor *tensor, ir::TypeInfo &typeInfo);
+ void loadSparsity(const Tensor *tensor, ir::TypeInfo &typeInfo);
+ void loadOperationIO(const Operator *op, ir::OperandIndexSequence &inputs,
+ ir::OperandIndexSequence &outputs);
+ // Create operations from Operator
+ void loadOperation(const Operator *op, ir::Graph &subg);
+ // Load Strides and Paddings from options to param
+ template <typename Param, typename OptionsType>
+ void loadStridesAndPaddings(Param &param, const OptionsType *options);
+ // Load Pool2D param
+ template <typename Param> void loadPool2DOptions(Param &param, const Pool2DOptions *options);
+ // Get BuiltinOperator
+ BuiltinOperator getBuiltinOperator(const Operator *op)
+ {
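+    // Opcodes below PLACEHOLDER_FOR_GREATER_OP_CODES are still carried in the deprecated int8
+    // field, so fall back to deprecated_builtin_code for them.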
+ auto const builtin_opcode = _domain_model->operator_codes()->Get(op->opcode_index());
+ auto builtin_op = builtin_opcode->builtin_code();
+ if (builtin_op < BuiltinOperator::BuiltinOperator_PLACEHOLDER_FOR_GREATER_OP_CODES)
+ builtin_op = static_cast<BuiltinOperator>(builtin_opcode->deprecated_builtin_code());
+
+ return builtin_op;
+ }
+
+private:
+ std::unique_ptr<ir::Data> loadMetadata(const uint32_t buffer_idx);
+ virtual std::unique_ptr<ir::Graph> loadSubgraph(const SubGraph *subg) = 0;
+ // Operations
+ template <typename OpIR, typename... Args>
+ const OpIR *loadOperationTo(const Operator *op, ir::Graph &subg, Args &&...args);
+
+ void loadAddV2(const Operator *op, ir::Graph &subg);
+ void loadArgMinMax(const Operator *op, ir::Graph &subg, bool is_argmax);
+ void loadBatchMatMul(const Operator *op, ir::Graph &subg);
+ void loadBinaryArithmetic(const Operator *op, ir::Graph &subg,
+ ir::operation::BinaryArithmetic::ArithmeticType op_type);
+ void loadComparison(const Operator *op, ir::Graph &subg);
+ void loadConcatenation(const Operator *op, ir::Graph &subg);
+ void loadConv2D(const Operator *op, ir::Graph &subg);
+ void loadCustom(const Operator *op, ir::Graph &subg);
+ void loadDepthToSpace(const Operator *op, ir::Graph &subg);
+ void loadDepthwiseConv2D(const Operator *op, ir::Graph &subg);
+ void loadEinsum(const Operator *op, ir::Graph &subg);
+ void loadElementwiseActivation(const Operator *op, ir::Graph &subg,
+ ir::operation::ElementwiseActivation::Type op_type,
+ float alpha = 0.f, float beta = 0.f);
+ void loadElementwiseBinary(const Operator *op, ir::Graph &subg,
+ ir::operation::ElementwiseBinary::ElementwiseBinaryType op_type);
+ void loadElementwiseUnary(const Operator *op, ir::Graph &subg,
+ ir::operation::ElementwiseUnary::Type op_type);
+ void loadFC(const Operator *op, ir::Graph &subg);
+ void loadFusedBatchNorm(const Operator *op, ir::Graph &subg);
+ void loadGather(const Operator *op, ir::Graph &subg);
+ void loadIf(const Operator *op, ir::Graph &subg);
+ void loadLeakyRelu(const Operator *op, ir::Graph &subg);
+ void loadLogSoftmax(const Operator *op, ir::Graph &subg);
+ void loadDetectionPostProcess(const Operator *op, ir::Graph &subg);
+ void loadOneHot(const Operator *op, ir::Graph &subg);
+ void loadPack(const Operator *op, ir::Graph &subg);
+ void loadPool2D(const Operator *op, ir::Graph &subg, ir::operation::Pool2D::PoolType op_type);
+ void loadReduce(const Operator *op, ir::Graph &subg,
+ ir::operation::Reduce::ReduceType reduce_type);
+ void loadReduceAll(const Operator *op, ir::Graph &subg);
+ void loadReshape(const Operator *op, ir::Graph &subg);
+ void loadResizeBilinear(const Operator *op, ir::Graph &subg);
+ void loadResizeNearestNeighbor(const Operator *op, ir::Graph &subg);
+ void loadSoftmax(const Operator *op, ir::Graph &subg);
+ void loadSpaceToDepth(const Operator *op, ir::Graph &subg);
+ void loadSplit(const Operator *op, ir::Graph &subg);
+ void loadSplitV(const Operator *op, ir::Graph &subg);
+ void loadSqueeze(const Operator *op, ir::Graph &subg);
+ void loadStridedSlice(const Operator *op, ir::Graph &subg);
+ void loadTransposeConv(const Operator *op, ir::Graph &subg);
+ void loadUnidirectionalSequenceLSTM(const Operator *op, ir::Graph &subg);
+ void loadUnpack(const Operator *op, ir::Graph &subg);
+ void loadWhile(const Operator *op, ir::Graph &subg);
+
+ void verifySubgraphIndex(int subg_index)
+ {
+ const auto num_subgraphs = _domain_model->subgraphs()->size();
+ if (subg_index < 0 || subg_index >= static_cast<int32_t>(num_subgraphs))
+ throw std::runtime_error{std::string{"Invalid subgraph index - "} +
+ std::to_string(subg_index)};
+ }
+
+protected:
+ // Base address for mapped region for loading (if needed)
+ uint8_t *_base;
+ // Memory page size
+ int32_t _pagesize;
+ // loaded file description
+ int _fd;
+ // Reference to ir::model (to be loaded from _domain_model)
+ std::unique_ptr<ir::Model> &_model;
+ const Model *_domain_model;
+ // Maps Tensor indices to onert Operands.
+ std::vector<ir::OperandIndex> _tensor_to_operand;
+ std::unordered_map<ir::OperandIndex, std::string> _tensor_names;
+ // Verifier
+ std::unique_ptr<Verifier> _verifier;
+ // Boolean flag to use MMAPED_DATA
+ bool _use_mmaped_data = false;
+
+ std::unordered_map<uint32_t /* Buffer Index in circle file */, std::shared_ptr<ir::Data>>
+ _buf_to_data;
+};
+
+template <typename LoaderDomain>
+void BaseLoader<LoaderDomain>::BaseLoader::loadFromFile(const std::string &file_path)
+{
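+  // Map the whole model file read-only; the mapping is released right after loadModel() returns.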
+ _fd = open(file_path.c_str(), O_RDONLY);
+ if (_fd < 0)
+ {
+ throw std::runtime_error("Failed to open file " + file_path);
+ }
+
+ struct stat file_stat;
+ if (fstat(_fd, &file_stat) != 0)
+ {
+ throw std::runtime_error("Fstat failed or file " + file_path + " is not a regular file");
+ }
+ int size = file_stat.st_size;
+
+ // Map model file into memory region
+ _base = static_cast<uint8_t *>(mmap(NULL, size, PROT_READ, MAP_PRIVATE, _fd, 0));
+ if (_base == MAP_FAILED)
+ {
+ close(_fd);
+ throw std::runtime_error("mmap failed - " + std::string(strerror(errno)));
+ }
+
+ _verifier = std::make_unique<Verifier>(reinterpret_cast<const std::uint8_t *>(_base), size);
+
+ loadModel();
+ munmap(_base, size);
+
+ close(_fd);
+}
+
+template <typename LoaderDomain>
+void BaseLoader<LoaderDomain>::BaseLoader::loadFromBuffer(uint8_t *buffer, size_t size)
+{
+ _base = buffer;
+ _verifier = std::make_unique<Verifier>(reinterpret_cast<const std::uint8_t *>(_base), size);
+ loadModel();
+}
+
+template <typename LoaderDomain>
+std::unique_ptr<ir::Data>
+BaseLoader<LoaderDomain>::BaseLoader::loadMetadata(const uint32_t buffer_idx)
+{
+ assert(_domain_model != nullptr);
+ const auto *data = _domain_model->buffers()->Get(buffer_idx)->data();
+ if (data == nullptr)
+ throw std::runtime_error("Metadata buffer is not found");
+
+ if (_fd == -1) // Model is from memory
+ {
+ return std::make_unique<ir::ExternalData>(data->data(), data->size());
+ }
+ else // Model is loaded(mmap'd) from a file
+ {
+ size_t data_size = data->size();
+ ptrdiff_t offset_start = data->data() - _base;
+ ptrdiff_t offset_end = offset_start + data_size;
+
+ ptrdiff_t page_start = (offset_start / _pagesize) * _pagesize;
+ size_t mapping_size = offset_end - page_start;
+
+    // Since metadata is not accessed often at inference/training time, always use mmaped-data
+ // Ref : https://github.com/Samsung/ONE/issues/3961#issuecomment-681750231
+ return std::make_unique<ir::MMapedData>(_fd, page_start, mapping_size, offset_start, data_size);
+ }
+}
+
+template <typename LoaderDomain>
+ir::Activation
+BaseLoader<LoaderDomain>::BaseLoader::convertActivation(const ActivationFunctionType type)
+{
+ switch (type)
+ {
+ case ActivationFunctionType::ActivationFunctionType_NONE:
+ return ir::Activation::NONE;
+ case ActivationFunctionType::ActivationFunctionType_RELU:
+ return ir::Activation::RELU;
+ case ActivationFunctionType::ActivationFunctionType_RELU_N1_TO_1:
+ return ir::Activation::RELU1;
+ case ActivationFunctionType::ActivationFunctionType_RELU6:
+ return ir::Activation::RELU6;
+ case ActivationFunctionType::ActivationFunctionType_TANH:
+ return ir::Activation::TANH;
+ default:
+ throw std::runtime_error(std::string("Unsupported or invalid activation type: ") +
+ std::to_string(static_cast<int>(type)));
+ }
+}
+
+template <typename LoaderDomain>
+ir::DataType BaseLoader<LoaderDomain>::BaseLoader::tensorTypeToDataType(const TensorType type)
+{
+ switch (type)
+ {
+ case TensorType::TensorType_FLOAT32:
+ return ir::DataType::FLOAT32;
+ case TensorType::TensorType_FLOAT16:
+ return ir::DataType::FLOAT16;
+ case TensorType::TensorType_INT32:
+ return ir::DataType::INT32;
+ case TensorType::TensorType_UINT8:
+ return ir::DataType::QUANT_UINT8_ASYMM;
+ case TensorType::TensorType_INT64:
+ return ir::DataType::INT64;
+ // case TensorType::TensorType_STRING:
+ case TensorType::TensorType_BOOL:
+ return ir::DataType::BOOL8;
+ case TensorType::TensorType_INT16:
+ return ir::DataType::QUANT_INT16_ASYMM;
+ // case TensorType::TensorType_COMPLEX64
+ case TensorType::TensorType_INT8:
+ return ir::DataType::QUANT_INT8_ASYMM;
+ // case TensorType::TensorType_FLOAT64
+ case TensorType::TensorType_UINT32:
+ return ir::DataType::UINT32;
+ default:
+ throw std::runtime_error(
+ std::string("Unsupported tensor type: ").append(EnumNameTensorType(type)));
+ }
+}
+
+template <typename LoaderDomain>
+ir::OperandIndex BaseLoader<LoaderDomain>::BaseLoader::tensorIdxToOperandIdx(int32_t tensorIdx)
+{
+ return isOptionalInputTensor(tensorIdx) ? ir::OperandIndex() : _tensor_to_operand[tensorIdx];
+}
+
+template <typename LoaderDomain>
+flexbuffers::Map BaseLoader<LoaderDomain>::BaseLoader::getCustomOpAttrMap(const Operator *op)
+{
+ size_t custom_op_data_size = op->custom_options()->size();
+ auto custom_op_data = op->custom_options()->Data();
+ auto data_root = flexbuffers::GetRoot(custom_op_data, custom_op_data_size);
+ return data_root.AsMap();
+}
+
+/* The Copy() helper is copied from TensorFlow Lite */
+template <typename T> bool Copy(const T *data_ptr, std::vector<uint16_t> &arr)
+{
+ if (data_ptr->values() == nullptr)
+ {
+ return false;
+ }
+
+ int size = data_ptr->values()->size();
+ arr.reserve(size);
+ for (int i = 0; i < size; i++)
+ {
+ arr.emplace_back(static_cast<uint16_t>(data_ptr->values()->Get(i)));
+ }
+ return true;
+}
+
+template <typename LoaderDomain>
+ir::OperandIndex BaseLoader<LoaderDomain>::loadOperand(const Tensor *tensor, ir::Graph &subg)
+{
+ ir::Shape shape;
+ // Shape
+ const auto *tensor_shape = tensor->shape();
+ if (tensor_shape != nullptr)
+ {
+ for (const auto &dim : *tensor_shape)
+ {
+ shape.append(dim);
+ }
+ }
+
+ // Note for tensor->shape_signature()
+ // We don't handle shape signature
+ // How we handle:
+ // If shape_signature[k] == -1, we will use tensor->shape()[k] == 1
+  //      If the app wants to change the input shape, nnfw_apply_input_tensorinfo() can
+ // be used.
+
+ // TypeInfo
+ ir::TypeInfo type_info(tensorTypeToDataType(tensor->type()));
+ loadQuantization(tensor, type_info);
+ loadSparsity(tensor, type_info);
+
+ // Create operand
+ const auto operand_index = subg.addOperand(shape, type_info);
+
+ // Constant tensors are indicated by non-empty data.
+ const auto *data = _domain_model->buffers()->Get(tensor->buffer())->data();
+ if (data != nullptr)
+ {
+ using std::ptrdiff_t;
+ std::shared_ptr<ir::Data> data_obj;
+
+ if (_fd == -1) // Model is from memory
+ {
+ data_obj = std::make_shared<ir::ExternalData>(data->data(), data->size());
+ }
+ else // Model is loaded(mmap'd) from a file
+ {
+ size_t data_size = data->size();
+ ptrdiff_t unaligned_offset_start = data->data() - _base;
+ ptrdiff_t offset_end = unaligned_offset_start + data_size;
+
+ // Calculated aligned offset from base address of mapped region
+ // munmap accepts memory address which is a multiple of the pagesize
+ ptrdiff_t aligned_offset_start = (unaligned_offset_start / _pagesize) * _pagesize;
+ size_t mmap_size = offset_end - aligned_offset_start;
+
+ uint32_t buf_idx = tensor->buffer();
+ auto buffer_found = _buf_to_data.find(buf_idx);
+
+ if (buffer_found != _buf_to_data.end())
+ {
+ // Another tensor points this buffer and its matching Data(either CachedData or MMapedData)
+ // was already created. Let's reuse the Data
+ data_obj = buffer_found->second;
+ }
+ else if (_use_mmaped_data)
+ {
+ data_obj = std::make_shared<ir::MMapedData>(_fd, aligned_offset_start, mmap_size,
+ unaligned_offset_start, data_size);
+ _buf_to_data[buf_idx] = data_obj;
+ }
+ else
+ {
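+        // No mmap caching requested: map the region temporarily, copy the constant into
+        // CachedData, then unmap right away.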
+ size_t offset = unaligned_offset_start - aligned_offset_start;
+ uint8_t *mmap_base = static_cast<uint8_t *>(
+ mmap(NULL, mmap_size, PROT_READ, MAP_PRIVATE, _fd, aligned_offset_start));
+
+ data_obj = std::make_shared<ir::CachedData>(mmap_base + offset, data_size);
+ _buf_to_data[buf_idx] = data_obj;
+
+ munmap(mmap_base, mmap_size);
+ }
+ }
+ subg.setOperandValue(operand_index, std::move(data_obj));
+ }
+
+ _tensor_names.emplace(operand_index, tensor->name()->str());
+
+ // Variable
+ if (tensor->is_variable())
+ {
+ if (data != nullptr)
+ throw std::runtime_error("Variable tensor with buffer is not supported!");
+
+ subg.operands().at(operand_index).info().setAsVariable();
+ }
+
+ return operand_index;
+}
+
+template <typename LoaderDomain>
+void BaseLoader<LoaderDomain>::loadQuantization(const Tensor *tensor, ir::TypeInfo &typeInfo)
+{
+ auto q_params = tensor->quantization();
+ if (q_params == nullptr || q_params->scale() == nullptr || q_params->scale()->size() == 0)
+ {
+ typeInfo.quantization(0., 0);
+ return;
+ }
+ if (q_params->zero_point() == nullptr)
+ {
+ throw std::runtime_error("Quantization params: scale is not null, but zero_point is null.");
+ }
+ const size_t num_scales = q_params->scale()->size();
+ if (num_scales != q_params->zero_point()->size())
+ {
+ throw std::runtime_error("Quantization params: scale size != zero_point size");
+ }
+ std::vector<float> scales;
+ std::vector<int32_t> zero_points;
+ scales.resize(num_scales);
+ zero_points.resize(num_scales);
+ for (size_t i = 0; i < num_scales; ++i)
+ {
+ scales[i] = q_params->scale()->Get(i);
+ // zero_point is defined as long (i64) in schema while TypeInfo's zero_point is int32_t.
+ // int64_t is used instead of long because long is 4 byte in most 32bit architecture.
+ int64_t zero_point = q_params->zero_point()->Get(i);
+ if (zero_point < std::numeric_limits<int32_t>::min() ||
+ zero_point > std::numeric_limits<int32_t>::max())
+ throw std::runtime_error("Zero_point is out of int32 range.");
+ zero_points[i] = static_cast<int32_t>(zero_point);
+ }
+ auto details = q_params->details_as_CustomQuantization();
+ if (details != nullptr)
+ throw std::runtime_error("Custom Quantization is not supported");
+ typeInfo.quantization(std::move(scales), std::move(zero_points));
+}
+
+template <typename LoaderDomain>
+void BaseLoader<LoaderDomain>::loadSparsity(const Tensor *tensor, ir::TypeInfo &typeInfo)
+{
+ auto src_sparsity = tensor->sparsity();
+ if (src_sparsity != nullptr)
+ {
+ std::vector<uint16_t> w1_segments;
+ std::vector<uint16_t> w1_indices;
+ // check traversal_order
+ if (src_sparsity->traversal_order())
+ {
+ const int traversal_order_size = src_sparsity->traversal_order()->size();
+ for (int i = 0; i < traversal_order_size; ++i)
+ {
+ if (i != src_sparsity->traversal_order()->Get(i))
+          throw std::runtime_error("Only traversal_order [0, 1, ..., n-1] is supported.");
+ }
+ }
+ // check block_map
+ int block_rank = 0;
+ if (src_sparsity->block_map())
+ {
+ block_rank = src_sparsity->block_map()->size();
+ for (int i = 0; i < block_rank; ++i)
+ {
+ if (i != src_sparsity->block_map()->Get(i))
+          throw std::runtime_error("Only block_map [0, 1, ..., n-1] is supported.");
+ }
+ }
+ // load metadata
+ const auto dim_metadata_size = src_sparsity->dim_metadata()->size();
+ const auto dense_rank = tensor->shape() ? tensor->shape()->size() : 0;
+ if (dense_rank + block_rank != dim_metadata_size)
+ throw std::runtime_error("sparsity dim_metadata length is wrong.");
+ bool random_sparsity = dim_metadata_size == 2 && block_rank == 0;
+ bool block2D_sparsity = dim_metadata_size == 4 && block_rank == 2;
+    if (!random_sparsity && !block2D_sparsity)
+ throw std::runtime_error(
+ "sparsity is supported only for 2D tensor with random or 16x1 block sparsity.");
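+    // For reference, the two accepted dim_metadata layouts are:
+    //   - 2D random sparsity : [DENSE, SPARSE_CSR]
+    //   - 2D 16x1 block      : [DENSE, SPARSE_CSR, DENSE(block), DENSE(block)]
+    // where the trailing DENSE entries carry the block sizes read further below.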
+
+ const auto *src_metadata = src_sparsity->dim_metadata()->Get(0);
+ if (src_metadata->format() != DimensionType::DimensionType_DENSE)
+ throw std::runtime_error("sparse tensor dim[0] is not DENSE");
+ src_metadata = src_sparsity->dim_metadata()->Get(1);
+ if (src_metadata->format() != DimensionType::DimensionType_SPARSE_CSR)
+      throw std::runtime_error("sparse tensor dim[1] is not SPARSE_CSR");
+ auto ParseSparseIndexVector = [src_metadata, &w1_segments, &w1_indices]() {
+ if (src_metadata->array_segments() == nullptr || src_metadata->array_indices() == nullptr)
+ return false;
+ bool status = true;
+      /* `onert` internally uses the uint16 type regardless of the values of
+         array_segments_type and array_indices_type */
+ switch (src_metadata->array_segments_type())
+ {
+ case SparseIndexVector::SparseIndexVector_Int32Vector:
+ throw std::runtime_error("sparse tensor with int32 segment type is not supported");
+ case SparseIndexVector::SparseIndexVector_Uint16Vector:
+ status = Copy(src_metadata->array_segments_as_Uint16Vector(), w1_segments);
+ break;
+ case SparseIndexVector::SparseIndexVector_Uint8Vector:
+ status = Copy(src_metadata->array_segments_as_Uint8Vector(), w1_segments);
+ break;
+ default:
+ return false;
+ }
+ if (status != true)
+ return false;
+ switch (src_metadata->array_indices_type())
+ {
+ case SparseIndexVector::SparseIndexVector_Int32Vector:
+ throw std::runtime_error("sparse tensor with int32 indices type is not supported");
+ case SparseIndexVector::SparseIndexVector_Uint16Vector:
+ return Copy(src_metadata->array_indices_as_Uint16Vector(), w1_indices);
+ case SparseIndexVector::SparseIndexVector_Uint8Vector:
+ return Copy(src_metadata->array_indices_as_Uint8Vector(), w1_indices);
+ default:
+ break;
+ }
+ return false;
+ };
+ if (ParseSparseIndexVector() == false)
+ throw std::runtime_error("Error during parsing sparsity index information");
+ // Get block size
+ std::vector<int32_t> block_size;
+ for (int i = 0; i < block_rank; ++i)
+ {
+ auto block_metadata = src_sparsity->dim_metadata()->Get(dense_rank + i);
+ if (block_metadata->format() != DimensionType::DimensionType_DENSE)
+ throw std::runtime_error("block dimension must be DENSE.");
+ block_size.push_back(block_metadata->dense_size());
+ }
+ typeInfo.sparsity(std::make_shared<ir::Sparsity>(std::move(w1_segments), std::move(w1_indices),
+ std::move(block_size)));
+ }
+}
+
+template <typename LoaderDomain>
+void BaseLoader<LoaderDomain>::loadOperationIO(const Operator *op, ir::OperandIndexSequence &inputs,
+ ir::OperandIndexSequence &outputs)
+{
+ for (const std::int32_t idx : *op->inputs())
+ {
+ // Optional tensors are not supported yet except for FULLY_CONNECTED and BCQ_FULLY_CONNECTED
+ auto check_optional_input = [&]() {
+ auto builtin_code = getBuiltinOperator(op);
+ if (isOptionalInputTensor(idx) && !allowOptionalInputTensor(builtin_code))
+ throw std::runtime_error(
+ std::string("loader doesn't support optional input tensor yet for ")
+ .append(EnumNameBuiltinOperator(builtin_code)));
+ };
+ check_optional_input();
+ inputs.append(tensorIdxToOperandIdx(idx));
+ }
+
+ for (const std::int32_t idx : *op->outputs())
+ {
+ outputs.append(tensorIdxToOperandIdx(idx));
+ }
+}
+
+template <typename LoaderDomain>
+template <typename Param, typename OptionsType>
+void BaseLoader<LoaderDomain>::loadStridesAndPaddings(Param &param, const OptionsType *options)
+{
+ // Strides
+ param.stride.vertical = options->stride_h();
+ param.stride.horizontal = options->stride_w();
+ // Paddings
+ switch (options->padding())
+ {
+ case Padding::Padding_SAME:
+ param.padding.type = ir::PaddingType::SAME;
+ break;
+ case Padding::Padding_VALID:
+ param.padding.type = ir::PaddingType::VALID;
+ break;
+ default:
+ throw std::runtime_error{"Invalid padding type"};
+ }
+  // The explicit padding fields of param are left unused (only the padding type is set here)
+}
+
+template <typename LoaderDomain>
+template <typename Param>
+void BaseLoader<LoaderDomain>::loadPool2DOptions(Param &param, const Pool2DOptions *options)
+{
+ // Strides and Paddings
+ if (options->stride_h() <= 0 || options->stride_w() <= 0)
+ throw std::runtime_error{"Invalid stride vertical or horizontal - both must be bigger than 0"};
+ loadStridesAndPaddings(param, options);
+  // Filter width and height
+ if (options->filter_width() <= 0 || options->filter_height() <= 0)
+ throw std::runtime_error{"Invalid filter width or height - both must be bigger than 0"};
+ param.kw = options->filter_width();
+ param.kh = options->filter_height();
+ // Activation
+ param.activation = convertActivation(options->fused_activation_function());
+}
+
+template <typename LoaderDomain>
+template <typename OpIR, typename... Args>
+const OpIR *BaseLoader<LoaderDomain>::loadOperationTo(const Operator *op, ir::Graph &subg,
+ Args &&...args)
+{
+  static_assert(sizeof...(args) <= 1, "You can't have more than 1 argument!");
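+  // The single optional argument, when present, is the operation's Param struct
+  // (see the call sites below, e.g. loadConv2D and loadPool2D).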
+ ir::OperandIndexSequence inputs;
+ ir::OperandIndexSequence outputs;
+
+ loadOperationIO(op, inputs, outputs);
+
+ std::unique_ptr<OpIR> new_op(new OpIR(inputs, outputs, std::forward<Args>(args)...));
+ auto ret = new_op.get();
+ subg.addOperation(std::move(new_op));
+
+ return ret;
+}
+
+template <typename LoaderDomain>
+void BaseLoader<LoaderDomain>::loadConv2D(const Operator *op, ir::Graph &subg)
+{
+ ir::operation::Conv2D::Param param;
+ const auto *options = op->builtin_options_as_Conv2DOptions();
+ param.activation = convertActivation(options->fused_activation_function());
+ loadStridesAndPaddings(param, options);
+ param.dilation.width_factor = options->dilation_w_factor();
+ param.dilation.height_factor = options->dilation_h_factor();
+
+ const auto conv = loadOperationTo<ir::operation::Conv2D>(op, subg, param);
+
+  // TFLite supports old hybrid quantization (float input/output, uint8 kernel),
+  // but the weight type is interpreted as int8 internally
+ const auto &input_operand =
+ subg.operands().at(conv->getInputs().at(ir::operation::Conv2D::INPUT));
+ auto &weights_operand = subg.operands().at(conv->getInputs().at(ir::operation::Conv2D::KERNEL));
+ if (input_operand.typeInfo().type() == ir::DataType::FLOAT32 &&
+ ((weights_operand.typeInfo().type() == ir::DataType::QUANT_UINT8_ASYMM) ||
+ weights_operand.typeInfo().type() == ir::DataType::QUANT_INT8_ASYMM))
+ {
+ weights_operand.type(ir::DataType::QUANT_INT8_SYMM);
+ }
+}
+
+template <typename LoaderDomain>
+void BaseLoader<LoaderDomain>::loadDepthwiseConv2D(const Operator *op, ir::Graph &subg)
+{
+ ir::operation::DepthwiseConv2D::Param param;
+ const auto *options = op->builtin_options_as_DepthwiseConv2DOptions();
+ param.activation = convertActivation(options->fused_activation_function());
+ loadStridesAndPaddings(param, options);
+ param.multiplier = options->depth_multiplier();
+ param.dilation.width_factor = options->dilation_w_factor();
+ param.dilation.height_factor = options->dilation_h_factor();
+
+ const auto dconv = loadOperationTo<ir::operation::DepthwiseConv2D>(op, subg, param);
+
+ // TFLite does not support old hybrid quantization (float input/output, uint8 kernel)
+ // for depthwise convolution.
+  // But for consistency with Conv2D and FC, we interpret the weight type as int8 internally
+ const auto &input_operand =
+ subg.operands().at(dconv->getInputs().at(ir::operation::DepthwiseConv2D::INPUT));
+ auto &weights_operand =
+ subg.operands().at(dconv->getInputs().at(ir::operation::DepthwiseConv2D::KERNEL));
+ if (input_operand.typeInfo().type() == ir::DataType::FLOAT32 &&
+ ((weights_operand.typeInfo().type() == ir::DataType::QUANT_UINT8_ASYMM) ||
+ weights_operand.typeInfo().type() == ir::DataType::QUANT_INT8_ASYMM))
+ {
+ weights_operand.type(ir::DataType::QUANT_INT8_SYMM);
+ }
+}
+
+template <typename LoaderDomain>
+void BaseLoader<LoaderDomain>::loadTransposeConv(const Operator *op, ir::Graph &subg)
+{
+ ir::operation::TransposeConv::Param param;
+ const auto *options = op->builtin_options_as_TransposeConvOptions();
+ loadStridesAndPaddings(param, options);
+
+ loadOperationTo<ir::operation::TransposeConv>(op, subg, param);
+}
+
+template <typename LoaderDomain>
+void BaseLoader<LoaderDomain>::loadPool2D(const Operator *op, ir::Graph &subg,
+ ir::operation::Pool2D::PoolType op_type)
+{
+ ir::operation::Pool2D::Param param;
+ param.op_type = op_type;
+ const auto *options = op->builtin_options_as_Pool2DOptions();
+
+ loadPool2DOptions(param, options);
+
+ loadOperationTo<ir::operation::Pool2D>(op, subg, param);
+}
+
+template <typename LoaderDomain>
+void BaseLoader<LoaderDomain>::loadReshape(const Operator *op, ir::Graph &subg)
+{
+ ir::operation::Reshape::Param param{};
+ const auto *options = op->builtin_options_as_ReshapeOptions();
+ if (options != nullptr)
+ {
+ const auto *new_shape = options->new_shape();
+ if (new_shape)
+ {
+      for (uint32_t i = 0; i < new_shape->size(); ++i)
+ {
+ param.new_shape.push_back(new_shape->Get(i));
+ }
+ }
+ }
+
+ loadOperationTo<ir::operation::Reshape>(op, subg, param);
+}
+
+template <typename LoaderDomain>
+void BaseLoader<LoaderDomain>::loadSoftmax(const Operator *op, ir::Graph &subg)
+{
+ ir::operation::Softmax::Param param;
+ const auto *options = op->builtin_options_as_SoftmaxOptions();
+ // Beta
+ param.beta = options->beta();
+
+ loadOperationTo<ir::operation::Softmax>(op, subg, param);
+}
+
+template <typename LoaderDomain>
+void BaseLoader<LoaderDomain>::loadConcatenation(const Operator *op, ir::Graph &subg)
+{
+ ir::operation::Concat::Param param;
+ const auto *options = op->builtin_options_as_ConcatenationOptions();
+ // Axis
+ param.axis = options->axis();
+ // activation unused
+
+ loadOperationTo<ir::operation::Concat>(op, subg, param);
+}
+
+template <typename LoaderDomain>
+void BaseLoader<LoaderDomain>::loadFC(const Operator *op, ir::Graph &subg)
+{
+ ir::operation::FullyConnected::Param param;
+ const auto *options = op->builtin_options_as_FullyConnectedOptions();
+
+ param.activation = convertActivation(options->fused_activation_function());
+ param.weights_format = static_cast<ir::FullyConnectedWeightsFormat>(options->weights_format());
+
+ const auto fc = loadOperationTo<ir::operation::FullyConnected>(op, subg, param);
+
+  // TFLite supports old hybrid quantization (float input/output, uint8 kernel),
+  // but the weight type is interpreted as int8 internally
+ const auto &input_operand =
+ subg.operands().at(fc->getInputs().at(ir::operation::FullyConnected::INPUT));
+ auto &weights_operand =
+ subg.operands().at(fc->getInputs().at(ir::operation::FullyConnected::WEIGHT));
+ if (input_operand.typeInfo().type() == ir::DataType::FLOAT32 &&
+ ((weights_operand.typeInfo().type() == ir::DataType::QUANT_UINT8_ASYMM) ||
+ weights_operand.typeInfo().type() == ir::DataType::QUANT_INT8_ASYMM))
+ {
+ weights_operand.type(ir::DataType::QUANT_INT8_SYMM);
+ }
+}
+
+template <typename LoaderDomain>
+void BaseLoader<LoaderDomain>::loadAddV2(const Operator *op, ir::Graph &subg)
+{
+ ir::operation::BinaryArithmetic::Param param;
+ param.arithmetic_type = ir::operation::BinaryArithmetic::ArithmeticType::ADD;
+
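+  // The custom options are a FlexBuffers map; for AddV2 a hypothetical payload
+  // would look like {"fused_activation_function": 1}, where the integer maps to
+  // the domain's ActivationFunctionType enum.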
+ if (op->custom_options() == nullptr)
+ {
+ param.activation = ir::Activation::NONE;
+ }
+ else
+ {
+ const auto attr_map = getCustomOpAttrMap(op);
+ const auto fused_activation_func = static_cast<typename LoaderDomain::ActivationFunctionType>(
+ attr_map["fused_activation_function"].AsInt8());
+ param.activation = convertActivation(fused_activation_func);
+ }
+
+ loadOperationTo<ir::operation::BinaryArithmetic>(op, subg, param);
+}
+
+template <typename LoaderDomain>
+void BaseLoader<LoaderDomain>::loadDepthToSpace(const Operator *op, ir::Graph &subg)
+{
+ ir::operation::DepthToSpace::Param param;
+ const auto *options = op->builtin_options_as_DepthToSpaceOptions();
+ param.block_size = options->block_size();
+
+ loadOperationTo<ir::operation::DepthToSpace>(op, subg, param);
+}
+
+template <typename LoaderDomain>
+void BaseLoader<LoaderDomain>::loadBinaryArithmetic(
+ const Operator *op, ir::Graph &subg, ir::operation::BinaryArithmetic::ArithmeticType op_type)
+{
+ ir::operation::BinaryArithmetic::Param param;
+ param.arithmetic_type = op_type;
+ switch (op_type)
+ {
+ case ir::operation::BinaryArithmetic::ArithmeticType::ADD:
+ {
+ const auto *add_options = op->builtin_options_as_AddOptions();
+ param.activation = convertActivation(add_options->fused_activation_function());
+ break;
+ }
+ case ir::operation::BinaryArithmetic::ArithmeticType::SUB:
+ {
+ const auto *sub_options = op->builtin_options_as_SubOptions();
+ param.activation = convertActivation(sub_options->fused_activation_function());
+ break;
+ }
+ case ir::operation::BinaryArithmetic::ArithmeticType::MUL:
+ {
+ const auto *mul_options = op->builtin_options_as_MulOptions();
+ param.activation = convertActivation(mul_options->fused_activation_function());
+ break;
+ }
+ case ir::operation::BinaryArithmetic::ArithmeticType::DIV:
+ {
+ const auto *div_options = op->builtin_options_as_DivOptions();
+ param.activation = convertActivation(div_options->fused_activation_function());
+ break;
+ }
+ default:
+ assert(false &&
+ "The function 'loadBinaryArithmetic' supports only BinaryArithmetic operations");
+ break;
+ }
+
+ loadOperationTo<ir::operation::BinaryArithmetic>(op, subg, param);
+}
+
+template <typename LoaderDomain>
+void BaseLoader<LoaderDomain>::loadPack(const Operator *op, ir::Graph &subg)
+{
+ ir::operation::Pack::Param param;
+ const auto *options = op->builtin_options_as_PackOptions();
+ param.num = options->values_count();
+ param.axis = options->axis();
+
+ loadOperationTo<ir::operation::Pack>(op, subg, param);
+}
+
+template <typename LoaderDomain>
+void BaseLoader<LoaderDomain>::loadElementwiseActivation(
+ const Operator *op, ir::Graph &subg, ir::operation::ElementwiseActivation::Type op_type,
+ float alpha, float beta)
+{
+ ir::operation::ElementwiseActivation::Param param;
+ param.op_type = op_type;
+ param.alpha = alpha;
+ param.beta = beta;
+
+ loadOperationTo<ir::operation::ElementwiseActivation>(op, subg, param);
+}
+
+template <typename LoaderDomain>
+void BaseLoader<LoaderDomain>::loadResizeBilinear(const Operator *op, ir::Graph &subg)
+{
+ ir::operation::ResizeBilinear::Param param;
+  // height_out and width_out are used only on NNAPI
+ assert(op->inputs()->size() == 2);
+ param.height_out = 0;
+ param.width_out = 0;
+ param.align_corners = op->builtin_options_as_ResizeBilinearOptions()->align_corners();
+ param.half_pixel_centers = op->builtin_options_as_ResizeBilinearOptions()->half_pixel_centers();
+
+ loadOperationTo<ir::operation::ResizeBilinear>(op, subg, param);
+}
+
+template <typename LoaderDomain>
+void BaseLoader<LoaderDomain>::loadResizeNearestNeighbor(const Operator *op, ir::Graph &subg)
+{
+ ir::operation::ResizeNearestNeighbor::Param param;
+  // height_out and width_out are used only on NNAPI
+ assert(op->inputs()->size() == 2);
+ param.height_out = 0;
+ param.width_out = 0;
+ param.align_corners = op->builtin_options_as_ResizeNearestNeighborOptions()->align_corners();
+
+ loadOperationTo<ir::operation::ResizeNearestNeighbor>(op, subg, param);
+}
+
+template <typename LoaderDomain>
+void BaseLoader<LoaderDomain>::loadReduce(const Operator *op, ir::Graph &subg,
+ ir::operation::Reduce::ReduceType reduce_type)
+{
+ ir::operation::Reduce::Param param;
+ param.reduce_type = reduce_type;
+ param.keep_dims = op->builtin_options_as_ReducerOptions()->keep_dims();
+
+ loadOperationTo<ir::operation::Reduce>(op, subg, param);
+}
+
+template <typename LoaderDomain>
+void BaseLoader<LoaderDomain>::loadReduceAll(const Operator *op, ir::Graph &subg)
+{
+ ir::operation::Reduce::Param param;
+ param.reduce_type = ir::operation::Reduce::ReduceType::ALL;
+ if (op->custom_options() == nullptr)
+ {
+ param.keep_dims = false;
+ }
+ else
+ {
+ const auto attr_map = getCustomOpAttrMap(op);
+ param.keep_dims = attr_map["keep_dims"].AsBool();
+ }
+
+ loadOperationTo<ir::operation::Reduce>(op, subg, param);
+}
+
+template <typename LoaderDomain>
+void BaseLoader<LoaderDomain>::loadElementwiseBinary(
+ const Operator *op, ir::Graph &subg,
+ ir::operation::ElementwiseBinary::ElementwiseBinaryType op_type)
+{
+ ir::operation::ElementwiseBinary::Param param;
+ param.op_type = op_type;
+
+ loadOperationTo<ir::operation::ElementwiseBinary>(op, subg, param);
+}
+
+template <typename LoaderDomain>
+void BaseLoader<LoaderDomain>::loadElementwiseUnary(const Operator *op, ir::Graph &subg,
+ ir::operation::ElementwiseUnary::Type op_type)
+{
+ ir::operation::ElementwiseUnary::Param param;
+ param.op_type = op_type;
+
+ const auto eu = loadOperationTo<ir::operation::ElementwiseUnary>(op, subg, param);
+ if (op_type == ir::operation::ElementwiseUnary::Type::CAST)
+ {
+ auto qasymm8ToUint8 = [](ir::Operand &operand) {
+ if (operand.typeInfo().type() == ir::DataType::QUANT_UINT8_ASYMM)
+ {
+ operand.type(ir::DataType::UINT8);
+ }
+ };
+ qasymm8ToUint8(
+ subg.operands().at(eu->getInputs().at(ir::operation::ElementwiseUnary::Input::INPUT)));
+ qasymm8ToUint8(subg.operands().at(eu->getOutputs().at(0)));
+ }
+}
+
+template <typename LoaderDomain>
+void BaseLoader<LoaderDomain>::loadGather(const Operator *op, ir::Graph &subg)
+{
+ ir::operation::Gather::Param param;
+ param.axis = op->builtin_options_as_GatherOptions()->axis();
+
+ loadOperationTo<ir::operation::Gather>(op, subg, param);
+}
+
+template <typename LoaderDomain>
+void BaseLoader<LoaderDomain>::loadDetectionPostProcess(const Operator *op, ir::Graph &subg)
+{
+ const auto &m = getCustomOpAttrMap(op);
+
+ ir::operation::DetectionPostProcess::Param param;
+
+ param.max_detections = m["max_detections"].AsInt32();
+
+ // TODO fixme
+ param.max_classes_per_detection = m["max_classes_per_detection"].AsInt32();
+ if (m["detections_per_class"].IsNull())
+ param.max_boxes_per_class = 100;
+ else
+ param.max_boxes_per_class = m["detections_per_class"].AsInt32();
+
+ if (m["use_regular_nms"].IsNull())
+ param.do_fast_eval = true;
+ else
+ param.do_fast_eval = !m["use_regular_nms"].AsBool();
+
+ param.score_threshold = m["nms_score_threshold"].AsFloat();
+ param.iou_threshold = m["nms_iou_threshold"].AsFloat();
+
+ // TODO add num classes support
+ param.num_classes = m["num_classes"].AsInt32();
+
+ param.scale.y_scale = m["y_scale"].AsFloat();
+ param.scale.x_scale = m["x_scale"].AsFloat();
+ param.scale.h_scale = m["h_scale"].AsFloat();
+ param.scale.w_scale = m["w_scale"].AsFloat();
+
+ // TODO depends on input model framework
+ param.center_size_boxes = true;
+
+ loadOperationTo<ir::operation::DetectionPostProcess>(op, subg, param);
+}
+
+template <typename LoaderDomain>
+void BaseLoader<LoaderDomain>::loadBatchMatMul(const Operator *op, ir::Graph &subg)
+{
+ ir::operation::BatchMatMul::Param param;
+
+ const auto builtin_op = getBuiltinOperator(op);
+
+ switch (builtin_op)
+ {
+ case BuiltinOperator::BuiltinOperator_BATCH_MATMUL:
+      // Handled in each loader because the option names differ:
+ // Circle: adjoint_lhs, adjoint_rhs
+ // TFLite: adj_x, adj_y
+ throw std::runtime_error(
+ std::string("Cannot handle here: ").append(EnumNameBuiltinOperator(builtin_op)) + " as " +
+ EnumNameBuiltinOperator(BuiltinOperator::BuiltinOperator_BATCH_MATMUL));
+ case BuiltinOperator::BuiltinOperator_CUSTOM:
+ if (op->custom_options() == nullptr)
+ {
+ param.adj_x = false;
+ param.adj_y = false;
+ }
+ else
+ {
+ const auto attr_map = getCustomOpAttrMap(op);
+ param.adj_x = attr_map["adj_x"].AsBool();
+ param.adj_y = attr_map["adj_y"].AsBool();
+ }
+ break;
+ default:
+ throw std::runtime_error(
+ std::string("Wrong loaded operation: ").append(EnumNameBuiltinOperator(builtin_op)) +
+ " as " + EnumNameBuiltinOperator(BuiltinOperator::BuiltinOperator_BATCH_MATMUL));
+ }
+
+ loadOperationTo<ir::operation::BatchMatMul>(op, subg, param);
+}
+
+template <typename LoaderDomain>
+void BaseLoader<LoaderDomain>::loadSpaceToDepth(const Operator *op, ir::Graph &subg)
+{
+ ir::operation::SpaceToDepth::Param param;
+ const auto *options = op->builtin_options_as_SpaceToDepthOptions();
+ param.block_size = options->block_size();
+
+ loadOperationTo<ir::operation::SpaceToDepth>(op, subg, param);
+}
+
+template <typename LoaderDomain>
+void BaseLoader<LoaderDomain>::loadCustom(const Operator *op, ir::Graph &subg)
+{
+ ir::OperandIndexSequence inputs;
+ ir::OperandIndexSequence outputs;
+
+ assert(op->custom_options_format() == CustomOptionsFormat::CustomOptionsFormat_FLEXBUFFERS &&
+ "Unsupported custom operation options format");
+
+ auto *op_code = _domain_model->operator_codes()->Get(op->opcode_index());
+ auto custom_op_name = op_code->custom_code()->str();
+
+ enum class BuiltinOP
+ {
+ AddV2,
+ ReduceAll,
+ MatrixBandPart,
+ BatchMatMul,
+ Einsum,
+ BroadcastTo,
+ FusedBatchNorm,
+ StatelessRandomUniform,
+ Erf,
+ DetectionPostProcess
+ };
+
+ // Mapping from custom op name string to BuiltinOP enum
+ std::map<std::string, BuiltinOP> builtin_map = {
+ {"AddV2", BuiltinOP::AddV2},
+ {"All", BuiltinOP::ReduceAll},
+ {"MatrixBandPart", BuiltinOP::MatrixBandPart},
+ {"BatchMatMulV2", BuiltinOP::BatchMatMul},
+ {"Einsum", BuiltinOP::Einsum},
+ {"FusedBatchNormV3", BuiltinOP::FusedBatchNorm},
+ {"BroadcastTo", BuiltinOP::BroadcastTo},
+ {"StatelessRandomUniform", BuiltinOP::StatelessRandomUniform},
+ {"Erf", BuiltinOP::Erf},
+ {"TFLite_Detection_PostProcess", BuiltinOP::DetectionPostProcess},
+ };
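+  // Names not listed above are handled by the fallback: builtin_map.at() throws
+  // std::out_of_range and the catch block below loads the operator as a generic
+  // ir::operation::Custom with its raw custom_options attached as user data.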
+
+ try
+ {
+    // builtin_map.at() throws std::out_of_range for an unknown custom op
+ auto custom_op_id = builtin_map.at(custom_op_name);
+ switch (custom_op_id)
+ {
+ case BuiltinOP::AddV2:
+ loadAddV2(op, subg);
+ break;
+ case BuiltinOP::ReduceAll:
+ loadReduceAll(op, subg);
+ break;
+ case BuiltinOP::MatrixBandPart:
+ loadOperationTo<ir::operation::MatrixBandPart>(op, subg);
+ break;
+ case BuiltinOP::BatchMatMul:
+ loadBatchMatMul(op, subg);
+ break;
+ case BuiltinOP::Einsum:
+ loadEinsum(op, subg);
+ break;
+ case BuiltinOP::BroadcastTo:
+ loadOperationTo<ir::operation::BroadcastTo>(op, subg);
+ break;
+ case BuiltinOP::FusedBatchNorm:
+ loadFusedBatchNorm(op, subg);
+ break;
+ case BuiltinOP::StatelessRandomUniform:
+ loadOperationTo<ir::operation::StatelessRandomUniform>(op, subg);
+ break;
+ case BuiltinOP::Erf:
+ loadElementwiseUnary(op, subg, ir::operation::ElementwiseUnary::Type::ERF);
+ break;
+ case BuiltinOP::DetectionPostProcess:
+ loadDetectionPostProcess(op, subg);
+ break;
+ default:
+ throw std::runtime_error{
+ "Loader: Custom OP map is defined but operation loader function is not defined"};
+ }
+
+ return;
+ }
+ catch (...)
+ {
+ loadOperationIO(op, inputs, outputs);
+
+ auto constraint = ir::OperandConstraint::createExact(inputs.size());
+
+ size_t custom_op_data_size = op->custom_options()->size();
+ auto custom_op_data = new char[custom_op_data_size];
+ std::copy(op->custom_options()->begin(), op->custom_options()->end(), custom_op_data);
+
+ ir::operation::Custom::Userdata userdata{};
+ userdata.data = custom_op_data;
+ userdata.size = custom_op_data_size;
+
+ auto new_op = std::make_unique<ir::operation::Custom>(constraint, inputs, outputs,
+ custom_op_name, userdata);
+
+ subg.addOperation(std::move(new_op));
+ }
+}
+
+template <typename LoaderDomain>
+void BaseLoader<LoaderDomain>::loadSqueeze(const Operator *op, ir::Graph &subg)
+{
+ ir::operation::Squeeze::Param param;
+ const auto *options = op->builtin_options_as_SqueezeOptions();
+ const auto *dims = options->squeeze_dims();
+ if (dims)
+ {
+ if (dims->size() > sizeof(param.dims) / sizeof(param.dims[0]))
+      throw std::runtime_error("Squeeze: 'param.dims' is out of range.");
+ param.ndim = dims->size();
+ for (int i = 0; i < param.ndim; ++i)
+ param.dims[i] = dims->Get(i);
+ }
+ else
+ param.ndim = 0;
+
+ loadOperationTo<ir::operation::Squeeze>(op, subg, param);
+}
+
+template <typename LoaderDomain>
+void BaseLoader<LoaderDomain>::loadSplit(const Operator *op, ir::Graph &subg)
+{
+ ir::operation::Split::Param param;
+ const auto *options = op->builtin_options_as_SplitOptions();
+ param.num_splits = options->num_splits();
+
+ loadOperationTo<ir::operation::Split>(op, subg, param);
+}
+
+template <typename LoaderDomain>
+void BaseLoader<LoaderDomain>::loadSplitV(const Operator *op, ir::Graph &subg)
+{
+ ir::operation::SplitV::Param param;
+ const auto *options = op->builtin_options_as_SplitVOptions();
+ param.num_splits = options->num_splits();
+
+ loadOperationTo<ir::operation::SplitV>(op, subg, param);
+}
+
+template <typename LoaderDomain>
+void BaseLoader<LoaderDomain>::loadStridedSlice(const Operator *op, ir::Graph &subg)
+{
+ ir::operation::StridedSlice::Param param;
+ const auto *options = op->builtin_options_as_StridedSliceOptions();
+ param.begin_mask = options->begin_mask();
+ param.end_mask = options->end_mask();
+ param.shrink_axis_mask = options->shrink_axis_mask();
+
+ loadOperationTo<ir::operation::StridedSlice>(op, subg, param);
+}
+
+template <typename LoaderDomain>
+void BaseLoader<LoaderDomain>::loadUnpack(const Operator *op, ir::Graph &subg)
+{
+ ir::operation::Unpack::Param param;
+ const auto *options = op->builtin_options_as_UnpackOptions();
+ param.num = options->num();
+ param.axis = options->axis();
+
+ loadOperationTo<ir::operation::Unpack>(op, subg, param);
+}
+
+template <typename LoaderDomain>
+void BaseLoader<LoaderDomain>::loadComparison(const Operator *op, ir::Graph &subg)
+{
+ ir::operation::Comparison::Param param;
+ const auto builtin_op = getBuiltinOperator(op);
+
+ switch (builtin_op)
+ {
+ case BuiltinOperator::BuiltinOperator_EQUAL:
+ param.comparison_type = ir::operation::Comparison::ComparisonType::Equal;
+ break;
+ case BuiltinOperator::BuiltinOperator_NOT_EQUAL:
+ param.comparison_type = ir::operation::Comparison::ComparisonType::NotEqual;
+ break;
+ case BuiltinOperator::BuiltinOperator_GREATER_EQUAL:
+ param.comparison_type = ir::operation::Comparison::ComparisonType::GreaterEqual;
+ break;
+ case BuiltinOperator::BuiltinOperator_GREATER:
+ param.comparison_type = ir::operation::Comparison::ComparisonType::Greater;
+ break;
+ case BuiltinOperator::BuiltinOperator_LESS_EQUAL:
+ param.comparison_type = ir::operation::Comparison::ComparisonType::LessEqual;
+ break;
+ case BuiltinOperator::BuiltinOperator_LESS:
+ param.comparison_type = ir::operation::Comparison::ComparisonType::Less;
+ break;
+ default:
+ throw std::runtime_error(
+ std::string("Unsupported operation: ").append(EnumNameBuiltinOperator(builtin_op)));
+ }
+
+ loadOperationTo<ir::operation::Comparison>(op, subg, param);
+}
+
+template <typename LoaderDomain>
+void BaseLoader<LoaderDomain>::loadEinsum(const Operator *op, ir::Graph &subg)
+{
+ ir::operation::Einsum::Param param;
+ if (op->custom_options() == nullptr)
+ {
+ throw std::runtime_error{"Einsum: empty equation"};
+ }
+ else
+ {
+ const auto attr_map = getCustomOpAttrMap(op);
+ param.equation = attr_map["equation"].ToString();
+ }
+
+ const auto es = loadOperationTo<ir::operation::Einsum>(op, subg, param);
+ if (es->getInputs().size() != 2)
+ {
+    throw std::runtime_error{"Einsum: NYI input - only two inputs are supported"};
+ }
+}
+template <typename LoaderDomain>
+void BaseLoader<LoaderDomain>::loadFusedBatchNorm(const Operator *op, ir::Graph &subg)
+{
+ ir::operation::FusedBatchNorm::Param param;
+ if (op->custom_options() == nullptr)
+ {
+ throw std::runtime_error{"FusedBatchNorm: empty option"};
+ }
+ else
+ {
+ const auto attr_map = getCustomOpAttrMap(op);
+ param.is_training = attr_map["is_training"].AsBool();
+ param.epsilon = attr_map["epsilon"].AsFloat();
+ param.data_format = attr_map["data_format"].ToString();
+ }
+
+ const auto fbn = loadOperationTo<ir::operation::FusedBatchNorm>(op, subg, param);
+
+ if (fbn->getInputs().size() != 5)
+ {
+    throw std::runtime_error{"FusedBatchNorm: NYI input - only five inputs are supported"};
+ }
+}
+
+template <typename LoaderDomain>
+void BaseLoader<LoaderDomain>::loadOneHot(const Operator *op, ir::Graph &subg)
+{
+ if (op->inputs()->size() != 4 || op->outputs()->size() != 1)
+ throw std::runtime_error("OneHot Op has wrong number of input or output tensors.");
+
+ // Set parameter
+ ir::operation::OneHot::Param param;
+ param.axis = op->builtin_options_as_OneHotOptions()->axis();
+
+ loadOperationTo<ir::operation::OneHot>(op, subg, param);
+}
+
+template <typename LoaderDomain>
+void BaseLoader<LoaderDomain>::loadIf(const Operator *op, ir::Graph &subg)
+{
+ const auto *options = op->builtin_options_as_IfOptions();
+ const int32_t then_index = options->then_subgraph_index();
+ const int32_t else_index = options->else_subgraph_index();
+
+ verifySubgraphIndex(then_index);
+ verifySubgraphIndex(else_index);
+
+ ir::operation::If::Param param;
+ param.then_subg_index = ir::SubgraphIndex{static_cast<uint16_t>(then_index)};
+ param.else_subg_index = ir::SubgraphIndex{static_cast<uint16_t>(else_index)};
+
+ loadOperationTo<ir::operation::If>(op, subg, param);
+}
+
+template <typename LoaderDomain>
+void BaseLoader<LoaderDomain>::loadWhile(const Operator *op, ir::Graph &subg)
+{
+ const auto *options = op->builtin_options_as_WhileOptions();
+ const int32_t cond_index = options->cond_subgraph_index();
+ const int32_t body_index = options->body_subgraph_index();
+
+ verifySubgraphIndex(cond_index);
+ verifySubgraphIndex(body_index);
+
+ ir::operation::While::Param param;
+ param.cond_subg_index = ir::SubgraphIndex{static_cast<uint16_t>(cond_index)};
+ param.body_subg_index = ir::SubgraphIndex{static_cast<uint16_t>(body_index)};
+
+ loadOperationTo<ir::operation::While>(op, subg, param);
+}
+
+template <typename LoaderDomain>
+void BaseLoader<LoaderDomain>::loadArgMinMax(const Operator *op, ir::Graph &subg, bool is_argmax)
+{
+ ir::operation::ArgMinMax::Param param;
+ const auto output_type = is_argmax ? op->builtin_options_as_ArgMaxOptions()->output_type()
+ : op->builtin_options_as_ArgMinOptions()->output_type();
+ param.output_type = tensorTypeToDataType(output_type);
+ param.is_arg_max = is_argmax;
+
+ loadOperationTo<ir::operation::ArgMinMax>(op, subg, param);
+}
+
+template <typename LoaderDomain>
+void BaseLoader<LoaderDomain>::loadLogSoftmax(const Operator *op, ir::Graph &subg)
+{
+ ir::operation::LogSoftmax::Param param;
+ // In tflite, beta is fixed to 1.0 and axis is fixed to -1.
+ param.beta = 1.0f;
+ param.axis = -1;
+
+ loadOperationTo<ir::operation::LogSoftmax>(op, subg, param);
+}
+
+template <typename LoaderDomain>
+void BaseLoader<LoaderDomain>::loadLeakyRelu(const Operator *op, ir::Graph &subg)
+{
+ float alpha = op->builtin_options_as_LeakyReluOptions()->alpha();
+ loadElementwiseActivation(op, subg, ir::operation::ElementwiseActivation::Type::LEAKY_RELU, alpha,
+ 1.f);
+}
+
+template <typename LoaderDomain>
+void BaseLoader<LoaderDomain>::loadUnidirectionalSequenceLSTM(const Operator *op, ir::Graph &subg)
+{
+ ir::operation::LSTM::Param param;
+ const auto *options = op->builtin_options_as_UnidirectionalSequenceLSTMOptions();
+ param.activation = convertActivation(options->fused_activation_function());
+ param.cell_threshold = options->cell_clip();
+ param.projection_threshold = options->proj_clip();
+ param.time_major = options->time_major();
+  // The asymmetric_quantize_inputs option is not used yet
+
+ ir::OperandIndexSequence inputs;
+ for (const std::int32_t idx : *op->inputs())
+ {
+ inputs.append(tensorIdxToOperandIdx(idx));
+ }
+
+ ir::OperandIndexSequence outputs;
+ // loader doesn't support optional output tensor yet
+ if (op->outputs()->size() != 1)
+ {
+ auto builtin_code = getBuiltinOperator(op);
+ throw std::runtime_error(std::string("loader doesn't support optional output tensor yet for ")
+ .append(EnumNameBuiltinOperator(builtin_code)));
+ }
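+  // Outputs 0 .. OUTPUT-1 are optional; they are filled with undefined operand
+  // indices below so that the single real output lands at the OUTPUT position.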
+ for (size_t i = 0; i < ir::operation::LSTM::Output::OUTPUT; ++i)
+ {
+ // Add optional outputs
+ outputs.append(ir::OperandIndex());
+ }
+ outputs.append(tensorIdxToOperandIdx(op->outputs()->Get(0)));
+
+ std::unique_ptr<ir::operation::LSTM> new_op(new ir::operation::LSTM(inputs, outputs, param));
+ subg.addOperation(std::move(new_op));
+}
+
+template <typename LoaderDomain>
+void BaseLoader<LoaderDomain>::loadOperation(const Operator *op, ir::Graph &subg)
+{
+ auto const builtin_op = getBuiltinOperator(op);
+
+ switch (builtin_op)
+ {
+ case BuiltinOperator::BuiltinOperator_ADD_N:
+ loadOperationTo<ir::operation::AddN>(op, subg);
+ return;
+ case BuiltinOperator::BuiltinOperator_CONV_2D:
+ loadConv2D(op, subg);
+ return;
+ case BuiltinOperator::BuiltinOperator_AVERAGE_POOL_2D:
+ loadPool2D(op, subg, ir::operation::Pool2D::PoolType::AVG);
+ return;
+ case BuiltinOperator::BuiltinOperator_DEPTHWISE_CONV_2D:
+ loadDepthwiseConv2D(op, subg);
+ return;
+ case BuiltinOperator::BuiltinOperator_TRANSPOSE_CONV:
+ loadTransposeConv(op, subg);
+ return;
+ case BuiltinOperator::BuiltinOperator_RESHAPE:
+ loadReshape(op, subg);
+ return;
+ case BuiltinOperator::BuiltinOperator_SOFTMAX:
+ loadSoftmax(op, subg);
+ return;
+ case BuiltinOperator::BuiltinOperator_MAX_POOL_2D:
+ loadPool2D(op, subg, ir::operation::Pool2D::PoolType::MAX);
+ return;
+ case BuiltinOperator::BuiltinOperator_CONCATENATION:
+ loadConcatenation(op, subg);
+ return;
+ case BuiltinOperator::BuiltinOperator_FLOOR:
+ loadElementwiseUnary(op, subg, ir::operation::ElementwiseUnary::Type::FLOOR);
+ return;
+ case BuiltinOperator::BuiltinOperator_FULLY_CONNECTED:
+ loadFC(op, subg);
+ return;
+ case BuiltinOperator::BuiltinOperator_ADD:
+ loadBinaryArithmetic(op, subg, ir::operation::BinaryArithmetic::ArithmeticType::ADD);
+ return;
+ case BuiltinOperator::BuiltinOperator_SUB:
+ loadBinaryArithmetic(op, subg, ir::operation::BinaryArithmetic::ArithmeticType::SUB);
+ return;
+ case BuiltinOperator::BuiltinOperator_MUL:
+ loadBinaryArithmetic(op, subg, ir::operation::BinaryArithmetic::ArithmeticType::MUL);
+ return;
+ case BuiltinOperator::BuiltinOperator_DIV:
+ loadBinaryArithmetic(op, subg, ir::operation::BinaryArithmetic::ArithmeticType::DIV);
+ return;
+ case BuiltinOperator::BuiltinOperator_PACK:
+ loadPack(op, subg);
+ return;
+ case BuiltinOperator::BuiltinOperator_ELU:
+ loadElementwiseActivation(op, subg, ir::operation::ElementwiseActivation::Type::ELU);
+ return;
+ case BuiltinOperator::BuiltinOperator_RELU:
+ loadElementwiseActivation(op, subg, ir::operation::ElementwiseActivation::Type::RELU,
+ ir::operation::ElementwiseActivation::infinity, 0.f);
+ return;
+ case BuiltinOperator::BuiltinOperator_RELU_N1_TO_1:
+ loadElementwiseActivation(op, subg, ir::operation::ElementwiseActivation::Type::RELU, 1.f,
+ -1.f);
+ return;
+ case BuiltinOperator::BuiltinOperator_RELU6:
+ loadElementwiseActivation(op, subg, ir::operation::ElementwiseActivation::Type::RELU, 6.f,
+ 0.f);
+ return;
+ case BuiltinOperator::BuiltinOperator_RESIZE_BILINEAR:
+ loadResizeBilinear(op, subg);
+ return;
+ case BuiltinOperator::BuiltinOperator_RESIZE_NEAREST_NEIGHBOR:
+ loadResizeNearestNeighbor(op, subg);
+ return;
+ case BuiltinOperator::BuiltinOperator_RSQRT:
+ loadElementwiseUnary(op, subg, ir::operation::ElementwiseUnary::Type::RSQRT);
+ return;
+ case BuiltinOperator::BuiltinOperator_SELECT:
+ case BuiltinOperator::BuiltinOperator_SELECT_V2:
+ loadOperationTo<ir::operation::Select>(op, subg);
+ return;
+ case BuiltinOperator::BuiltinOperator_SQRT:
+ loadElementwiseUnary(op, subg, ir::operation::ElementwiseUnary::Type::SQRT);
+ return;
+ case BuiltinOperator::BuiltinOperator_SQUARE:
+ loadElementwiseUnary(op, subg, ir::operation::ElementwiseUnary::Type::SQUARE);
+ return;
+ case BuiltinOperator::BuiltinOperator_SQUARED_DIFFERENCE:
+ loadOperationTo<ir::operation::SquaredDifference>(op, subg);
+ return;
+ case BuiltinOperator::BuiltinOperator_TANH:
+ loadElementwiseActivation(op, subg, ir::operation::ElementwiseActivation::Type::TANH, 1.f,
+ 1.f);
+ return;
+ case BuiltinOperator::BuiltinOperator_TRANSPOSE:
+ loadOperationTo<ir::operation::Transpose>(op, subg);
+ return;
+ case BuiltinOperator::BuiltinOperator_MEAN:
+ loadReduce(op, subg, ir::operation::Reduce::ReduceType::MEAN);
+ return;
+ case BuiltinOperator::BuiltinOperator_REDUCE_ANY:
+ loadReduce(op, subg, ir::operation::Reduce::ReduceType::ANY);
+ return;
+ case BuiltinOperator::BuiltinOperator_REDUCE_MAX:
+ loadReduce(op, subg, ir::operation::Reduce::ReduceType::MAX);
+ return;
+ case BuiltinOperator::BuiltinOperator_REVERSE_V2:
+ loadOperationTo<ir::operation::Reverse>(op, subg);
+ return;
+ case BuiltinOperator::BuiltinOperator_PAD:
+ case BuiltinOperator::BuiltinOperator_PADV2:
+ loadOperationTo<ir::operation::Pad>(op, subg);
+ return;
+ case BuiltinOperator::BuiltinOperator_LOGISTIC:
+ loadElementwiseActivation(op, subg, ir::operation::ElementwiseActivation::Type::LOGISTIC);
+ return;
+ case BuiltinOperator::BuiltinOperator_EXP:
+ loadElementwiseUnary(op, subg, ir::operation::ElementwiseUnary::Type::EXP);
+ return;
+ case BuiltinOperator::BuiltinOperator_EXPAND_DIMS:
+ loadOperationTo<ir::operation::ExpandDims>(op, subg);
+ return;
+ case BuiltinOperator::BuiltinOperator_GATHER:
+ loadGather(op, subg);
+ return;
+ case BuiltinOperator::BuiltinOperator_SPACE_TO_BATCH_ND:
+ loadOperationTo<ir::operation::SpaceToBatchND>(op, subg);
+ return;
+ case BuiltinOperator::BuiltinOperator_BATCH_TO_SPACE_ND:
+ loadOperationTo<ir::operation::BatchToSpaceND>(op, subg);
+ return;
+ case BuiltinOperator::BuiltinOperator_SUM:
+ loadReduce(op, subg, ir::operation::Reduce::ReduceType::SUM);
+ return;
+ case BuiltinOperator::BuiltinOperator_CUSTOM:
+ loadCustom(op, subg);
+ return;
+ case BuiltinOperator::BuiltinOperator_SQUEEZE:
+ loadSqueeze(op, subg);
+ return;
+ case BuiltinOperator::BuiltinOperator_PRELU:
+ loadOperationTo<ir::operation::PReLU>(op, subg);
+ return;
+ case BuiltinOperator::BuiltinOperator_SPLIT:
+ loadSplit(op, subg);
+ return;
+ case BuiltinOperator::BuiltinOperator_SPLIT_V:
+ loadSplitV(op, subg);
+ return;
+ case BuiltinOperator::BuiltinOperator_SLICE:
+ loadOperationTo<ir::operation::Slice>(op, subg);
+ return;
+ case BuiltinOperator::BuiltinOperator_STRIDED_SLICE:
+ loadStridedSlice(op, subg);
+ return;
+ case BuiltinOperator::BuiltinOperator_UNPACK:
+ loadUnpack(op, subg);
+ return;
+ case BuiltinOperator::BuiltinOperator_FLOOR_DIV:
+ loadElementwiseBinary(op, subg,
+ ir::operation::ElementwiseBinary::ElementwiseBinaryType::FLOOR_DIV);
+ return;
+ case BuiltinOperator::BuiltinOperator_FLOOR_MOD:
+ loadElementwiseBinary(op, subg,
+ ir::operation::ElementwiseBinary::ElementwiseBinaryType::FLOOR_MOD);
+ return;
+ case BuiltinOperator::BuiltinOperator_MINIMUM:
+ loadElementwiseBinary(op, subg, ir::operation::ElementwiseBinary::ElementwiseBinaryType::MIN);
+ return;
+ case BuiltinOperator::BuiltinOperator_MAXIMUM:
+ loadElementwiseBinary(op, subg, ir::operation::ElementwiseBinary::ElementwiseBinaryType::MAX);
+ return;
+ case BuiltinOperator::BuiltinOperator_CAST:
+ loadElementwiseUnary(op, subg, ir::operation::ElementwiseUnary::Type::CAST);
+ return;
+ case BuiltinOperator::BuiltinOperator_EQUAL:
+ case BuiltinOperator::BuiltinOperator_NOT_EQUAL:
+ case BuiltinOperator::BuiltinOperator_GREATER_EQUAL:
+ case BuiltinOperator::BuiltinOperator_GREATER:
+ case BuiltinOperator::BuiltinOperator_LESS_EQUAL:
+ case BuiltinOperator::BuiltinOperator_LESS:
+ loadComparison(op, subg);
+ return;
+ case BuiltinOperator::BuiltinOperator_ONE_HOT:
+ loadOneHot(op, subg);
+ return;
+ case BuiltinOperator::BuiltinOperator_ABS:
+ loadElementwiseUnary(op, subg, ir::operation::ElementwiseUnary::Type::ABS);
+ return;
+ case BuiltinOperator::BuiltinOperator_COS:
+ loadElementwiseUnary(op, subg, ir::operation::ElementwiseUnary::Type::COS);
+ return;
+ case BuiltinOperator::BuiltinOperator_SIN:
+ loadElementwiseUnary(op, subg, ir::operation::ElementwiseUnary::Type::SIN);
+ return;
+ case BuiltinOperator::BuiltinOperator_SHAPE:
+ loadOperationTo<ir::operation::Shape>(op, subg);
+ return;
+ case BuiltinOperator::BuiltinOperator_REDUCE_PROD:
+ loadReduce(op, subg, ir::operation::Reduce::ReduceType::PROD);
+ return;
+ case BuiltinOperator::BuiltinOperator_IF:
+ loadIf(op, subg);
+ return;
+ case BuiltinOperator::BuiltinOperator_WHILE:
+ loadWhile(op, subg);
+ return;
+ case BuiltinOperator::BuiltinOperator_NEG:
+ loadElementwiseUnary(op, subg, ir::operation::ElementwiseUnary::Type::NEG);
+ return;
+ case BuiltinOperator::BuiltinOperator_ARG_MAX:
+ loadArgMinMax(op, subg, true);
+ return;
+ case BuiltinOperator::BuiltinOperator_ARG_MIN:
+ loadArgMinMax(op, subg, false);
+ return;
+ case BuiltinOperator::BuiltinOperator_LOG:
+ loadElementwiseUnary(op, subg, ir::operation::ElementwiseUnary::Type::LOG);
+ return;
+ case BuiltinOperator::BuiltinOperator_ROUND:
+ loadElementwiseUnary(op, subg, ir::operation::ElementwiseUnary::Type::ROUND);
+ return;
+ case BuiltinOperator::BuiltinOperator_POW:
+ loadOperationTo<ir::operation::Pow>(op, subg);
+ return;
+ case BuiltinOperator::BuiltinOperator_LOGICAL_NOT:
+ loadElementwiseUnary(op, subg, ir::operation::ElementwiseUnary::Type::LOGICAL_NOT);
+ return;
+ case BuiltinOperator::BuiltinOperator_LOGICAL_AND:
+ loadElementwiseBinary(op, subg,
+ ir::operation::ElementwiseBinary::ElementwiseBinaryType::LOGICAL_AND);
+ return;
+ case BuiltinOperator::BuiltinOperator_LOGICAL_OR:
+ loadElementwiseBinary(op, subg,
+ ir::operation::ElementwiseBinary::ElementwiseBinaryType::LOGICAL_OR);
+ return;
+ case BuiltinOperator::BuiltinOperator_FILL:
+ loadOperationTo<ir::operation::Fill>(op, subg);
+ return;
+ case BuiltinOperator::BuiltinOperator_ZEROS_LIKE:
+ loadElementwiseUnary(op, subg, ir::operation::ElementwiseUnary::Type::ZEROS_LIKE);
+ return;
+ case BuiltinOperator::BuiltinOperator_TILE:
+ loadOperationTo<ir::operation::Tile>(op, subg);
+ return;
+ case BuiltinOperator::BuiltinOperator_RANGE:
+ loadOperationTo<ir::operation::Range>(op, subg);
+ return;
+ case BuiltinOperator::BuiltinOperator_BATCH_MATMUL:
+ loadBatchMatMul(op, subg);
+ return;
+ case BuiltinOperator::BuiltinOperator_LOG_SOFTMAX:
+ loadLogSoftmax(op, subg);
+ return;
+ case BuiltinOperator::BuiltinOperator_QUANTIZE:
+ loadElementwiseUnary(op, subg, ir::operation::ElementwiseUnary::Type::QUANTIZE);
+ return;
+ case BuiltinOperator::BuiltinOperator_DEQUANTIZE:
+ loadElementwiseUnary(op, subg, ir::operation::ElementwiseUnary::Type::DEQUANTIZE);
+ return;
+ case BuiltinOperator::BuiltinOperator_SPACE_TO_DEPTH:
+ loadSpaceToDepth(op, subg);
+ return;
+ case BuiltinOperator::BuiltinOperator_L2_NORMALIZATION:
+ loadOperationTo<ir::operation::L2Normalization>(op, subg);
+      return;
+ case BuiltinOperator::BuiltinOperator_LEAKY_RELU:
+ loadLeakyRelu(op, subg);
+ return;
+ case BuiltinOperator::BuiltinOperator_RANK:
+ loadOperationTo<ir::operation::Rank>(op, subg);
+ return;
+ case BuiltinOperator::BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM:
+ loadUnidirectionalSequenceLSTM(op, subg);
+ return;
+ case BuiltinOperator::BuiltinOperator_DEPTH_TO_SPACE:
+ loadDepthToSpace(op, subg);
+ return;
+ case BuiltinOperator::BuiltinOperator_EMBEDDING_LOOKUP:
+ loadOperationTo<ir::operation::EmbeddingLookup>(op, subg);
+ return;
+ case BuiltinOperator::BuiltinOperator_HASHTABLE_LOOKUP:
+ loadOperationTo<ir::operation::HashtableLookup>(op, subg);
+ return;
+ default:
+ throw std::runtime_error(
+ std::string("Unsupported operation: ").append(EnumNameBuiltinOperator(builtin_op)));
+ }
+}
+
+template <typename LoaderDomain> void BaseLoader<LoaderDomain>::loadModel()
+{
+ LoaderDomain::VerifyModelBuffer(*_verifier.get());
+ _domain_model = LoaderDomain::GetModel(_base);
+
+ auto model = std::make_unique<ir::Model>();
+ // Version unused
+ // const auto version = _model->version();
+ // Description unused
+
+ // Load Metadata
+ auto const metadata_list = _domain_model->metadata();
+ if (metadata_list != nullptr)
+ {
+ for (uint32_t i = 0; i < metadata_list->size(); ++i)
+ {
+ const auto metadata = metadata_list->Get(i);
+ if (metadata->name() == nullptr)
+ continue; // metadata should have name
+
+ std::unique_ptr<const ir::Data> data = loadMetadata(metadata->buffer());
+ model->add_metadata(metadata->name()->str(), std::move(data));
+ }
+ }
+
+ // const auto *description = _model->description();
+ // Load subgraphs and map operations on subgraph
+ const auto subgraphs = _domain_model->subgraphs();
+ if (subgraphs->size() - 1 > ir::SubgraphIndex::max())
+ throw std::runtime_error{"The number of subgraphs cannot exceed " +
+ std::to_string(ir::SubgraphIndex::max() + 1)};
+ for (uint16_t subgraph_index = 0; subgraph_index < subgraphs->size(); ++subgraph_index)
+ {
+ auto subg = loadSubgraph((*_domain_model->subgraphs())[subgraph_index]);
+ // NOTE: Used () instead of {}, which does not check narrowing.
+    // It is okay since overflow is checked by the above if-statement.
+ model->push(ir::SubgraphIndex(subgraph_index), std::move(subg));
+ }
+ _model = std::move(model);
+}
+
+} // namespace loader
+} // namespace onert
+
+#endif //__ONERT_LOADER_BASE_LOADER_H__
diff --git a/runtime/onert/core/src/loader/CircleLoader.cc b/runtime/onert/core/src/loader/CircleLoader.cc
new file mode 100644
index 000000000..442a0f518
--- /dev/null
+++ b/runtime/onert/core/src/loader/CircleLoader.cc
@@ -0,0 +1,239 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "loader/CircleLoader.h"
+
+#include "BaseLoader.h"
+#include "circle_schema_generated.h"
+
+namespace onert
+{
+namespace loader
+{
+
+namespace
+{
+
+struct LoaderDomain
+{
+ using Verifier = flatbuffers::Verifier;
+ using ActivationFunctionType = circle::ActivationFunctionType;
+ using Buffer = circle::Buffer;
+ using BuiltinOperator = circle::BuiltinOperator;
+ using CustomOptionsFormat = circle::CustomOptionsFormat;
+ using Metadata = circle::Metadata;
+ using Model = circle::Model;
+ using Operator = circle::Operator;
+ using Padding = circle::Padding;
+ using Pool2DOptions = circle::Pool2DOptions;
+ using Tensor = circle::Tensor;
+ using TensorType = circle::TensorType;
+ using SubGraph = circle::SubGraph;
+ using DimensionType = circle::DimensionType;
+ using SparseIndexVector = circle::SparseIndexVector;
+
+ static const char *EnumNameBuiltinOperator(BuiltinOperator e)
+ {
+ return circle::EnumNameBuiltinOperator(e);
+ }
+ static const char *EnumNameActivationFunctionType(ActivationFunctionType e)
+ {
+ return circle::EnumNameActivationFunctionType(e);
+ }
+ static const char *EnumNameTensorType(TensorType e) { return circle::EnumNameTensorType(e); }
+ static const Model *GetModel(const void *buf) { return circle::GetModel(buf); }
+ static bool VerifyModelBuffer(Verifier &verifier) { return circle::VerifyModelBuffer(verifier); }
+};
+
+class CircleLoader final : public loader::BaseLoader<LoaderDomain>
+{
+protected:
+ // Different option name
+ // Circle: adjoint_lhs, adjoint_rhs
+ // TFLite: adj_x, adj_y
+ void loadBatchMatMul(const Operator *op, ir::Graph &subg);
+
+ // Only circle operations
+ void loadInstanceNorm(const Operator *op, ir::Graph &subg);
+ void loadBCQFullyConnected(const Operator *op, ir::Graph &subg);
+ void loadBCQGather(const Operator *op, ir::Graph &subg);
+
+public:
+ using BaseLoader::BaseLoader;
+
+ bool allowOptionalInputTensor(BuiltinOperator op) override
+ {
+ switch (op)
+ {
+ case BuiltinOperator::BuiltinOperator_FULLY_CONNECTED:
+ case BuiltinOperator::BuiltinOperator_BCQ_FULLY_CONNECTED:
+ case BuiltinOperator::BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM:
+ return true;
+ default:
+ return false;
+ }
+ }
+
+private:
+ std::unique_ptr<ir::Graph> loadSubgraph(const circle::SubGraph *circle_subg) override
+ {
+ auto subg = std::make_unique<ir::Graph>();
+ // Load tensors
+ _tensor_to_operand.resize(circle_subg->tensors()->size());
+ for (flatbuffers::uoffset_t i = 0; i < circle_subg->tensors()->size(); ++i)
+ {
+ _tensor_to_operand[i] = loadOperand(circle_subg->tensors()->Get(i), *subg);
+ subg->operands().at(_tensor_to_operand[i]).setOriginIndex(ir::OriginIndex(i));
+ }
+ // Set inputs
+ for (const std::int32_t input_ind : *circle_subg->inputs())
+ {
+ subg->addInput(tensorIdxToOperandIdx(input_ind),
+ _tensor_names.at(_tensor_to_operand[input_ind]));
+ }
+ // Set outputs
+ for (const std::int32_t output_ind : *circle_subg->outputs())
+ {
+ subg->addOutput(tensorIdxToOperandIdx(output_ind),
+ _tensor_names.at(_tensor_to_operand[output_ind]));
+ }
+ // Create operations
+ for (const auto *op : *circle_subg->operators())
+ {
+ CircleLoader::loadOperation(op, *subg);
+ }
+
+ // TODO Remove frontend layout feature
+ subg->setLayout(ir::Layout::NHWC);
+
+ subg->verify();
+
+ return subg;
+ }
+
+ void loadOperation(const circle::Operator *op, ir::Graph &subg)
+ {
+ auto const builtin_op = getBuiltinOperator(op);
+
+ switch (builtin_op)
+ {
+ case circle::BuiltinOperator::BuiltinOperator_BATCH_MATMUL:
+ loadBatchMatMul(op, subg);
+ return;
+ case circle::BuiltinOperator::BuiltinOperator_INSTANCE_NORM:
+ loadInstanceNorm(op, subg);
+ return;
+ case circle::BuiltinOperator::BuiltinOperator_BCQ_FULLY_CONNECTED:
+ loadBCQFullyConnected(op, subg);
+ return;
+ case circle::BuiltinOperator::BuiltinOperator_BCQ_GATHER:
+ loadBCQGather(op, subg);
+ return;
+ default:
+ BaseLoader::loadOperation(op, subg);
+ return;
+ }
+ }
+};
+
+void CircleLoader::loadBatchMatMul(const Operator *op, ir::Graph &subg)
+{
+ ir::OperandIndexSequence inputs;
+ ir::OperandIndexSequence outputs;
+
+ loadOperationIO(op, inputs, outputs);
+
+ ir::operation::BatchMatMul::Param param;
+ const auto *options = op->builtin_options_as_BatchMatMulOptions();
+
+ param.adj_x = options->adjoint_lhs();
+ param.adj_y = options->adjoint_rhs();
+
+ std::unique_ptr<ir::Operation> new_op(new ir::operation::BatchMatMul(inputs, outputs, param));
+ subg.addOperation(std::move(new_op));
+}
+
+void CircleLoader::loadInstanceNorm(const Operator *op, ir::Graph &subg)
+{
+ ir::OperandIndexSequence inputs;
+ ir::OperandIndexSequence outputs;
+
+ loadOperationIO(op, inputs, outputs);
+
+ ir::operation::InstanceNorm::Param param;
+ const auto *options = op->builtin_options_as_InstanceNormOptions();
+
+ param.activation = convertActivation(options->fused_activation_function());
+ // Use default value 1e-5 if value of epsilon is zero
+ param.epsilon = options->epsilon() == 0.f ? 1e-5 : options->epsilon();
+
+ std::unique_ptr<ir::Operation> new_op(new ir::operation::InstanceNorm(inputs, outputs, param));
+ subg.addOperation(std::move(new_op));
+}
+
+void CircleLoader::loadBCQGather(const Operator *op, ir::Graph &subg)
+{
+ ir::OperandIndexSequence inputs;
+ ir::OperandIndexSequence outputs;
+
+ loadOperationIO(op, inputs, outputs);
+
+ ir::operation::BCQGather::Param param;
+ const auto *options = op->builtin_options_as_BCQGatherOptions();
+ param.input_hidden_size = options->input_hidden_size();
+ param.axis = options->axis();
+
+ std::unique_ptr<ir::Operation> new_op(new ir::operation::BCQGather(inputs, outputs, param));
+ subg.addOperation(std::move(new_op));
+}
+
+void CircleLoader::loadBCQFullyConnected(const Operator *op, ir::Graph &subg)
+{
+ ir::OperandIndexSequence inputs;
+ ir::OperandIndexSequence outputs;
+
+ loadOperationIO(op, inputs, outputs);
+
+ ir::operation::BCQFullyConnected::Param param;
+ const auto *options = op->builtin_options_as_BCQFullyConnectedOptions();
+ param.weights_hidden_size = options->weights_hidden_size();
+ param.activation = convertActivation(options->fused_activation_function());
+
+ std::unique_ptr<ir::Operation> new_op(
+ new ir::operation::BCQFullyConnected(inputs, outputs, param));
+ subg.addOperation(std::move(new_op));
+}
+
+} // namespace
+
+std::unique_ptr<ir::Model> loadCircleModel(const std::string &filename)
+{
+ auto model = std::make_unique<ir::Model>();
+ CircleLoader loader(model);
+ loader.loadFromFile(filename);
+ return model;
+}
+
+std::unique_ptr<ir::Model> loadCircleModel(uint8_t *buffer, size_t size)
+{
+ auto model = std::make_unique<ir::Model>();
+ CircleLoader loader(model);
+ loader.loadFromBuffer(buffer, size);
+ return model;
+}
+
+} // namespace loader
+} // namespace onert
diff --git a/runtime/onert/core/src/loader/ModelLoader.cc b/runtime/onert/core/src/loader/ModelLoader.cc
new file mode 100644
index 000000000..1f3b4673c
--- /dev/null
+++ b/runtime/onert/core/src/loader/ModelLoader.cc
@@ -0,0 +1,85 @@
+/*
+ * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "loader/ModelLoader.h"
+
+#include "loader/ILoader.h"
+
+#include <dlfcn.h>
+
+namespace onert
+{
+namespace loader
+{
+
+std::unique_ptr<ir::Model> loadModel(const std::string &filename, const std::string &type)
+{
+ // Custom loader library name should be lib<type>_loader.so
+ std::string libname = "lib" + type + "_loader.so";
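+  // For example (hypothetical type), type "foo" resolves to "libfoo_loader.so".
+  // The library is expected to export the symbols resolved below:
+  //   ILoader *onert_loader_create();
+  //   void onert_loader_destroy(ILoader *);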
+
+ // Open custom loader library
+ void *handle = dlopen(libname.c_str(), RTLD_LAZY);
+ if (!handle)
+ throw std::runtime_error("Failed to open " + type + " loader");
+
+ // Get custom loader create function
+ using create_func_t = ILoader *(*)();
+ auto create_fn = reinterpret_cast<create_func_t>(dlsym(handle, "onert_loader_create"));
+ if (!create_fn)
+ {
+ dlclose(handle);
+ throw std::runtime_error("Failed to find loader create function");
+ }
+
+ // Get custom loader destroy function
+ using destroy_func_t = void (*)(ILoader *);
+ auto destroy_fn = reinterpret_cast<destroy_func_t>(dlsym(handle, "onert_loader_destroy"));
+ if (!destroy_fn)
+ {
+ dlclose(handle);
+ throw std::runtime_error("Failed to find loader destroy function");
+ }
+
+ // Create custom loader
+ auto loader = create_fn();
+ if (!loader)
+ {
+ dlclose(handle);
+    throw std::runtime_error("Failed to create " + type + " loader");
+ }
+
+ // Load model
+ auto model = loader->loadFromFile(filename);
+
+ // Destroy custom loader
+ destroy_fn(loader);
+
+ // Close custom loader library
+ //
+ // NOTE:
+  // It assumes that the custom loader will not be used frequently in a runtime session.
+  // If the custom loader is used frequently, the library should not be closed here;
+  // instead, the handle should be saved and reused.
+ dlclose(handle);
+
+ if (model)
+ return model;
+
+ throw std::runtime_error("Failed to load model " + filename);
+}
+
+} // namespace loader
+} // namespace onert
diff --git a/runtime/onert/core/src/loader/TFLiteLoader.cc b/runtime/onert/core/src/loader/TFLiteLoader.cc
new file mode 100644
index 000000000..745f39006
--- /dev/null
+++ b/runtime/onert/core/src/loader/TFLiteLoader.cc
@@ -0,0 +1,167 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "loader/TFLiteLoader.h"
+
+#include "BaseLoader.h"
+#include "tflite_schema_generated.h"
+
+namespace onert
+{
+namespace loader
+{
+
+namespace
+{
+
+struct LoaderDomain
+{
+ using Verifier = flatbuffers::Verifier;
+ using ActivationFunctionType = onert_tflite::ActivationFunctionType;
+ using Buffer = onert_tflite::Buffer;
+ using BuiltinOperator = onert_tflite::BuiltinOperator;
+ using CustomOptionsFormat = onert_tflite::CustomOptionsFormat;
+ using Model = onert_tflite::Model;
+ using Metadata = onert_tflite::Metadata;
+ using Operator = onert_tflite::Operator;
+ using Padding = onert_tflite::Padding;
+ using Pool2DOptions = onert_tflite::Pool2DOptions;
+ using Tensor = onert_tflite::Tensor;
+ using TensorType = onert_tflite::TensorType;
+ using SubGraph = onert_tflite::SubGraph;
+ using DimensionType = onert_tflite::DimensionType;
+ using SparseIndexVector = onert_tflite::SparseIndexVector;
+
+ static const char *EnumNameBuiltinOperator(BuiltinOperator e)
+ {
+ return onert_tflite::EnumNameBuiltinOperator(e);
+ }
+ static const char *EnumNameActivationFunctionType(ActivationFunctionType e)
+ {
+ return onert_tflite::EnumNameActivationFunctionType(e);
+ }
+ static const char *EnumNameTensorType(TensorType e)
+ {
+ return onert_tflite::EnumNameTensorType(e);
+ }
+ static const Model *GetModel(const void *buf) { return onert_tflite::GetModel(buf); }
+ static bool VerifyModelBuffer(Verifier &verifier)
+ {
+ return onert_tflite::VerifyModelBuffer(verifier);
+ }
+};
+
+class TFLiteLoader final : public loader::BaseLoader<LoaderDomain>
+{
+protected:
+ // Option names differ between Circle and TFLite:
+ // Circle: adjoint_lhs, adjoint_rhs
+ // TFLite: adj_x, adj_y
+ void loadBatchMatMul(const Operator *op, ir::Graph &subg);
+
+public:
+ using BaseLoader::BaseLoader;
+
+ bool allowOptionalInputTensor(BuiltinOperator op) override
+ {
+ switch (op)
+ {
+ case BuiltinOperator::BuiltinOperator_FULLY_CONNECTED:
+ case BuiltinOperator::BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM:
+ return true;
+ default:
+ return false;
+ }
+ }
+
+private:
+ std::unique_ptr<ir::Graph> loadSubgraph(const onert_tflite::SubGraph *tflite_subg) override
+ {
+ auto subg = std::make_unique<ir::Graph>();
+ // Load tensors
+ _tensor_to_operand.resize(tflite_subg->tensors()->size());
+ for (flatbuffers::uoffset_t i = 0; i < tflite_subg->tensors()->size(); ++i)
+ {
+ _tensor_to_operand[i] = loadOperand(tflite_subg->tensors()->Get(i), *subg);
+ }
+ // Set inputs
+ for (const std::int32_t input_ind : *tflite_subg->inputs())
+ {
+ subg->addInput(tensorIdxToOperandIdx(input_ind),
+ _tensor_names.at(_tensor_to_operand[input_ind]));
+ }
+ // Set outputs
+ for (const std::int32_t output_ind : *tflite_subg->outputs())
+ {
+ subg->addOutput(tensorIdxToOperandIdx(output_ind),
+ _tensor_names.at(_tensor_to_operand[output_ind]));
+ }
+ // Create operations
+ for (const auto *op : *tflite_subg->operators())
+ {
+ loadOperation(op, *subg);
+ }
+
+ subg->verify();
+
+ return subg;
+ }
+
+ void loadOperation(const onert_tflite::Operator *op, ir::Graph &subg)
+ {
+ auto const builtin_op = getBuiltinOperator(op);
+
+ switch (builtin_op)
+ {
+ case onert_tflite::BuiltinOperator::BuiltinOperator_BATCH_MATMUL:
+ loadBatchMatMul(op, subg);
+ return;
+ default:
+ BaseLoader::loadOperation(op, subg);
+ return;
+ }
+ }
+};
+
+void TFLiteLoader::loadBatchMatMul(const Operator *op, ir::Graph &subg)
+{
+ ir::OperandIndexSequence inputs;
+ ir::OperandIndexSequence outputs;
+
+ loadOperationIO(op, inputs, outputs);
+
+ ir::operation::BatchMatMul::Param param;
+ const auto *options = op->builtin_options_as_BatchMatMulOptions();
+
+ param.adj_x = options->adj_x();
+ param.adj_y = options->adj_y();
+
+ std::unique_ptr<ir::Operation> new_op(new ir::operation::BatchMatMul(inputs, outputs, param));
+ subg.addOperation(std::move(new_op));
+}
+
+} // namespace
+
+std::unique_ptr<ir::Model> loadTFLiteModel(const std::string &filename)
+{
+ auto model = std::make_unique<ir::Model>();
+ TFLiteLoader loader(model);
+ loader.loadFromFile(filename);
+ return model;
+}
+
+} // namespace loader
+} // namespace onert
diff --git a/runtime/onert/core/src/loader/TrainInfoLoader.cc b/runtime/onert/core/src/loader/TrainInfoLoader.cc
new file mode 100644
index 000000000..bb75daa6f
--- /dev/null
+++ b/runtime/onert/core/src/loader/TrainInfoLoader.cc
@@ -0,0 +1,139 @@
+/*
+ * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "loader/TrainInfoLoader.h"
+
+#include "circle_traininfo_generated.h"
+#include "flatbuffers/flatbuffers.h"
+
+namespace onert
+{
+namespace loader
+{
+
+const char *const TRAININFO_METADATA_NAME = "CIRCLE_TRAINING";
+
+namespace
+{
+
+ir::train::OptimizerInfo loadOptimizerInfo(const circle::ModelTraining *circle_model)
+{
+ assert(circle_model != nullptr);
+
+ // fill ir_opt from circle_opt
+ ir::train::OptimizerInfo ir_opt;
+ const circle::Optimizer circle_opt = circle_model->optimizer();
+
+ switch (circle_opt)
+ {
+ case circle::Optimizer_SGD:
+ ir_opt.optim_code = ir::train::OptimizerCode::SGD;
+ ir_opt.learning_rate = circle_model->optimizer_opt_as_SGDOptions()->learning_rate();
+ break;
+ case circle::Optimizer_ADAM:
+ ir_opt.optim_code = ir::train::OptimizerCode::Adam;
+ ir_opt.learning_rate = circle_model->optimizer_opt_as_AdamOptions()->learning_rate();
+ break;
+ default:
+ throw std::runtime_error("unknown optimzer");
+ }
+ return ir_opt;
+}
+
+ir::train::LossInfo loadLossInfo(const circle::ModelTraining *circle_model)
+{
+ assert(circle_model != nullptr);
+
+ // fill ir_loss from circle_loss
+ ir::train::LossInfo ir_loss;
+ const circle::LossFn circle_loss = circle_model->lossfn();
+ const circle::LossReductionType circle_loss_rdt = circle_model->loss_reduction_type();
+
+ switch (circle_loss)
+ {
+ case circle::LossFn::LossFn_CATEGORICAL_CROSSENTROPY:
+ ir_loss.loss_code = ir::train::LossCode::CategoricalCrossentropy;
+ break;
+ case circle::LossFn::LossFn_MEAN_SQUARED_ERROR:
+ ir_loss.loss_code = ir::train::LossCode::MeanSquaredError;
+ break;
+ case circle::LossFn::LossFn_SPARSE_CATEGORICAL_CROSSENTROPY:
+ // TODO Enable this conversion once the core supports sparse_categorical_crossentropy
+ throw std::runtime_error{"'sparse_categorical_crossentropy' is not supported yet"};
+ default:
+ throw std::runtime_error{"unknown loss function"};
+ }
+
+ switch (circle_loss_rdt)
+ {
+ case circle::LossReductionType::LossReductionType_SumOverBatchSize:
+ ir_loss.reduction_type = ir::train::LossReductionType::SumOverBatchSize;
+ break;
+ case circle::LossReductionType::LossReductionType_Sum:
+ ir_loss.reduction_type = ir::train::LossReductionType::Sum;
+ break;
+ default:
+ throw std::runtime_error{"unknown loss reduction type"};
+ }
+
+ return ir_loss;
+}
+
+std::set<ir::OperationIndex> loadTrainableOps(const circle::ModelTraining *circle_model)
+{
+ assert(circle_model != nullptr);
+
+ std::set<ir::OperationIndex> ir_trainable_ops;
+ const auto lists = circle_model->trainable_ops();
+ if (lists != nullptr)
+ {
+ for (::flatbuffers::uoffset_t i = 0; i < lists->size(); ++i)
+ {
+ const uint32_t op_index = lists->Get(i);
+ ir_trainable_ops.emplace(ir::OperationIndex{op_index});
+ }
+ }
+ return ir_trainable_ops;
+}
+} // namespace
+
+std::unique_ptr<ir::train::TrainingInfo> loadTrainingInfo(const uint8_t *buffer, const size_t size)
+{
+ assert(buffer != nullptr);
+
+ flatbuffers::Verifier v(buffer, size);
+ bool verified = circle::VerifyModelTrainingBuffer(v);
+ if (not verified)
+ throw std::runtime_error{"TrainingInfo buffer is not accessible"};
+
+ const circle::ModelTraining *circle_model =
+ circle::GetModelTraining(static_cast<const void *>(buffer));
+
+ assert(circle_model != nullptr);
+
+ auto tinfo = std::make_unique<ir::train::TrainingInfo>();
+ {
+ tinfo->setVersion(circle_model->version());
+ tinfo->setBatchSize(circle_model->batch_size());
+ tinfo->setOptimizerInfo(loadOptimizerInfo(circle_model));
+ tinfo->setLossInfo(loadLossInfo(circle_model));
+ tinfo->setTrainableOps(loadTrainableOps(circle_model));
+ }
+ return tinfo;
+}
+
+} // namespace loader
+} // namespace onert
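A brief, hypothetical caller-side sketch of the API above: the caller is assumed to have already extracted the bytes of the model metadata buffer whose name equals TRAININFO_METADATA_NAME ("CIRCLE_TRAINING"); loadTrainingInfo() then verifies and decodes them.

// Caller-side sketch (illustrative, not part of this patch).
#include "loader/TrainInfoLoader.h"

#include <cstdint>
#include <memory>
#include <vector>

std::unique_ptr<onert::ir::train::TrainingInfo>
decodeTrainingMetadata(const std::vector<uint8_t> &metadata_buf)
{
  // metadata_buf is assumed to hold the contents of the buffer referenced by the
  // metadata entry named onert::loader::TRAININFO_METADATA_NAME.
  return onert::loader::loadTrainingInfo(metadata_buf.data(), metadata_buf.size());
}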
diff --git a/runtime/onert/core/src/loader/tflite_schema.fbs b/runtime/onert/core/src/loader/tflite_schema.fbs
new file mode 100644
index 000000000..f7997528e
--- /dev/null
+++ b/runtime/onert/core/src/loader/tflite_schema.fbs
@@ -0,0 +1,1308 @@
+// Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Revision History
+// Version 0: Initial version.
+// Version 1: Add subgraphs to schema.
+// Version 2: Rename operators to conform to NN API.
+// Version 3: Move buffer data from Model.Subgraph.Tensors to Model.Buffers.
+// Version 3a: Add new builtin op code field. Has backward compatibility with
+// version 3.
+// Version 3b: Rename fields in SignatureDef. Has backward compatibility with
+// version 3 and 3a.
+
+// Change namespace to onert_tflite
+namespace onert_tflite;
+
+// This corresponds to the version.
+file_identifier "TFL3";
+// File extension of any written files.
+file_extension "tflite";
+
+// IMPORTANT: All new members of tables, enums and unions must be added at the
+// end to ensure backwards compatibility.
+
+// The type of data stored in a tensor.
+enum TensorType : byte {
+ FLOAT32 = 0,
+ FLOAT16 = 1,
+ INT32 = 2,
+ UINT8 = 3,
+ INT64 = 4,
+ STRING = 5,
+ BOOL = 6,
+ INT16 = 7,
+ COMPLEX64 = 8,
+ INT8 = 9,
+ FLOAT64 = 10,
+ COMPLEX128 = 11,
+ UINT64 = 12,
+ // Experimental: Resource and variant types are experimental and subject to
+ // change. Do not implement custom kernels using resource & variant types yet.
+ RESOURCE = 13,
+ VARIANT = 14,
+ UINT32 = 15,
+ UINT16 = 16
+}
+
+// Custom quantization parameters for experimenting with new quantization
+// techniques.
+table CustomQuantization {
+ custom:[ubyte] (force_align: 16);
+}
+
+// Represents a specific quantization technique's parameters.
+union QuantizationDetails {
+ CustomQuantization,
+}
+
+// Parameters for converting a quantized tensor back to float.
+table QuantizationParameters {
+ // These four parameters are the asymmetric linear quantization parameters.
+ // Given a quantized value q, the corresponding float value f should be:
+ // f = scale * (q - zero_point)
+ // For other quantization types, the QuantizationDetails below is used.
+ min:[float]; // For importing back into tensorflow.
+ max:[float]; // For importing back into tensorflow.
+ scale:[float]; // For dequantizing the tensor's values.
+ zero_point:[long];
+
+ // If this is not none, the other quantization parameters (i.e. min, max,
+ // scale, zero_point fields above) are ignored and the value of the
+ // QuantizationDetails union should be used.
+ details:QuantizationDetails;
+
+ // Specifies the dimension of the Tensor's shape that the scales and
+ // zero_points correspond to. For example, a tensor t, with dims=[4, 3, 2, 1]
+ // with quantization params:
+ // scale=[1.0, 2.0, 3.0], zero_point=[1, 2, 3], quantization_dimension=1
+ // will be quantized across the second dimension of t.
+ // t[:, 0, :, :] will have scale[0]=1.0, zero_point[0]=1
+ // t[:, 1, :, :] will have scale[1]=2.0, zero_point[0]=2
+ // t[:, 2, :, :] will have scale[2]=3.0, zero_point[0]=3
+ quantized_dimension:int;
+}
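As a small worked example of the formula above (illustrative only, not part of the schema): with scale = 0.5 and zero_point = 128, the quantized value q = 130 dequantizes to 0.5 * (130 - 128) = 1.0.

// Illustrative helper implementing f = scale * (q - zero_point) for a uint8 tensor.
#include <cstdint>

inline float dequantize(uint8_t q, float scale, int64_t zero_point)
{
  return scale * static_cast<float>(static_cast<int64_t>(q) - zero_point);
}
// dequantize(130, 0.5f, 128) == 1.0f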
+
+// Sparse tensors.
+// We use a modification of the TACO format.
+// Reference: http://tensor-compiler.org/kjolstad-oopsla17-tensor-compiler.pdf
+//
+// To encode a conceptual n-dimensional dense tensor with dims (d0, ..., dn-1),
+// potentially with a k-dimensional block (0 <= k <= n) with dims
+// (dn, ..., dn+k-1), the format needs to specify:
+// 1. In what order to traverse these dimensions. For example, to store a 2-D
+// matrix in row major order, the traversal order would be (d0, d1),
+// whereas to store it in column major order, the traversal order would be
+// (d1, d0). If the 2-D matrix has a 2-D inner block, the traversal order
+// could be (d0, d1, d2, d3).
+// 2. How each block dimension in (dn, ..., dn+k-1) maps to the original
+// tensor dimension in (d0, ..., dn-1).
+// 3. In the traversal order defined above, the format (dense vs. sparse) and
+// index metadata for each dimension. For a dense dimension, this is just
+// the size of that dimension. For a sparse dimension, it's the same as
+// the compressed index defined in the Compressed Sparse Row (CSR) format.
+// (http://scipy-lectures.org/advanced/scipy_sparse/csr_matrix.html)
+
+// The storage type for a dimension. Currently we support:
+// 1. DENSE: each coordinate in this dimension is stored implicitly.
+// 2. SPARSE_CSR: only the coordinates with non-zero elements are stored. The
+ // compression technique is the same as the one CSR uses.
+// More types like a sparse dimension with a different compression technique
+// could be added to the list in the future.
+enum DimensionType : byte {
+ DENSE = 0,
+ SPARSE_CSR = 1,
+}
+
+table Int32Vector {
+ values:[int];
+}
+
+table Uint16Vector {
+ values:[ushort] (force_align: 4);
+}
+
+table Uint8Vector {
+ values:[ubyte] (force_align: 4);
+}
+
+// Variable-typed buffer to store the index metadata for a sparse dimension.
+ // The widest type is Int32 instead of UInt32 because a tensor's shape is an int32
+// vector. We don't want the per-dimensional index to overflow that range.
+union SparseIndexVector {
+ Int32Vector,
+ Uint16Vector,
+ Uint8Vector
+}
+
+table DimensionMetadata {
+ // Whether a dimension is dense or sparse.
+ format:DimensionType;
+ // Index metadata used for a dimension.
+ // - If format is DimensionType.DENSE then we use the dense_size field to
+ // store the size of that dimension. Each index in that dimension is
+ // stored implicitly.
+ // - If format is DimensionType.SPARSE_CSR then we use array_segments and
+ // array_indices to encode that dimension. array_segments represents how
+ // to segment the indices array, each segment corresponds to one element
+ // in the previous dimension. array_indices represents the index of the
+ // non-zero elements within this dimension (as those in the CSR matrix
+ // format, where the first array is row pointers and the second array is
+ // column indices).
+ dense_size:int;
+ array_segments:SparseIndexVector;
+ array_indices:SparseIndexVector;
+}
+
+// Parameters to encode a sparse TfLite tensor.
+table SparsityParameters {
+ // The traversal order of the dimensions defined in the `shape` field of the
+ // conceptual dense tensor. For an n-dimensional tensor with dims (d0, d1,
+ // ..., dn-1),
+ // - if not block sparse, the traversal_order is just a permutation of (d0,
+ // ..., dn-1). For example, a 2-D matrix stored in row-major order would
+ // have traversal_order = (d0, d1).
+ // - if block sparse with a k-dimensional block (0 <= k <= n), the
+ // traversal_order has n + k elements. The first n elements are still a
+ // permutation of (d0, ..., dn-1). The last k elements are a permutation
+ // of (dn, ..., dn+k-1), defining how to traverse a block internally. For
+ // example, a 2-D matrix with 2-D blocks, both stored in row-major order
+ // would have traversal_order = (d0, d1, d2, d3).
+ traversal_order:[int];
+ // For an n-dimensional tensor with a k-dimensional block (0 <= k <= n),
+ // stores how a block dimension in (dn, ..., dn+k-1) maps to the original
+ // tensor dimension in (d0, ..., dn).
+ // It's stored in the order of (dn, ..., dn+k-1).
+ // If not block-sparse, this field is NULL.
+ block_map:[int];
+ // In the traversal order defined above, the metadata needed for
+ // each dimension to locate the non-zero values in the original dense tensor.
+ // The size of the dim_metadata array = the size of the traversal_order array
+ // = n + k.
+ dim_metadata:[DimensionMetadata];
+}
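To make the encoding concrete, here is a small worked example (illustrative only, not part of the schema) of a 4x4 matrix stored with a DENSE outer dimension and a SPARSE_CSR inner dimension.

// CSR encoding of [[1 0 0 0], [0 0 2 3], [0 0 0 0], [4 0 0 0]] per the scheme above.
#include <cstdint>
#include <vector>

const std::vector<int32_t> traversal_order{0, 1};         // row-major: d0, then d1
// d0 is DENSE with dense_size = 4; d1 is SPARSE_CSR described by:
const std::vector<int32_t> array_segments{0, 1, 3, 3, 4}; // row i owns entries [seg[i], seg[i+1])
const std::vector<int32_t> array_indices{0, 2, 3, 0};     // column index of each stored value
const std::vector<float> values{1, 2, 3, 4};              // only the non-zero elements are stored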
+
+table Tensor {
+ // The tensor shape. The meaning of each entry is operator-specific but
+ // builtin ops use: [batch size, height, width, number of channels] (That's
+ // Tensorflow's NHWC).
+ shape:[int];
+ type:TensorType;
+ // An index that refers to the buffers table at the root of the model. Or,
+ // if there is no data buffer associated (i.e. intermediate results), then
+ // this is 0 (which refers to an always existent empty buffer).
+ //
+ // The data_buffer itself is an opaque container, with the assumption that the
+ // target device is little-endian. In addition, all builtin operators assume
+ // the memory is ordered such that if `shape` is [4, 3, 2], then index
+ // [i, j, k] maps to data_buffer[i*3*2 + j*2 + k].
+ buffer:uint;
+ name:string; // For debugging and importing back into tensorflow.
+ quantization:QuantizationParameters; // Optional.
+
+ is_variable:bool = false;
+
+ // Parameters to encode a sparse tensor. See the example in
+ // tensorflow/lite/testdata/sparse_tensor.json.
+ sparsity:SparsityParameters; // Optional.
+
+ // Encodes `shape` with unknown dimensions. Unknown dimensions are
+ // represented with -1.
+ shape_signature:[int]; // Optional.
+
+ // If false, the rank or the number of tensor dimensions is unknown.
+ // If false, "shape" must be [].
+ has_rank: bool = false;
+}
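The index arithmetic in the `buffer` comment above generalizes to any rank; a minimal sketch (illustrative only):

// Row-major offset: for shape [4, 3, 2], flatIndex({i, j, k}, {4, 3, 2}) == i*3*2 + j*2 + k.
#include <cstddef>
#include <vector>

inline size_t flatIndex(const std::vector<size_t> &index, const std::vector<size_t> &shape)
{
  size_t offset = 0;
  for (size_t d = 0; d < shape.size(); ++d)
    offset = offset * shape[d] + index[d];
  return offset;
}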
+
+// A list of builtin operators. Builtin operators are slightly faster than custom
+// ones, but not by much. Moreover, while custom operators accept an opaque
+// object containing configuration parameters, builtins have a predetermined
+// set of acceptable options.
+// LINT.IfChange
+enum BuiltinOperator : int32 {
+ ADD = 0,
+ AVERAGE_POOL_2D = 1,
+ CONCATENATION = 2,
+ CONV_2D = 3,
+ DEPTHWISE_CONV_2D = 4,
+ DEPTH_TO_SPACE = 5,
+ DEQUANTIZE = 6,
+ EMBEDDING_LOOKUP = 7,
+ FLOOR = 8,
+ FULLY_CONNECTED = 9,
+ HASHTABLE_LOOKUP = 10,
+ L2_NORMALIZATION = 11,
+ L2_POOL_2D = 12,
+ LOCAL_RESPONSE_NORMALIZATION = 13,
+ LOGISTIC = 14,
+ LSH_PROJECTION = 15,
+ LSTM = 16,
+ MAX_POOL_2D = 17,
+ MUL = 18,
+ RELU = 19,
+ // NOTE(aselle): RELU_N1_TO_1 used to be called RELU1, but it was renamed
+ // since different model developers use RELU1 in different ways. Never
+ // create another op called RELU1.
+ RELU_N1_TO_1 = 20,
+ RELU6 = 21,
+ RESHAPE = 22,
+ RESIZE_BILINEAR = 23,
+ RNN = 24,
+ SOFTMAX = 25,
+ SPACE_TO_DEPTH = 26,
+ SVDF = 27,
+ TANH = 28,
+ CONCAT_EMBEDDINGS = 29,
+ SKIP_GRAM = 30,
+ CALL = 31,
+ CUSTOM = 32,
+ EMBEDDING_LOOKUP_SPARSE = 33,
+ PAD = 34,
+ UNIDIRECTIONAL_SEQUENCE_RNN = 35,
+ GATHER = 36,
+ BATCH_TO_SPACE_ND = 37,
+ SPACE_TO_BATCH_ND = 38,
+ TRANSPOSE = 39,
+ MEAN = 40,
+ SUB = 41,
+ DIV = 42,
+ SQUEEZE = 43,
+ UNIDIRECTIONAL_SEQUENCE_LSTM = 44,
+ STRIDED_SLICE = 45,
+ BIDIRECTIONAL_SEQUENCE_RNN = 46,
+ EXP = 47,
+ TOPK_V2 = 48,
+ SPLIT = 49,
+ LOG_SOFTMAX = 50,
+ // DELEGATE is a special op type for the operations which are delegated to
+ // other backends.
+ // WARNING: Experimental interface, subject to change
+ DELEGATE = 51,
+ BIDIRECTIONAL_SEQUENCE_LSTM = 52,
+ CAST = 53,
+ PRELU = 54,
+ MAXIMUM = 55,
+ ARG_MAX = 56,
+ MINIMUM = 57,
+ LESS = 58,
+ NEG = 59,
+ PADV2 = 60,
+ GREATER = 61,
+ GREATER_EQUAL = 62,
+ LESS_EQUAL = 63,
+ SELECT = 64,
+ SLICE = 65,
+ SIN = 66,
+ TRANSPOSE_CONV = 67,
+ SPARSE_TO_DENSE = 68,
+ TILE = 69,
+ EXPAND_DIMS = 70,
+ EQUAL = 71,
+ NOT_EQUAL = 72,
+ LOG = 73,
+ SUM = 74,
+ SQRT = 75,
+ RSQRT = 76,
+ SHAPE = 77,
+ POW = 78,
+ ARG_MIN = 79,
+ FAKE_QUANT = 80,
+ REDUCE_PROD = 81,
+ REDUCE_MAX = 82,
+ PACK = 83,
+ LOGICAL_OR = 84,
+ ONE_HOT = 85,
+ LOGICAL_AND = 86,
+ LOGICAL_NOT = 87,
+ UNPACK = 88,
+ REDUCE_MIN = 89,
+ FLOOR_DIV = 90,
+ REDUCE_ANY = 91,
+ SQUARE = 92,
+ ZEROS_LIKE = 93,
+ FILL = 94,
+ FLOOR_MOD = 95,
+ RANGE = 96,
+ RESIZE_NEAREST_NEIGHBOR = 97,
+ LEAKY_RELU = 98,
+ SQUARED_DIFFERENCE = 99,
+ MIRROR_PAD = 100,
+ ABS = 101,
+ SPLIT_V = 102,
+ UNIQUE = 103,
+ CEIL = 104,
+ REVERSE_V2 = 105,
+ ADD_N = 106,
+ GATHER_ND = 107,
+ COS = 108,
+ WHERE = 109,
+ RANK = 110,
+ ELU = 111,
+ REVERSE_SEQUENCE = 112,
+ MATRIX_DIAG = 113,
+ QUANTIZE = 114,
+ MATRIX_SET_DIAG = 115,
+ ROUND = 116,
+ HARD_SWISH = 117,
+ IF = 118,
+ WHILE = 119,
+ NON_MAX_SUPPRESSION_V4 = 120,
+ NON_MAX_SUPPRESSION_V5 = 121,
+ SCATTER_ND = 122,
+ SELECT_V2 = 123,
+ DENSIFY = 124,
+ SEGMENT_SUM = 125,
+ BATCH_MATMUL = 126,
+ PLACEHOLDER_FOR_GREATER_OP_CODES = 127,
+ CUMSUM = 128,
+ CALL_ONCE = 129,
+ BROADCAST_TO = 130,
+ RFFT2D = 131,
+ CONV_3D = 132,
+ IMAG=133,
+ REAL=134,
+ COMPLEX_ABS=135,
+ HASHTABLE = 136,
+ HASHTABLE_FIND = 137,
+ HASHTABLE_IMPORT = 138,
+ HASHTABLE_SIZE = 139,
+ REDUCE_ALL = 140,
+ CONV_3D_TRANSPOSE = 141,
+ VAR_HANDLE = 142,
+ READ_VARIABLE = 143,
+ ASSIGN_VARIABLE = 144,
+ BROADCAST_ARGS = 145,
+ RANDOM_STANDARD_NORMAL = 146,
+ BUCKETIZE = 147,
+ RANDOM_UNIFORM = 148,
+ MULTINOMIAL = 149,
+ GELU = 150,
+ DYNAMIC_UPDATE_SLICE = 151,
+ RELU_0_TO_1 = 152,
+ UNSORTED_SEGMENT_PROD = 153,
+ UNSORTED_SEGMENT_MAX = 154,
+ UNSORTED_SEGMENT_SUM = 155,
+ ATAN2 = 156
+}
+// LINT.ThenChange(nnapi_linter/linter.proto)
+
+// Options for the builtin operators.
+union BuiltinOptions {
+ Conv2DOptions,
+ DepthwiseConv2DOptions,
+ ConcatEmbeddingsOptions,
+ LSHProjectionOptions,
+ Pool2DOptions,
+ SVDFOptions,
+ RNNOptions,
+ FullyConnectedOptions,
+ SoftmaxOptions,
+ ConcatenationOptions,
+ AddOptions,
+ L2NormOptions,
+ LocalResponseNormalizationOptions,
+ LSTMOptions,
+ ResizeBilinearOptions,
+ CallOptions,
+ ReshapeOptions,
+ SkipGramOptions,
+ SpaceToDepthOptions,
+ EmbeddingLookupSparseOptions,
+ MulOptions,
+ PadOptions,
+ GatherOptions,
+ BatchToSpaceNDOptions,
+ SpaceToBatchNDOptions,
+ TransposeOptions,
+ ReducerOptions,
+ SubOptions,
+ DivOptions,
+ SqueezeOptions,
+ SequenceRNNOptions,
+ StridedSliceOptions,
+ ExpOptions,
+ TopKV2Options,
+ SplitOptions,
+ LogSoftmaxOptions,
+ CastOptions,
+ DequantizeOptions,
+ MaximumMinimumOptions,
+ ArgMaxOptions,
+ LessOptions,
+ NegOptions,
+ PadV2Options,
+ GreaterOptions,
+ GreaterEqualOptions,
+ LessEqualOptions,
+ SelectOptions,
+ SliceOptions,
+ TransposeConvOptions,
+ SparseToDenseOptions,
+ TileOptions,
+ ExpandDimsOptions,
+ EqualOptions,
+ NotEqualOptions,
+ ShapeOptions,
+ PowOptions,
+ ArgMinOptions,
+ FakeQuantOptions,
+ PackOptions,
+ LogicalOrOptions,
+ OneHotOptions,
+ LogicalAndOptions,
+ LogicalNotOptions,
+ UnpackOptions,
+ FloorDivOptions,
+ SquareOptions,
+ ZerosLikeOptions,
+ FillOptions,
+ BidirectionalSequenceLSTMOptions,
+ BidirectionalSequenceRNNOptions,
+ UnidirectionalSequenceLSTMOptions,
+ FloorModOptions,
+ RangeOptions,
+ ResizeNearestNeighborOptions,
+ LeakyReluOptions,
+ SquaredDifferenceOptions,
+ MirrorPadOptions,
+ AbsOptions,
+ SplitVOptions,
+ UniqueOptions,
+ ReverseV2Options,
+ AddNOptions,
+ GatherNdOptions,
+ CosOptions,
+ WhereOptions,
+ RankOptions,
+ ReverseSequenceOptions,
+ MatrixDiagOptions,
+ QuantizeOptions,
+ MatrixSetDiagOptions,
+ HardSwishOptions,
+ IfOptions,
+ WhileOptions,
+ DepthToSpaceOptions,
+ NonMaxSuppressionV4Options,
+ NonMaxSuppressionV5Options,
+ ScatterNdOptions,
+ SelectV2Options,
+ DensifyOptions,
+ SegmentSumOptions,
+ BatchMatMulOptions,
+ CumsumOptions,
+ CallOnceOptions,
+ BroadcastToOptions,
+ Rfft2dOptions,
+ Conv3DOptions,
+ HashtableOptions,
+ HashtableFindOptions,
+ HashtableImportOptions,
+ HashtableSizeOptions,
+ VarHandleOptions,
+ ReadVariableOptions,
+ AssignVariableOptions,
+ RandomOptions,
+ BucketizeOptions,
+ GeluOptions,
+ DynamicUpdateSliceOptions,
+ UnsortedSegmentProdOptions,
+ UnsortedSegmentMaxOptions,
+ UnsortedSegmentSumOptions,
+ ATan2Options
+}
+
+// LINT.IfChange
+enum Padding : byte { SAME, VALID }
+// LINT.ThenChange(//tensorflow/compiler/mlir/lite/ir/tfl_op_enums.td)
+
+// LINT.IfChange
+enum ActivationFunctionType : byte {
+ NONE = 0,
+ RELU = 1,
+ RELU_N1_TO_1 = 2,
+ RELU6 = 3,
+ TANH = 4,
+ SIGN_BIT = 5,
+}
+// LINT.ThenChange(//tensorflow/compiler/mlir/lite/ir/tfl_op_enums.td)
+
+table Conv2DOptions {
+ padding:Padding;
+ stride_w:int;
+ stride_h:int;
+ fused_activation_function:ActivationFunctionType;
+ dilation_w_factor:int = 1;
+ dilation_h_factor:int = 1;
+}
+
+// Options for both Conv3D and Conv3DTranspose.
+table Conv3DOptions {
+ padding:Padding;
+ stride_d:int;
+ stride_w:int;
+ stride_h:int;
+ fused_activation_function:ActivationFunctionType;
+ dilation_d_factor:int = 1;
+ dilation_w_factor:int = 1;
+ dilation_h_factor:int = 1;
+}
+
+table Pool2DOptions {
+ padding:Padding;
+ stride_w:int;
+ stride_h:int;
+ filter_width:int;
+ filter_height:int;
+ fused_activation_function:ActivationFunctionType;
+}
+
+table DepthwiseConv2DOptions {
+ // Parameters for DepthwiseConv version 1 or above.
+ padding:Padding;
+ stride_w:int;
+ stride_h:int;
+ // `depth_multiplier` is redundant. It's used by CPU kernels in
+ // TensorFlow 2.0 or below, but ignored in versions above.
+ // See comments in lite/c/builtin_op_data.h for more details.
+ depth_multiplier:int;
+ fused_activation_function:ActivationFunctionType;
+ // Parameters for DepthwiseConv version 2 or above.
+ dilation_w_factor:int = 1;
+ dilation_h_factor:int = 1;
+}
+
+table ConcatEmbeddingsOptions {
+ num_channels:int;
+ num_columns_per_channel:[int];
+ embedding_dim_per_channel:[int]; // This could be inferred from parameters.
+}
+
+enum LSHProjectionType: byte {
+ UNKNOWN = 0,
+ SPARSE = 1,
+ DENSE = 2,
+}
+
+table LSHProjectionOptions {
+ type: LSHProjectionType;
+}
+
+table SVDFOptions {
+ rank:int;
+ fused_activation_function:ActivationFunctionType;
+ // For weights-only quantization, use asymmetric quantization for non
+ // constant inputs at evaluation time.
+ asymmetric_quantize_inputs:bool;
+}
+
+// An implementation of TensorFlow RNNCell.
+table RNNOptions {
+ fused_activation_function:ActivationFunctionType;
+ asymmetric_quantize_inputs:bool;
+}
+
+// An implementation of TensorFlow dynamic_rnn with RNNCell.
+table SequenceRNNOptions {
+ time_major:bool;
+ fused_activation_function:ActivationFunctionType;
+ asymmetric_quantize_inputs:bool;
+}
+
+// An implementation of TensorFlow bidirectional_dynamic_rnn with RNNCell.
+table BidirectionalSequenceRNNOptions {
+ time_major:bool;
+ fused_activation_function:ActivationFunctionType;
+ merge_outputs: bool;
+ asymmetric_quantize_inputs:bool;
+}
+
+// LINT.IfChange
+enum FullyConnectedOptionsWeightsFormat: byte {
+ DEFAULT = 0,
+ SHUFFLED4x16INT8 = 1,
+}
+// LINT.ThenChange(//tensorflow/compiler/mlir/lite/ir/tfl_op_enums.td)
+
+// An implementation of TensorFlow fully_connected (a.k.a Dense) layer.
+table FullyConnectedOptions {
+ // Parameters for FullyConnected version 1 or above.
+ fused_activation_function:ActivationFunctionType;
+
+ // Parameters for FullyConnected version 2 or above.
+ weights_format:FullyConnectedOptionsWeightsFormat = DEFAULT;
+
+ // Parameters for FullyConnected version 5 or above.
+ // If set to true, then the number of dimensions is preserved. Furthermore,
+ // all but the last dimension of the input and output shapes will be equal.
+ keep_num_dims: bool;
+
+ // Parameters for FullyConnected version 7 or above.
+ // If set to true, then weights-only op will use asymmetric quantization for
+ // inputs.
+ asymmetric_quantize_inputs: bool;
+}
+
+table SoftmaxOptions {
+ beta: float;
+}
+
+// An implementation of TensorFlow concat.
+table ConcatenationOptions {
+ axis:int;
+ fused_activation_function:ActivationFunctionType;
+}
+
+table AddOptions {
+ fused_activation_function:ActivationFunctionType;
+ // Parameters supported by version 3.
+ pot_scale_int16:bool = true;
+}
+
+table MulOptions {
+ fused_activation_function:ActivationFunctionType;
+}
+
+table L2NormOptions {
+ // This field is currently ignored in the L2 Norm Op.
+ fused_activation_function:ActivationFunctionType;
+}
+
+table LocalResponseNormalizationOptions {
+ radius:int;
+ bias:float;
+ alpha:float;
+ beta:float;
+}
+
+// LINT.IfChange
+enum LSTMKernelType : byte {
+ // Full LSTM kernel which supports peephole and projection.
+ FULL = 0,
+ // Basic LSTM kernels. Equivalent to TensorFlow BasicLSTMCell.
+ BASIC = 1,
+}
+// LINT.ThenChange(//tensorflow/compiler/mlir/lite/ir/tfl_op_enums.td)
+
+// An implementation of TensorFlow LSTMCell and CoupledInputForgetGateLSTMCell
+table LSTMOptions {
+ // Parameters for LSTM version 1 or above.
+ fused_activation_function:ActivationFunctionType;
+ cell_clip: float; // Optional, 0.0 means no clipping
+ proj_clip: float; // Optional, 0.0 means no clipping
+
+ // Parameters for LSTM version 2 or above.
+ // Basic kernel is only supported in version 2 or above.
+ kernel_type: LSTMKernelType = FULL;
+
+ // Parameters for LSTM version 4 or above.
+ asymmetric_quantize_inputs: bool;
+}
+
+// An implementation of TensorFlow dynamic_rnn with LSTMCell.
+table UnidirectionalSequenceLSTMOptions {
+ fused_activation_function:ActivationFunctionType;
+ cell_clip: float; // Optional, 0.0 means no clipping
+ proj_clip: float; // Optional, 0.0 means no clipping
+
+ // If true then first dimension is sequence, otherwise batch.
+ time_major:bool;
+
+ // Parameter for Unidirectional Sequence LSTM version 4.
+ asymmetric_quantize_inputs:bool;
+}
+
+table BidirectionalSequenceLSTMOptions {
+ // Parameters supported by version 1:
+ fused_activation_function:ActivationFunctionType;
+ cell_clip: float; // Optional, 0.0 means no clipping
+ proj_clip: float; // Optional, 0.0 means no clipping
+
+ // If true, store the outputs of both directions into the first output.
+ merge_outputs: bool;
+
+ // Parameters supported by version 2:
+ // If true then first dimension is sequence, otherwise batch.
+ // Version 1 implementations assumed time_major to be true, so this default
+ // value should never change.
+ time_major: bool = true;
+
+ // Parameters for version 3 or above.
+ asymmetric_quantize_inputs:bool;
+}
+
+table ResizeBilinearOptions {
+ new_height: int (deprecated);
+ new_width: int (deprecated);
+ align_corners: bool;
+ half_pixel_centers: bool;
+}
+
+table ResizeNearestNeighborOptions {
+ align_corners: bool;
+ half_pixel_centers: bool;
+}
+
+// A call operation options
+table CallOptions {
+ // The subgraph index that needs to be called.
+ subgraph:uint;
+}
+
+table PadOptions {
+}
+
+table PadV2Options {
+}
+
+table ReshapeOptions {
+ new_shape:[int];
+}
+
+table SpaceToBatchNDOptions {
+}
+
+table BatchToSpaceNDOptions {
+}
+
+table SkipGramOptions {
+ ngram_size: int;
+ max_skip_size: int;
+ include_all_ngrams: bool;
+}
+
+table SpaceToDepthOptions {
+ block_size: int;
+}
+
+table DepthToSpaceOptions {
+ block_size: int;
+}
+
+table SubOptions {
+ fused_activation_function:ActivationFunctionType;
+ // Parameters supported by version 5
+ pot_scale_int16:bool = true;
+}
+
+table DivOptions {
+ fused_activation_function:ActivationFunctionType;
+}
+
+table TopKV2Options {
+}
+
+enum CombinerType : byte {
+ SUM = 0,
+ MEAN = 1,
+ SQRTN = 2,
+}
+
+table EmbeddingLookupSparseOptions {
+ combiner:CombinerType;
+}
+
+table GatherOptions {
+ axis: int;
+ // Parameters for Gather version 5 or above.
+ batch_dims: int = 0;
+}
+
+table TransposeOptions {
+}
+
+table ExpOptions {
+}
+
+table CosOptions {
+}
+
+table ReducerOptions {
+ keep_dims: bool;
+}
+
+table SqueezeOptions {
+ squeeze_dims:[int];
+}
+
+table SplitOptions {
+ num_splits: int;
+}
+
+table SplitVOptions {
+ num_splits: int;
+}
+
+table StridedSliceOptions {
+ begin_mask: int;
+ end_mask: int;
+ ellipsis_mask: int;
+ new_axis_mask: int;
+ shrink_axis_mask: int;
+}
+
+table LogSoftmaxOptions {
+}
+
+table CastOptions {
+ in_data_type: TensorType;
+ out_data_type: TensorType;
+}
+
+table DequantizeOptions {
+}
+
+table MaximumMinimumOptions {
+}
+
+table TileOptions {
+}
+
+table ArgMaxOptions {
+ output_type : TensorType;
+}
+
+table ArgMinOptions {
+ output_type : TensorType;
+}
+
+table GreaterOptions {
+}
+
+table GreaterEqualOptions {
+}
+
+table LessOptions {
+}
+
+table LessEqualOptions {
+}
+
+table NegOptions {
+}
+
+table SelectOptions {
+}
+
+table SliceOptions {
+}
+
+table TransposeConvOptions {
+ padding:Padding;
+ stride_w:int;
+ stride_h:int;
+}
+
+table ExpandDimsOptions {
+}
+
+table SparseToDenseOptions {
+ validate_indices:bool;
+}
+
+table EqualOptions {
+}
+
+table NotEqualOptions {
+}
+
+table ShapeOptions {
+ // Optional output type of the operation (int32 or int64). Defaults to int32.
+ out_type : TensorType;
+}
+
+table RankOptions {
+}
+
+table PowOptions {
+}
+
+table FakeQuantOptions {
+ // Parameters supported by version 1:
+ min:float;
+ max:float;
+ num_bits:int;
+
+ // Parameters supported by version 2:
+ narrow_range:bool;
+}
+
+table PackOptions {
+ values_count:int;
+ axis:int;
+}
+
+table LogicalOrOptions {
+}
+
+table OneHotOptions {
+ axis:int;
+}
+
+table AbsOptions {
+}
+
+
+table HardSwishOptions {
+}
+
+table LogicalAndOptions {
+}
+
+table LogicalNotOptions {
+}
+
+table UnpackOptions {
+ num:int;
+ axis:int;
+}
+
+table FloorDivOptions {
+}
+
+table SquareOptions {
+}
+
+table ZerosLikeOptions {
+}
+
+table FillOptions {
+}
+
+table FloorModOptions {
+}
+
+table RangeOptions {
+}
+
+table LeakyReluOptions {
+ alpha:float;
+}
+
+table SquaredDifferenceOptions {
+}
+
+// LINT.IfChange
+enum MirrorPadMode : byte {
+ // Doesn't include borders.
+ REFLECT = 0,
+ // Includes borders.
+ SYMMETRIC = 1,
+}
+// LINT.ThenChange(//tensorflow/compiler/mlir/lite/ir/tfl_op_enums.td)
+
+table MirrorPadOptions {
+ mode:MirrorPadMode;
+}
+
+table UniqueOptions {
+ idx_out_type:TensorType = INT32;
+}
+
+table ReverseV2Options {
+}
+
+table AddNOptions {
+}
+
+table GatherNdOptions {
+}
+
+table WhereOptions {
+}
+
+table ReverseSequenceOptions {
+ seq_dim:int;
+ batch_dim:int = 0;
+}
+
+table MatrixDiagOptions {
+}
+
+table QuantizeOptions {
+}
+
+table MatrixSetDiagOptions {
+}
+
+table IfOptions {
+ then_subgraph_index:int;
+ else_subgraph_index:int;
+}
+
+table CallOnceOptions {
+ init_subgraph_index:int;
+}
+
+table WhileOptions {
+ cond_subgraph_index:int;
+ body_subgraph_index:int;
+}
+
+table NonMaxSuppressionV4Options {
+}
+
+table NonMaxSuppressionV5Options {
+}
+
+table ScatterNdOptions {
+}
+
+table SelectV2Options {
+}
+
+table DensifyOptions {
+}
+
+table SegmentSumOptions {
+}
+
+table BatchMatMulOptions {
+ adj_x:bool;
+ adj_y:bool;
+ // Parameters for BatchMatMul version 4 or above.
+ // If set to true, then weights-only op will use asymmetric quantization for
+ // inputs.
+ asymmetric_quantize_inputs: bool;
+}
+
+table CumsumOptions {
+ exclusive:bool;
+ reverse:bool;
+}
+
+table BroadcastToOptions {
+}
+
+table Rfft2dOptions {
+}
+
+table HashtableOptions {
+ // The identity of hash tables. This identity will be used across different
+ // subgraphs in the same interpreter instance.
+ table_id:int;
+ key_dtype:TensorType;
+ value_dtype:TensorType;
+}
+
+table HashtableFindOptions {
+}
+
+table HashtableImportOptions {
+}
+
+table HashtableSizeOptions {
+}
+
+table VarHandleOptions {
+ container:string;
+ shared_name:string;
+}
+
+table ReadVariableOptions {
+}
+
+table AssignVariableOptions {
+}
+
+table RandomOptions {
+ seed: long;
+ seed2: long;
+}
+
+table BucketizeOptions {
+ boundaries: [float]; // The bucket boundaries.
+}
+
+table GeluOptions {
+ approximate: bool;
+}
+
+table DynamicUpdateSliceOptions {
+}
+
+table UnsortedSegmentProdOptions {
+}
+
+table UnsortedSegmentMaxOptions {
+}
+
+table UnsortedSegmentSumOptions {
+}
+
+table ATan2Options {
+}
+
+
+// An OperatorCode can be an enum value (BuiltinOperator) if the operator is a
+// builtin, or a string if the operator is custom.
+table OperatorCode {
+ // This field is for backward compatibility. This field will be used when
+ // the value of the extended builtin_code field is less than
+ // BuiltinOperator_PLACEHOLDER_FOR_GREATER_OP_CODES.
+ deprecated_builtin_code:byte;
+ custom_code:string;
+
+ // The version of the operator. The version needs to be bumped whenever new
+ // parameters are introduced into an op.
+ version:int = 1;
+
+ // This field is introduced for resolving op builtin code shortage problem
+ // (the original BuiltinOperator enum field was represented as a byte).
+ // This field will be used when the value of the extended builtin_code field
+ // is greater than BuiltinOperator_PLACEHOLDER_FOR_GREATER_OP_CODES.
+ builtin_code:BuiltinOperator;
+}
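A hedged sketch of the resolution rule the two comments above describe; the runtime's actual helper (getBuiltinOperator() in BaseLoader, not shown in this diff) may differ in detail.

// Illustrative only: pick the effective builtin code from an OperatorCode entry.
#include "tflite_schema_generated.h"

#include <algorithm>

inline onert_tflite::BuiltinOperator
getEffectiveBuiltinCode(const onert_tflite::OperatorCode *opcode)
{
  // For ops above the old byte range, deprecated_builtin_code is pinned to
  // PLACEHOLDER_FOR_GREATER_OP_CODES (127), so the larger of the two fields is the
  // real operator for both old and new producers.
  return std::max(opcode->builtin_code(),
                  static_cast<onert_tflite::BuiltinOperator>(opcode->deprecated_builtin_code()));
}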
+
+enum CustomOptionsFormat : byte {
+ FLEXBUFFERS = 0,
+}
+
+// An operator takes tensors as inputs and outputs. The type of operation being
+// performed is determined by an index into the list of valid OperatorCodes,
+// while the specifics of each operation are configured using builtin_options
+// or custom_options.
+table Operator {
+ // Index into the operator_codes array. Using an integer here avoids
+ // complicated map lookups.
+ opcode_index:uint;
+
+ // Optional input are indicated by -1.
+ inputs:[int];
+ outputs:[int];
+
+ builtin_options:BuiltinOptions;
+ custom_options:[ubyte];
+ custom_options_format:CustomOptionsFormat;
+
+ // A list of booleans indicating the input tensors which are being mutated by
+ // this operator (e.g. used by RNN and LSTM).
+ // For example, if the "inputs" array refers to 5 tensors and the second and
+ // fifth are mutable variables, then this list will contain
+ // [false, true, false, false, true].
+ //
+ // If the list is empty, no variable is mutated in this operator.
+ // The list either has the same length as `inputs`, or is empty.
+ mutating_variable_inputs:[bool];
+
+ // A list of indices to the subgraph's "tensors" that are internal to an Op.
+ // Internal tensors are those that do not flow in or out of the operation,
+ // but instead are part of internal computation. As such, the operation's
+ // implementation may manage its memory more efficiently. They are needed
+ // however (i.e. not just an implementation detail) since they are part of the
+ // computation, which may require relevant metadata such as quantization
+ // parameters.
+ intermediates:[int];
+}
+
+// The root type, defining a subgraph, which typically represents an entire
+// model.
+table SubGraph {
+ // A list of all tensors used in this subgraph.
+ tensors:[Tensor];
+
+ // Indices of the tensors that are inputs into this subgraph. Note this is
+ // the list of non-static tensors that feed into the subgraph for inference.
+ inputs:[int];
+
+ // Indices of the tensors that are outputs out of this subgraph. Note this is
+ // the list of output tensors that are considered the product of the
+ // subgraph's inference.
+ outputs:[int];
+
+ // All operators, in execution order.
+ operators:[Operator];
+
+ // Name of this subgraph (used for debugging).
+ name:string;
+}
+
+// Table of raw data buffers (used for constant tensors). Referenced by tensors
+// by index. The generous alignment accommodates mmap-friendly data structures.
+table Buffer {
+ data:[ubyte] (force_align: 16);
+}
+
+table Metadata {
+ // A human readable string to uniquely identify a Metadata.
+ name:string;
+ // An index to the buffers table.
+ buffer:uint;
+}
+
+// Map from an alias name of tensor to tensor index in the graph.
+// This is used in Signature def.
+table TensorMap {
+ // Represents the alias to use for this tensor.
+ name:string;
+
+ // The actual tensor index in the primary graph, that 'name' corresponds to.
+ tensor_index:uint;
+}
+
+// This corresponds to SignatureDef in Tensorflow SavedModel.
+// The SignatureDef will be part of the SavedModel provided for conversion.
+table SignatureDef {
+ // Named inputs for this signature.
+ inputs:[TensorMap];
+
+ // Named outputs for this signature.
+ outputs:[TensorMap];
+
+ // Key value which was in the Tensorflow SavedModel SignatureDef map.
+ signature_key:string;
+
+ // Model tag, deprecated.
+ deprecated_tag:string (deprecated);
+
+ // Index of the subgraph that corresponds to the exported method.
+ subgraph_index:uint;
+}
+
+table Model {
+ // Version of the schema.
+ version:uint;
+
+ // A list of all operator codes used in this model. This is
+ // kept in order because operators carry an index into this
+ // vector.
+ operator_codes:[OperatorCode];
+
+ // All the subgraphs of the model. The 0th is assumed to be the main
+ // model.
+ subgraphs:[SubGraph];
+
+ // A description of the model.
+ description:string;
+
+ // Buffers of the model.
+ // Note the 0th entry of this array must be an empty buffer (sentinel).
+ // This is a convention so that tensors without a buffer can provide 0 as
+ // their buffer.
+ buffers:[Buffer];
+
+ // Metadata about the model. Indirects into the existing buffers list.
+ // Deprecated, prefer to use metadata field.
+ metadata_buffer:[int];
+
+ // Metadata about the model.
+ metadata:[Metadata];
+
+ // Optional SignatureDefs for the model.
+ signature_defs:[SignatureDef];
+}
+
+root_type Model;
diff --git a/runtime/onert/core/src/loader/tflite_schema_generated.h b/runtime/onert/core/src/loader/tflite_schema_generated.h
new file mode 100644
index 000000000..9d891841a
--- /dev/null
+++ b/runtime/onert/core/src/loader/tflite_schema_generated.h
@@ -0,0 +1,11989 @@
+/*
+ * Copyright (c) 2019-2024 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+// automatically generated by the FlatBuffers compiler, do not modify
+
+#ifndef FLATBUFFERS_GENERATED_TFLITESCHEMA_ONERT_TFLITE_H_
+#define FLATBUFFERS_GENERATED_TFLITESCHEMA_ONERT_TFLITE_H_
+
+#include "flatbuffers/flatbuffers.h"
+
+// Ensure the included flatbuffers.h is the same version as when this file was
+// generated, otherwise it may not be compatible.
+static_assert(FLATBUFFERS_VERSION_MAJOR == 23 && FLATBUFFERS_VERSION_MINOR == 5 &&
+ FLATBUFFERS_VERSION_REVISION == 26,
+ "Non-compatible flatbuffers version included");
+
+namespace onert_tflite
+{
+
+struct CustomQuantization;
+struct CustomQuantizationBuilder;
+
+struct QuantizationParameters;
+struct QuantizationParametersBuilder;
+
+struct Int32Vector;
+struct Int32VectorBuilder;
+
+struct Uint16Vector;
+struct Uint16VectorBuilder;
+
+struct Uint8Vector;
+struct Uint8VectorBuilder;
+
+struct DimensionMetadata;
+struct DimensionMetadataBuilder;
+
+struct SparsityParameters;
+struct SparsityParametersBuilder;
+
+struct Tensor;
+struct TensorBuilder;
+
+struct Conv2DOptions;
+struct Conv2DOptionsBuilder;
+
+struct Conv3DOptions;
+struct Conv3DOptionsBuilder;
+
+struct Pool2DOptions;
+struct Pool2DOptionsBuilder;
+
+struct DepthwiseConv2DOptions;
+struct DepthwiseConv2DOptionsBuilder;
+
+struct ConcatEmbeddingsOptions;
+struct ConcatEmbeddingsOptionsBuilder;
+
+struct LSHProjectionOptions;
+struct LSHProjectionOptionsBuilder;
+
+struct SVDFOptions;
+struct SVDFOptionsBuilder;
+
+struct RNNOptions;
+struct RNNOptionsBuilder;
+
+struct SequenceRNNOptions;
+struct SequenceRNNOptionsBuilder;
+
+struct BidirectionalSequenceRNNOptions;
+struct BidirectionalSequenceRNNOptionsBuilder;
+
+struct FullyConnectedOptions;
+struct FullyConnectedOptionsBuilder;
+
+struct SoftmaxOptions;
+struct SoftmaxOptionsBuilder;
+
+struct ConcatenationOptions;
+struct ConcatenationOptionsBuilder;
+
+struct AddOptions;
+struct AddOptionsBuilder;
+
+struct MulOptions;
+struct MulOptionsBuilder;
+
+struct L2NormOptions;
+struct L2NormOptionsBuilder;
+
+struct LocalResponseNormalizationOptions;
+struct LocalResponseNormalizationOptionsBuilder;
+
+struct LSTMOptions;
+struct LSTMOptionsBuilder;
+
+struct UnidirectionalSequenceLSTMOptions;
+struct UnidirectionalSequenceLSTMOptionsBuilder;
+
+struct BidirectionalSequenceLSTMOptions;
+struct BidirectionalSequenceLSTMOptionsBuilder;
+
+struct ResizeBilinearOptions;
+struct ResizeBilinearOptionsBuilder;
+
+struct ResizeNearestNeighborOptions;
+struct ResizeNearestNeighborOptionsBuilder;
+
+struct CallOptions;
+struct CallOptionsBuilder;
+
+struct PadOptions;
+struct PadOptionsBuilder;
+
+struct PadV2Options;
+struct PadV2OptionsBuilder;
+
+struct ReshapeOptions;
+struct ReshapeOptionsBuilder;
+
+struct SpaceToBatchNDOptions;
+struct SpaceToBatchNDOptionsBuilder;
+
+struct BatchToSpaceNDOptions;
+struct BatchToSpaceNDOptionsBuilder;
+
+struct SkipGramOptions;
+struct SkipGramOptionsBuilder;
+
+struct SpaceToDepthOptions;
+struct SpaceToDepthOptionsBuilder;
+
+struct DepthToSpaceOptions;
+struct DepthToSpaceOptionsBuilder;
+
+struct SubOptions;
+struct SubOptionsBuilder;
+
+struct DivOptions;
+struct DivOptionsBuilder;
+
+struct TopKV2Options;
+struct TopKV2OptionsBuilder;
+
+struct EmbeddingLookupSparseOptions;
+struct EmbeddingLookupSparseOptionsBuilder;
+
+struct GatherOptions;
+struct GatherOptionsBuilder;
+
+struct TransposeOptions;
+struct TransposeOptionsBuilder;
+
+struct ExpOptions;
+struct ExpOptionsBuilder;
+
+struct CosOptions;
+struct CosOptionsBuilder;
+
+struct ReducerOptions;
+struct ReducerOptionsBuilder;
+
+struct SqueezeOptions;
+struct SqueezeOptionsBuilder;
+
+struct SplitOptions;
+struct SplitOptionsBuilder;
+
+struct SplitVOptions;
+struct SplitVOptionsBuilder;
+
+struct StridedSliceOptions;
+struct StridedSliceOptionsBuilder;
+
+struct LogSoftmaxOptions;
+struct LogSoftmaxOptionsBuilder;
+
+struct CastOptions;
+struct CastOptionsBuilder;
+
+struct DequantizeOptions;
+struct DequantizeOptionsBuilder;
+
+struct MaximumMinimumOptions;
+struct MaximumMinimumOptionsBuilder;
+
+struct TileOptions;
+struct TileOptionsBuilder;
+
+struct ArgMaxOptions;
+struct ArgMaxOptionsBuilder;
+
+struct ArgMinOptions;
+struct ArgMinOptionsBuilder;
+
+struct GreaterOptions;
+struct GreaterOptionsBuilder;
+
+struct GreaterEqualOptions;
+struct GreaterEqualOptionsBuilder;
+
+struct LessOptions;
+struct LessOptionsBuilder;
+
+struct LessEqualOptions;
+struct LessEqualOptionsBuilder;
+
+struct NegOptions;
+struct NegOptionsBuilder;
+
+struct SelectOptions;
+struct SelectOptionsBuilder;
+
+struct SliceOptions;
+struct SliceOptionsBuilder;
+
+struct TransposeConvOptions;
+struct TransposeConvOptionsBuilder;
+
+struct ExpandDimsOptions;
+struct ExpandDimsOptionsBuilder;
+
+struct SparseToDenseOptions;
+struct SparseToDenseOptionsBuilder;
+
+struct EqualOptions;
+struct EqualOptionsBuilder;
+
+struct NotEqualOptions;
+struct NotEqualOptionsBuilder;
+
+struct ShapeOptions;
+struct ShapeOptionsBuilder;
+
+struct RankOptions;
+struct RankOptionsBuilder;
+
+struct PowOptions;
+struct PowOptionsBuilder;
+
+struct FakeQuantOptions;
+struct FakeQuantOptionsBuilder;
+
+struct PackOptions;
+struct PackOptionsBuilder;
+
+struct LogicalOrOptions;
+struct LogicalOrOptionsBuilder;
+
+struct OneHotOptions;
+struct OneHotOptionsBuilder;
+
+struct AbsOptions;
+struct AbsOptionsBuilder;
+
+struct HardSwishOptions;
+struct HardSwishOptionsBuilder;
+
+struct LogicalAndOptions;
+struct LogicalAndOptionsBuilder;
+
+struct LogicalNotOptions;
+struct LogicalNotOptionsBuilder;
+
+struct UnpackOptions;
+struct UnpackOptionsBuilder;
+
+struct FloorDivOptions;
+struct FloorDivOptionsBuilder;
+
+struct SquareOptions;
+struct SquareOptionsBuilder;
+
+struct ZerosLikeOptions;
+struct ZerosLikeOptionsBuilder;
+
+struct FillOptions;
+struct FillOptionsBuilder;
+
+struct FloorModOptions;
+struct FloorModOptionsBuilder;
+
+struct RangeOptions;
+struct RangeOptionsBuilder;
+
+struct LeakyReluOptions;
+struct LeakyReluOptionsBuilder;
+
+struct SquaredDifferenceOptions;
+struct SquaredDifferenceOptionsBuilder;
+
+struct MirrorPadOptions;
+struct MirrorPadOptionsBuilder;
+
+struct UniqueOptions;
+struct UniqueOptionsBuilder;
+
+struct ReverseV2Options;
+struct ReverseV2OptionsBuilder;
+
+struct AddNOptions;
+struct AddNOptionsBuilder;
+
+struct GatherNdOptions;
+struct GatherNdOptionsBuilder;
+
+struct WhereOptions;
+struct WhereOptionsBuilder;
+
+struct ReverseSequenceOptions;
+struct ReverseSequenceOptionsBuilder;
+
+struct MatrixDiagOptions;
+struct MatrixDiagOptionsBuilder;
+
+struct QuantizeOptions;
+struct QuantizeOptionsBuilder;
+
+struct MatrixSetDiagOptions;
+struct MatrixSetDiagOptionsBuilder;
+
+struct IfOptions;
+struct IfOptionsBuilder;
+
+struct CallOnceOptions;
+struct CallOnceOptionsBuilder;
+
+struct WhileOptions;
+struct WhileOptionsBuilder;
+
+struct NonMaxSuppressionV4Options;
+struct NonMaxSuppressionV4OptionsBuilder;
+
+struct NonMaxSuppressionV5Options;
+struct NonMaxSuppressionV5OptionsBuilder;
+
+struct ScatterNdOptions;
+struct ScatterNdOptionsBuilder;
+
+struct SelectV2Options;
+struct SelectV2OptionsBuilder;
+
+struct DensifyOptions;
+struct DensifyOptionsBuilder;
+
+struct SegmentSumOptions;
+struct SegmentSumOptionsBuilder;
+
+struct BatchMatMulOptions;
+struct BatchMatMulOptionsBuilder;
+
+struct CumsumOptions;
+struct CumsumOptionsBuilder;
+
+struct BroadcastToOptions;
+struct BroadcastToOptionsBuilder;
+
+struct Rfft2dOptions;
+struct Rfft2dOptionsBuilder;
+
+struct HashtableOptions;
+struct HashtableOptionsBuilder;
+
+struct HashtableFindOptions;
+struct HashtableFindOptionsBuilder;
+
+struct HashtableImportOptions;
+struct HashtableImportOptionsBuilder;
+
+struct HashtableSizeOptions;
+struct HashtableSizeOptionsBuilder;
+
+struct VarHandleOptions;
+struct VarHandleOptionsBuilder;
+
+struct ReadVariableOptions;
+struct ReadVariableOptionsBuilder;
+
+struct AssignVariableOptions;
+struct AssignVariableOptionsBuilder;
+
+struct RandomOptions;
+struct RandomOptionsBuilder;
+
+struct BucketizeOptions;
+struct BucketizeOptionsBuilder;
+
+struct GeluOptions;
+struct GeluOptionsBuilder;
+
+struct DynamicUpdateSliceOptions;
+struct DynamicUpdateSliceOptionsBuilder;
+
+struct UnsortedSegmentProdOptions;
+struct UnsortedSegmentProdOptionsBuilder;
+
+struct UnsortedSegmentMaxOptions;
+struct UnsortedSegmentMaxOptionsBuilder;
+
+struct UnsortedSegmentSumOptions;
+struct UnsortedSegmentSumOptionsBuilder;
+
+struct ATan2Options;
+struct ATan2OptionsBuilder;
+
+struct OperatorCode;
+struct OperatorCodeBuilder;
+
+struct Operator;
+struct OperatorBuilder;
+
+struct SubGraph;
+struct SubGraphBuilder;
+
+struct Buffer;
+struct BufferBuilder;
+
+struct Metadata;
+struct MetadataBuilder;
+
+struct TensorMap;
+struct TensorMapBuilder;
+
+struct SignatureDef;
+struct SignatureDefBuilder;
+
+struct Model;
+struct ModelBuilder;
+
+enum TensorType : int8_t
+{
+ TensorType_FLOAT32 = 0,
+ TensorType_FLOAT16 = 1,
+ TensorType_INT32 = 2,
+ TensorType_UINT8 = 3,
+ TensorType_INT64 = 4,
+ TensorType_STRING = 5,
+ TensorType_BOOL = 6,
+ TensorType_INT16 = 7,
+ TensorType_COMPLEX64 = 8,
+ TensorType_INT8 = 9,
+ TensorType_FLOAT64 = 10,
+ TensorType_COMPLEX128 = 11,
+ TensorType_UINT64 = 12,
+ TensorType_RESOURCE = 13,
+ TensorType_VARIANT = 14,
+ TensorType_UINT32 = 15,
+ TensorType_UINT16 = 16,
+ TensorType_MIN = TensorType_FLOAT32,
+ TensorType_MAX = TensorType_UINT16
+};
+
+inline const TensorType (&EnumValuesTensorType())[17]
+{
+ static const TensorType values[] = {
+ TensorType_FLOAT32, TensorType_FLOAT16, TensorType_INT32, TensorType_UINT8,
+ TensorType_INT64, TensorType_STRING, TensorType_BOOL, TensorType_INT16,
+ TensorType_COMPLEX64, TensorType_INT8, TensorType_FLOAT64, TensorType_COMPLEX128,
+ TensorType_UINT64, TensorType_RESOURCE, TensorType_VARIANT, TensorType_UINT32,
+ TensorType_UINT16};
+ return values;
+}
+
+inline const char *const *EnumNamesTensorType()
+{
+ static const char *const names[18] = {"FLOAT32", "FLOAT16", "INT32", "UINT8", "INT64",
+ "STRING", "BOOL", "INT16", "COMPLEX64", "INT8",
+ "FLOAT64", "COMPLEX128", "UINT64", "RESOURCE", "VARIANT",
+ "UINT32", "UINT16", nullptr};
+ return names;
+}
+
+inline const char *EnumNameTensorType(TensorType e)
+{
+ if (::flatbuffers::IsOutRange(e, TensorType_FLOAT32, TensorType_UINT16))
+ return "";
+ const size_t index = static_cast<size_t>(e);
+ return EnumNamesTensorType()[index];
+}
+
+enum QuantizationDetails : uint8_t
+{
+ QuantizationDetails_NONE = 0,
+ QuantizationDetails_CustomQuantization = 1,
+ QuantizationDetails_MIN = QuantizationDetails_NONE,
+ QuantizationDetails_MAX = QuantizationDetails_CustomQuantization
+};
+
+inline const QuantizationDetails (&EnumValuesQuantizationDetails())[2]
+{
+ static const QuantizationDetails values[] = {QuantizationDetails_NONE,
+ QuantizationDetails_CustomQuantization};
+ return values;
+}
+
+inline const char *const *EnumNamesQuantizationDetails()
+{
+ static const char *const names[3] = {"NONE", "CustomQuantization", nullptr};
+ return names;
+}
+
+inline const char *EnumNameQuantizationDetails(QuantizationDetails e)
+{
+ if (::flatbuffers::IsOutRange(e, QuantizationDetails_NONE,
+ QuantizationDetails_CustomQuantization))
+ return "";
+ const size_t index = static_cast<size_t>(e);
+ return EnumNamesQuantizationDetails()[index];
+}
+
+template <typename T> struct QuantizationDetailsTraits
+{
+ static const QuantizationDetails enum_value = QuantizationDetails_NONE;
+};
+
+template <> struct QuantizationDetailsTraits<onert_tflite::CustomQuantization>
+{
+ static const QuantizationDetails enum_value = QuantizationDetails_CustomQuantization;
+};
+
+bool VerifyQuantizationDetails(::flatbuffers::Verifier &verifier, const void *obj,
+ QuantizationDetails type);
+bool VerifyQuantizationDetailsVector(
+ ::flatbuffers::Verifier &verifier,
+ const ::flatbuffers::Vector<::flatbuffers::Offset<void>> *values,
+ const ::flatbuffers::Vector<uint8_t> *types);
+
+enum DimensionType : int8_t
+{
+ DimensionType_DENSE = 0,
+ DimensionType_SPARSE_CSR = 1,
+ DimensionType_MIN = DimensionType_DENSE,
+ DimensionType_MAX = DimensionType_SPARSE_CSR
+};
+
+inline const DimensionType (&EnumValuesDimensionType())[2]
+{
+ static const DimensionType values[] = {DimensionType_DENSE, DimensionType_SPARSE_CSR};
+ return values;
+}
+
+inline const char *const *EnumNamesDimensionType()
+{
+ static const char *const names[3] = {"DENSE", "SPARSE_CSR", nullptr};
+ return names;
+}
+
+inline const char *EnumNameDimensionType(DimensionType e)
+{
+ if (::flatbuffers::IsOutRange(e, DimensionType_DENSE, DimensionType_SPARSE_CSR))
+ return "";
+ const size_t index = static_cast<size_t>(e);
+ return EnumNamesDimensionType()[index];
+}
+
+enum SparseIndexVector : uint8_t
+{
+ SparseIndexVector_NONE = 0,
+ SparseIndexVector_Int32Vector = 1,
+ SparseIndexVector_Uint16Vector = 2,
+ SparseIndexVector_Uint8Vector = 3,
+ SparseIndexVector_MIN = SparseIndexVector_NONE,
+ SparseIndexVector_MAX = SparseIndexVector_Uint8Vector
+};
+
+inline const SparseIndexVector (&EnumValuesSparseIndexVector())[4]
+{
+ static const SparseIndexVector values[] = {SparseIndexVector_NONE, SparseIndexVector_Int32Vector,
+ SparseIndexVector_Uint16Vector,
+ SparseIndexVector_Uint8Vector};
+ return values;
+}
+
+inline const char *const *EnumNamesSparseIndexVector()
+{
+ static const char *const names[5] = {"NONE", "Int32Vector", "Uint16Vector", "Uint8Vector",
+ nullptr};
+ return names;
+}
+
+inline const char *EnumNameSparseIndexVector(SparseIndexVector e)
+{
+ if (::flatbuffers::IsOutRange(e, SparseIndexVector_NONE, SparseIndexVector_Uint8Vector))
+ return "";
+ const size_t index = static_cast<size_t>(e);
+ return EnumNamesSparseIndexVector()[index];
+}
+
+template <typename T> struct SparseIndexVectorTraits
+{
+ static const SparseIndexVector enum_value = SparseIndexVector_NONE;
+};
+
+template <> struct SparseIndexVectorTraits<onert_tflite::Int32Vector>
+{
+ static const SparseIndexVector enum_value = SparseIndexVector_Int32Vector;
+};
+
+template <> struct SparseIndexVectorTraits<onert_tflite::Uint16Vector>
+{
+ static const SparseIndexVector enum_value = SparseIndexVector_Uint16Vector;
+};
+
+template <> struct SparseIndexVectorTraits<onert_tflite::Uint8Vector>
+{
+ static const SparseIndexVector enum_value = SparseIndexVector_Uint8Vector;
+};
+
+bool VerifySparseIndexVector(::flatbuffers::Verifier &verifier, const void *obj,
+ SparseIndexVector type);
+bool VerifySparseIndexVectorVector(::flatbuffers::Verifier &verifier,
+ const ::flatbuffers::Vector<::flatbuffers::Offset<void>> *values,
+ const ::flatbuffers::Vector<uint8_t> *types);
+
+enum BuiltinOperator : int32_t
+{
+ BuiltinOperator_ADD = 0,
+ BuiltinOperator_AVERAGE_POOL_2D = 1,
+ BuiltinOperator_CONCATENATION = 2,
+ BuiltinOperator_CONV_2D = 3,
+ BuiltinOperator_DEPTHWISE_CONV_2D = 4,
+ BuiltinOperator_DEPTH_TO_SPACE = 5,
+ BuiltinOperator_DEQUANTIZE = 6,
+ BuiltinOperator_EMBEDDING_LOOKUP = 7,
+ BuiltinOperator_FLOOR = 8,
+ BuiltinOperator_FULLY_CONNECTED = 9,
+ BuiltinOperator_HASHTABLE_LOOKUP = 10,
+ BuiltinOperator_L2_NORMALIZATION = 11,
+ BuiltinOperator_L2_POOL_2D = 12,
+ BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION = 13,
+ BuiltinOperator_LOGISTIC = 14,
+ BuiltinOperator_LSH_PROJECTION = 15,
+ BuiltinOperator_LSTM = 16,
+ BuiltinOperator_MAX_POOL_2D = 17,
+ BuiltinOperator_MUL = 18,
+ BuiltinOperator_RELU = 19,
+ BuiltinOperator_RELU_N1_TO_1 = 20,
+ BuiltinOperator_RELU6 = 21,
+ BuiltinOperator_RESHAPE = 22,
+ BuiltinOperator_RESIZE_BILINEAR = 23,
+ BuiltinOperator_RNN = 24,
+ BuiltinOperator_SOFTMAX = 25,
+ BuiltinOperator_SPACE_TO_DEPTH = 26,
+ BuiltinOperator_SVDF = 27,
+ BuiltinOperator_TANH = 28,
+ BuiltinOperator_CONCAT_EMBEDDINGS = 29,
+ BuiltinOperator_SKIP_GRAM = 30,
+ BuiltinOperator_CALL = 31,
+ BuiltinOperator_CUSTOM = 32,
+ BuiltinOperator_EMBEDDING_LOOKUP_SPARSE = 33,
+ BuiltinOperator_PAD = 34,
+ BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN = 35,
+ BuiltinOperator_GATHER = 36,
+ BuiltinOperator_BATCH_TO_SPACE_ND = 37,
+ BuiltinOperator_SPACE_TO_BATCH_ND = 38,
+ BuiltinOperator_TRANSPOSE = 39,
+ BuiltinOperator_MEAN = 40,
+ BuiltinOperator_SUB = 41,
+ BuiltinOperator_DIV = 42,
+ BuiltinOperator_SQUEEZE = 43,
+ BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM = 44,
+ BuiltinOperator_STRIDED_SLICE = 45,
+ BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN = 46,
+ BuiltinOperator_EXP = 47,
+ BuiltinOperator_TOPK_V2 = 48,
+ BuiltinOperator_SPLIT = 49,
+ BuiltinOperator_LOG_SOFTMAX = 50,
+ BuiltinOperator_DELEGATE = 51,
+ BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM = 52,
+ BuiltinOperator_CAST = 53,
+ BuiltinOperator_PRELU = 54,
+ BuiltinOperator_MAXIMUM = 55,
+ BuiltinOperator_ARG_MAX = 56,
+ BuiltinOperator_MINIMUM = 57,
+ BuiltinOperator_LESS = 58,
+ BuiltinOperator_NEG = 59,
+ BuiltinOperator_PADV2 = 60,
+ BuiltinOperator_GREATER = 61,
+ BuiltinOperator_GREATER_EQUAL = 62,
+ BuiltinOperator_LESS_EQUAL = 63,
+ BuiltinOperator_SELECT = 64,
+ BuiltinOperator_SLICE = 65,
+ BuiltinOperator_SIN = 66,
+ BuiltinOperator_TRANSPOSE_CONV = 67,
+ BuiltinOperator_SPARSE_TO_DENSE = 68,
+ BuiltinOperator_TILE = 69,
+ BuiltinOperator_EXPAND_DIMS = 70,
+ BuiltinOperator_EQUAL = 71,
+ BuiltinOperator_NOT_EQUAL = 72,
+ BuiltinOperator_LOG = 73,
+ BuiltinOperator_SUM = 74,
+ BuiltinOperator_SQRT = 75,
+ BuiltinOperator_RSQRT = 76,
+ BuiltinOperator_SHAPE = 77,
+ BuiltinOperator_POW = 78,
+ BuiltinOperator_ARG_MIN = 79,
+ BuiltinOperator_FAKE_QUANT = 80,
+ BuiltinOperator_REDUCE_PROD = 81,
+ BuiltinOperator_REDUCE_MAX = 82,
+ BuiltinOperator_PACK = 83,
+ BuiltinOperator_LOGICAL_OR = 84,
+ BuiltinOperator_ONE_HOT = 85,
+ BuiltinOperator_LOGICAL_AND = 86,
+ BuiltinOperator_LOGICAL_NOT = 87,
+ BuiltinOperator_UNPACK = 88,
+ BuiltinOperator_REDUCE_MIN = 89,
+ BuiltinOperator_FLOOR_DIV = 90,
+ BuiltinOperator_REDUCE_ANY = 91,
+ BuiltinOperator_SQUARE = 92,
+ BuiltinOperator_ZEROS_LIKE = 93,
+ BuiltinOperator_FILL = 94,
+ BuiltinOperator_FLOOR_MOD = 95,
+ BuiltinOperator_RANGE = 96,
+ BuiltinOperator_RESIZE_NEAREST_NEIGHBOR = 97,
+ BuiltinOperator_LEAKY_RELU = 98,
+ BuiltinOperator_SQUARED_DIFFERENCE = 99,
+ BuiltinOperator_MIRROR_PAD = 100,
+ BuiltinOperator_ABS = 101,
+ BuiltinOperator_SPLIT_V = 102,
+ BuiltinOperator_UNIQUE = 103,
+ BuiltinOperator_CEIL = 104,
+ BuiltinOperator_REVERSE_V2 = 105,
+ BuiltinOperator_ADD_N = 106,
+ BuiltinOperator_GATHER_ND = 107,
+ BuiltinOperator_COS = 108,
+ BuiltinOperator_WHERE = 109,
+ BuiltinOperator_RANK = 110,
+ BuiltinOperator_ELU = 111,
+ BuiltinOperator_REVERSE_SEQUENCE = 112,
+ BuiltinOperator_MATRIX_DIAG = 113,
+ BuiltinOperator_QUANTIZE = 114,
+ BuiltinOperator_MATRIX_SET_DIAG = 115,
+ BuiltinOperator_ROUND = 116,
+ BuiltinOperator_HARD_SWISH = 117,
+ BuiltinOperator_IF = 118,
+ BuiltinOperator_WHILE = 119,
+ BuiltinOperator_NON_MAX_SUPPRESSION_V4 = 120,
+ BuiltinOperator_NON_MAX_SUPPRESSION_V5 = 121,
+ BuiltinOperator_SCATTER_ND = 122,
+ BuiltinOperator_SELECT_V2 = 123,
+ BuiltinOperator_DENSIFY = 124,
+ BuiltinOperator_SEGMENT_SUM = 125,
+ BuiltinOperator_BATCH_MATMUL = 126,
+ BuiltinOperator_PLACEHOLDER_FOR_GREATER_OP_CODES = 127,
+ BuiltinOperator_CUMSUM = 128,
+ BuiltinOperator_CALL_ONCE = 129,
+ BuiltinOperator_BROADCAST_TO = 130,
+ BuiltinOperator_RFFT2D = 131,
+ BuiltinOperator_CONV_3D = 132,
+ BuiltinOperator_IMAG = 133,
+ BuiltinOperator_REAL = 134,
+ BuiltinOperator_COMPLEX_ABS = 135,
+ BuiltinOperator_HASHTABLE = 136,
+ BuiltinOperator_HASHTABLE_FIND = 137,
+ BuiltinOperator_HASHTABLE_IMPORT = 138,
+ BuiltinOperator_HASHTABLE_SIZE = 139,
+ BuiltinOperator_REDUCE_ALL = 140,
+ BuiltinOperator_CONV_3D_TRANSPOSE = 141,
+ BuiltinOperator_VAR_HANDLE = 142,
+ BuiltinOperator_READ_VARIABLE = 143,
+ BuiltinOperator_ASSIGN_VARIABLE = 144,
+ BuiltinOperator_BROADCAST_ARGS = 145,
+ BuiltinOperator_RANDOM_STANDARD_NORMAL = 146,
+ BuiltinOperator_BUCKETIZE = 147,
+ BuiltinOperator_RANDOM_UNIFORM = 148,
+ BuiltinOperator_MULTINOMIAL = 149,
+ BuiltinOperator_GELU = 150,
+ BuiltinOperator_DYNAMIC_UPDATE_SLICE = 151,
+ BuiltinOperator_RELU_0_TO_1 = 152,
+ BuiltinOperator_UNSORTED_SEGMENT_PROD = 153,
+ BuiltinOperator_UNSORTED_SEGMENT_MAX = 154,
+ BuiltinOperator_UNSORTED_SEGMENT_SUM = 155,
+ BuiltinOperator_ATAN2 = 156,
+ BuiltinOperator_MIN = BuiltinOperator_ADD,
+ BuiltinOperator_MAX = BuiltinOperator_ATAN2
+};
+
+inline const BuiltinOperator (&EnumValuesBuiltinOperator())[157]
+{
+ static const BuiltinOperator values[] = {BuiltinOperator_ADD,
+ BuiltinOperator_AVERAGE_POOL_2D,
+ BuiltinOperator_CONCATENATION,
+ BuiltinOperator_CONV_2D,
+ BuiltinOperator_DEPTHWISE_CONV_2D,
+ BuiltinOperator_DEPTH_TO_SPACE,
+ BuiltinOperator_DEQUANTIZE,
+ BuiltinOperator_EMBEDDING_LOOKUP,
+ BuiltinOperator_FLOOR,
+ BuiltinOperator_FULLY_CONNECTED,
+ BuiltinOperator_HASHTABLE_LOOKUP,
+ BuiltinOperator_L2_NORMALIZATION,
+ BuiltinOperator_L2_POOL_2D,
+ BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION,
+ BuiltinOperator_LOGISTIC,
+ BuiltinOperator_LSH_PROJECTION,
+ BuiltinOperator_LSTM,
+ BuiltinOperator_MAX_POOL_2D,
+ BuiltinOperator_MUL,
+ BuiltinOperator_RELU,
+ BuiltinOperator_RELU_N1_TO_1,
+ BuiltinOperator_RELU6,
+ BuiltinOperator_RESHAPE,
+ BuiltinOperator_RESIZE_BILINEAR,
+ BuiltinOperator_RNN,
+ BuiltinOperator_SOFTMAX,
+ BuiltinOperator_SPACE_TO_DEPTH,
+ BuiltinOperator_SVDF,
+ BuiltinOperator_TANH,
+ BuiltinOperator_CONCAT_EMBEDDINGS,
+ BuiltinOperator_SKIP_GRAM,
+ BuiltinOperator_CALL,
+ BuiltinOperator_CUSTOM,
+ BuiltinOperator_EMBEDDING_LOOKUP_SPARSE,
+ BuiltinOperator_PAD,
+ BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN,
+ BuiltinOperator_GATHER,
+ BuiltinOperator_BATCH_TO_SPACE_ND,
+ BuiltinOperator_SPACE_TO_BATCH_ND,
+ BuiltinOperator_TRANSPOSE,
+ BuiltinOperator_MEAN,
+ BuiltinOperator_SUB,
+ BuiltinOperator_DIV,
+ BuiltinOperator_SQUEEZE,
+ BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM,
+ BuiltinOperator_STRIDED_SLICE,
+ BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN,
+ BuiltinOperator_EXP,
+ BuiltinOperator_TOPK_V2,
+ BuiltinOperator_SPLIT,
+ BuiltinOperator_LOG_SOFTMAX,
+ BuiltinOperator_DELEGATE,
+ BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM,
+ BuiltinOperator_CAST,
+ BuiltinOperator_PRELU,
+ BuiltinOperator_MAXIMUM,
+ BuiltinOperator_ARG_MAX,
+ BuiltinOperator_MINIMUM,
+ BuiltinOperator_LESS,
+ BuiltinOperator_NEG,
+ BuiltinOperator_PADV2,
+ BuiltinOperator_GREATER,
+ BuiltinOperator_GREATER_EQUAL,
+ BuiltinOperator_LESS_EQUAL,
+ BuiltinOperator_SELECT,
+ BuiltinOperator_SLICE,
+ BuiltinOperator_SIN,
+ BuiltinOperator_TRANSPOSE_CONV,
+ BuiltinOperator_SPARSE_TO_DENSE,
+ BuiltinOperator_TILE,
+ BuiltinOperator_EXPAND_DIMS,
+ BuiltinOperator_EQUAL,
+ BuiltinOperator_NOT_EQUAL,
+ BuiltinOperator_LOG,
+ BuiltinOperator_SUM,
+ BuiltinOperator_SQRT,
+ BuiltinOperator_RSQRT,
+ BuiltinOperator_SHAPE,
+ BuiltinOperator_POW,
+ BuiltinOperator_ARG_MIN,
+ BuiltinOperator_FAKE_QUANT,
+ BuiltinOperator_REDUCE_PROD,
+ BuiltinOperator_REDUCE_MAX,
+ BuiltinOperator_PACK,
+ BuiltinOperator_LOGICAL_OR,
+ BuiltinOperator_ONE_HOT,
+ BuiltinOperator_LOGICAL_AND,
+ BuiltinOperator_LOGICAL_NOT,
+ BuiltinOperator_UNPACK,
+ BuiltinOperator_REDUCE_MIN,
+ BuiltinOperator_FLOOR_DIV,
+ BuiltinOperator_REDUCE_ANY,
+ BuiltinOperator_SQUARE,
+ BuiltinOperator_ZEROS_LIKE,
+ BuiltinOperator_FILL,
+ BuiltinOperator_FLOOR_MOD,
+ BuiltinOperator_RANGE,
+ BuiltinOperator_RESIZE_NEAREST_NEIGHBOR,
+ BuiltinOperator_LEAKY_RELU,
+ BuiltinOperator_SQUARED_DIFFERENCE,
+ BuiltinOperator_MIRROR_PAD,
+ BuiltinOperator_ABS,
+ BuiltinOperator_SPLIT_V,
+ BuiltinOperator_UNIQUE,
+ BuiltinOperator_CEIL,
+ BuiltinOperator_REVERSE_V2,
+ BuiltinOperator_ADD_N,
+ BuiltinOperator_GATHER_ND,
+ BuiltinOperator_COS,
+ BuiltinOperator_WHERE,
+ BuiltinOperator_RANK,
+ BuiltinOperator_ELU,
+ BuiltinOperator_REVERSE_SEQUENCE,
+ BuiltinOperator_MATRIX_DIAG,
+ BuiltinOperator_QUANTIZE,
+ BuiltinOperator_MATRIX_SET_DIAG,
+ BuiltinOperator_ROUND,
+ BuiltinOperator_HARD_SWISH,
+ BuiltinOperator_IF,
+ BuiltinOperator_WHILE,
+ BuiltinOperator_NON_MAX_SUPPRESSION_V4,
+ BuiltinOperator_NON_MAX_SUPPRESSION_V5,
+ BuiltinOperator_SCATTER_ND,
+ BuiltinOperator_SELECT_V2,
+ BuiltinOperator_DENSIFY,
+ BuiltinOperator_SEGMENT_SUM,
+ BuiltinOperator_BATCH_MATMUL,
+ BuiltinOperator_PLACEHOLDER_FOR_GREATER_OP_CODES,
+ BuiltinOperator_CUMSUM,
+ BuiltinOperator_CALL_ONCE,
+ BuiltinOperator_BROADCAST_TO,
+ BuiltinOperator_RFFT2D,
+ BuiltinOperator_CONV_3D,
+ BuiltinOperator_IMAG,
+ BuiltinOperator_REAL,
+ BuiltinOperator_COMPLEX_ABS,
+ BuiltinOperator_HASHTABLE,
+ BuiltinOperator_HASHTABLE_FIND,
+ BuiltinOperator_HASHTABLE_IMPORT,
+ BuiltinOperator_HASHTABLE_SIZE,
+ BuiltinOperator_REDUCE_ALL,
+ BuiltinOperator_CONV_3D_TRANSPOSE,
+ BuiltinOperator_VAR_HANDLE,
+ BuiltinOperator_READ_VARIABLE,
+ BuiltinOperator_ASSIGN_VARIABLE,
+ BuiltinOperator_BROADCAST_ARGS,
+ BuiltinOperator_RANDOM_STANDARD_NORMAL,
+ BuiltinOperator_BUCKETIZE,
+ BuiltinOperator_RANDOM_UNIFORM,
+ BuiltinOperator_MULTINOMIAL,
+ BuiltinOperator_GELU,
+ BuiltinOperator_DYNAMIC_UPDATE_SLICE,
+ BuiltinOperator_RELU_0_TO_1,
+ BuiltinOperator_UNSORTED_SEGMENT_PROD,
+ BuiltinOperator_UNSORTED_SEGMENT_MAX,
+ BuiltinOperator_UNSORTED_SEGMENT_SUM,
+ BuiltinOperator_ATAN2};
+ return values;
+}
+
+inline const char *const *EnumNamesBuiltinOperator()
+{
+ static const char *const names[158] = {"ADD",
+ "AVERAGE_POOL_2D",
+ "CONCATENATION",
+ "CONV_2D",
+ "DEPTHWISE_CONV_2D",
+ "DEPTH_TO_SPACE",
+ "DEQUANTIZE",
+ "EMBEDDING_LOOKUP",
+ "FLOOR",
+ "FULLY_CONNECTED",
+ "HASHTABLE_LOOKUP",
+ "L2_NORMALIZATION",
+ "L2_POOL_2D",
+ "LOCAL_RESPONSE_NORMALIZATION",
+ "LOGISTIC",
+ "LSH_PROJECTION",
+ "LSTM",
+ "MAX_POOL_2D",
+ "MUL",
+ "RELU",
+ "RELU_N1_TO_1",
+ "RELU6",
+ "RESHAPE",
+ "RESIZE_BILINEAR",
+ "RNN",
+ "SOFTMAX",
+ "SPACE_TO_DEPTH",
+ "SVDF",
+ "TANH",
+ "CONCAT_EMBEDDINGS",
+ "SKIP_GRAM",
+ "CALL",
+ "CUSTOM",
+ "EMBEDDING_LOOKUP_SPARSE",
+ "PAD",
+ "UNIDIRECTIONAL_SEQUENCE_RNN",
+ "GATHER",
+ "BATCH_TO_SPACE_ND",
+ "SPACE_TO_BATCH_ND",
+ "TRANSPOSE",
+ "MEAN",
+ "SUB",
+ "DIV",
+ "SQUEEZE",
+ "UNIDIRECTIONAL_SEQUENCE_LSTM",
+ "STRIDED_SLICE",
+ "BIDIRECTIONAL_SEQUENCE_RNN",
+ "EXP",
+ "TOPK_V2",
+ "SPLIT",
+ "LOG_SOFTMAX",
+ "DELEGATE",
+ "BIDIRECTIONAL_SEQUENCE_LSTM",
+ "CAST",
+ "PRELU",
+ "MAXIMUM",
+ "ARG_MAX",
+ "MINIMUM",
+ "LESS",
+ "NEG",
+ "PADV2",
+ "GREATER",
+ "GREATER_EQUAL",
+ "LESS_EQUAL",
+ "SELECT",
+ "SLICE",
+ "SIN",
+ "TRANSPOSE_CONV",
+ "SPARSE_TO_DENSE",
+ "TILE",
+ "EXPAND_DIMS",
+ "EQUAL",
+ "NOT_EQUAL",
+ "LOG",
+ "SUM",
+ "SQRT",
+ "RSQRT",
+ "SHAPE",
+ "POW",
+ "ARG_MIN",
+ "FAKE_QUANT",
+ "REDUCE_PROD",
+ "REDUCE_MAX",
+ "PACK",
+ "LOGICAL_OR",
+ "ONE_HOT",
+ "LOGICAL_AND",
+ "LOGICAL_NOT",
+ "UNPACK",
+ "REDUCE_MIN",
+ "FLOOR_DIV",
+ "REDUCE_ANY",
+ "SQUARE",
+ "ZEROS_LIKE",
+ "FILL",
+ "FLOOR_MOD",
+ "RANGE",
+ "RESIZE_NEAREST_NEIGHBOR",
+ "LEAKY_RELU",
+ "SQUARED_DIFFERENCE",
+ "MIRROR_PAD",
+ "ABS",
+ "SPLIT_V",
+ "UNIQUE",
+ "CEIL",
+ "REVERSE_V2",
+ "ADD_N",
+ "GATHER_ND",
+ "COS",
+ "WHERE",
+ "RANK",
+ "ELU",
+ "REVERSE_SEQUENCE",
+ "MATRIX_DIAG",
+ "QUANTIZE",
+ "MATRIX_SET_DIAG",
+ "ROUND",
+ "HARD_SWISH",
+ "IF",
+ "WHILE",
+ "NON_MAX_SUPPRESSION_V4",
+ "NON_MAX_SUPPRESSION_V5",
+ "SCATTER_ND",
+ "SELECT_V2",
+ "DENSIFY",
+ "SEGMENT_SUM",
+ "BATCH_MATMUL",
+ "PLACEHOLDER_FOR_GREATER_OP_CODES",
+ "CUMSUM",
+ "CALL_ONCE",
+ "BROADCAST_TO",
+ "RFFT2D",
+ "CONV_3D",
+ "IMAG",
+ "REAL",
+ "COMPLEX_ABS",
+ "HASHTABLE",
+ "HASHTABLE_FIND",
+ "HASHTABLE_IMPORT",
+ "HASHTABLE_SIZE",
+ "REDUCE_ALL",
+ "CONV_3D_TRANSPOSE",
+ "VAR_HANDLE",
+ "READ_VARIABLE",
+ "ASSIGN_VARIABLE",
+ "BROADCAST_ARGS",
+ "RANDOM_STANDARD_NORMAL",
+ "BUCKETIZE",
+ "RANDOM_UNIFORM",
+ "MULTINOMIAL",
+ "GELU",
+ "DYNAMIC_UPDATE_SLICE",
+ "RELU_0_TO_1",
+ "UNSORTED_SEGMENT_PROD",
+ "UNSORTED_SEGMENT_MAX",
+ "UNSORTED_SEGMENT_SUM",
+ "ATAN2",
+ nullptr};
+ return names;
+}
+
+inline const char *EnumNameBuiltinOperator(BuiltinOperator e)
+{
+ if (::flatbuffers::IsOutRange(e, BuiltinOperator_ADD, BuiltinOperator_ATAN2))
+ return "";
+ const size_t index = static_cast<size_t>(e);
+ return EnumNamesBuiltinOperator()[index];
+}
+
+enum BuiltinOptions : uint8_t
+{
+ BuiltinOptions_NONE = 0,
+ BuiltinOptions_Conv2DOptions = 1,
+ BuiltinOptions_DepthwiseConv2DOptions = 2,
+ BuiltinOptions_ConcatEmbeddingsOptions = 3,
+ BuiltinOptions_LSHProjectionOptions = 4,
+ BuiltinOptions_Pool2DOptions = 5,
+ BuiltinOptions_SVDFOptions = 6,
+ BuiltinOptions_RNNOptions = 7,
+ BuiltinOptions_FullyConnectedOptions = 8,
+ BuiltinOptions_SoftmaxOptions = 9,
+ BuiltinOptions_ConcatenationOptions = 10,
+ BuiltinOptions_AddOptions = 11,
+ BuiltinOptions_L2NormOptions = 12,
+ BuiltinOptions_LocalResponseNormalizationOptions = 13,
+ BuiltinOptions_LSTMOptions = 14,
+ BuiltinOptions_ResizeBilinearOptions = 15,
+ BuiltinOptions_CallOptions = 16,
+ BuiltinOptions_ReshapeOptions = 17,
+ BuiltinOptions_SkipGramOptions = 18,
+ BuiltinOptions_SpaceToDepthOptions = 19,
+ BuiltinOptions_EmbeddingLookupSparseOptions = 20,
+ BuiltinOptions_MulOptions = 21,
+ BuiltinOptions_PadOptions = 22,
+ BuiltinOptions_GatherOptions = 23,
+ BuiltinOptions_BatchToSpaceNDOptions = 24,
+ BuiltinOptions_SpaceToBatchNDOptions = 25,
+ BuiltinOptions_TransposeOptions = 26,
+ BuiltinOptions_ReducerOptions = 27,
+ BuiltinOptions_SubOptions = 28,
+ BuiltinOptions_DivOptions = 29,
+ BuiltinOptions_SqueezeOptions = 30,
+ BuiltinOptions_SequenceRNNOptions = 31,
+ BuiltinOptions_StridedSliceOptions = 32,
+ BuiltinOptions_ExpOptions = 33,
+ BuiltinOptions_TopKV2Options = 34,
+ BuiltinOptions_SplitOptions = 35,
+ BuiltinOptions_LogSoftmaxOptions = 36,
+ BuiltinOptions_CastOptions = 37,
+ BuiltinOptions_DequantizeOptions = 38,
+ BuiltinOptions_MaximumMinimumOptions = 39,
+ BuiltinOptions_ArgMaxOptions = 40,
+ BuiltinOptions_LessOptions = 41,
+ BuiltinOptions_NegOptions = 42,
+ BuiltinOptions_PadV2Options = 43,
+ BuiltinOptions_GreaterOptions = 44,
+ BuiltinOptions_GreaterEqualOptions = 45,
+ BuiltinOptions_LessEqualOptions = 46,
+ BuiltinOptions_SelectOptions = 47,
+ BuiltinOptions_SliceOptions = 48,
+ BuiltinOptions_TransposeConvOptions = 49,
+ BuiltinOptions_SparseToDenseOptions = 50,
+ BuiltinOptions_TileOptions = 51,
+ BuiltinOptions_ExpandDimsOptions = 52,
+ BuiltinOptions_EqualOptions = 53,
+ BuiltinOptions_NotEqualOptions = 54,
+ BuiltinOptions_ShapeOptions = 55,
+ BuiltinOptions_PowOptions = 56,
+ BuiltinOptions_ArgMinOptions = 57,
+ BuiltinOptions_FakeQuantOptions = 58,
+ BuiltinOptions_PackOptions = 59,
+ BuiltinOptions_LogicalOrOptions = 60,
+ BuiltinOptions_OneHotOptions = 61,
+ BuiltinOptions_LogicalAndOptions = 62,
+ BuiltinOptions_LogicalNotOptions = 63,
+ BuiltinOptions_UnpackOptions = 64,
+ BuiltinOptions_FloorDivOptions = 65,
+ BuiltinOptions_SquareOptions = 66,
+ BuiltinOptions_ZerosLikeOptions = 67,
+ BuiltinOptions_FillOptions = 68,
+ BuiltinOptions_BidirectionalSequenceLSTMOptions = 69,
+ BuiltinOptions_BidirectionalSequenceRNNOptions = 70,
+ BuiltinOptions_UnidirectionalSequenceLSTMOptions = 71,
+ BuiltinOptions_FloorModOptions = 72,
+ BuiltinOptions_RangeOptions = 73,
+ BuiltinOptions_ResizeNearestNeighborOptions = 74,
+ BuiltinOptions_LeakyReluOptions = 75,
+ BuiltinOptions_SquaredDifferenceOptions = 76,
+ BuiltinOptions_MirrorPadOptions = 77,
+ BuiltinOptions_AbsOptions = 78,
+ BuiltinOptions_SplitVOptions = 79,
+ BuiltinOptions_UniqueOptions = 80,
+ BuiltinOptions_ReverseV2Options = 81,
+ BuiltinOptions_AddNOptions = 82,
+ BuiltinOptions_GatherNdOptions = 83,
+ BuiltinOptions_CosOptions = 84,
+ BuiltinOptions_WhereOptions = 85,
+ BuiltinOptions_RankOptions = 86,
+ BuiltinOptions_ReverseSequenceOptions = 87,
+ BuiltinOptions_MatrixDiagOptions = 88,
+ BuiltinOptions_QuantizeOptions = 89,
+ BuiltinOptions_MatrixSetDiagOptions = 90,
+ BuiltinOptions_HardSwishOptions = 91,
+ BuiltinOptions_IfOptions = 92,
+ BuiltinOptions_WhileOptions = 93,
+ BuiltinOptions_DepthToSpaceOptions = 94,
+ BuiltinOptions_NonMaxSuppressionV4Options = 95,
+ BuiltinOptions_NonMaxSuppressionV5Options = 96,
+ BuiltinOptions_ScatterNdOptions = 97,
+ BuiltinOptions_SelectV2Options = 98,
+ BuiltinOptions_DensifyOptions = 99,
+ BuiltinOptions_SegmentSumOptions = 100,
+ BuiltinOptions_BatchMatMulOptions = 101,
+ BuiltinOptions_CumsumOptions = 102,
+ BuiltinOptions_CallOnceOptions = 103,
+ BuiltinOptions_BroadcastToOptions = 104,
+ BuiltinOptions_Rfft2dOptions = 105,
+ BuiltinOptions_Conv3DOptions = 106,
+ BuiltinOptions_HashtableOptions = 107,
+ BuiltinOptions_HashtableFindOptions = 108,
+ BuiltinOptions_HashtableImportOptions = 109,
+ BuiltinOptions_HashtableSizeOptions = 110,
+ BuiltinOptions_VarHandleOptions = 111,
+ BuiltinOptions_ReadVariableOptions = 112,
+ BuiltinOptions_AssignVariableOptions = 113,
+ BuiltinOptions_RandomOptions = 114,
+ BuiltinOptions_BucketizeOptions = 115,
+ BuiltinOptions_GeluOptions = 116,
+ BuiltinOptions_DynamicUpdateSliceOptions = 117,
+ BuiltinOptions_UnsortedSegmentProdOptions = 118,
+ BuiltinOptions_UnsortedSegmentMaxOptions = 119,
+ BuiltinOptions_UnsortedSegmentSumOptions = 120,
+ BuiltinOptions_ATan2Options = 121,
+ BuiltinOptions_MIN = BuiltinOptions_NONE,
+ BuiltinOptions_MAX = BuiltinOptions_ATan2Options
+};
+
+inline const BuiltinOptions (&EnumValuesBuiltinOptions())[122]
+{
+ static const BuiltinOptions values[] = {BuiltinOptions_NONE,
+ BuiltinOptions_Conv2DOptions,
+ BuiltinOptions_DepthwiseConv2DOptions,
+ BuiltinOptions_ConcatEmbeddingsOptions,
+ BuiltinOptions_LSHProjectionOptions,
+ BuiltinOptions_Pool2DOptions,
+ BuiltinOptions_SVDFOptions,
+ BuiltinOptions_RNNOptions,
+ BuiltinOptions_FullyConnectedOptions,
+ BuiltinOptions_SoftmaxOptions,
+ BuiltinOptions_ConcatenationOptions,
+ BuiltinOptions_AddOptions,
+ BuiltinOptions_L2NormOptions,
+ BuiltinOptions_LocalResponseNormalizationOptions,
+ BuiltinOptions_LSTMOptions,
+ BuiltinOptions_ResizeBilinearOptions,
+ BuiltinOptions_CallOptions,
+ BuiltinOptions_ReshapeOptions,
+ BuiltinOptions_SkipGramOptions,
+ BuiltinOptions_SpaceToDepthOptions,
+ BuiltinOptions_EmbeddingLookupSparseOptions,
+ BuiltinOptions_MulOptions,
+ BuiltinOptions_PadOptions,
+ BuiltinOptions_GatherOptions,
+ BuiltinOptions_BatchToSpaceNDOptions,
+ BuiltinOptions_SpaceToBatchNDOptions,
+ BuiltinOptions_TransposeOptions,
+ BuiltinOptions_ReducerOptions,
+ BuiltinOptions_SubOptions,
+ BuiltinOptions_DivOptions,
+ BuiltinOptions_SqueezeOptions,
+ BuiltinOptions_SequenceRNNOptions,
+ BuiltinOptions_StridedSliceOptions,
+ BuiltinOptions_ExpOptions,
+ BuiltinOptions_TopKV2Options,
+ BuiltinOptions_SplitOptions,
+ BuiltinOptions_LogSoftmaxOptions,
+ BuiltinOptions_CastOptions,
+ BuiltinOptions_DequantizeOptions,
+ BuiltinOptions_MaximumMinimumOptions,
+ BuiltinOptions_ArgMaxOptions,
+ BuiltinOptions_LessOptions,
+ BuiltinOptions_NegOptions,
+ BuiltinOptions_PadV2Options,
+ BuiltinOptions_GreaterOptions,
+ BuiltinOptions_GreaterEqualOptions,
+ BuiltinOptions_LessEqualOptions,
+ BuiltinOptions_SelectOptions,
+ BuiltinOptions_SliceOptions,
+ BuiltinOptions_TransposeConvOptions,
+ BuiltinOptions_SparseToDenseOptions,
+ BuiltinOptions_TileOptions,
+ BuiltinOptions_ExpandDimsOptions,
+ BuiltinOptions_EqualOptions,
+ BuiltinOptions_NotEqualOptions,
+ BuiltinOptions_ShapeOptions,
+ BuiltinOptions_PowOptions,
+ BuiltinOptions_ArgMinOptions,
+ BuiltinOptions_FakeQuantOptions,
+ BuiltinOptions_PackOptions,
+ BuiltinOptions_LogicalOrOptions,
+ BuiltinOptions_OneHotOptions,
+ BuiltinOptions_LogicalAndOptions,
+ BuiltinOptions_LogicalNotOptions,
+ BuiltinOptions_UnpackOptions,
+ BuiltinOptions_FloorDivOptions,
+ BuiltinOptions_SquareOptions,
+ BuiltinOptions_ZerosLikeOptions,
+ BuiltinOptions_FillOptions,
+ BuiltinOptions_BidirectionalSequenceLSTMOptions,
+ BuiltinOptions_BidirectionalSequenceRNNOptions,
+ BuiltinOptions_UnidirectionalSequenceLSTMOptions,
+ BuiltinOptions_FloorModOptions,
+ BuiltinOptions_RangeOptions,
+ BuiltinOptions_ResizeNearestNeighborOptions,
+ BuiltinOptions_LeakyReluOptions,
+ BuiltinOptions_SquaredDifferenceOptions,
+ BuiltinOptions_MirrorPadOptions,
+ BuiltinOptions_AbsOptions,
+ BuiltinOptions_SplitVOptions,
+ BuiltinOptions_UniqueOptions,
+ BuiltinOptions_ReverseV2Options,
+ BuiltinOptions_AddNOptions,
+ BuiltinOptions_GatherNdOptions,
+ BuiltinOptions_CosOptions,
+ BuiltinOptions_WhereOptions,
+ BuiltinOptions_RankOptions,
+ BuiltinOptions_ReverseSequenceOptions,
+ BuiltinOptions_MatrixDiagOptions,
+ BuiltinOptions_QuantizeOptions,
+ BuiltinOptions_MatrixSetDiagOptions,
+ BuiltinOptions_HardSwishOptions,
+ BuiltinOptions_IfOptions,
+ BuiltinOptions_WhileOptions,
+ BuiltinOptions_DepthToSpaceOptions,
+ BuiltinOptions_NonMaxSuppressionV4Options,
+ BuiltinOptions_NonMaxSuppressionV5Options,
+ BuiltinOptions_ScatterNdOptions,
+ BuiltinOptions_SelectV2Options,
+ BuiltinOptions_DensifyOptions,
+ BuiltinOptions_SegmentSumOptions,
+ BuiltinOptions_BatchMatMulOptions,
+ BuiltinOptions_CumsumOptions,
+ BuiltinOptions_CallOnceOptions,
+ BuiltinOptions_BroadcastToOptions,
+ BuiltinOptions_Rfft2dOptions,
+ BuiltinOptions_Conv3DOptions,
+ BuiltinOptions_HashtableOptions,
+ BuiltinOptions_HashtableFindOptions,
+ BuiltinOptions_HashtableImportOptions,
+ BuiltinOptions_HashtableSizeOptions,
+ BuiltinOptions_VarHandleOptions,
+ BuiltinOptions_ReadVariableOptions,
+ BuiltinOptions_AssignVariableOptions,
+ BuiltinOptions_RandomOptions,
+ BuiltinOptions_BucketizeOptions,
+ BuiltinOptions_GeluOptions,
+ BuiltinOptions_DynamicUpdateSliceOptions,
+ BuiltinOptions_UnsortedSegmentProdOptions,
+ BuiltinOptions_UnsortedSegmentMaxOptions,
+ BuiltinOptions_UnsortedSegmentSumOptions,
+ BuiltinOptions_ATan2Options};
+ return values;
+}
+
+inline const char *const *EnumNamesBuiltinOptions()
+{
+ static const char *const names[123] = {"NONE",
+ "Conv2DOptions",
+ "DepthwiseConv2DOptions",
+ "ConcatEmbeddingsOptions",
+ "LSHProjectionOptions",
+ "Pool2DOptions",
+ "SVDFOptions",
+ "RNNOptions",
+ "FullyConnectedOptions",
+ "SoftmaxOptions",
+ "ConcatenationOptions",
+ "AddOptions",
+ "L2NormOptions",
+ "LocalResponseNormalizationOptions",
+ "LSTMOptions",
+ "ResizeBilinearOptions",
+ "CallOptions",
+ "ReshapeOptions",
+ "SkipGramOptions",
+ "SpaceToDepthOptions",
+ "EmbeddingLookupSparseOptions",
+ "MulOptions",
+ "PadOptions",
+ "GatherOptions",
+ "BatchToSpaceNDOptions",
+ "SpaceToBatchNDOptions",
+ "TransposeOptions",
+ "ReducerOptions",
+ "SubOptions",
+ "DivOptions",
+ "SqueezeOptions",
+ "SequenceRNNOptions",
+ "StridedSliceOptions",
+ "ExpOptions",
+ "TopKV2Options",
+ "SplitOptions",
+ "LogSoftmaxOptions",
+ "CastOptions",
+ "DequantizeOptions",
+ "MaximumMinimumOptions",
+ "ArgMaxOptions",
+ "LessOptions",
+ "NegOptions",
+ "PadV2Options",
+ "GreaterOptions",
+ "GreaterEqualOptions",
+ "LessEqualOptions",
+ "SelectOptions",
+ "SliceOptions",
+ "TransposeConvOptions",
+ "SparseToDenseOptions",
+ "TileOptions",
+ "ExpandDimsOptions",
+ "EqualOptions",
+ "NotEqualOptions",
+ "ShapeOptions",
+ "PowOptions",
+ "ArgMinOptions",
+ "FakeQuantOptions",
+ "PackOptions",
+ "LogicalOrOptions",
+ "OneHotOptions",
+ "LogicalAndOptions",
+ "LogicalNotOptions",
+ "UnpackOptions",
+ "FloorDivOptions",
+ "SquareOptions",
+ "ZerosLikeOptions",
+ "FillOptions",
+ "BidirectionalSequenceLSTMOptions",
+ "BidirectionalSequenceRNNOptions",
+ "UnidirectionalSequenceLSTMOptions",
+ "FloorModOptions",
+ "RangeOptions",
+ "ResizeNearestNeighborOptions",
+ "LeakyReluOptions",
+ "SquaredDifferenceOptions",
+ "MirrorPadOptions",
+ "AbsOptions",
+ "SplitVOptions",
+ "UniqueOptions",
+ "ReverseV2Options",
+ "AddNOptions",
+ "GatherNdOptions",
+ "CosOptions",
+ "WhereOptions",
+ "RankOptions",
+ "ReverseSequenceOptions",
+ "MatrixDiagOptions",
+ "QuantizeOptions",
+ "MatrixSetDiagOptions",
+ "HardSwishOptions",
+ "IfOptions",
+ "WhileOptions",
+ "DepthToSpaceOptions",
+ "NonMaxSuppressionV4Options",
+ "NonMaxSuppressionV5Options",
+ "ScatterNdOptions",
+ "SelectV2Options",
+ "DensifyOptions",
+ "SegmentSumOptions",
+ "BatchMatMulOptions",
+ "CumsumOptions",
+ "CallOnceOptions",
+ "BroadcastToOptions",
+ "Rfft2dOptions",
+ "Conv3DOptions",
+ "HashtableOptions",
+ "HashtableFindOptions",
+ "HashtableImportOptions",
+ "HashtableSizeOptions",
+ "VarHandleOptions",
+ "ReadVariableOptions",
+ "AssignVariableOptions",
+ "RandomOptions",
+ "BucketizeOptions",
+ "GeluOptions",
+ "DynamicUpdateSliceOptions",
+ "UnsortedSegmentProdOptions",
+ "UnsortedSegmentMaxOptions",
+ "UnsortedSegmentSumOptions",
+ "ATan2Options",
+ nullptr};
+ return names;
+}
+
+inline const char *EnumNameBuiltinOptions(BuiltinOptions e)
+{
+ if (::flatbuffers::IsOutRange(e, BuiltinOptions_NONE, BuiltinOptions_ATan2Options))
+ return "";
+ const size_t index = static_cast<size_t>(e);
+ return EnumNamesBuiltinOptions()[index];
+}
+
+template <typename T> struct BuiltinOptionsTraits
+{
+ static const BuiltinOptions enum_value = BuiltinOptions_NONE;
+};
+
+template <> struct BuiltinOptionsTraits<onert_tflite::Conv2DOptions>
+{
+ static const BuiltinOptions enum_value = BuiltinOptions_Conv2DOptions;
+};
+
+template <> struct BuiltinOptionsTraits<onert_tflite::DepthwiseConv2DOptions>
+{
+ static const BuiltinOptions enum_value = BuiltinOptions_DepthwiseConv2DOptions;
+};
+
+template <> struct BuiltinOptionsTraits<onert_tflite::ConcatEmbeddingsOptions>
+{
+ static const BuiltinOptions enum_value = BuiltinOptions_ConcatEmbeddingsOptions;
+};
+
+template <> struct BuiltinOptionsTraits<onert_tflite::LSHProjectionOptions>
+{
+ static const BuiltinOptions enum_value = BuiltinOptions_LSHProjectionOptions;
+};
+
+template <> struct BuiltinOptionsTraits<onert_tflite::Pool2DOptions>
+{
+ static const BuiltinOptions enum_value = BuiltinOptions_Pool2DOptions;
+};
+
+template <> struct BuiltinOptionsTraits<onert_tflite::SVDFOptions>
+{
+ static const BuiltinOptions enum_value = BuiltinOptions_SVDFOptions;
+};
+
+template <> struct BuiltinOptionsTraits<onert_tflite::RNNOptions>
+{
+ static const BuiltinOptions enum_value = BuiltinOptions_RNNOptions;
+};
+
+template <> struct BuiltinOptionsTraits<onert_tflite::FullyConnectedOptions>
+{
+ static const BuiltinOptions enum_value = BuiltinOptions_FullyConnectedOptions;
+};
+
+template <> struct BuiltinOptionsTraits<onert_tflite::SoftmaxOptions>
+{
+ static const BuiltinOptions enum_value = BuiltinOptions_SoftmaxOptions;
+};
+
+template <> struct BuiltinOptionsTraits<onert_tflite::ConcatenationOptions>
+{
+ static const BuiltinOptions enum_value = BuiltinOptions_ConcatenationOptions;
+};
+
+template <> struct BuiltinOptionsTraits<onert_tflite::AddOptions>
+{
+ static const BuiltinOptions enum_value = BuiltinOptions_AddOptions;
+};
+
+template <> struct BuiltinOptionsTraits<onert_tflite::L2NormOptions>
+{
+ static const BuiltinOptions enum_value = BuiltinOptions_L2NormOptions;
+};
+
+template <> struct BuiltinOptionsTraits<onert_tflite::LocalResponseNormalizationOptions>
+{
+ static const BuiltinOptions enum_value = BuiltinOptions_LocalResponseNormalizationOptions;
+};
+
+template <> struct BuiltinOptionsTraits<onert_tflite::LSTMOptions>
+{
+ static const BuiltinOptions enum_value = BuiltinOptions_LSTMOptions;
+};
+
+template <> struct BuiltinOptionsTraits<onert_tflite::ResizeBilinearOptions>
+{
+ static const BuiltinOptions enum_value = BuiltinOptions_ResizeBilinearOptions;
+};
+
+template <> struct BuiltinOptionsTraits<onert_tflite::CallOptions>
+{
+ static const BuiltinOptions enum_value = BuiltinOptions_CallOptions;
+};
+
+template <> struct BuiltinOptionsTraits<onert_tflite::ReshapeOptions>
+{
+ static const BuiltinOptions enum_value = BuiltinOptions_ReshapeOptions;
+};
+
+template <> struct BuiltinOptionsTraits<onert_tflite::SkipGramOptions>
+{
+ static const BuiltinOptions enum_value = BuiltinOptions_SkipGramOptions;
+};
+
+template <> struct BuiltinOptionsTraits<onert_tflite::SpaceToDepthOptions>
+{
+ static const BuiltinOptions enum_value = BuiltinOptions_SpaceToDepthOptions;
+};
+
+template <> struct BuiltinOptionsTraits<onert_tflite::EmbeddingLookupSparseOptions>
+{
+ static const BuiltinOptions enum_value = BuiltinOptions_EmbeddingLookupSparseOptions;
+};
+
+template <> struct BuiltinOptionsTraits<onert_tflite::MulOptions>
+{
+ static const BuiltinOptions enum_value = BuiltinOptions_MulOptions;
+};
+
+template <> struct BuiltinOptionsTraits<onert_tflite::PadOptions>
+{
+ static const BuiltinOptions enum_value = BuiltinOptions_PadOptions;
+};
+
+template <> struct BuiltinOptionsTraits<onert_tflite::GatherOptions>
+{
+ static const BuiltinOptions enum_value = BuiltinOptions_GatherOptions;
+};
+
+template <> struct BuiltinOptionsTraits<onert_tflite::BatchToSpaceNDOptions>
+{
+ static const BuiltinOptions enum_value = BuiltinOptions_BatchToSpaceNDOptions;
+};
+
+template <> struct BuiltinOptionsTraits<onert_tflite::SpaceToBatchNDOptions>
+{
+ static const BuiltinOptions enum_value = BuiltinOptions_SpaceToBatchNDOptions;
+};
+
+template <> struct BuiltinOptionsTraits<onert_tflite::TransposeOptions>
+{
+ static const BuiltinOptions enum_value = BuiltinOptions_TransposeOptions;
+};
+
+template <> struct BuiltinOptionsTraits<onert_tflite::ReducerOptions>
+{
+ static const BuiltinOptions enum_value = BuiltinOptions_ReducerOptions;
+};
+
+template <> struct BuiltinOptionsTraits<onert_tflite::SubOptions>
+{
+ static const BuiltinOptions enum_value = BuiltinOptions_SubOptions;
+};
+
+template <> struct BuiltinOptionsTraits<onert_tflite::DivOptions>
+{
+ static const BuiltinOptions enum_value = BuiltinOptions_DivOptions;
+};
+
+template <> struct BuiltinOptionsTraits<onert_tflite::SqueezeOptions>
+{
+ static const BuiltinOptions enum_value = BuiltinOptions_SqueezeOptions;
+};
+
+template <> struct BuiltinOptionsTraits<onert_tflite::SequenceRNNOptions>
+{
+ static const BuiltinOptions enum_value = BuiltinOptions_SequenceRNNOptions;
+};
+
+template <> struct BuiltinOptionsTraits<onert_tflite::StridedSliceOptions>
+{
+ static const BuiltinOptions enum_value = BuiltinOptions_StridedSliceOptions;
+};
+
+template <> struct BuiltinOptionsTraits<onert_tflite::ExpOptions>
+{
+ static const BuiltinOptions enum_value = BuiltinOptions_ExpOptions;
+};
+
+template <> struct BuiltinOptionsTraits<onert_tflite::TopKV2Options>
+{
+ static const BuiltinOptions enum_value = BuiltinOptions_TopKV2Options;
+};
+
+template <> struct BuiltinOptionsTraits<onert_tflite::SplitOptions>
+{
+ static const BuiltinOptions enum_value = BuiltinOptions_SplitOptions;
+};
+
+template <> struct BuiltinOptionsTraits<onert_tflite::LogSoftmaxOptions>
+{
+ static const BuiltinOptions enum_value = BuiltinOptions_LogSoftmaxOptions;
+};
+
+template <> struct BuiltinOptionsTraits<onert_tflite::CastOptions>
+{
+ static const BuiltinOptions enum_value = BuiltinOptions_CastOptions;
+};
+
+template <> struct BuiltinOptionsTraits<onert_tflite::DequantizeOptions>
+{
+ static const BuiltinOptions enum_value = BuiltinOptions_DequantizeOptions;
+};
+
+template <> struct BuiltinOptionsTraits<onert_tflite::MaximumMinimumOptions>
+{
+ static const BuiltinOptions enum_value = BuiltinOptions_MaximumMinimumOptions;
+};
+
+template <> struct BuiltinOptionsTraits<onert_tflite::ArgMaxOptions>
+{
+ static const BuiltinOptions enum_value = BuiltinOptions_ArgMaxOptions;
+};
+
+template <> struct BuiltinOptionsTraits<onert_tflite::LessOptions>
+{
+ static const BuiltinOptions enum_value = BuiltinOptions_LessOptions;
+};
+
+template <> struct BuiltinOptionsTraits<onert_tflite::NegOptions>
+{
+ static const BuiltinOptions enum_value = BuiltinOptions_NegOptions;
+};
+
+template <> struct BuiltinOptionsTraits<onert_tflite::PadV2Options>
+{
+ static const BuiltinOptions enum_value = BuiltinOptions_PadV2Options;
+};
+
+template <> struct BuiltinOptionsTraits<onert_tflite::GreaterOptions>
+{
+ static const BuiltinOptions enum_value = BuiltinOptions_GreaterOptions;
+};
+
+template <> struct BuiltinOptionsTraits<onert_tflite::GreaterEqualOptions>
+{
+ static const BuiltinOptions enum_value = BuiltinOptions_GreaterEqualOptions;
+};
+
+template <> struct BuiltinOptionsTraits<onert_tflite::LessEqualOptions>
+{
+ static const BuiltinOptions enum_value = BuiltinOptions_LessEqualOptions;
+};
+
+template <> struct BuiltinOptionsTraits<onert_tflite::SelectOptions>
+{
+ static const BuiltinOptions enum_value = BuiltinOptions_SelectOptions;
+};
+
+template <> struct BuiltinOptionsTraits<onert_tflite::SliceOptions>
+{
+ static const BuiltinOptions enum_value = BuiltinOptions_SliceOptions;
+};
+
+template <> struct BuiltinOptionsTraits<onert_tflite::TransposeConvOptions>
+{
+ static const BuiltinOptions enum_value = BuiltinOptions_TransposeConvOptions;
+};
+
+template <> struct BuiltinOptionsTraits<onert_tflite::SparseToDenseOptions>
+{
+ static const BuiltinOptions enum_value = BuiltinOptions_SparseToDenseOptions;
+};
+
+template <> struct BuiltinOptionsTraits<onert_tflite::TileOptions>
+{
+ static const BuiltinOptions enum_value = BuiltinOptions_TileOptions;
+};
+
+template <> struct BuiltinOptionsTraits<onert_tflite::ExpandDimsOptions>
+{
+ static const BuiltinOptions enum_value = BuiltinOptions_ExpandDimsOptions;
+};
+
+template <> struct BuiltinOptionsTraits<onert_tflite::EqualOptions>
+{
+ static const BuiltinOptions enum_value = BuiltinOptions_EqualOptions;
+};
+
+template <> struct BuiltinOptionsTraits<onert_tflite::NotEqualOptions>
+{
+ static const BuiltinOptions enum_value = BuiltinOptions_NotEqualOptions;
+};
+
+template <> struct BuiltinOptionsTraits<onert_tflite::ShapeOptions>
+{
+ static const BuiltinOptions enum_value = BuiltinOptions_ShapeOptions;
+};
+
+template <> struct BuiltinOptionsTraits<onert_tflite::PowOptions>
+{
+ static const BuiltinOptions enum_value = BuiltinOptions_PowOptions;
+};
+
+template <> struct BuiltinOptionsTraits<onert_tflite::ArgMinOptions>
+{
+ static const BuiltinOptions enum_value = BuiltinOptions_ArgMinOptions;
+};
+
+template <> struct BuiltinOptionsTraits<onert_tflite::FakeQuantOptions>
+{
+ static const BuiltinOptions enum_value = BuiltinOptions_FakeQuantOptions;
+};
+
+template <> struct BuiltinOptionsTraits<onert_tflite::PackOptions>
+{
+ static const BuiltinOptions enum_value = BuiltinOptions_PackOptions;
+};
+
+template <> struct BuiltinOptionsTraits<onert_tflite::LogicalOrOptions>
+{
+ static const BuiltinOptions enum_value = BuiltinOptions_LogicalOrOptions;
+};
+
+template <> struct BuiltinOptionsTraits<onert_tflite::OneHotOptions>
+{
+ static const BuiltinOptions enum_value = BuiltinOptions_OneHotOptions;
+};
+
+template <> struct BuiltinOptionsTraits<onert_tflite::LogicalAndOptions>
+{
+ static const BuiltinOptions enum_value = BuiltinOptions_LogicalAndOptions;
+};
+
+template <> struct BuiltinOptionsTraits<onert_tflite::LogicalNotOptions>
+{
+ static const BuiltinOptions enum_value = BuiltinOptions_LogicalNotOptions;
+};
+
+template <> struct BuiltinOptionsTraits<onert_tflite::UnpackOptions>
+{
+ static const BuiltinOptions enum_value = BuiltinOptions_UnpackOptions;
+};
+
+template <> struct BuiltinOptionsTraits<onert_tflite::FloorDivOptions>
+{
+ static const BuiltinOptions enum_value = BuiltinOptions_FloorDivOptions;
+};
+
+template <> struct BuiltinOptionsTraits<onert_tflite::SquareOptions>
+{
+ static const BuiltinOptions enum_value = BuiltinOptions_SquareOptions;
+};
+
+template <> struct BuiltinOptionsTraits<onert_tflite::ZerosLikeOptions>
+{
+ static const BuiltinOptions enum_value = BuiltinOptions_ZerosLikeOptions;
+};
+
+template <> struct BuiltinOptionsTraits<onert_tflite::FillOptions>
+{
+ static const BuiltinOptions enum_value = BuiltinOptions_FillOptions;
+};
+
+template <> struct BuiltinOptionsTraits<onert_tflite::BidirectionalSequenceLSTMOptions>
+{
+ static const BuiltinOptions enum_value = BuiltinOptions_BidirectionalSequenceLSTMOptions;
+};
+
+template <> struct BuiltinOptionsTraits<onert_tflite::BidirectionalSequenceRNNOptions>
+{
+ static const BuiltinOptions enum_value = BuiltinOptions_BidirectionalSequenceRNNOptions;
+};
+
+template <> struct BuiltinOptionsTraits<onert_tflite::UnidirectionalSequenceLSTMOptions>
+{
+ static const BuiltinOptions enum_value = BuiltinOptions_UnidirectionalSequenceLSTMOptions;
+};
+
+template <> struct BuiltinOptionsTraits<onert_tflite::FloorModOptions>
+{
+ static const BuiltinOptions enum_value = BuiltinOptions_FloorModOptions;
+};
+
+template <> struct BuiltinOptionsTraits<onert_tflite::RangeOptions>
+{
+ static const BuiltinOptions enum_value = BuiltinOptions_RangeOptions;
+};
+
+template <> struct BuiltinOptionsTraits<onert_tflite::ResizeNearestNeighborOptions>
+{
+ static const BuiltinOptions enum_value = BuiltinOptions_ResizeNearestNeighborOptions;
+};
+
+template <> struct BuiltinOptionsTraits<onert_tflite::LeakyReluOptions>
+{
+ static const BuiltinOptions enum_value = BuiltinOptions_LeakyReluOptions;
+};
+
+template <> struct BuiltinOptionsTraits<onert_tflite::SquaredDifferenceOptions>
+{
+ static const BuiltinOptions enum_value = BuiltinOptions_SquaredDifferenceOptions;
+};
+
+template <> struct BuiltinOptionsTraits<onert_tflite::MirrorPadOptions>
+{
+ static const BuiltinOptions enum_value = BuiltinOptions_MirrorPadOptions;
+};
+
+template <> struct BuiltinOptionsTraits<onert_tflite::AbsOptions>
+{
+ static const BuiltinOptions enum_value = BuiltinOptions_AbsOptions;
+};
+
+template <> struct BuiltinOptionsTraits<onert_tflite::SplitVOptions>
+{
+ static const BuiltinOptions enum_value = BuiltinOptions_SplitVOptions;
+};
+
+template <> struct BuiltinOptionsTraits<onert_tflite::UniqueOptions>
+{
+ static const BuiltinOptions enum_value = BuiltinOptions_UniqueOptions;
+};
+
+template <> struct BuiltinOptionsTraits<onert_tflite::ReverseV2Options>
+{
+ static const BuiltinOptions enum_value = BuiltinOptions_ReverseV2Options;
+};
+
+template <> struct BuiltinOptionsTraits<onert_tflite::AddNOptions>
+{
+ static const BuiltinOptions enum_value = BuiltinOptions_AddNOptions;
+};
+
+template <> struct BuiltinOptionsTraits<onert_tflite::GatherNdOptions>
+{
+ static const BuiltinOptions enum_value = BuiltinOptions_GatherNdOptions;
+};
+
+template <> struct BuiltinOptionsTraits<onert_tflite::CosOptions>
+{
+ static const BuiltinOptions enum_value = BuiltinOptions_CosOptions;
+};
+
+template <> struct BuiltinOptionsTraits<onert_tflite::WhereOptions>
+{
+ static const BuiltinOptions enum_value = BuiltinOptions_WhereOptions;
+};
+
+template <> struct BuiltinOptionsTraits<onert_tflite::RankOptions>
+{
+ static const BuiltinOptions enum_value = BuiltinOptions_RankOptions;
+};
+
+template <> struct BuiltinOptionsTraits<onert_tflite::ReverseSequenceOptions>
+{
+ static const BuiltinOptions enum_value = BuiltinOptions_ReverseSequenceOptions;
+};
+
+template <> struct BuiltinOptionsTraits<onert_tflite::MatrixDiagOptions>
+{
+ static const BuiltinOptions enum_value = BuiltinOptions_MatrixDiagOptions;
+};
+
+template <> struct BuiltinOptionsTraits<onert_tflite::QuantizeOptions>
+{
+ static const BuiltinOptions enum_value = BuiltinOptions_QuantizeOptions;
+};
+
+template <> struct BuiltinOptionsTraits<onert_tflite::MatrixSetDiagOptions>
+{
+ static const BuiltinOptions enum_value = BuiltinOptions_MatrixSetDiagOptions;
+};
+
+template <> struct BuiltinOptionsTraits<onert_tflite::HardSwishOptions>
+{
+ static const BuiltinOptions enum_value = BuiltinOptions_HardSwishOptions;
+};
+
+template <> struct BuiltinOptionsTraits<onert_tflite::IfOptions>
+{
+ static const BuiltinOptions enum_value = BuiltinOptions_IfOptions;
+};
+
+template <> struct BuiltinOptionsTraits<onert_tflite::WhileOptions>
+{
+ static const BuiltinOptions enum_value = BuiltinOptions_WhileOptions;
+};
+
+template <> struct BuiltinOptionsTraits<onert_tflite::DepthToSpaceOptions>
+{
+ static const BuiltinOptions enum_value = BuiltinOptions_DepthToSpaceOptions;
+};
+
+template <> struct BuiltinOptionsTraits<onert_tflite::NonMaxSuppressionV4Options>
+{
+ static const BuiltinOptions enum_value = BuiltinOptions_NonMaxSuppressionV4Options;
+};
+
+template <> struct BuiltinOptionsTraits<onert_tflite::NonMaxSuppressionV5Options>
+{
+ static const BuiltinOptions enum_value = BuiltinOptions_NonMaxSuppressionV5Options;
+};
+
+template <> struct BuiltinOptionsTraits<onert_tflite::ScatterNdOptions>
+{
+ static const BuiltinOptions enum_value = BuiltinOptions_ScatterNdOptions;
+};
+
+template <> struct BuiltinOptionsTraits<onert_tflite::SelectV2Options>
+{
+ static const BuiltinOptions enum_value = BuiltinOptions_SelectV2Options;
+};
+
+template <> struct BuiltinOptionsTraits<onert_tflite::DensifyOptions>
+{
+ static const BuiltinOptions enum_value = BuiltinOptions_DensifyOptions;
+};
+
+template <> struct BuiltinOptionsTraits<onert_tflite::SegmentSumOptions>
+{
+ static const BuiltinOptions enum_value = BuiltinOptions_SegmentSumOptions;
+};
+
+template <> struct BuiltinOptionsTraits<onert_tflite::BatchMatMulOptions>
+{
+ static const BuiltinOptions enum_value = BuiltinOptions_BatchMatMulOptions;
+};
+
+template <> struct BuiltinOptionsTraits<onert_tflite::CumsumOptions>
+{
+ static const BuiltinOptions enum_value = BuiltinOptions_CumsumOptions;
+};
+
+template <> struct BuiltinOptionsTraits<onert_tflite::CallOnceOptions>
+{
+ static const BuiltinOptions enum_value = BuiltinOptions_CallOnceOptions;
+};
+
+template <> struct BuiltinOptionsTraits<onert_tflite::BroadcastToOptions>
+{
+ static const BuiltinOptions enum_value = BuiltinOptions_BroadcastToOptions;
+};
+
+template <> struct BuiltinOptionsTraits<onert_tflite::Rfft2dOptions>
+{
+ static const BuiltinOptions enum_value = BuiltinOptions_Rfft2dOptions;
+};
+
+template <> struct BuiltinOptionsTraits<onert_tflite::Conv3DOptions>
+{
+ static const BuiltinOptions enum_value = BuiltinOptions_Conv3DOptions;
+};
+
+template <> struct BuiltinOptionsTraits<onert_tflite::HashtableOptions>
+{
+ static const BuiltinOptions enum_value = BuiltinOptions_HashtableOptions;
+};
+
+template <> struct BuiltinOptionsTraits<onert_tflite::HashtableFindOptions>
+{
+ static const BuiltinOptions enum_value = BuiltinOptions_HashtableFindOptions;
+};
+
+template <> struct BuiltinOptionsTraits<onert_tflite::HashtableImportOptions>
+{
+ static const BuiltinOptions enum_value = BuiltinOptions_HashtableImportOptions;
+};
+
+template <> struct BuiltinOptionsTraits<onert_tflite::HashtableSizeOptions>
+{
+ static const BuiltinOptions enum_value = BuiltinOptions_HashtableSizeOptions;
+};
+
+template <> struct BuiltinOptionsTraits<onert_tflite::VarHandleOptions>
+{
+ static const BuiltinOptions enum_value = BuiltinOptions_VarHandleOptions;
+};
+
+template <> struct BuiltinOptionsTraits<onert_tflite::ReadVariableOptions>
+{
+ static const BuiltinOptions enum_value = BuiltinOptions_ReadVariableOptions;
+};
+
+template <> struct BuiltinOptionsTraits<onert_tflite::AssignVariableOptions>
+{
+ static const BuiltinOptions enum_value = BuiltinOptions_AssignVariableOptions;
+};
+
+template <> struct BuiltinOptionsTraits<onert_tflite::RandomOptions>
+{
+ static const BuiltinOptions enum_value = BuiltinOptions_RandomOptions;
+};
+
+template <> struct BuiltinOptionsTraits<onert_tflite::BucketizeOptions>
+{
+ static const BuiltinOptions enum_value = BuiltinOptions_BucketizeOptions;
+};
+
+template <> struct BuiltinOptionsTraits<onert_tflite::GeluOptions>
+{
+ static const BuiltinOptions enum_value = BuiltinOptions_GeluOptions;
+};
+
+template <> struct BuiltinOptionsTraits<onert_tflite::DynamicUpdateSliceOptions>
+{
+ static const BuiltinOptions enum_value = BuiltinOptions_DynamicUpdateSliceOptions;
+};
+
+template <> struct BuiltinOptionsTraits<onert_tflite::UnsortedSegmentProdOptions>
+{
+ static const BuiltinOptions enum_value = BuiltinOptions_UnsortedSegmentProdOptions;
+};
+
+template <> struct BuiltinOptionsTraits<onert_tflite::UnsortedSegmentMaxOptions>
+{
+ static const BuiltinOptions enum_value = BuiltinOptions_UnsortedSegmentMaxOptions;
+};
+
+template <> struct BuiltinOptionsTraits<onert_tflite::UnsortedSegmentSumOptions>
+{
+ static const BuiltinOptions enum_value = BuiltinOptions_UnsortedSegmentSumOptions;
+};
+
+template <> struct BuiltinOptionsTraits<onert_tflite::ATan2Options>
+{
+ static const BuiltinOptions enum_value = BuiltinOptions_ATan2Options;
+};
+
+bool VerifyBuiltinOptions(::flatbuffers::Verifier &verifier, const void *obj, BuiltinOptions type);
+bool VerifyBuiltinOptionsVector(::flatbuffers::Verifier &verifier,
+ const ::flatbuffers::Vector<::flatbuffers::Offset<void>> *values,
+ const ::flatbuffers::Vector<uint8_t> *types);
+
+enum Padding : int8_t
+{
+ Padding_SAME = 0,
+ Padding_VALID = 1,
+ Padding_MIN = Padding_SAME,
+ Padding_MAX = Padding_VALID
+};
+
+inline const Padding (&EnumValuesPadding())[2]
+{
+ static const Padding values[] = {Padding_SAME, Padding_VALID};
+ return values;
+}
+
+inline const char *const *EnumNamesPadding()
+{
+ static const char *const names[3] = {"SAME", "VALID", nullptr};
+ return names;
+}
+
+inline const char *EnumNamePadding(Padding e)
+{
+ if (::flatbuffers::IsOutRange(e, Padding_SAME, Padding_VALID))
+ return "";
+ const size_t index = static_cast<size_t>(e);
+ return EnumNamesPadding()[index];
+}
+
+enum ActivationFunctionType : int8_t
+{
+ ActivationFunctionType_NONE = 0,
+ ActivationFunctionType_RELU = 1,
+ ActivationFunctionType_RELU_N1_TO_1 = 2,
+ ActivationFunctionType_RELU6 = 3,
+ ActivationFunctionType_TANH = 4,
+ ActivationFunctionType_SIGN_BIT = 5,
+ ActivationFunctionType_MIN = ActivationFunctionType_NONE,
+ ActivationFunctionType_MAX = ActivationFunctionType_SIGN_BIT
+};
+
+inline const ActivationFunctionType (&EnumValuesActivationFunctionType())[6]
+{
+ static const ActivationFunctionType values[] = {
+ ActivationFunctionType_NONE, ActivationFunctionType_RELU, ActivationFunctionType_RELU_N1_TO_1,
+ ActivationFunctionType_RELU6, ActivationFunctionType_TANH, ActivationFunctionType_SIGN_BIT};
+ return values;
+}
+
+inline const char *const *EnumNamesActivationFunctionType()
+{
+ static const char *const names[7] = {"NONE", "RELU", "RELU_N1_TO_1", "RELU6",
+ "TANH", "SIGN_BIT", nullptr};
+ return names;
+}
+
+inline const char *EnumNameActivationFunctionType(ActivationFunctionType e)
+{
+ if (::flatbuffers::IsOutRange(e, ActivationFunctionType_NONE, ActivationFunctionType_SIGN_BIT))
+ return "";
+ const size_t index = static_cast<size_t>(e);
+ return EnumNamesActivationFunctionType()[index];
+}
+
+enum LSHProjectionType : int8_t
+{
+ LSHProjectionType_UNKNOWN = 0,
+ LSHProjectionType_SPARSE = 1,
+ LSHProjectionType_DENSE = 2,
+ LSHProjectionType_MIN = LSHProjectionType_UNKNOWN,
+ LSHProjectionType_MAX = LSHProjectionType_DENSE
+};
+
+inline const LSHProjectionType (&EnumValuesLSHProjectionType())[3]
+{
+ static const LSHProjectionType values[] = {LSHProjectionType_UNKNOWN, LSHProjectionType_SPARSE,
+ LSHProjectionType_DENSE};
+ return values;
+}
+
+inline const char *const *EnumNamesLSHProjectionType()
+{
+ static const char *const names[4] = {"UNKNOWN", "SPARSE", "DENSE", nullptr};
+ return names;
+}
+
+inline const char *EnumNameLSHProjectionType(LSHProjectionType e)
+{
+ if (::flatbuffers::IsOutRange(e, LSHProjectionType_UNKNOWN, LSHProjectionType_DENSE))
+ return "";
+ const size_t index = static_cast<size_t>(e);
+ return EnumNamesLSHProjectionType()[index];
+}
+
+enum FullyConnectedOptionsWeightsFormat : int8_t
+{
+ FullyConnectedOptionsWeightsFormat_DEFAULT = 0,
+ FullyConnectedOptionsWeightsFormat_SHUFFLED4x16INT8 = 1,
+ FullyConnectedOptionsWeightsFormat_MIN = FullyConnectedOptionsWeightsFormat_DEFAULT,
+ FullyConnectedOptionsWeightsFormat_MAX = FullyConnectedOptionsWeightsFormat_SHUFFLED4x16INT8
+};
+
+inline const FullyConnectedOptionsWeightsFormat (&EnumValuesFullyConnectedOptionsWeightsFormat())[2]
+{
+ static const FullyConnectedOptionsWeightsFormat values[] = {
+ FullyConnectedOptionsWeightsFormat_DEFAULT,
+ FullyConnectedOptionsWeightsFormat_SHUFFLED4x16INT8};
+ return values;
+}
+
+inline const char *const *EnumNamesFullyConnectedOptionsWeightsFormat()
+{
+ static const char *const names[3] = {"DEFAULT", "SHUFFLED4x16INT8", nullptr};
+ return names;
+}
+
+inline const char *EnumNameFullyConnectedOptionsWeightsFormat(FullyConnectedOptionsWeightsFormat e)
+{
+ if (::flatbuffers::IsOutRange(e, FullyConnectedOptionsWeightsFormat_DEFAULT,
+ FullyConnectedOptionsWeightsFormat_SHUFFLED4x16INT8))
+ return "";
+ const size_t index = static_cast<size_t>(e);
+ return EnumNamesFullyConnectedOptionsWeightsFormat()[index];
+}
+
+enum LSTMKernelType : int8_t
+{
+ LSTMKernelType_FULL = 0,
+ LSTMKernelType_BASIC = 1,
+ LSTMKernelType_MIN = LSTMKernelType_FULL,
+ LSTMKernelType_MAX = LSTMKernelType_BASIC
+};
+
+inline const LSTMKernelType (&EnumValuesLSTMKernelType())[2]
+{
+ static const LSTMKernelType values[] = {LSTMKernelType_FULL, LSTMKernelType_BASIC};
+ return values;
+}
+
+inline const char *const *EnumNamesLSTMKernelType()
+{
+ static const char *const names[3] = {"FULL", "BASIC", nullptr};
+ return names;
+}
+
+inline const char *EnumNameLSTMKernelType(LSTMKernelType e)
+{
+ if (::flatbuffers::IsOutRange(e, LSTMKernelType_FULL, LSTMKernelType_BASIC))
+ return "";
+ const size_t index = static_cast<size_t>(e);
+ return EnumNamesLSTMKernelType()[index];
+}
+
+enum CombinerType : int8_t
+{
+ CombinerType_SUM = 0,
+ CombinerType_MEAN = 1,
+ CombinerType_SQRTN = 2,
+ CombinerType_MIN = CombinerType_SUM,
+ CombinerType_MAX = CombinerType_SQRTN
+};
+
+inline const CombinerType (&EnumValuesCombinerType())[3]
+{
+ static const CombinerType values[] = {CombinerType_SUM, CombinerType_MEAN, CombinerType_SQRTN};
+ return values;
+}
+
+inline const char *const *EnumNamesCombinerType()
+{
+ static const char *const names[4] = {"SUM", "MEAN", "SQRTN", nullptr};
+ return names;
+}
+
+inline const char *EnumNameCombinerType(CombinerType e)
+{
+ if (::flatbuffers::IsOutRange(e, CombinerType_SUM, CombinerType_SQRTN))
+ return "";
+ const size_t index = static_cast<size_t>(e);
+ return EnumNamesCombinerType()[index];
+}
+
+enum MirrorPadMode : int8_t
+{
+ MirrorPadMode_REFLECT = 0,
+ MirrorPadMode_SYMMETRIC = 1,
+ MirrorPadMode_MIN = MirrorPadMode_REFLECT,
+ MirrorPadMode_MAX = MirrorPadMode_SYMMETRIC
+};
+
+inline const MirrorPadMode (&EnumValuesMirrorPadMode())[2]
+{
+ static const MirrorPadMode values[] = {MirrorPadMode_REFLECT, MirrorPadMode_SYMMETRIC};
+ return values;
+}
+
+inline const char *const *EnumNamesMirrorPadMode()
+{
+ static const char *const names[3] = {"REFLECT", "SYMMETRIC", nullptr};
+ return names;
+}
+
+inline const char *EnumNameMirrorPadMode(MirrorPadMode e)
+{
+ if (::flatbuffers::IsOutRange(e, MirrorPadMode_REFLECT, MirrorPadMode_SYMMETRIC))
+ return "";
+ const size_t index = static_cast<size_t>(e);
+ return EnumNamesMirrorPadMode()[index];
+}
+
+enum CustomOptionsFormat : int8_t
+{
+ CustomOptionsFormat_FLEXBUFFERS = 0,
+ CustomOptionsFormat_MIN = CustomOptionsFormat_FLEXBUFFERS,
+ CustomOptionsFormat_MAX = CustomOptionsFormat_FLEXBUFFERS
+};
+
+inline const CustomOptionsFormat (&EnumValuesCustomOptionsFormat())[1]
+{
+ static const CustomOptionsFormat values[] = {CustomOptionsFormat_FLEXBUFFERS};
+ return values;
+}
+
+inline const char *const *EnumNamesCustomOptionsFormat()
+{
+ static const char *const names[2] = {"FLEXBUFFERS", nullptr};
+ return names;
+}
+
+inline const char *EnumNameCustomOptionsFormat(CustomOptionsFormat e)
+{
+ if (::flatbuffers::IsOutRange(e, CustomOptionsFormat_FLEXBUFFERS,
+ CustomOptionsFormat_FLEXBUFFERS))
+ return "";
+ const size_t index = static_cast<size_t>(e);
+ return EnumNamesCustomOptionsFormat()[index];
+}
+
+struct CustomQuantization FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table
+{
+ typedef CustomQuantizationBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE
+ {
+ VT_CUSTOM = 4
+ };
+ const ::flatbuffers::Vector<uint8_t> *custom() const
+ {
+ return GetPointer<const ::flatbuffers::Vector<uint8_t> *>(VT_CUSTOM);
+ }
+ bool Verify(::flatbuffers::Verifier &verifier) const
+ {
+ return VerifyTableStart(verifier) && VerifyOffset(verifier, VT_CUSTOM) &&
+ verifier.VerifyVector(custom()) && verifier.EndTable();
+ }
+};
+
+struct CustomQuantizationBuilder
+{
+ typedef CustomQuantization Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ void add_custom(::flatbuffers::Offset<::flatbuffers::Vector<uint8_t>> custom)
+ {
+ fbb_.AddOffset(CustomQuantization::VT_CUSTOM, custom);
+ }
+ explicit CustomQuantizationBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb)
+ {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<CustomQuantization> Finish()
+ {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<CustomQuantization>(end);
+ return o;
+ }
+};
+
+inline ::flatbuffers::Offset<CustomQuantization>
+CreateCustomQuantization(::flatbuffers::FlatBufferBuilder &_fbb,
+ ::flatbuffers::Offset<::flatbuffers::Vector<uint8_t>> custom = 0)
+{
+ CustomQuantizationBuilder builder_(_fbb);
+ builder_.add_custom(custom);
+ return builder_.Finish();
+}
+
+inline ::flatbuffers::Offset<CustomQuantization>
+CreateCustomQuantizationDirect(::flatbuffers::FlatBufferBuilder &_fbb,
+ const std::vector<uint8_t> *custom = nullptr)
+{
+ if (custom)
+ {
+ _fbb.ForceVectorAlignment(custom->size(), sizeof(uint8_t), 16);
+ }
+ auto custom__ = custom ? _fbb.CreateVector<uint8_t>(*custom) : 0;
+ return onert_tflite::CreateCustomQuantization(_fbb, custom__);
+}
+
+struct QuantizationParameters FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table
+{
+ typedef QuantizationParametersBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE
+ {
+ VT_MIN = 4,
+ VT_MAX = 6,
+ VT_SCALE = 8,
+ VT_ZERO_POINT = 10,
+ VT_DETAILS_TYPE = 12,
+ VT_DETAILS = 14,
+ VT_QUANTIZED_DIMENSION = 16
+ };
+ const ::flatbuffers::Vector<float> *min() const
+ {
+ return GetPointer<const ::flatbuffers::Vector<float> *>(VT_MIN);
+ }
+ const ::flatbuffers::Vector<float> *max() const
+ {
+ return GetPointer<const ::flatbuffers::Vector<float> *>(VT_MAX);
+ }
+ const ::flatbuffers::Vector<float> *scale() const
+ {
+ return GetPointer<const ::flatbuffers::Vector<float> *>(VT_SCALE);
+ }
+ const ::flatbuffers::Vector<int64_t> *zero_point() const
+ {
+ return GetPointer<const ::flatbuffers::Vector<int64_t> *>(VT_ZERO_POINT);
+ }
+ onert_tflite::QuantizationDetails details_type() const
+ {
+ return static_cast<onert_tflite::QuantizationDetails>(GetField<uint8_t>(VT_DETAILS_TYPE, 0));
+ }
+ const void *details() const { return GetPointer<const void *>(VT_DETAILS); }
+ template <typename T> const T *details_as() const;
+ const onert_tflite::CustomQuantization *details_as_CustomQuantization() const
+ {
+ return details_type() == onert_tflite::QuantizationDetails_CustomQuantization
+ ? static_cast<const onert_tflite::CustomQuantization *>(details())
+ : nullptr;
+ }
+ int32_t quantized_dimension() const { return GetField<int32_t>(VT_QUANTIZED_DIMENSION, 0); }
+ bool Verify(::flatbuffers::Verifier &verifier) const
+ {
+ return VerifyTableStart(verifier) && VerifyOffset(verifier, VT_MIN) &&
+ verifier.VerifyVector(min()) && VerifyOffset(verifier, VT_MAX) &&
+ verifier.VerifyVector(max()) && VerifyOffset(verifier, VT_SCALE) &&
+ verifier.VerifyVector(scale()) && VerifyOffset(verifier, VT_ZERO_POINT) &&
+ verifier.VerifyVector(zero_point()) &&
+ VerifyField<uint8_t>(verifier, VT_DETAILS_TYPE, 1) &&
+ VerifyOffset(verifier, VT_DETAILS) &&
+ VerifyQuantizationDetails(verifier, details(), details_type()) &&
+ VerifyField<int32_t>(verifier, VT_QUANTIZED_DIMENSION, 4) && verifier.EndTable();
+ }
+};
+
+template <>
+inline const onert_tflite::CustomQuantization *
+QuantizationParameters::details_as<onert_tflite::CustomQuantization>() const
+{
+ return details_as_CustomQuantization();
+}
+
+struct QuantizationParametersBuilder
+{
+ typedef QuantizationParameters Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ void add_min(::flatbuffers::Offset<::flatbuffers::Vector<float>> min)
+ {
+ fbb_.AddOffset(QuantizationParameters::VT_MIN, min);
+ }
+ void add_max(::flatbuffers::Offset<::flatbuffers::Vector<float>> max)
+ {
+ fbb_.AddOffset(QuantizationParameters::VT_MAX, max);
+ }
+ void add_scale(::flatbuffers::Offset<::flatbuffers::Vector<float>> scale)
+ {
+ fbb_.AddOffset(QuantizationParameters::VT_SCALE, scale);
+ }
+ void add_zero_point(::flatbuffers::Offset<::flatbuffers::Vector<int64_t>> zero_point)
+ {
+ fbb_.AddOffset(QuantizationParameters::VT_ZERO_POINT, zero_point);
+ }
+ void add_details_type(onert_tflite::QuantizationDetails details_type)
+ {
+ fbb_.AddElement<uint8_t>(QuantizationParameters::VT_DETAILS_TYPE,
+ static_cast<uint8_t>(details_type), 0);
+ }
+ void add_details(::flatbuffers::Offset<void> details)
+ {
+ fbb_.AddOffset(QuantizationParameters::VT_DETAILS, details);
+ }
+ void add_quantized_dimension(int32_t quantized_dimension)
+ {
+ fbb_.AddElement<int32_t>(QuantizationParameters::VT_QUANTIZED_DIMENSION, quantized_dimension,
+ 0);
+ }
+ explicit QuantizationParametersBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb)
+ {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<QuantizationParameters> Finish()
+ {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<QuantizationParameters>(end);
+ return o;
+ }
+};
+
+inline ::flatbuffers::Offset<QuantizationParameters> CreateQuantizationParameters(
+ ::flatbuffers::FlatBufferBuilder &_fbb,
+ ::flatbuffers::Offset<::flatbuffers::Vector<float>> min = 0,
+ ::flatbuffers::Offset<::flatbuffers::Vector<float>> max = 0,
+ ::flatbuffers::Offset<::flatbuffers::Vector<float>> scale = 0,
+ ::flatbuffers::Offset<::flatbuffers::Vector<int64_t>> zero_point = 0,
+ onert_tflite::QuantizationDetails details_type = onert_tflite::QuantizationDetails_NONE,
+ ::flatbuffers::Offset<void> details = 0, int32_t quantized_dimension = 0)
+{
+ QuantizationParametersBuilder builder_(_fbb);
+ builder_.add_quantized_dimension(quantized_dimension);
+ builder_.add_details(details);
+ builder_.add_zero_point(zero_point);
+ builder_.add_scale(scale);
+ builder_.add_max(max);
+ builder_.add_min(min);
+ builder_.add_details_type(details_type);
+ return builder_.Finish();
+}
+
+inline ::flatbuffers::Offset<QuantizationParameters> CreateQuantizationParametersDirect(
+ ::flatbuffers::FlatBufferBuilder &_fbb, const std::vector<float> *min = nullptr,
+ const std::vector<float> *max = nullptr, const std::vector<float> *scale = nullptr,
+ const std::vector<int64_t> *zero_point = nullptr,
+ onert_tflite::QuantizationDetails details_type = onert_tflite::QuantizationDetails_NONE,
+ ::flatbuffers::Offset<void> details = 0, int32_t quantized_dimension = 0)
+{
+ auto min__ = min ? _fbb.CreateVector<float>(*min) : 0;
+ auto max__ = max ? _fbb.CreateVector<float>(*max) : 0;
+ auto scale__ = scale ? _fbb.CreateVector<float>(*scale) : 0;
+ auto zero_point__ = zero_point ? _fbb.CreateVector<int64_t>(*zero_point) : 0;
+ return onert_tflite::CreateQuantizationParameters(_fbb, min__, max__, scale__, zero_point__,
+ details_type, details, quantized_dimension);
+}
+
+struct Int32Vector FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table
+{
+ typedef Int32VectorBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE
+ {
+ VT_VALUES = 4
+ };
+ const ::flatbuffers::Vector<int32_t> *values() const
+ {
+ return GetPointer<const ::flatbuffers::Vector<int32_t> *>(VT_VALUES);
+ }
+ bool Verify(::flatbuffers::Verifier &verifier) const
+ {
+ return VerifyTableStart(verifier) && VerifyOffset(verifier, VT_VALUES) &&
+ verifier.VerifyVector(values()) && verifier.EndTable();
+ }
+};
+
+struct Int32VectorBuilder
+{
+ typedef Int32Vector Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ void add_values(::flatbuffers::Offset<::flatbuffers::Vector<int32_t>> values)
+ {
+ fbb_.AddOffset(Int32Vector::VT_VALUES, values);
+ }
+ explicit Int32VectorBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb)
+ {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<Int32Vector> Finish()
+ {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<Int32Vector>(end);
+ return o;
+ }
+};
+
+inline ::flatbuffers::Offset<Int32Vector>
+CreateInt32Vector(::flatbuffers::FlatBufferBuilder &_fbb,
+ ::flatbuffers::Offset<::flatbuffers::Vector<int32_t>> values = 0)
+{
+ Int32VectorBuilder builder_(_fbb);
+ builder_.add_values(values);
+ return builder_.Finish();
+}
+
+inline ::flatbuffers::Offset<Int32Vector>
+CreateInt32VectorDirect(::flatbuffers::FlatBufferBuilder &_fbb,
+ const std::vector<int32_t> *values = nullptr)
+{
+ auto values__ = values ? _fbb.CreateVector<int32_t>(*values) : 0;
+ return onert_tflite::CreateInt32Vector(_fbb, values__);
+}
+
+struct Uint16Vector FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table
+{
+ typedef Uint16VectorBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE
+ {
+ VT_VALUES = 4
+ };
+ const ::flatbuffers::Vector<uint16_t> *values() const
+ {
+ return GetPointer<const ::flatbuffers::Vector<uint16_t> *>(VT_VALUES);
+ }
+ bool Verify(::flatbuffers::Verifier &verifier) const
+ {
+ return VerifyTableStart(verifier) && VerifyOffset(verifier, VT_VALUES) &&
+ verifier.VerifyVector(values()) && verifier.EndTable();
+ }
+};
+
+struct Uint16VectorBuilder
+{
+ typedef Uint16Vector Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ void add_values(::flatbuffers::Offset<::flatbuffers::Vector<uint16_t>> values)
+ {
+ fbb_.AddOffset(Uint16Vector::VT_VALUES, values);
+ }
+ explicit Uint16VectorBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb)
+ {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<Uint16Vector> Finish()
+ {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<Uint16Vector>(end);
+ return o;
+ }
+};
+
+inline ::flatbuffers::Offset<Uint16Vector>
+CreateUint16Vector(::flatbuffers::FlatBufferBuilder &_fbb,
+ ::flatbuffers::Offset<::flatbuffers::Vector<uint16_t>> values = 0)
+{
+ Uint16VectorBuilder builder_(_fbb);
+ builder_.add_values(values);
+ return builder_.Finish();
+}
+
+inline ::flatbuffers::Offset<Uint16Vector>
+CreateUint16VectorDirect(::flatbuffers::FlatBufferBuilder &_fbb,
+ const std::vector<uint16_t> *values = nullptr)
+{
+ if (values)
+ {
+ _fbb.ForceVectorAlignment(values->size(), sizeof(uint16_t), 4);
+ }
+ auto values__ = values ? _fbb.CreateVector<uint16_t>(*values) : 0;
+ return onert_tflite::CreateUint16Vector(_fbb, values__);
+}
+
+struct Uint8Vector FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table
+{
+ typedef Uint8VectorBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE
+ {
+ VT_VALUES = 4
+ };
+ const ::flatbuffers::Vector<uint8_t> *values() const
+ {
+ return GetPointer<const ::flatbuffers::Vector<uint8_t> *>(VT_VALUES);
+ }
+ bool Verify(::flatbuffers::Verifier &verifier) const
+ {
+ return VerifyTableStart(verifier) && VerifyOffset(verifier, VT_VALUES) &&
+ verifier.VerifyVector(values()) && verifier.EndTable();
+ }
+};
+
+struct Uint8VectorBuilder
+{
+ typedef Uint8Vector Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ void add_values(::flatbuffers::Offset<::flatbuffers::Vector<uint8_t>> values)
+ {
+ fbb_.AddOffset(Uint8Vector::VT_VALUES, values);
+ }
+ explicit Uint8VectorBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb)
+ {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<Uint8Vector> Finish()
+ {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<Uint8Vector>(end);
+ return o;
+ }
+};
+
+inline ::flatbuffers::Offset<Uint8Vector>
+CreateUint8Vector(::flatbuffers::FlatBufferBuilder &_fbb,
+ ::flatbuffers::Offset<::flatbuffers::Vector<uint8_t>> values = 0)
+{
+ Uint8VectorBuilder builder_(_fbb);
+ builder_.add_values(values);
+ return builder_.Finish();
+}
+
+inline ::flatbuffers::Offset<Uint8Vector>
+CreateUint8VectorDirect(::flatbuffers::FlatBufferBuilder &_fbb,
+ const std::vector<uint8_t> *values = nullptr)
+{
+ if (values)
+ {
+ _fbb.ForceVectorAlignment(values->size(), sizeof(uint8_t), 4);
+ }
+ auto values__ = values ? _fbb.CreateVector<uint8_t>(*values) : 0;
+ return onert_tflite::CreateUint8Vector(_fbb, values__);
+}
+
+struct DimensionMetadata FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table
+{
+ typedef DimensionMetadataBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE
+ {
+ VT_FORMAT = 4,
+ VT_DENSE_SIZE = 6,
+ VT_ARRAY_SEGMENTS_TYPE = 8,
+ VT_ARRAY_SEGMENTS = 10,
+ VT_ARRAY_INDICES_TYPE = 12,
+ VT_ARRAY_INDICES = 14
+ };
+ onert_tflite::DimensionType format() const
+ {
+ return static_cast<onert_tflite::DimensionType>(GetField<int8_t>(VT_FORMAT, 0));
+ }
+ int32_t dense_size() const { return GetField<int32_t>(VT_DENSE_SIZE, 0); }
+ onert_tflite::SparseIndexVector array_segments_type() const
+ {
+ return static_cast<onert_tflite::SparseIndexVector>(
+ GetField<uint8_t>(VT_ARRAY_SEGMENTS_TYPE, 0));
+ }
+ const void *array_segments() const { return GetPointer<const void *>(VT_ARRAY_SEGMENTS); }
+ template <typename T> const T *array_segments_as() const;
+ const onert_tflite::Int32Vector *array_segments_as_Int32Vector() const
+ {
+ return array_segments_type() == onert_tflite::SparseIndexVector_Int32Vector
+ ? static_cast<const onert_tflite::Int32Vector *>(array_segments())
+ : nullptr;
+ }
+ const onert_tflite::Uint16Vector *array_segments_as_Uint16Vector() const
+ {
+ return array_segments_type() == onert_tflite::SparseIndexVector_Uint16Vector
+ ? static_cast<const onert_tflite::Uint16Vector *>(array_segments())
+ : nullptr;
+ }
+ const onert_tflite::Uint8Vector *array_segments_as_Uint8Vector() const
+ {
+ return array_segments_type() == onert_tflite::SparseIndexVector_Uint8Vector
+ ? static_cast<const onert_tflite::Uint8Vector *>(array_segments())
+ : nullptr;
+ }
+ onert_tflite::SparseIndexVector array_indices_type() const
+ {
+ return static_cast<onert_tflite::SparseIndexVector>(
+ GetField<uint8_t>(VT_ARRAY_INDICES_TYPE, 0));
+ }
+ const void *array_indices() const { return GetPointer<const void *>(VT_ARRAY_INDICES); }
+ template <typename T> const T *array_indices_as() const;
+ const onert_tflite::Int32Vector *array_indices_as_Int32Vector() const
+ {
+ return array_indices_type() == onert_tflite::SparseIndexVector_Int32Vector
+ ? static_cast<const onert_tflite::Int32Vector *>(array_indices())
+ : nullptr;
+ }
+ const onert_tflite::Uint16Vector *array_indices_as_Uint16Vector() const
+ {
+ return array_indices_type() == onert_tflite::SparseIndexVector_Uint16Vector
+ ? static_cast<const onert_tflite::Uint16Vector *>(array_indices())
+ : nullptr;
+ }
+ const onert_tflite::Uint8Vector *array_indices_as_Uint8Vector() const
+ {
+ return array_indices_type() == onert_tflite::SparseIndexVector_Uint8Vector
+ ? static_cast<const onert_tflite::Uint8Vector *>(array_indices())
+ : nullptr;
+ }
+ bool Verify(::flatbuffers::Verifier &verifier) const
+ {
+ return VerifyTableStart(verifier) && VerifyField<int8_t>(verifier, VT_FORMAT, 1) &&
+ VerifyField<int32_t>(verifier, VT_DENSE_SIZE, 4) &&
+ VerifyField<uint8_t>(verifier, VT_ARRAY_SEGMENTS_TYPE, 1) &&
+ VerifyOffset(verifier, VT_ARRAY_SEGMENTS) &&
+ VerifySparseIndexVector(verifier, array_segments(), array_segments_type()) &&
+ VerifyField<uint8_t>(verifier, VT_ARRAY_INDICES_TYPE, 1) &&
+ VerifyOffset(verifier, VT_ARRAY_INDICES) &&
+ VerifySparseIndexVector(verifier, array_indices(), array_indices_type()) &&
+ verifier.EndTable();
+ }
+};
+
+template <>
+inline const onert_tflite::Int32Vector *
+DimensionMetadata::array_segments_as<onert_tflite::Int32Vector>() const
+{
+ return array_segments_as_Int32Vector();
+}
+
+template <>
+inline const onert_tflite::Uint16Vector *
+DimensionMetadata::array_segments_as<onert_tflite::Uint16Vector>() const
+{
+ return array_segments_as_Uint16Vector();
+}
+
+template <>
+inline const onert_tflite::Uint8Vector *
+DimensionMetadata::array_segments_as<onert_tflite::Uint8Vector>() const
+{
+ return array_segments_as_Uint8Vector();
+}
+
+template <>
+inline const onert_tflite::Int32Vector *
+DimensionMetadata::array_indices_as<onert_tflite::Int32Vector>() const
+{
+ return array_indices_as_Int32Vector();
+}
+
+template <>
+inline const onert_tflite::Uint16Vector *
+DimensionMetadata::array_indices_as<onert_tflite::Uint16Vector>() const
+{
+ return array_indices_as_Uint16Vector();
+}
+
+template <>
+inline const onert_tflite::Uint8Vector *
+DimensionMetadata::array_indices_as<onert_tflite::Uint8Vector>() const
+{
+ return array_indices_as_Uint8Vector();
+}
+
+struct DimensionMetadataBuilder
+{
+ typedef DimensionMetadata Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ void add_format(onert_tflite::DimensionType format)
+ {
+ fbb_.AddElement<int8_t>(DimensionMetadata::VT_FORMAT, static_cast<int8_t>(format), 0);
+ }
+ void add_dense_size(int32_t dense_size)
+ {
+ fbb_.AddElement<int32_t>(DimensionMetadata::VT_DENSE_SIZE, dense_size, 0);
+ }
+ void add_array_segments_type(onert_tflite::SparseIndexVector array_segments_type)
+ {
+ fbb_.AddElement<uint8_t>(DimensionMetadata::VT_ARRAY_SEGMENTS_TYPE,
+ static_cast<uint8_t>(array_segments_type), 0);
+ }
+ void add_array_segments(::flatbuffers::Offset<void> array_segments)
+ {
+ fbb_.AddOffset(DimensionMetadata::VT_ARRAY_SEGMENTS, array_segments);
+ }
+ void add_array_indices_type(onert_tflite::SparseIndexVector array_indices_type)
+ {
+ fbb_.AddElement<uint8_t>(DimensionMetadata::VT_ARRAY_INDICES_TYPE,
+ static_cast<uint8_t>(array_indices_type), 0);
+ }
+ void add_array_indices(::flatbuffers::Offset<void> array_indices)
+ {
+ fbb_.AddOffset(DimensionMetadata::VT_ARRAY_INDICES, array_indices);
+ }
+ explicit DimensionMetadataBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb)
+ {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<DimensionMetadata> Finish()
+ {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<DimensionMetadata>(end);
+ return o;
+ }
+};
+
+inline ::flatbuffers::Offset<DimensionMetadata> CreateDimensionMetadata(
+ ::flatbuffers::FlatBufferBuilder &_fbb,
+ onert_tflite::DimensionType format = onert_tflite::DimensionType_DENSE, int32_t dense_size = 0,
+ onert_tflite::SparseIndexVector array_segments_type = onert_tflite::SparseIndexVector_NONE,
+ ::flatbuffers::Offset<void> array_segments = 0,
+ onert_tflite::SparseIndexVector array_indices_type = onert_tflite::SparseIndexVector_NONE,
+ ::flatbuffers::Offset<void> array_indices = 0)
+{
+ DimensionMetadataBuilder builder_(_fbb);
+ builder_.add_array_indices(array_indices);
+ builder_.add_array_segments(array_segments);
+ builder_.add_dense_size(dense_size);
+ builder_.add_array_indices_type(array_indices_type);
+ builder_.add_array_segments_type(array_segments_type);
+ builder_.add_format(format);
+ return builder_.Finish();
+}
+
+struct SparsityParameters FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table
+{
+ typedef SparsityParametersBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE
+ {
+ VT_TRAVERSAL_ORDER = 4,
+ VT_BLOCK_MAP = 6,
+ VT_DIM_METADATA = 8
+ };
+ const ::flatbuffers::Vector<int32_t> *traversal_order() const
+ {
+ return GetPointer<const ::flatbuffers::Vector<int32_t> *>(VT_TRAVERSAL_ORDER);
+ }
+ const ::flatbuffers::Vector<int32_t> *block_map() const
+ {
+ return GetPointer<const ::flatbuffers::Vector<int32_t> *>(VT_BLOCK_MAP);
+ }
+ const ::flatbuffers::Vector<::flatbuffers::Offset<onert_tflite::DimensionMetadata>> *
+ dim_metadata() const
+ {
+ return GetPointer<
+ const ::flatbuffers::Vector<::flatbuffers::Offset<onert_tflite::DimensionMetadata>> *>(
+ VT_DIM_METADATA);
+ }
+ bool Verify(::flatbuffers::Verifier &verifier) const
+ {
+ return VerifyTableStart(verifier) && VerifyOffset(verifier, VT_TRAVERSAL_ORDER) &&
+ verifier.VerifyVector(traversal_order()) && VerifyOffset(verifier, VT_BLOCK_MAP) &&
+ verifier.VerifyVector(block_map()) && VerifyOffset(verifier, VT_DIM_METADATA) &&
+ verifier.VerifyVector(dim_metadata()) && verifier.VerifyVectorOfTables(dim_metadata()) &&
+ verifier.EndTable();
+ }
+};
+
+struct SparsityParametersBuilder
+{
+ typedef SparsityParameters Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ void add_traversal_order(::flatbuffers::Offset<::flatbuffers::Vector<int32_t>> traversal_order)
+ {
+ fbb_.AddOffset(SparsityParameters::VT_TRAVERSAL_ORDER, traversal_order);
+ }
+ void add_block_map(::flatbuffers::Offset<::flatbuffers::Vector<int32_t>> block_map)
+ {
+ fbb_.AddOffset(SparsityParameters::VT_BLOCK_MAP, block_map);
+ }
+ void
+ add_dim_metadata(::flatbuffers::Offset<
+ ::flatbuffers::Vector<::flatbuffers::Offset<onert_tflite::DimensionMetadata>>>
+ dim_metadata)
+ {
+ fbb_.AddOffset(SparsityParameters::VT_DIM_METADATA, dim_metadata);
+ }
+ explicit SparsityParametersBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb)
+ {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<SparsityParameters> Finish()
+ {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<SparsityParameters>(end);
+ return o;
+ }
+};
+
+inline ::flatbuffers::Offset<SparsityParameters> CreateSparsityParameters(
+ ::flatbuffers::FlatBufferBuilder &_fbb,
+ ::flatbuffers::Offset<::flatbuffers::Vector<int32_t>> traversal_order = 0,
+ ::flatbuffers::Offset<::flatbuffers::Vector<int32_t>> block_map = 0,
+ ::flatbuffers::Offset<
+ ::flatbuffers::Vector<::flatbuffers::Offset<onert_tflite::DimensionMetadata>>>
+ dim_metadata = 0)
+{
+ SparsityParametersBuilder builder_(_fbb);
+ builder_.add_dim_metadata(dim_metadata);
+ builder_.add_block_map(block_map);
+ builder_.add_traversal_order(traversal_order);
+ return builder_.Finish();
+}
+
+inline ::flatbuffers::Offset<SparsityParameters> CreateSparsityParametersDirect(
+ ::flatbuffers::FlatBufferBuilder &_fbb, const std::vector<int32_t> *traversal_order = nullptr,
+ const std::vector<int32_t> *block_map = nullptr,
+ const std::vector<::flatbuffers::Offset<onert_tflite::DimensionMetadata>> *dim_metadata = nullptr)
+{
+ auto traversal_order__ = traversal_order ? _fbb.CreateVector<int32_t>(*traversal_order) : 0;
+ auto block_map__ = block_map ? _fbb.CreateVector<int32_t>(*block_map) : 0;
+ auto dim_metadata__ =
+ dim_metadata
+ ? _fbb.CreateVector<::flatbuffers::Offset<onert_tflite::DimensionMetadata>>(*dim_metadata)
+ : 0;
+ return onert_tflite::CreateSparsityParameters(_fbb, traversal_order__, block_map__,
+ dim_metadata__);
+}
+
+struct Tensor FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table
+{
+ typedef TensorBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE
+ {
+ VT_SHAPE = 4,
+ VT_TYPE = 6,
+ VT_BUFFER = 8,
+ VT_NAME = 10,
+ VT_QUANTIZATION = 12,
+ VT_IS_VARIABLE = 14,
+ VT_SPARSITY = 16,
+ VT_SHAPE_SIGNATURE = 18,
+ VT_HAS_RANK = 20
+ };
+ const ::flatbuffers::Vector<int32_t> *shape() const
+ {
+ return GetPointer<const ::flatbuffers::Vector<int32_t> *>(VT_SHAPE);
+ }
+ onert_tflite::TensorType type() const
+ {
+ return static_cast<onert_tflite::TensorType>(GetField<int8_t>(VT_TYPE, 0));
+ }
+ uint32_t buffer() const { return GetField<uint32_t>(VT_BUFFER, 0); }
+ const ::flatbuffers::String *name() const
+ {
+ return GetPointer<const ::flatbuffers::String *>(VT_NAME);
+ }
+ const onert_tflite::QuantizationParameters *quantization() const
+ {
+ return GetPointer<const onert_tflite::QuantizationParameters *>(VT_QUANTIZATION);
+ }
+ bool is_variable() const { return GetField<uint8_t>(VT_IS_VARIABLE, 0) != 0; }
+ const onert_tflite::SparsityParameters *sparsity() const
+ {
+ return GetPointer<const onert_tflite::SparsityParameters *>(VT_SPARSITY);
+ }
+ const ::flatbuffers::Vector<int32_t> *shape_signature() const
+ {
+ return GetPointer<const ::flatbuffers::Vector<int32_t> *>(VT_SHAPE_SIGNATURE);
+ }
+ bool has_rank() const { return GetField<uint8_t>(VT_HAS_RANK, 0) != 0; }
+ bool Verify(::flatbuffers::Verifier &verifier) const
+ {
+ return VerifyTableStart(verifier) && VerifyOffset(verifier, VT_SHAPE) &&
+ verifier.VerifyVector(shape()) && VerifyField<int8_t>(verifier, VT_TYPE, 1) &&
+ VerifyField<uint32_t>(verifier, VT_BUFFER, 4) && VerifyOffset(verifier, VT_NAME) &&
+ verifier.VerifyString(name()) && VerifyOffset(verifier, VT_QUANTIZATION) &&
+ verifier.VerifyTable(quantization()) &&
+ VerifyField<uint8_t>(verifier, VT_IS_VARIABLE, 1) &&
+ VerifyOffset(verifier, VT_SPARSITY) && verifier.VerifyTable(sparsity()) &&
+ VerifyOffset(verifier, VT_SHAPE_SIGNATURE) && verifier.VerifyVector(shape_signature()) &&
+ VerifyField<uint8_t>(verifier, VT_HAS_RANK, 1) && verifier.EndTable();
+ }
+};
+
+struct TensorBuilder
+{
+ typedef Tensor Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ void add_shape(::flatbuffers::Offset<::flatbuffers::Vector<int32_t>> shape)
+ {
+ fbb_.AddOffset(Tensor::VT_SHAPE, shape);
+ }
+ void add_type(onert_tflite::TensorType type)
+ {
+ fbb_.AddElement<int8_t>(Tensor::VT_TYPE, static_cast<int8_t>(type), 0);
+ }
+ void add_buffer(uint32_t buffer) { fbb_.AddElement<uint32_t>(Tensor::VT_BUFFER, buffer, 0); }
+ void add_name(::flatbuffers::Offset<::flatbuffers::String> name)
+ {
+ fbb_.AddOffset(Tensor::VT_NAME, name);
+ }
+ void add_quantization(::flatbuffers::Offset<onert_tflite::QuantizationParameters> quantization)
+ {
+ fbb_.AddOffset(Tensor::VT_QUANTIZATION, quantization);
+ }
+ void add_is_variable(bool is_variable)
+ {
+ fbb_.AddElement<uint8_t>(Tensor::VT_IS_VARIABLE, static_cast<uint8_t>(is_variable), 0);
+ }
+ void add_sparsity(::flatbuffers::Offset<onert_tflite::SparsityParameters> sparsity)
+ {
+ fbb_.AddOffset(Tensor::VT_SPARSITY, sparsity);
+ }
+ void add_shape_signature(::flatbuffers::Offset<::flatbuffers::Vector<int32_t>> shape_signature)
+ {
+ fbb_.AddOffset(Tensor::VT_SHAPE_SIGNATURE, shape_signature);
+ }
+ void add_has_rank(bool has_rank)
+ {
+ fbb_.AddElement<uint8_t>(Tensor::VT_HAS_RANK, static_cast<uint8_t>(has_rank), 0);
+ }
+ explicit TensorBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb)
+ {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<Tensor> Finish()
+ {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<Tensor>(end);
+ return o;
+ }
+};
+
+inline ::flatbuffers::Offset<Tensor> CreateTensor(
+ ::flatbuffers::FlatBufferBuilder &_fbb,
+ ::flatbuffers::Offset<::flatbuffers::Vector<int32_t>> shape = 0,
+ onert_tflite::TensorType type = onert_tflite::TensorType_FLOAT32, uint32_t buffer = 0,
+ ::flatbuffers::Offset<::flatbuffers::String> name = 0,
+ ::flatbuffers::Offset<onert_tflite::QuantizationParameters> quantization = 0,
+ bool is_variable = false, ::flatbuffers::Offset<onert_tflite::SparsityParameters> sparsity = 0,
+ ::flatbuffers::Offset<::flatbuffers::Vector<int32_t>> shape_signature = 0, bool has_rank = false)
+{
+ TensorBuilder builder_(_fbb);
+ builder_.add_shape_signature(shape_signature);
+ builder_.add_sparsity(sparsity);
+ builder_.add_quantization(quantization);
+ builder_.add_name(name);
+ builder_.add_buffer(buffer);
+ builder_.add_shape(shape);
+ builder_.add_has_rank(has_rank);
+ builder_.add_is_variable(is_variable);
+ builder_.add_type(type);
+ return builder_.Finish();
+}
+
+inline ::flatbuffers::Offset<Tensor> CreateTensorDirect(
+ ::flatbuffers::FlatBufferBuilder &_fbb, const std::vector<int32_t> *shape = nullptr,
+ onert_tflite::TensorType type = onert_tflite::TensorType_FLOAT32, uint32_t buffer = 0,
+ const char *name = nullptr,
+ ::flatbuffers::Offset<onert_tflite::QuantizationParameters> quantization = 0,
+ bool is_variable = false, ::flatbuffers::Offset<onert_tflite::SparsityParameters> sparsity = 0,
+ const std::vector<int32_t> *shape_signature = nullptr, bool has_rank = false)
+{
+ auto shape__ = shape ? _fbb.CreateVector<int32_t>(*shape) : 0;
+ auto name__ = name ? _fbb.CreateString(name) : 0;
+ auto shape_signature__ = shape_signature ? _fbb.CreateVector<int32_t>(*shape_signature) : 0;
+ return onert_tflite::CreateTensor(_fbb, shape__, type, buffer, name__, quantization, is_variable,
+ sparsity, shape_signature__, has_rank);
+}
+
+struct Conv2DOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table
+{
+ typedef Conv2DOptionsBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE
+ {
+ VT_PADDING = 4,
+ VT_STRIDE_W = 6,
+ VT_STRIDE_H = 8,
+ VT_FUSED_ACTIVATION_FUNCTION = 10,
+ VT_DILATION_W_FACTOR = 12,
+ VT_DILATION_H_FACTOR = 14
+ };
+ onert_tflite::Padding padding() const
+ {
+ return static_cast<onert_tflite::Padding>(GetField<int8_t>(VT_PADDING, 0));
+ }
+ int32_t stride_w() const { return GetField<int32_t>(VT_STRIDE_W, 0); }
+ int32_t stride_h() const { return GetField<int32_t>(VT_STRIDE_H, 0); }
+ onert_tflite::ActivationFunctionType fused_activation_function() const
+ {
+ return static_cast<onert_tflite::ActivationFunctionType>(
+ GetField<int8_t>(VT_FUSED_ACTIVATION_FUNCTION, 0));
+ }
+ int32_t dilation_w_factor() const { return GetField<int32_t>(VT_DILATION_W_FACTOR, 1); }
+ int32_t dilation_h_factor() const { return GetField<int32_t>(VT_DILATION_H_FACTOR, 1); }
+ bool Verify(::flatbuffers::Verifier &verifier) const
+ {
+ return VerifyTableStart(verifier) && VerifyField<int8_t>(verifier, VT_PADDING, 1) &&
+ VerifyField<int32_t>(verifier, VT_STRIDE_W, 4) &&
+ VerifyField<int32_t>(verifier, VT_STRIDE_H, 4) &&
+ VerifyField<int8_t>(verifier, VT_FUSED_ACTIVATION_FUNCTION, 1) &&
+ VerifyField<int32_t>(verifier, VT_DILATION_W_FACTOR, 4) &&
+ VerifyField<int32_t>(verifier, VT_DILATION_H_FACTOR, 4) && verifier.EndTable();
+ }
+};
+
+struct Conv2DOptionsBuilder
+{
+ typedef Conv2DOptions Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ void add_padding(onert_tflite::Padding padding)
+ {
+ fbb_.AddElement<int8_t>(Conv2DOptions::VT_PADDING, static_cast<int8_t>(padding), 0);
+ }
+ void add_stride_w(int32_t stride_w)
+ {
+ fbb_.AddElement<int32_t>(Conv2DOptions::VT_STRIDE_W, stride_w, 0);
+ }
+ void add_stride_h(int32_t stride_h)
+ {
+ fbb_.AddElement<int32_t>(Conv2DOptions::VT_STRIDE_H, stride_h, 0);
+ }
+ void add_fused_activation_function(onert_tflite::ActivationFunctionType fused_activation_function)
+ {
+ fbb_.AddElement<int8_t>(Conv2DOptions::VT_FUSED_ACTIVATION_FUNCTION,
+ static_cast<int8_t>(fused_activation_function), 0);
+ }
+ void add_dilation_w_factor(int32_t dilation_w_factor)
+ {
+ fbb_.AddElement<int32_t>(Conv2DOptions::VT_DILATION_W_FACTOR, dilation_w_factor, 1);
+ }
+ void add_dilation_h_factor(int32_t dilation_h_factor)
+ {
+ fbb_.AddElement<int32_t>(Conv2DOptions::VT_DILATION_H_FACTOR, dilation_h_factor, 1);
+ }
+ explicit Conv2DOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb)
+ {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<Conv2DOptions> Finish()
+ {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<Conv2DOptions>(end);
+ return o;
+ }
+};
+
+inline ::flatbuffers::Offset<Conv2DOptions>
+CreateConv2DOptions(::flatbuffers::FlatBufferBuilder &_fbb,
+ onert_tflite::Padding padding = onert_tflite::Padding_SAME,
+ int32_t stride_w = 0, int32_t stride_h = 0,
+ onert_tflite::ActivationFunctionType fused_activation_function =
+ onert_tflite::ActivationFunctionType_NONE,
+ int32_t dilation_w_factor = 1, int32_t dilation_h_factor = 1)
+{
+ Conv2DOptionsBuilder builder_(_fbb);
+ builder_.add_dilation_h_factor(dilation_h_factor);
+ builder_.add_dilation_w_factor(dilation_w_factor);
+ builder_.add_stride_h(stride_h);
+ builder_.add_stride_w(stride_w);
+ builder_.add_fused_activation_function(fused_activation_function);
+ builder_.add_padding(padding);
+ return builder_.Finish();
+}
+
+struct Conv3DOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table
+{
+ typedef Conv3DOptionsBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE
+ {
+ VT_PADDING = 4,
+ VT_STRIDE_D = 6,
+ VT_STRIDE_W = 8,
+ VT_STRIDE_H = 10,
+ VT_FUSED_ACTIVATION_FUNCTION = 12,
+ VT_DILATION_D_FACTOR = 14,
+ VT_DILATION_W_FACTOR = 16,
+ VT_DILATION_H_FACTOR = 18
+ };
+ onert_tflite::Padding padding() const
+ {
+ return static_cast<onert_tflite::Padding>(GetField<int8_t>(VT_PADDING, 0));
+ }
+ int32_t stride_d() const { return GetField<int32_t>(VT_STRIDE_D, 0); }
+ int32_t stride_w() const { return GetField<int32_t>(VT_STRIDE_W, 0); }
+ int32_t stride_h() const { return GetField<int32_t>(VT_STRIDE_H, 0); }
+ onert_tflite::ActivationFunctionType fused_activation_function() const
+ {
+ return static_cast<onert_tflite::ActivationFunctionType>(
+ GetField<int8_t>(VT_FUSED_ACTIVATION_FUNCTION, 0));
+ }
+ int32_t dilation_d_factor() const { return GetField<int32_t>(VT_DILATION_D_FACTOR, 1); }
+ int32_t dilation_w_factor() const { return GetField<int32_t>(VT_DILATION_W_FACTOR, 1); }
+ int32_t dilation_h_factor() const { return GetField<int32_t>(VT_DILATION_H_FACTOR, 1); }
+ bool Verify(::flatbuffers::Verifier &verifier) const
+ {
+ return VerifyTableStart(verifier) && VerifyField<int8_t>(verifier, VT_PADDING, 1) &&
+ VerifyField<int32_t>(verifier, VT_STRIDE_D, 4) &&
+ VerifyField<int32_t>(verifier, VT_STRIDE_W, 4) &&
+ VerifyField<int32_t>(verifier, VT_STRIDE_H, 4) &&
+ VerifyField<int8_t>(verifier, VT_FUSED_ACTIVATION_FUNCTION, 1) &&
+ VerifyField<int32_t>(verifier, VT_DILATION_D_FACTOR, 4) &&
+ VerifyField<int32_t>(verifier, VT_DILATION_W_FACTOR, 4) &&
+ VerifyField<int32_t>(verifier, VT_DILATION_H_FACTOR, 4) && verifier.EndTable();
+ }
+};
+
+struct Conv3DOptionsBuilder
+{
+ typedef Conv3DOptions Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ void add_padding(onert_tflite::Padding padding)
+ {
+ fbb_.AddElement<int8_t>(Conv3DOptions::VT_PADDING, static_cast<int8_t>(padding), 0);
+ }
+ void add_stride_d(int32_t stride_d)
+ {
+ fbb_.AddElement<int32_t>(Conv3DOptions::VT_STRIDE_D, stride_d, 0);
+ }
+ void add_stride_w(int32_t stride_w)
+ {
+ fbb_.AddElement<int32_t>(Conv3DOptions::VT_STRIDE_W, stride_w, 0);
+ }
+ void add_stride_h(int32_t stride_h)
+ {
+ fbb_.AddElement<int32_t>(Conv3DOptions::VT_STRIDE_H, stride_h, 0);
+ }
+ void add_fused_activation_function(onert_tflite::ActivationFunctionType fused_activation_function)
+ {
+ fbb_.AddElement<int8_t>(Conv3DOptions::VT_FUSED_ACTIVATION_FUNCTION,
+ static_cast<int8_t>(fused_activation_function), 0);
+ }
+ void add_dilation_d_factor(int32_t dilation_d_factor)
+ {
+ fbb_.AddElement<int32_t>(Conv3DOptions::VT_DILATION_D_FACTOR, dilation_d_factor, 1);
+ }
+ void add_dilation_w_factor(int32_t dilation_w_factor)
+ {
+ fbb_.AddElement<int32_t>(Conv3DOptions::VT_DILATION_W_FACTOR, dilation_w_factor, 1);
+ }
+ void add_dilation_h_factor(int32_t dilation_h_factor)
+ {
+ fbb_.AddElement<int32_t>(Conv3DOptions::VT_DILATION_H_FACTOR, dilation_h_factor, 1);
+ }
+ explicit Conv3DOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb)
+ {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<Conv3DOptions> Finish()
+ {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<Conv3DOptions>(end);
+ return o;
+ }
+};
+
+inline ::flatbuffers::Offset<Conv3DOptions>
+CreateConv3DOptions(::flatbuffers::FlatBufferBuilder &_fbb,
+ onert_tflite::Padding padding = onert_tflite::Padding_SAME,
+ int32_t stride_d = 0, int32_t stride_w = 0, int32_t stride_h = 0,
+ onert_tflite::ActivationFunctionType fused_activation_function =
+ onert_tflite::ActivationFunctionType_NONE,
+ int32_t dilation_d_factor = 1, int32_t dilation_w_factor = 1,
+ int32_t dilation_h_factor = 1)
+{
+ Conv3DOptionsBuilder builder_(_fbb);
+ builder_.add_dilation_h_factor(dilation_h_factor);
+ builder_.add_dilation_w_factor(dilation_w_factor);
+ builder_.add_dilation_d_factor(dilation_d_factor);
+ builder_.add_stride_h(stride_h);
+ builder_.add_stride_w(stride_w);
+ builder_.add_stride_d(stride_d);
+ builder_.add_fused_activation_function(fused_activation_function);
+ builder_.add_padding(padding);
+ return builder_.Finish();
+}
+
+struct Pool2DOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table
+{
+ typedef Pool2DOptionsBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE
+ {
+ VT_PADDING = 4,
+ VT_STRIDE_W = 6,
+ VT_STRIDE_H = 8,
+ VT_FILTER_WIDTH = 10,
+ VT_FILTER_HEIGHT = 12,
+ VT_FUSED_ACTIVATION_FUNCTION = 14
+ };
+ onert_tflite::Padding padding() const
+ {
+ return static_cast<onert_tflite::Padding>(GetField<int8_t>(VT_PADDING, 0));
+ }
+ int32_t stride_w() const { return GetField<int32_t>(VT_STRIDE_W, 0); }
+ int32_t stride_h() const { return GetField<int32_t>(VT_STRIDE_H, 0); }
+ int32_t filter_width() const { return GetField<int32_t>(VT_FILTER_WIDTH, 0); }
+ int32_t filter_height() const { return GetField<int32_t>(VT_FILTER_HEIGHT, 0); }
+ onert_tflite::ActivationFunctionType fused_activation_function() const
+ {
+ return static_cast<onert_tflite::ActivationFunctionType>(
+ GetField<int8_t>(VT_FUSED_ACTIVATION_FUNCTION, 0));
+ }
+ bool Verify(::flatbuffers::Verifier &verifier) const
+ {
+ return VerifyTableStart(verifier) && VerifyField<int8_t>(verifier, VT_PADDING, 1) &&
+ VerifyField<int32_t>(verifier, VT_STRIDE_W, 4) &&
+ VerifyField<int32_t>(verifier, VT_STRIDE_H, 4) &&
+ VerifyField<int32_t>(verifier, VT_FILTER_WIDTH, 4) &&
+ VerifyField<int32_t>(verifier, VT_FILTER_HEIGHT, 4) &&
+ VerifyField<int8_t>(verifier, VT_FUSED_ACTIVATION_FUNCTION, 1) && verifier.EndTable();
+ }
+};
+
+struct Pool2DOptionsBuilder
+{
+ typedef Pool2DOptions Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ void add_padding(onert_tflite::Padding padding)
+ {
+ fbb_.AddElement<int8_t>(Pool2DOptions::VT_PADDING, static_cast<int8_t>(padding), 0);
+ }
+ void add_stride_w(int32_t stride_w)
+ {
+ fbb_.AddElement<int32_t>(Pool2DOptions::VT_STRIDE_W, stride_w, 0);
+ }
+ void add_stride_h(int32_t stride_h)
+ {
+ fbb_.AddElement<int32_t>(Pool2DOptions::VT_STRIDE_H, stride_h, 0);
+ }
+ void add_filter_width(int32_t filter_width)
+ {
+ fbb_.AddElement<int32_t>(Pool2DOptions::VT_FILTER_WIDTH, filter_width, 0);
+ }
+ void add_filter_height(int32_t filter_height)
+ {
+ fbb_.AddElement<int32_t>(Pool2DOptions::VT_FILTER_HEIGHT, filter_height, 0);
+ }
+ void add_fused_activation_function(onert_tflite::ActivationFunctionType fused_activation_function)
+ {
+ fbb_.AddElement<int8_t>(Pool2DOptions::VT_FUSED_ACTIVATION_FUNCTION,
+ static_cast<int8_t>(fused_activation_function), 0);
+ }
+ explicit Pool2DOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb)
+ {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<Pool2DOptions> Finish()
+ {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<Pool2DOptions>(end);
+ return o;
+ }
+};
+
+inline ::flatbuffers::Offset<Pool2DOptions>
+CreatePool2DOptions(::flatbuffers::FlatBufferBuilder &_fbb,
+ onert_tflite::Padding padding = onert_tflite::Padding_SAME,
+ int32_t stride_w = 0, int32_t stride_h = 0, int32_t filter_width = 0,
+ int32_t filter_height = 0,
+ onert_tflite::ActivationFunctionType fused_activation_function =
+ onert_tflite::ActivationFunctionType_NONE)
+{
+ Pool2DOptionsBuilder builder_(_fbb);
+ builder_.add_filter_height(filter_height);
+ builder_.add_filter_width(filter_width);
+ builder_.add_stride_h(stride_h);
+ builder_.add_stride_w(stride_w);
+ builder_.add_fused_activation_function(fused_activation_function);
+ builder_.add_padding(padding);
+ return builder_.Finish();
+}
+
+struct DepthwiseConv2DOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table
+{
+ typedef DepthwiseConv2DOptionsBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE
+ {
+ VT_PADDING = 4,
+ VT_STRIDE_W = 6,
+ VT_STRIDE_H = 8,
+ VT_DEPTH_MULTIPLIER = 10,
+ VT_FUSED_ACTIVATION_FUNCTION = 12,
+ VT_DILATION_W_FACTOR = 14,
+ VT_DILATION_H_FACTOR = 16
+ };
+ onert_tflite::Padding padding() const
+ {
+ return static_cast<onert_tflite::Padding>(GetField<int8_t>(VT_PADDING, 0));
+ }
+ int32_t stride_w() const { return GetField<int32_t>(VT_STRIDE_W, 0); }
+ int32_t stride_h() const { return GetField<int32_t>(VT_STRIDE_H, 0); }
+ int32_t depth_multiplier() const { return GetField<int32_t>(VT_DEPTH_MULTIPLIER, 0); }
+ onert_tflite::ActivationFunctionType fused_activation_function() const
+ {
+ return static_cast<onert_tflite::ActivationFunctionType>(
+ GetField<int8_t>(VT_FUSED_ACTIVATION_FUNCTION, 0));
+ }
+ int32_t dilation_w_factor() const { return GetField<int32_t>(VT_DILATION_W_FACTOR, 1); }
+ int32_t dilation_h_factor() const { return GetField<int32_t>(VT_DILATION_H_FACTOR, 1); }
+ bool Verify(::flatbuffers::Verifier &verifier) const
+ {
+ return VerifyTableStart(verifier) && VerifyField<int8_t>(verifier, VT_PADDING, 1) &&
+ VerifyField<int32_t>(verifier, VT_STRIDE_W, 4) &&
+ VerifyField<int32_t>(verifier, VT_STRIDE_H, 4) &&
+ VerifyField<int32_t>(verifier, VT_DEPTH_MULTIPLIER, 4) &&
+ VerifyField<int8_t>(verifier, VT_FUSED_ACTIVATION_FUNCTION, 1) &&
+ VerifyField<int32_t>(verifier, VT_DILATION_W_FACTOR, 4) &&
+ VerifyField<int32_t>(verifier, VT_DILATION_H_FACTOR, 4) && verifier.EndTable();
+ }
+};
+
+struct DepthwiseConv2DOptionsBuilder
+{
+ typedef DepthwiseConv2DOptions Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ void add_padding(onert_tflite::Padding padding)
+ {
+ fbb_.AddElement<int8_t>(DepthwiseConv2DOptions::VT_PADDING, static_cast<int8_t>(padding), 0);
+ }
+ void add_stride_w(int32_t stride_w)
+ {
+ fbb_.AddElement<int32_t>(DepthwiseConv2DOptions::VT_STRIDE_W, stride_w, 0);
+ }
+ void add_stride_h(int32_t stride_h)
+ {
+ fbb_.AddElement<int32_t>(DepthwiseConv2DOptions::VT_STRIDE_H, stride_h, 0);
+ }
+ void add_depth_multiplier(int32_t depth_multiplier)
+ {
+ fbb_.AddElement<int32_t>(DepthwiseConv2DOptions::VT_DEPTH_MULTIPLIER, depth_multiplier, 0);
+ }
+ void add_fused_activation_function(onert_tflite::ActivationFunctionType fused_activation_function)
+ {
+ fbb_.AddElement<int8_t>(DepthwiseConv2DOptions::VT_FUSED_ACTIVATION_FUNCTION,
+ static_cast<int8_t>(fused_activation_function), 0);
+ }
+ void add_dilation_w_factor(int32_t dilation_w_factor)
+ {
+ fbb_.AddElement<int32_t>(DepthwiseConv2DOptions::VT_DILATION_W_FACTOR, dilation_w_factor, 1);
+ }
+ void add_dilation_h_factor(int32_t dilation_h_factor)
+ {
+ fbb_.AddElement<int32_t>(DepthwiseConv2DOptions::VT_DILATION_H_FACTOR, dilation_h_factor, 1);
+ }
+ explicit DepthwiseConv2DOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb)
+ {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<DepthwiseConv2DOptions> Finish()
+ {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<DepthwiseConv2DOptions>(end);
+ return o;
+ }
+};
+
+inline ::flatbuffers::Offset<DepthwiseConv2DOptions>
+CreateDepthwiseConv2DOptions(::flatbuffers::FlatBufferBuilder &_fbb,
+ onert_tflite::Padding padding = onert_tflite::Padding_SAME,
+ int32_t stride_w = 0, int32_t stride_h = 0,
+ int32_t depth_multiplier = 0,
+ onert_tflite::ActivationFunctionType fused_activation_function =
+ onert_tflite::ActivationFunctionType_NONE,
+ int32_t dilation_w_factor = 1, int32_t dilation_h_factor = 1)
+{
+ DepthwiseConv2DOptionsBuilder builder_(_fbb);
+ builder_.add_dilation_h_factor(dilation_h_factor);
+ builder_.add_dilation_w_factor(dilation_w_factor);
+ builder_.add_depth_multiplier(depth_multiplier);
+ builder_.add_stride_h(stride_h);
+ builder_.add_stride_w(stride_w);
+ builder_.add_fused_activation_function(fused_activation_function);
+ builder_.add_padding(padding);
+ return builder_.Finish();
+}
+
+struct ConcatEmbeddingsOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table
+{
+ typedef ConcatEmbeddingsOptionsBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE
+ {
+ VT_NUM_CHANNELS = 4,
+ VT_NUM_COLUMNS_PER_CHANNEL = 6,
+ VT_EMBEDDING_DIM_PER_CHANNEL = 8
+ };
+ int32_t num_channels() const { return GetField<int32_t>(VT_NUM_CHANNELS, 0); }
+ const ::flatbuffers::Vector<int32_t> *num_columns_per_channel() const
+ {
+ return GetPointer<const ::flatbuffers::Vector<int32_t> *>(VT_NUM_COLUMNS_PER_CHANNEL);
+ }
+ const ::flatbuffers::Vector<int32_t> *embedding_dim_per_channel() const
+ {
+ return GetPointer<const ::flatbuffers::Vector<int32_t> *>(VT_EMBEDDING_DIM_PER_CHANNEL);
+ }
+ bool Verify(::flatbuffers::Verifier &verifier) const
+ {
+ return VerifyTableStart(verifier) && VerifyField<int32_t>(verifier, VT_NUM_CHANNELS, 4) &&
+ VerifyOffset(verifier, VT_NUM_COLUMNS_PER_CHANNEL) &&
+ verifier.VerifyVector(num_columns_per_channel()) &&
+ VerifyOffset(verifier, VT_EMBEDDING_DIM_PER_CHANNEL) &&
+ verifier.VerifyVector(embedding_dim_per_channel()) && verifier.EndTable();
+ }
+};
+
+struct ConcatEmbeddingsOptionsBuilder
+{
+ typedef ConcatEmbeddingsOptions Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ void add_num_channels(int32_t num_channels)
+ {
+ fbb_.AddElement<int32_t>(ConcatEmbeddingsOptions::VT_NUM_CHANNELS, num_channels, 0);
+ }
+ void add_num_columns_per_channel(
+ ::flatbuffers::Offset<::flatbuffers::Vector<int32_t>> num_columns_per_channel)
+ {
+ fbb_.AddOffset(ConcatEmbeddingsOptions::VT_NUM_COLUMNS_PER_CHANNEL, num_columns_per_channel);
+ }
+ void add_embedding_dim_per_channel(
+ ::flatbuffers::Offset<::flatbuffers::Vector<int32_t>> embedding_dim_per_channel)
+ {
+ fbb_.AddOffset(ConcatEmbeddingsOptions::VT_EMBEDDING_DIM_PER_CHANNEL,
+ embedding_dim_per_channel);
+ }
+ explicit ConcatEmbeddingsOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb)
+ {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<ConcatEmbeddingsOptions> Finish()
+ {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<ConcatEmbeddingsOptions>(end);
+ return o;
+ }
+};
+
+inline ::flatbuffers::Offset<ConcatEmbeddingsOptions> CreateConcatEmbeddingsOptions(
+ ::flatbuffers::FlatBufferBuilder &_fbb, int32_t num_channels = 0,
+ ::flatbuffers::Offset<::flatbuffers::Vector<int32_t>> num_columns_per_channel = 0,
+ ::flatbuffers::Offset<::flatbuffers::Vector<int32_t>> embedding_dim_per_channel = 0)
+{
+ ConcatEmbeddingsOptionsBuilder builder_(_fbb);
+ builder_.add_embedding_dim_per_channel(embedding_dim_per_channel);
+ builder_.add_num_columns_per_channel(num_columns_per_channel);
+ builder_.add_num_channels(num_channels);
+ return builder_.Finish();
+}
+
+inline ::flatbuffers::Offset<ConcatEmbeddingsOptions>
+CreateConcatEmbeddingsOptionsDirect(::flatbuffers::FlatBufferBuilder &_fbb,
+ int32_t num_channels = 0,
+ const std::vector<int32_t> *num_columns_per_channel = nullptr,
+ const std::vector<int32_t> *embedding_dim_per_channel = nullptr)
+{
+ auto num_columns_per_channel__ =
+ num_columns_per_channel ? _fbb.CreateVector<int32_t>(*num_columns_per_channel) : 0;
+ auto embedding_dim_per_channel__ =
+ embedding_dim_per_channel ? _fbb.CreateVector<int32_t>(*embedding_dim_per_channel) : 0;
+ return onert_tflite::CreateConcatEmbeddingsOptions(_fbb, num_channels, num_columns_per_channel__,
+ embedding_dim_per_channel__);
+}
+
+struct LSHProjectionOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table
+{
+ typedef LSHProjectionOptionsBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE
+ {
+ VT_TYPE = 4
+ };
+ onert_tflite::LSHProjectionType type() const
+ {
+ return static_cast<onert_tflite::LSHProjectionType>(GetField<int8_t>(VT_TYPE, 0));
+ }
+ bool Verify(::flatbuffers::Verifier &verifier) const
+ {
+ return VerifyTableStart(verifier) && VerifyField<int8_t>(verifier, VT_TYPE, 1) &&
+ verifier.EndTable();
+ }
+};
+
+struct LSHProjectionOptionsBuilder
+{
+ typedef LSHProjectionOptions Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ void add_type(onert_tflite::LSHProjectionType type)
+ {
+ fbb_.AddElement<int8_t>(LSHProjectionOptions::VT_TYPE, static_cast<int8_t>(type), 0);
+ }
+ explicit LSHProjectionOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb)
+ {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<LSHProjectionOptions> Finish()
+ {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<LSHProjectionOptions>(end);
+ return o;
+ }
+};
+
+inline ::flatbuffers::Offset<LSHProjectionOptions> CreateLSHProjectionOptions(
+ ::flatbuffers::FlatBufferBuilder &_fbb,
+ onert_tflite::LSHProjectionType type = onert_tflite::LSHProjectionType_UNKNOWN)
+{
+ LSHProjectionOptionsBuilder builder_(_fbb);
+ builder_.add_type(type);
+ return builder_.Finish();
+}
+
+struct SVDFOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table
+{
+ typedef SVDFOptionsBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE
+ {
+ VT_RANK = 4,
+ VT_FUSED_ACTIVATION_FUNCTION = 6,
+ VT_ASYMMETRIC_QUANTIZE_INPUTS = 8
+ };
+ int32_t rank() const { return GetField<int32_t>(VT_RANK, 0); }
+ onert_tflite::ActivationFunctionType fused_activation_function() const
+ {
+ return static_cast<onert_tflite::ActivationFunctionType>(
+ GetField<int8_t>(VT_FUSED_ACTIVATION_FUNCTION, 0));
+ }
+ bool asymmetric_quantize_inputs() const
+ {
+ return GetField<uint8_t>(VT_ASYMMETRIC_QUANTIZE_INPUTS, 0) != 0;
+ }
+ bool Verify(::flatbuffers::Verifier &verifier) const
+ {
+ return VerifyTableStart(verifier) && VerifyField<int32_t>(verifier, VT_RANK, 4) &&
+ VerifyField<int8_t>(verifier, VT_FUSED_ACTIVATION_FUNCTION, 1) &&
+ VerifyField<uint8_t>(verifier, VT_ASYMMETRIC_QUANTIZE_INPUTS, 1) && verifier.EndTable();
+ }
+};
+
+struct SVDFOptionsBuilder
+{
+ typedef SVDFOptions Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ void add_rank(int32_t rank) { fbb_.AddElement<int32_t>(SVDFOptions::VT_RANK, rank, 0); }
+ void add_fused_activation_function(onert_tflite::ActivationFunctionType fused_activation_function)
+ {
+ fbb_.AddElement<int8_t>(SVDFOptions::VT_FUSED_ACTIVATION_FUNCTION,
+ static_cast<int8_t>(fused_activation_function), 0);
+ }
+ void add_asymmetric_quantize_inputs(bool asymmetric_quantize_inputs)
+ {
+ fbb_.AddElement<uint8_t>(SVDFOptions::VT_ASYMMETRIC_QUANTIZE_INPUTS,
+ static_cast<uint8_t>(asymmetric_quantize_inputs), 0);
+ }
+ explicit SVDFOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb)
+ {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<SVDFOptions> Finish()
+ {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<SVDFOptions>(end);
+ return o;
+ }
+};
+
+inline ::flatbuffers::Offset<SVDFOptions>
+CreateSVDFOptions(::flatbuffers::FlatBufferBuilder &_fbb, int32_t rank = 0,
+ onert_tflite::ActivationFunctionType fused_activation_function =
+ onert_tflite::ActivationFunctionType_NONE,
+ bool asymmetric_quantize_inputs = false)
+{
+ SVDFOptionsBuilder builder_(_fbb);
+ builder_.add_rank(rank);
+ builder_.add_asymmetric_quantize_inputs(asymmetric_quantize_inputs);
+ builder_.add_fused_activation_function(fused_activation_function);
+ return builder_.Finish();
+}
+
+struct RNNOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table
+{
+ typedef RNNOptionsBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE
+ {
+ VT_FUSED_ACTIVATION_FUNCTION = 4,
+ VT_ASYMMETRIC_QUANTIZE_INPUTS = 6
+ };
+ onert_tflite::ActivationFunctionType fused_activation_function() const
+ {
+ return static_cast<onert_tflite::ActivationFunctionType>(
+ GetField<int8_t>(VT_FUSED_ACTIVATION_FUNCTION, 0));
+ }
+ bool asymmetric_quantize_inputs() const
+ {
+ return GetField<uint8_t>(VT_ASYMMETRIC_QUANTIZE_INPUTS, 0) != 0;
+ }
+ bool Verify(::flatbuffers::Verifier &verifier) const
+ {
+ return VerifyTableStart(verifier) &&
+ VerifyField<int8_t>(verifier, VT_FUSED_ACTIVATION_FUNCTION, 1) &&
+ VerifyField<uint8_t>(verifier, VT_ASYMMETRIC_QUANTIZE_INPUTS, 1) && verifier.EndTable();
+ }
+};
+
+struct RNNOptionsBuilder
+{
+ typedef RNNOptions Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ void add_fused_activation_function(onert_tflite::ActivationFunctionType fused_activation_function)
+ {
+ fbb_.AddElement<int8_t>(RNNOptions::VT_FUSED_ACTIVATION_FUNCTION,
+ static_cast<int8_t>(fused_activation_function), 0);
+ }
+ void add_asymmetric_quantize_inputs(bool asymmetric_quantize_inputs)
+ {
+ fbb_.AddElement<uint8_t>(RNNOptions::VT_ASYMMETRIC_QUANTIZE_INPUTS,
+ static_cast<uint8_t>(asymmetric_quantize_inputs), 0);
+ }
+ explicit RNNOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb)
+ {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<RNNOptions> Finish()
+ {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<RNNOptions>(end);
+ return o;
+ }
+};
+
+inline ::flatbuffers::Offset<RNNOptions>
+CreateRNNOptions(::flatbuffers::FlatBufferBuilder &_fbb,
+ onert_tflite::ActivationFunctionType fused_activation_function =
+ onert_tflite::ActivationFunctionType_NONE,
+ bool asymmetric_quantize_inputs = false)
+{
+ RNNOptionsBuilder builder_(_fbb);
+ builder_.add_asymmetric_quantize_inputs(asymmetric_quantize_inputs);
+ builder_.add_fused_activation_function(fused_activation_function);
+ return builder_.Finish();
+}
+
+struct SequenceRNNOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table
+{
+ typedef SequenceRNNOptionsBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE
+ {
+ VT_TIME_MAJOR = 4,
+ VT_FUSED_ACTIVATION_FUNCTION = 6,
+ VT_ASYMMETRIC_QUANTIZE_INPUTS = 8
+ };
+ bool time_major() const { return GetField<uint8_t>(VT_TIME_MAJOR, 0) != 0; }
+ onert_tflite::ActivationFunctionType fused_activation_function() const
+ {
+ return static_cast<onert_tflite::ActivationFunctionType>(
+ GetField<int8_t>(VT_FUSED_ACTIVATION_FUNCTION, 0));
+ }
+ bool asymmetric_quantize_inputs() const
+ {
+ return GetField<uint8_t>(VT_ASYMMETRIC_QUANTIZE_INPUTS, 0) != 0;
+ }
+ bool Verify(::flatbuffers::Verifier &verifier) const
+ {
+ return VerifyTableStart(verifier) && VerifyField<uint8_t>(verifier, VT_TIME_MAJOR, 1) &&
+ VerifyField<int8_t>(verifier, VT_FUSED_ACTIVATION_FUNCTION, 1) &&
+ VerifyField<uint8_t>(verifier, VT_ASYMMETRIC_QUANTIZE_INPUTS, 1) && verifier.EndTable();
+ }
+};
+
+struct SequenceRNNOptionsBuilder
+{
+ typedef SequenceRNNOptions Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ void add_time_major(bool time_major)
+ {
+ fbb_.AddElement<uint8_t>(SequenceRNNOptions::VT_TIME_MAJOR, static_cast<uint8_t>(time_major),
+ 0);
+ }
+ void add_fused_activation_function(onert_tflite::ActivationFunctionType fused_activation_function)
+ {
+ fbb_.AddElement<int8_t>(SequenceRNNOptions::VT_FUSED_ACTIVATION_FUNCTION,
+ static_cast<int8_t>(fused_activation_function), 0);
+ }
+ void add_asymmetric_quantize_inputs(bool asymmetric_quantize_inputs)
+ {
+ fbb_.AddElement<uint8_t>(SequenceRNNOptions::VT_ASYMMETRIC_QUANTIZE_INPUTS,
+ static_cast<uint8_t>(asymmetric_quantize_inputs), 0);
+ }
+ explicit SequenceRNNOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb)
+ {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<SequenceRNNOptions> Finish()
+ {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<SequenceRNNOptions>(end);
+ return o;
+ }
+};
+
+inline ::flatbuffers::Offset<SequenceRNNOptions>
+CreateSequenceRNNOptions(::flatbuffers::FlatBufferBuilder &_fbb, bool time_major = false,
+ onert_tflite::ActivationFunctionType fused_activation_function =
+ onert_tflite::ActivationFunctionType_NONE,
+ bool asymmetric_quantize_inputs = false)
+{
+ SequenceRNNOptionsBuilder builder_(_fbb);
+ builder_.add_asymmetric_quantize_inputs(asymmetric_quantize_inputs);
+ builder_.add_fused_activation_function(fused_activation_function);
+ builder_.add_time_major(time_major);
+ return builder_.Finish();
+}
+
+struct BidirectionalSequenceRNNOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table
+{
+ typedef BidirectionalSequenceRNNOptionsBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE
+ {
+ VT_TIME_MAJOR = 4,
+ VT_FUSED_ACTIVATION_FUNCTION = 6,
+ VT_MERGE_OUTPUTS = 8,
+ VT_ASYMMETRIC_QUANTIZE_INPUTS = 10
+ };
+ bool time_major() const { return GetField<uint8_t>(VT_TIME_MAJOR, 0) != 0; }
+ onert_tflite::ActivationFunctionType fused_activation_function() const
+ {
+ return static_cast<onert_tflite::ActivationFunctionType>(
+ GetField<int8_t>(VT_FUSED_ACTIVATION_FUNCTION, 0));
+ }
+ bool merge_outputs() const { return GetField<uint8_t>(VT_MERGE_OUTPUTS, 0) != 0; }
+ bool asymmetric_quantize_inputs() const
+ {
+ return GetField<uint8_t>(VT_ASYMMETRIC_QUANTIZE_INPUTS, 0) != 0;
+ }
+ bool Verify(::flatbuffers::Verifier &verifier) const
+ {
+ return VerifyTableStart(verifier) && VerifyField<uint8_t>(verifier, VT_TIME_MAJOR, 1) &&
+ VerifyField<int8_t>(verifier, VT_FUSED_ACTIVATION_FUNCTION, 1) &&
+ VerifyField<uint8_t>(verifier, VT_MERGE_OUTPUTS, 1) &&
+ VerifyField<uint8_t>(verifier, VT_ASYMMETRIC_QUANTIZE_INPUTS, 1) && verifier.EndTable();
+ }
+};
+
+struct BidirectionalSequenceRNNOptionsBuilder
+{
+ typedef BidirectionalSequenceRNNOptions Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ void add_time_major(bool time_major)
+ {
+ fbb_.AddElement<uint8_t>(BidirectionalSequenceRNNOptions::VT_TIME_MAJOR,
+ static_cast<uint8_t>(time_major), 0);
+ }
+ void add_fused_activation_function(onert_tflite::ActivationFunctionType fused_activation_function)
+ {
+ fbb_.AddElement<int8_t>(BidirectionalSequenceRNNOptions::VT_FUSED_ACTIVATION_FUNCTION,
+ static_cast<int8_t>(fused_activation_function), 0);
+ }
+ void add_merge_outputs(bool merge_outputs)
+ {
+ fbb_.AddElement<uint8_t>(BidirectionalSequenceRNNOptions::VT_MERGE_OUTPUTS,
+ static_cast<uint8_t>(merge_outputs), 0);
+ }
+ void add_asymmetric_quantize_inputs(bool asymmetric_quantize_inputs)
+ {
+ fbb_.AddElement<uint8_t>(BidirectionalSequenceRNNOptions::VT_ASYMMETRIC_QUANTIZE_INPUTS,
+ static_cast<uint8_t>(asymmetric_quantize_inputs), 0);
+ }
+ explicit BidirectionalSequenceRNNOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb)
+ {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<BidirectionalSequenceRNNOptions> Finish()
+ {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<BidirectionalSequenceRNNOptions>(end);
+ return o;
+ }
+};
+
+inline ::flatbuffers::Offset<BidirectionalSequenceRNNOptions> CreateBidirectionalSequenceRNNOptions(
+ ::flatbuffers::FlatBufferBuilder &_fbb, bool time_major = false,
+ onert_tflite::ActivationFunctionType fused_activation_function =
+ onert_tflite::ActivationFunctionType_NONE,
+ bool merge_outputs = false, bool asymmetric_quantize_inputs = false)
+{
+ BidirectionalSequenceRNNOptionsBuilder builder_(_fbb);
+ builder_.add_asymmetric_quantize_inputs(asymmetric_quantize_inputs);
+ builder_.add_merge_outputs(merge_outputs);
+ builder_.add_fused_activation_function(fused_activation_function);
+ builder_.add_time_major(time_major);
+ return builder_.Finish();
+}
+
+struct FullyConnectedOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table
+{
+ typedef FullyConnectedOptionsBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE
+ {
+ VT_FUSED_ACTIVATION_FUNCTION = 4,
+ VT_WEIGHTS_FORMAT = 6,
+ VT_KEEP_NUM_DIMS = 8,
+ VT_ASYMMETRIC_QUANTIZE_INPUTS = 10
+ };
+ onert_tflite::ActivationFunctionType fused_activation_function() const
+ {
+ return static_cast<onert_tflite::ActivationFunctionType>(
+ GetField<int8_t>(VT_FUSED_ACTIVATION_FUNCTION, 0));
+ }
+ onert_tflite::FullyConnectedOptionsWeightsFormat weights_format() const
+ {
+ return static_cast<onert_tflite::FullyConnectedOptionsWeightsFormat>(
+ GetField<int8_t>(VT_WEIGHTS_FORMAT, 0));
+ }
+ bool keep_num_dims() const { return GetField<uint8_t>(VT_KEEP_NUM_DIMS, 0) != 0; }
+ bool asymmetric_quantize_inputs() const
+ {
+ return GetField<uint8_t>(VT_ASYMMETRIC_QUANTIZE_INPUTS, 0) != 0;
+ }
+ bool Verify(::flatbuffers::Verifier &verifier) const
+ {
+ return VerifyTableStart(verifier) &&
+ VerifyField<int8_t>(verifier, VT_FUSED_ACTIVATION_FUNCTION, 1) &&
+ VerifyField<int8_t>(verifier, VT_WEIGHTS_FORMAT, 1) &&
+ VerifyField<uint8_t>(verifier, VT_KEEP_NUM_DIMS, 1) &&
+ VerifyField<uint8_t>(verifier, VT_ASYMMETRIC_QUANTIZE_INPUTS, 1) && verifier.EndTable();
+ }
+};
+
+struct FullyConnectedOptionsBuilder
+{
+ typedef FullyConnectedOptions Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ void add_fused_activation_function(onert_tflite::ActivationFunctionType fused_activation_function)
+ {
+ fbb_.AddElement<int8_t>(FullyConnectedOptions::VT_FUSED_ACTIVATION_FUNCTION,
+ static_cast<int8_t>(fused_activation_function), 0);
+ }
+ void add_weights_format(onert_tflite::FullyConnectedOptionsWeightsFormat weights_format)
+ {
+ fbb_.AddElement<int8_t>(FullyConnectedOptions::VT_WEIGHTS_FORMAT,
+ static_cast<int8_t>(weights_format), 0);
+ }
+ void add_keep_num_dims(bool keep_num_dims)
+ {
+ fbb_.AddElement<uint8_t>(FullyConnectedOptions::VT_KEEP_NUM_DIMS,
+ static_cast<uint8_t>(keep_num_dims), 0);
+ }
+ void add_asymmetric_quantize_inputs(bool asymmetric_quantize_inputs)
+ {
+ fbb_.AddElement<uint8_t>(FullyConnectedOptions::VT_ASYMMETRIC_QUANTIZE_INPUTS,
+ static_cast<uint8_t>(asymmetric_quantize_inputs), 0);
+ }
+ explicit FullyConnectedOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb)
+ {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<FullyConnectedOptions> Finish()
+ {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<FullyConnectedOptions>(end);
+ return o;
+ }
+};
+
+inline ::flatbuffers::Offset<FullyConnectedOptions>
+CreateFullyConnectedOptions(::flatbuffers::FlatBufferBuilder &_fbb,
+ onert_tflite::ActivationFunctionType fused_activation_function =
+ onert_tflite::ActivationFunctionType_NONE,
+ onert_tflite::FullyConnectedOptionsWeightsFormat weights_format =
+ onert_tflite::FullyConnectedOptionsWeightsFormat_DEFAULT,
+ bool keep_num_dims = false, bool asymmetric_quantize_inputs = false)
+{
+ FullyConnectedOptionsBuilder builder_(_fbb);
+ builder_.add_asymmetric_quantize_inputs(asymmetric_quantize_inputs);
+ builder_.add_keep_num_dims(keep_num_dims);
+ builder_.add_weights_format(weights_format);
+ builder_.add_fused_activation_function(fused_activation_function);
+ return builder_.Finish();
+}
+
+struct SoftmaxOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table
+{
+ typedef SoftmaxOptionsBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE
+ {
+ VT_BETA = 4
+ };
+ float beta() const { return GetField<float>(VT_BETA, 0.0f); }
+ bool Verify(::flatbuffers::Verifier &verifier) const
+ {
+ return VerifyTableStart(verifier) && VerifyField<float>(verifier, VT_BETA, 4) &&
+ verifier.EndTable();
+ }
+};
+
+struct SoftmaxOptionsBuilder
+{
+ typedef SoftmaxOptions Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ void add_beta(float beta) { fbb_.AddElement<float>(SoftmaxOptions::VT_BETA, beta, 0.0f); }
+ explicit SoftmaxOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb)
+ {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<SoftmaxOptions> Finish()
+ {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<SoftmaxOptions>(end);
+ return o;
+ }
+};
+
+inline ::flatbuffers::Offset<SoftmaxOptions>
+CreateSoftmaxOptions(::flatbuffers::FlatBufferBuilder &_fbb, float beta = 0.0f)
+{
+ SoftmaxOptionsBuilder builder_(_fbb);
+ builder_.add_beta(beta);
+ return builder_.Finish();
+}
+
+struct ConcatenationOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table
+{
+ typedef ConcatenationOptionsBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE
+ {
+ VT_AXIS = 4,
+ VT_FUSED_ACTIVATION_FUNCTION = 6
+ };
+ int32_t axis() const { return GetField<int32_t>(VT_AXIS, 0); }
+ onert_tflite::ActivationFunctionType fused_activation_function() const
+ {
+ return static_cast<onert_tflite::ActivationFunctionType>(
+ GetField<int8_t>(VT_FUSED_ACTIVATION_FUNCTION, 0));
+ }
+ bool Verify(::flatbuffers::Verifier &verifier) const
+ {
+ return VerifyTableStart(verifier) && VerifyField<int32_t>(verifier, VT_AXIS, 4) &&
+ VerifyField<int8_t>(verifier, VT_FUSED_ACTIVATION_FUNCTION, 1) && verifier.EndTable();
+ }
+};
+
+struct ConcatenationOptionsBuilder
+{
+ typedef ConcatenationOptions Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ void add_axis(int32_t axis) { fbb_.AddElement<int32_t>(ConcatenationOptions::VT_AXIS, axis, 0); }
+ void add_fused_activation_function(onert_tflite::ActivationFunctionType fused_activation_function)
+ {
+ fbb_.AddElement<int8_t>(ConcatenationOptions::VT_FUSED_ACTIVATION_FUNCTION,
+ static_cast<int8_t>(fused_activation_function), 0);
+ }
+ explicit ConcatenationOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb)
+ {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<ConcatenationOptions> Finish()
+ {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<ConcatenationOptions>(end);
+ return o;
+ }
+};
+
+inline ::flatbuffers::Offset<ConcatenationOptions>
+CreateConcatenationOptions(::flatbuffers::FlatBufferBuilder &_fbb, int32_t axis = 0,
+ onert_tflite::ActivationFunctionType fused_activation_function =
+ onert_tflite::ActivationFunctionType_NONE)
+{
+ ConcatenationOptionsBuilder builder_(_fbb);
+ builder_.add_axis(axis);
+ builder_.add_fused_activation_function(fused_activation_function);
+ return builder_.Finish();
+}
+
+struct AddOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table
+{
+ typedef AddOptionsBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE
+ {
+ VT_FUSED_ACTIVATION_FUNCTION = 4,
+ VT_POT_SCALE_INT16 = 6
+ };
+ onert_tflite::ActivationFunctionType fused_activation_function() const
+ {
+ return static_cast<onert_tflite::ActivationFunctionType>(
+ GetField<int8_t>(VT_FUSED_ACTIVATION_FUNCTION, 0));
+ }
+ bool pot_scale_int16() const { return GetField<uint8_t>(VT_POT_SCALE_INT16, 1) != 0; }
+ bool Verify(::flatbuffers::Verifier &verifier) const
+ {
+ return VerifyTableStart(verifier) &&
+ VerifyField<int8_t>(verifier, VT_FUSED_ACTIVATION_FUNCTION, 1) &&
+ VerifyField<uint8_t>(verifier, VT_POT_SCALE_INT16, 1) && verifier.EndTable();
+ }
+};
+
+struct AddOptionsBuilder
+{
+ typedef AddOptions Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ void add_fused_activation_function(onert_tflite::ActivationFunctionType fused_activation_function)
+ {
+ fbb_.AddElement<int8_t>(AddOptions::VT_FUSED_ACTIVATION_FUNCTION,
+ static_cast<int8_t>(fused_activation_function), 0);
+ }
+ void add_pot_scale_int16(bool pot_scale_int16)
+ {
+ fbb_.AddElement<uint8_t>(AddOptions::VT_POT_SCALE_INT16, static_cast<uint8_t>(pot_scale_int16),
+ 1);
+ }
+ explicit AddOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb)
+ {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<AddOptions> Finish()
+ {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<AddOptions>(end);
+ return o;
+ }
+};
+
+inline ::flatbuffers::Offset<AddOptions>
+CreateAddOptions(::flatbuffers::FlatBufferBuilder &_fbb,
+ onert_tflite::ActivationFunctionType fused_activation_function =
+ onert_tflite::ActivationFunctionType_NONE,
+ bool pot_scale_int16 = true)
+{
+ AddOptionsBuilder builder_(_fbb);
+ builder_.add_pot_scale_int16(pot_scale_int16);
+ builder_.add_fused_activation_function(fused_activation_function);
+ return builder_.Finish();
+}
+
+struct MulOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table
+{
+ typedef MulOptionsBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE
+ {
+ VT_FUSED_ACTIVATION_FUNCTION = 4
+ };
+ onert_tflite::ActivationFunctionType fused_activation_function() const
+ {
+ return static_cast<onert_tflite::ActivationFunctionType>(
+ GetField<int8_t>(VT_FUSED_ACTIVATION_FUNCTION, 0));
+ }
+ bool Verify(::flatbuffers::Verifier &verifier) const
+ {
+ return VerifyTableStart(verifier) &&
+ VerifyField<int8_t>(verifier, VT_FUSED_ACTIVATION_FUNCTION, 1) && verifier.EndTable();
+ }
+};
+
+struct MulOptionsBuilder
+{
+ typedef MulOptions Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ void add_fused_activation_function(onert_tflite::ActivationFunctionType fused_activation_function)
+ {
+ fbb_.AddElement<int8_t>(MulOptions::VT_FUSED_ACTIVATION_FUNCTION,
+ static_cast<int8_t>(fused_activation_function), 0);
+ }
+ explicit MulOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb)
+ {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<MulOptions> Finish()
+ {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<MulOptions>(end);
+ return o;
+ }
+};
+
+inline ::flatbuffers::Offset<MulOptions>
+CreateMulOptions(::flatbuffers::FlatBufferBuilder &_fbb,
+ onert_tflite::ActivationFunctionType fused_activation_function =
+ onert_tflite::ActivationFunctionType_NONE)
+{
+ MulOptionsBuilder builder_(_fbb);
+ builder_.add_fused_activation_function(fused_activation_function);
+ return builder_.Finish();
+}
+
+struct L2NormOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table
+{
+ typedef L2NormOptionsBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE
+ {
+ VT_FUSED_ACTIVATION_FUNCTION = 4
+ };
+ onert_tflite::ActivationFunctionType fused_activation_function() const
+ {
+ return static_cast<onert_tflite::ActivationFunctionType>(
+ GetField<int8_t>(VT_FUSED_ACTIVATION_FUNCTION, 0));
+ }
+ bool Verify(::flatbuffers::Verifier &verifier) const
+ {
+ return VerifyTableStart(verifier) &&
+ VerifyField<int8_t>(verifier, VT_FUSED_ACTIVATION_FUNCTION, 1) && verifier.EndTable();
+ }
+};
+
+struct L2NormOptionsBuilder
+{
+ typedef L2NormOptions Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ void add_fused_activation_function(onert_tflite::ActivationFunctionType fused_activation_function)
+ {
+ fbb_.AddElement<int8_t>(L2NormOptions::VT_FUSED_ACTIVATION_FUNCTION,
+ static_cast<int8_t>(fused_activation_function), 0);
+ }
+ explicit L2NormOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb)
+ {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<L2NormOptions> Finish()
+ {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<L2NormOptions>(end);
+ return o;
+ }
+};
+
+inline ::flatbuffers::Offset<L2NormOptions>
+CreateL2NormOptions(::flatbuffers::FlatBufferBuilder &_fbb,
+ onert_tflite::ActivationFunctionType fused_activation_function =
+ onert_tflite::ActivationFunctionType_NONE)
+{
+ L2NormOptionsBuilder builder_(_fbb);
+ builder_.add_fused_activation_function(fused_activation_function);
+ return builder_.Finish();
+}
+
+struct LocalResponseNormalizationOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table
+{
+ typedef LocalResponseNormalizationOptionsBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE
+ {
+ VT_RADIUS = 4,
+ VT_BIAS = 6,
+ VT_ALPHA = 8,
+ VT_BETA = 10
+ };
+ int32_t radius() const { return GetField<int32_t>(VT_RADIUS, 0); }
+ float bias() const { return GetField<float>(VT_BIAS, 0.0f); }
+ float alpha() const { return GetField<float>(VT_ALPHA, 0.0f); }
+ float beta() const { return GetField<float>(VT_BETA, 0.0f); }
+ bool Verify(::flatbuffers::Verifier &verifier) const
+ {
+ return VerifyTableStart(verifier) && VerifyField<int32_t>(verifier, VT_RADIUS, 4) &&
+ VerifyField<float>(verifier, VT_BIAS, 4) && VerifyField<float>(verifier, VT_ALPHA, 4) &&
+ VerifyField<float>(verifier, VT_BETA, 4) && verifier.EndTable();
+ }
+};
+
+struct LocalResponseNormalizationOptionsBuilder
+{
+ typedef LocalResponseNormalizationOptions Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ void add_radius(int32_t radius)
+ {
+ fbb_.AddElement<int32_t>(LocalResponseNormalizationOptions::VT_RADIUS, radius, 0);
+ }
+ void add_bias(float bias)
+ {
+ fbb_.AddElement<float>(LocalResponseNormalizationOptions::VT_BIAS, bias, 0.0f);
+ }
+ void add_alpha(float alpha)
+ {
+ fbb_.AddElement<float>(LocalResponseNormalizationOptions::VT_ALPHA, alpha, 0.0f);
+ }
+ void add_beta(float beta)
+ {
+ fbb_.AddElement<float>(LocalResponseNormalizationOptions::VT_BETA, beta, 0.0f);
+ }
+ explicit LocalResponseNormalizationOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb)
+ {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<LocalResponseNormalizationOptions> Finish()
+ {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<LocalResponseNormalizationOptions>(end);
+ return o;
+ }
+};
+
+inline ::flatbuffers::Offset<LocalResponseNormalizationOptions>
+CreateLocalResponseNormalizationOptions(::flatbuffers::FlatBufferBuilder &_fbb, int32_t radius = 0,
+ float bias = 0.0f, float alpha = 0.0f, float beta = 0.0f)
+{
+ LocalResponseNormalizationOptionsBuilder builder_(_fbb);
+ builder_.add_beta(beta);
+ builder_.add_alpha(alpha);
+ builder_.add_bias(bias);
+ builder_.add_radius(radius);
+ return builder_.Finish();
+}
+
+struct LSTMOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table
+{
+ typedef LSTMOptionsBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE
+ {
+ VT_FUSED_ACTIVATION_FUNCTION = 4,
+ VT_CELL_CLIP = 6,
+ VT_PROJ_CLIP = 8,
+ VT_KERNEL_TYPE = 10,
+ VT_ASYMMETRIC_QUANTIZE_INPUTS = 12
+ };
+ onert_tflite::ActivationFunctionType fused_activation_function() const
+ {
+ return static_cast<onert_tflite::ActivationFunctionType>(
+ GetField<int8_t>(VT_FUSED_ACTIVATION_FUNCTION, 0));
+ }
+ float cell_clip() const { return GetField<float>(VT_CELL_CLIP, 0.0f); }
+ float proj_clip() const { return GetField<float>(VT_PROJ_CLIP, 0.0f); }
+ onert_tflite::LSTMKernelType kernel_type() const
+ {
+ return static_cast<onert_tflite::LSTMKernelType>(GetField<int8_t>(VT_KERNEL_TYPE, 0));
+ }
+ bool asymmetric_quantize_inputs() const
+ {
+ return GetField<uint8_t>(VT_ASYMMETRIC_QUANTIZE_INPUTS, 0) != 0;
+ }
+ bool Verify(::flatbuffers::Verifier &verifier) const
+ {
+ return VerifyTableStart(verifier) &&
+ VerifyField<int8_t>(verifier, VT_FUSED_ACTIVATION_FUNCTION, 1) &&
+ VerifyField<float>(verifier, VT_CELL_CLIP, 4) &&
+ VerifyField<float>(verifier, VT_PROJ_CLIP, 4) &&
+ VerifyField<int8_t>(verifier, VT_KERNEL_TYPE, 1) &&
+ VerifyField<uint8_t>(verifier, VT_ASYMMETRIC_QUANTIZE_INPUTS, 1) && verifier.EndTable();
+ }
+};
+
+struct LSTMOptionsBuilder
+{
+ typedef LSTMOptions Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ void add_fused_activation_function(onert_tflite::ActivationFunctionType fused_activation_function)
+ {
+ fbb_.AddElement<int8_t>(LSTMOptions::VT_FUSED_ACTIVATION_FUNCTION,
+ static_cast<int8_t>(fused_activation_function), 0);
+ }
+ void add_cell_clip(float cell_clip)
+ {
+ fbb_.AddElement<float>(LSTMOptions::VT_CELL_CLIP, cell_clip, 0.0f);
+ }
+ void add_proj_clip(float proj_clip)
+ {
+ fbb_.AddElement<float>(LSTMOptions::VT_PROJ_CLIP, proj_clip, 0.0f);
+ }
+ void add_kernel_type(onert_tflite::LSTMKernelType kernel_type)
+ {
+ fbb_.AddElement<int8_t>(LSTMOptions::VT_KERNEL_TYPE, static_cast<int8_t>(kernel_type), 0);
+ }
+ void add_asymmetric_quantize_inputs(bool asymmetric_quantize_inputs)
+ {
+ fbb_.AddElement<uint8_t>(LSTMOptions::VT_ASYMMETRIC_QUANTIZE_INPUTS,
+ static_cast<uint8_t>(asymmetric_quantize_inputs), 0);
+ }
+ explicit LSTMOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb)
+ {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<LSTMOptions> Finish()
+ {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<LSTMOptions>(end);
+ return o;
+ }
+};
+
+inline ::flatbuffers::Offset<LSTMOptions>
+CreateLSTMOptions(::flatbuffers::FlatBufferBuilder &_fbb,
+ onert_tflite::ActivationFunctionType fused_activation_function =
+ onert_tflite::ActivationFunctionType_NONE,
+ float cell_clip = 0.0f, float proj_clip = 0.0f,
+ onert_tflite::LSTMKernelType kernel_type = onert_tflite::LSTMKernelType_FULL,
+ bool asymmetric_quantize_inputs = false)
+{
+ LSTMOptionsBuilder builder_(_fbb);
+ builder_.add_proj_clip(proj_clip);
+ builder_.add_cell_clip(cell_clip);
+ builder_.add_asymmetric_quantize_inputs(asymmetric_quantize_inputs);
+ builder_.add_kernel_type(kernel_type);
+ builder_.add_fused_activation_function(fused_activation_function);
+ return builder_.Finish();
+}
+
+struct UnidirectionalSequenceLSTMOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table
+{
+ typedef UnidirectionalSequenceLSTMOptionsBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE
+ {
+ VT_FUSED_ACTIVATION_FUNCTION = 4,
+ VT_CELL_CLIP = 6,
+ VT_PROJ_CLIP = 8,
+ VT_TIME_MAJOR = 10,
+ VT_ASYMMETRIC_QUANTIZE_INPUTS = 12
+ };
+ onert_tflite::ActivationFunctionType fused_activation_function() const
+ {
+ return static_cast<onert_tflite::ActivationFunctionType>(
+ GetField<int8_t>(VT_FUSED_ACTIVATION_FUNCTION, 0));
+ }
+ float cell_clip() const { return GetField<float>(VT_CELL_CLIP, 0.0f); }
+ float proj_clip() const { return GetField<float>(VT_PROJ_CLIP, 0.0f); }
+ bool time_major() const { return GetField<uint8_t>(VT_TIME_MAJOR, 0) != 0; }
+ bool asymmetric_quantize_inputs() const
+ {
+ return GetField<uint8_t>(VT_ASYMMETRIC_QUANTIZE_INPUTS, 0) != 0;
+ }
+ bool Verify(::flatbuffers::Verifier &verifier) const
+ {
+ return VerifyTableStart(verifier) &&
+ VerifyField<int8_t>(verifier, VT_FUSED_ACTIVATION_FUNCTION, 1) &&
+ VerifyField<float>(verifier, VT_CELL_CLIP, 4) &&
+ VerifyField<float>(verifier, VT_PROJ_CLIP, 4) &&
+ VerifyField<uint8_t>(verifier, VT_TIME_MAJOR, 1) &&
+ VerifyField<uint8_t>(verifier, VT_ASYMMETRIC_QUANTIZE_INPUTS, 1) && verifier.EndTable();
+ }
+};
+
+struct UnidirectionalSequenceLSTMOptionsBuilder
+{
+ typedef UnidirectionalSequenceLSTMOptions Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ void add_fused_activation_function(onert_tflite::ActivationFunctionType fused_activation_function)
+ {
+ fbb_.AddElement<int8_t>(UnidirectionalSequenceLSTMOptions::VT_FUSED_ACTIVATION_FUNCTION,
+ static_cast<int8_t>(fused_activation_function), 0);
+ }
+ void add_cell_clip(float cell_clip)
+ {
+ fbb_.AddElement<float>(UnidirectionalSequenceLSTMOptions::VT_CELL_CLIP, cell_clip, 0.0f);
+ }
+ void add_proj_clip(float proj_clip)
+ {
+ fbb_.AddElement<float>(UnidirectionalSequenceLSTMOptions::VT_PROJ_CLIP, proj_clip, 0.0f);
+ }
+ void add_time_major(bool time_major)
+ {
+ fbb_.AddElement<uint8_t>(UnidirectionalSequenceLSTMOptions::VT_TIME_MAJOR,
+ static_cast<uint8_t>(time_major), 0);
+ }
+ void add_asymmetric_quantize_inputs(bool asymmetric_quantize_inputs)
+ {
+ fbb_.AddElement<uint8_t>(UnidirectionalSequenceLSTMOptions::VT_ASYMMETRIC_QUANTIZE_INPUTS,
+ static_cast<uint8_t>(asymmetric_quantize_inputs), 0);
+ }
+ explicit UnidirectionalSequenceLSTMOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb)
+ {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<UnidirectionalSequenceLSTMOptions> Finish()
+ {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<UnidirectionalSequenceLSTMOptions>(end);
+ return o;
+ }
+};
+
+inline ::flatbuffers::Offset<UnidirectionalSequenceLSTMOptions>
+CreateUnidirectionalSequenceLSTMOptions(
+ ::flatbuffers::FlatBufferBuilder &_fbb,
+ onert_tflite::ActivationFunctionType fused_activation_function =
+ onert_tflite::ActivationFunctionType_NONE,
+ float cell_clip = 0.0f, float proj_clip = 0.0f, bool time_major = false,
+ bool asymmetric_quantize_inputs = false)
+{
+ UnidirectionalSequenceLSTMOptionsBuilder builder_(_fbb);
+ builder_.add_proj_clip(proj_clip);
+ builder_.add_cell_clip(cell_clip);
+ builder_.add_asymmetric_quantize_inputs(asymmetric_quantize_inputs);
+ builder_.add_time_major(time_major);
+ builder_.add_fused_activation_function(fused_activation_function);
+ return builder_.Finish();
+}
+
+struct BidirectionalSequenceLSTMOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table
+{
+ typedef BidirectionalSequenceLSTMOptionsBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE
+ {
+ VT_FUSED_ACTIVATION_FUNCTION = 4,
+ VT_CELL_CLIP = 6,
+ VT_PROJ_CLIP = 8,
+ VT_MERGE_OUTPUTS = 10,
+ VT_TIME_MAJOR = 12,
+ VT_ASYMMETRIC_QUANTIZE_INPUTS = 14
+ };
+ onert_tflite::ActivationFunctionType fused_activation_function() const
+ {
+ return static_cast<onert_tflite::ActivationFunctionType>(
+ GetField<int8_t>(VT_FUSED_ACTIVATION_FUNCTION, 0));
+ }
+ float cell_clip() const { return GetField<float>(VT_CELL_CLIP, 0.0f); }
+ float proj_clip() const { return GetField<float>(VT_PROJ_CLIP, 0.0f); }
+ bool merge_outputs() const { return GetField<uint8_t>(VT_MERGE_OUTPUTS, 0) != 0; }
+ bool time_major() const { return GetField<uint8_t>(VT_TIME_MAJOR, 1) != 0; }
+ bool asymmetric_quantize_inputs() const
+ {
+ return GetField<uint8_t>(VT_ASYMMETRIC_QUANTIZE_INPUTS, 0) != 0;
+ }
+ bool Verify(::flatbuffers::Verifier &verifier) const
+ {
+ return VerifyTableStart(verifier) &&
+ VerifyField<int8_t>(verifier, VT_FUSED_ACTIVATION_FUNCTION, 1) &&
+ VerifyField<float>(verifier, VT_CELL_CLIP, 4) &&
+ VerifyField<float>(verifier, VT_PROJ_CLIP, 4) &&
+ VerifyField<uint8_t>(verifier, VT_MERGE_OUTPUTS, 1) &&
+ VerifyField<uint8_t>(verifier, VT_TIME_MAJOR, 1) &&
+ VerifyField<uint8_t>(verifier, VT_ASYMMETRIC_QUANTIZE_INPUTS, 1) && verifier.EndTable();
+ }
+};
+
+struct BidirectionalSequenceLSTMOptionsBuilder
+{
+ typedef BidirectionalSequenceLSTMOptions Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ void add_fused_activation_function(onert_tflite::ActivationFunctionType fused_activation_function)
+ {
+ fbb_.AddElement<int8_t>(BidirectionalSequenceLSTMOptions::VT_FUSED_ACTIVATION_FUNCTION,
+ static_cast<int8_t>(fused_activation_function), 0);
+ }
+ void add_cell_clip(float cell_clip)
+ {
+ fbb_.AddElement<float>(BidirectionalSequenceLSTMOptions::VT_CELL_CLIP, cell_clip, 0.0f);
+ }
+ void add_proj_clip(float proj_clip)
+ {
+ fbb_.AddElement<float>(BidirectionalSequenceLSTMOptions::VT_PROJ_CLIP, proj_clip, 0.0f);
+ }
+ void add_merge_outputs(bool merge_outputs)
+ {
+ fbb_.AddElement<uint8_t>(BidirectionalSequenceLSTMOptions::VT_MERGE_OUTPUTS,
+ static_cast<uint8_t>(merge_outputs), 0);
+ }
+ void add_time_major(bool time_major)
+ {
+ fbb_.AddElement<uint8_t>(BidirectionalSequenceLSTMOptions::VT_TIME_MAJOR,
+ static_cast<uint8_t>(time_major), 1);
+ }
+ void add_asymmetric_quantize_inputs(bool asymmetric_quantize_inputs)
+ {
+ fbb_.AddElement<uint8_t>(BidirectionalSequenceLSTMOptions::VT_ASYMMETRIC_QUANTIZE_INPUTS,
+ static_cast<uint8_t>(asymmetric_quantize_inputs), 0);
+ }
+ explicit BidirectionalSequenceLSTMOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb)
+ {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<BidirectionalSequenceLSTMOptions> Finish()
+ {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<BidirectionalSequenceLSTMOptions>(end);
+ return o;
+ }
+};
+
+inline ::flatbuffers::Offset<BidirectionalSequenceLSTMOptions>
+CreateBidirectionalSequenceLSTMOptions(
+ ::flatbuffers::FlatBufferBuilder &_fbb,
+ onert_tflite::ActivationFunctionType fused_activation_function =
+ onert_tflite::ActivationFunctionType_NONE,
+ float cell_clip = 0.0f, float proj_clip = 0.0f, bool merge_outputs = false,
+ bool time_major = true, bool asymmetric_quantize_inputs = false)
+{
+ BidirectionalSequenceLSTMOptionsBuilder builder_(_fbb);
+ builder_.add_proj_clip(proj_clip);
+ builder_.add_cell_clip(cell_clip);
+ builder_.add_asymmetric_quantize_inputs(asymmetric_quantize_inputs);
+ builder_.add_time_major(time_major);
+ builder_.add_merge_outputs(merge_outputs);
+ builder_.add_fused_activation_function(fused_activation_function);
+ return builder_.Finish();
+}
+
+struct ResizeBilinearOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table
+{
+ typedef ResizeBilinearOptionsBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE
+ {
+ VT_ALIGN_CORNERS = 8,
+ VT_HALF_PIXEL_CENTERS = 10
+ };
+ bool align_corners() const { return GetField<uint8_t>(VT_ALIGN_CORNERS, 0) != 0; }
+ bool half_pixel_centers() const { return GetField<uint8_t>(VT_HALF_PIXEL_CENTERS, 0) != 0; }
+ bool Verify(::flatbuffers::Verifier &verifier) const
+ {
+ return VerifyTableStart(verifier) && VerifyField<uint8_t>(verifier, VT_ALIGN_CORNERS, 1) &&
+ VerifyField<uint8_t>(verifier, VT_HALF_PIXEL_CENTERS, 1) && verifier.EndTable();
+ }
+};
+
+struct ResizeBilinearOptionsBuilder
+{
+ typedef ResizeBilinearOptions Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ void add_align_corners(bool align_corners)
+ {
+ fbb_.AddElement<uint8_t>(ResizeBilinearOptions::VT_ALIGN_CORNERS,
+ static_cast<uint8_t>(align_corners), 0);
+ }
+ void add_half_pixel_centers(bool half_pixel_centers)
+ {
+ fbb_.AddElement<uint8_t>(ResizeBilinearOptions::VT_HALF_PIXEL_CENTERS,
+ static_cast<uint8_t>(half_pixel_centers), 0);
+ }
+ explicit ResizeBilinearOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb)
+ {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<ResizeBilinearOptions> Finish()
+ {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<ResizeBilinearOptions>(end);
+ return o;
+ }
+};
+
+inline ::flatbuffers::Offset<ResizeBilinearOptions>
+CreateResizeBilinearOptions(::flatbuffers::FlatBufferBuilder &_fbb, bool align_corners = false,
+ bool half_pixel_centers = false)
+{
+ ResizeBilinearOptionsBuilder builder_(_fbb);
+ builder_.add_half_pixel_centers(half_pixel_centers);
+ builder_.add_align_corners(align_corners);
+ return builder_.Finish();
+}
+
+struct ResizeNearestNeighborOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table
+{
+ typedef ResizeNearestNeighborOptionsBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE
+ {
+ VT_ALIGN_CORNERS = 4,
+ VT_HALF_PIXEL_CENTERS = 6
+ };
+ bool align_corners() const { return GetField<uint8_t>(VT_ALIGN_CORNERS, 0) != 0; }
+ bool half_pixel_centers() const { return GetField<uint8_t>(VT_HALF_PIXEL_CENTERS, 0) != 0; }
+ bool Verify(::flatbuffers::Verifier &verifier) const
+ {
+ return VerifyTableStart(verifier) && VerifyField<uint8_t>(verifier, VT_ALIGN_CORNERS, 1) &&
+ VerifyField<uint8_t>(verifier, VT_HALF_PIXEL_CENTERS, 1) && verifier.EndTable();
+ }
+};
+
+struct ResizeNearestNeighborOptionsBuilder
+{
+ typedef ResizeNearestNeighborOptions Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ void add_align_corners(bool align_corners)
+ {
+ fbb_.AddElement<uint8_t>(ResizeNearestNeighborOptions::VT_ALIGN_CORNERS,
+ static_cast<uint8_t>(align_corners), 0);
+ }
+ void add_half_pixel_centers(bool half_pixel_centers)
+ {
+ fbb_.AddElement<uint8_t>(ResizeNearestNeighborOptions::VT_HALF_PIXEL_CENTERS,
+ static_cast<uint8_t>(half_pixel_centers), 0);
+ }
+ explicit ResizeNearestNeighborOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb)
+ {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<ResizeNearestNeighborOptions> Finish()
+ {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<ResizeNearestNeighborOptions>(end);
+ return o;
+ }
+};
+
+inline ::flatbuffers::Offset<ResizeNearestNeighborOptions>
+CreateResizeNearestNeighborOptions(::flatbuffers::FlatBufferBuilder &_fbb,
+ bool align_corners = false, bool half_pixel_centers = false)
+{
+ ResizeNearestNeighborOptionsBuilder builder_(_fbb);
+ builder_.add_half_pixel_centers(half_pixel_centers);
+ builder_.add_align_corners(align_corners);
+ return builder_.Finish();
+}
+
+struct CallOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table
+{
+ typedef CallOptionsBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE
+ {
+ VT_SUBGRAPH = 4
+ };
+ uint32_t subgraph() const { return GetField<uint32_t>(VT_SUBGRAPH, 0); }
+ bool Verify(::flatbuffers::Verifier &verifier) const
+ {
+ return VerifyTableStart(verifier) && VerifyField<uint32_t>(verifier, VT_SUBGRAPH, 4) &&
+ verifier.EndTable();
+ }
+};
+
+struct CallOptionsBuilder
+{
+ typedef CallOptions Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ void add_subgraph(uint32_t subgraph)
+ {
+ fbb_.AddElement<uint32_t>(CallOptions::VT_SUBGRAPH, subgraph, 0);
+ }
+ explicit CallOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb)
+ {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<CallOptions> Finish()
+ {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<CallOptions>(end);
+ return o;
+ }
+};
+
+inline ::flatbuffers::Offset<CallOptions> CreateCallOptions(::flatbuffers::FlatBufferBuilder &_fbb,
+ uint32_t subgraph = 0)
+{
+ CallOptionsBuilder builder_(_fbb);
+ builder_.add_subgraph(subgraph);
+ return builder_.Finish();
+}
+
+struct PadOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table
+{
+ typedef PadOptionsBuilder Builder;
+ bool Verify(::flatbuffers::Verifier &verifier) const
+ {
+ return VerifyTableStart(verifier) && verifier.EndTable();
+ }
+};
+
+struct PadOptionsBuilder
+{
+ typedef PadOptions Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ explicit PadOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb)
+ {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<PadOptions> Finish()
+ {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<PadOptions>(end);
+ return o;
+ }
+};
+
+inline ::flatbuffers::Offset<PadOptions> CreatePadOptions(::flatbuffers::FlatBufferBuilder &_fbb)
+{
+ PadOptionsBuilder builder_(_fbb);
+ return builder_.Finish();
+}
+
+struct PadV2Options FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table
+{
+ typedef PadV2OptionsBuilder Builder;
+ bool Verify(::flatbuffers::Verifier &verifier) const
+ {
+ return VerifyTableStart(verifier) && verifier.EndTable();
+ }
+};
+
+struct PadV2OptionsBuilder
+{
+ typedef PadV2Options Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ explicit PadV2OptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb)
+ {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<PadV2Options> Finish()
+ {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<PadV2Options>(end);
+ return o;
+ }
+};
+
+inline ::flatbuffers::Offset<PadV2Options>
+CreatePadV2Options(::flatbuffers::FlatBufferBuilder &_fbb)
+{
+ PadV2OptionsBuilder builder_(_fbb);
+ return builder_.Finish();
+}
+
+struct ReshapeOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table
+{
+ typedef ReshapeOptionsBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE
+ {
+ VT_NEW_SHAPE = 4
+ };
+ const ::flatbuffers::Vector<int32_t> *new_shape() const
+ {
+ return GetPointer<const ::flatbuffers::Vector<int32_t> *>(VT_NEW_SHAPE);
+ }
+ bool Verify(::flatbuffers::Verifier &verifier) const
+ {
+ return VerifyTableStart(verifier) && VerifyOffset(verifier, VT_NEW_SHAPE) &&
+ verifier.VerifyVector(new_shape()) && verifier.EndTable();
+ }
+};
+
+struct ReshapeOptionsBuilder
+{
+ typedef ReshapeOptions Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ void add_new_shape(::flatbuffers::Offset<::flatbuffers::Vector<int32_t>> new_shape)
+ {
+ fbb_.AddOffset(ReshapeOptions::VT_NEW_SHAPE, new_shape);
+ }
+ explicit ReshapeOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb)
+ {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<ReshapeOptions> Finish()
+ {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<ReshapeOptions>(end);
+ return o;
+ }
+};
+
+inline ::flatbuffers::Offset<ReshapeOptions>
+CreateReshapeOptions(::flatbuffers::FlatBufferBuilder &_fbb,
+ ::flatbuffers::Offset<::flatbuffers::Vector<int32_t>> new_shape = 0)
+{
+ ReshapeOptionsBuilder builder_(_fbb);
+ builder_.add_new_shape(new_shape);
+ return builder_.Finish();
+}
+
+inline ::flatbuffers::Offset<ReshapeOptions>
+CreateReshapeOptionsDirect(::flatbuffers::FlatBufferBuilder &_fbb,
+ const std::vector<int32_t> *new_shape = nullptr)
+{
+ auto new_shape__ = new_shape ? _fbb.CreateVector<int32_t>(*new_shape) : 0;
+ return onert_tflite::CreateReshapeOptions(_fbb, new_shape__);
+}
+
+struct SpaceToBatchNDOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table
+{
+ typedef SpaceToBatchNDOptionsBuilder Builder;
+ bool Verify(::flatbuffers::Verifier &verifier) const
+ {
+ return VerifyTableStart(verifier) && verifier.EndTable();
+ }
+};
+
+struct SpaceToBatchNDOptionsBuilder
+{
+ typedef SpaceToBatchNDOptions Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ explicit SpaceToBatchNDOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb)
+ {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<SpaceToBatchNDOptions> Finish()
+ {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<SpaceToBatchNDOptions>(end);
+ return o;
+ }
+};
+
+inline ::flatbuffers::Offset<SpaceToBatchNDOptions>
+CreateSpaceToBatchNDOptions(::flatbuffers::FlatBufferBuilder &_fbb)
+{
+ SpaceToBatchNDOptionsBuilder builder_(_fbb);
+ return builder_.Finish();
+}
+
+struct BatchToSpaceNDOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table
+{
+ typedef BatchToSpaceNDOptionsBuilder Builder;
+ bool Verify(::flatbuffers::Verifier &verifier) const
+ {
+ return VerifyTableStart(verifier) && verifier.EndTable();
+ }
+};
+
+struct BatchToSpaceNDOptionsBuilder
+{
+ typedef BatchToSpaceNDOptions Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ explicit BatchToSpaceNDOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb)
+ {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<BatchToSpaceNDOptions> Finish()
+ {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<BatchToSpaceNDOptions>(end);
+ return o;
+ }
+};
+
+inline ::flatbuffers::Offset<BatchToSpaceNDOptions>
+CreateBatchToSpaceNDOptions(::flatbuffers::FlatBufferBuilder &_fbb)
+{
+ BatchToSpaceNDOptionsBuilder builder_(_fbb);
+ return builder_.Finish();
+}
+
+struct SkipGramOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table
+{
+ typedef SkipGramOptionsBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE
+ {
+ VT_NGRAM_SIZE = 4,
+ VT_MAX_SKIP_SIZE = 6,
+ VT_INCLUDE_ALL_NGRAMS = 8
+ };
+ int32_t ngram_size() const { return GetField<int32_t>(VT_NGRAM_SIZE, 0); }
+ int32_t max_skip_size() const { return GetField<int32_t>(VT_MAX_SKIP_SIZE, 0); }
+ bool include_all_ngrams() const { return GetField<uint8_t>(VT_INCLUDE_ALL_NGRAMS, 0) != 0; }
+ bool Verify(::flatbuffers::Verifier &verifier) const
+ {
+ return VerifyTableStart(verifier) && VerifyField<int32_t>(verifier, VT_NGRAM_SIZE, 4) &&
+ VerifyField<int32_t>(verifier, VT_MAX_SKIP_SIZE, 4) &&
+ VerifyField<uint8_t>(verifier, VT_INCLUDE_ALL_NGRAMS, 1) && verifier.EndTable();
+ }
+};
+
+struct SkipGramOptionsBuilder
+{
+ typedef SkipGramOptions Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ void add_ngram_size(int32_t ngram_size)
+ {
+ fbb_.AddElement<int32_t>(SkipGramOptions::VT_NGRAM_SIZE, ngram_size, 0);
+ }
+ void add_max_skip_size(int32_t max_skip_size)
+ {
+ fbb_.AddElement<int32_t>(SkipGramOptions::VT_MAX_SKIP_SIZE, max_skip_size, 0);
+ }
+ void add_include_all_ngrams(bool include_all_ngrams)
+ {
+ fbb_.AddElement<uint8_t>(SkipGramOptions::VT_INCLUDE_ALL_NGRAMS,
+ static_cast<uint8_t>(include_all_ngrams), 0);
+ }
+ explicit SkipGramOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb)
+ {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<SkipGramOptions> Finish()
+ {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<SkipGramOptions>(end);
+ return o;
+ }
+};
+
+inline ::flatbuffers::Offset<SkipGramOptions>
+CreateSkipGramOptions(::flatbuffers::FlatBufferBuilder &_fbb, int32_t ngram_size = 0,
+ int32_t max_skip_size = 0, bool include_all_ngrams = false)
+{
+ SkipGramOptionsBuilder builder_(_fbb);
+ builder_.add_max_skip_size(max_skip_size);
+ builder_.add_ngram_size(ngram_size);
+ builder_.add_include_all_ngrams(include_all_ngrams);
+ return builder_.Finish();
+}
+
+struct SpaceToDepthOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table
+{
+ typedef SpaceToDepthOptionsBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE
+ {
+ VT_BLOCK_SIZE = 4
+ };
+ int32_t block_size() const { return GetField<int32_t>(VT_BLOCK_SIZE, 0); }
+ bool Verify(::flatbuffers::Verifier &verifier) const
+ {
+ return VerifyTableStart(verifier) && VerifyField<int32_t>(verifier, VT_BLOCK_SIZE, 4) &&
+ verifier.EndTable();
+ }
+};
+
+struct SpaceToDepthOptionsBuilder
+{
+ typedef SpaceToDepthOptions Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ void add_block_size(int32_t block_size)
+ {
+ fbb_.AddElement<int32_t>(SpaceToDepthOptions::VT_BLOCK_SIZE, block_size, 0);
+ }
+ explicit SpaceToDepthOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb)
+ {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<SpaceToDepthOptions> Finish()
+ {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<SpaceToDepthOptions>(end);
+ return o;
+ }
+};
+
+inline ::flatbuffers::Offset<SpaceToDepthOptions>
+CreateSpaceToDepthOptions(::flatbuffers::FlatBufferBuilder &_fbb, int32_t block_size = 0)
+{
+ SpaceToDepthOptionsBuilder builder_(_fbb);
+ builder_.add_block_size(block_size);
+ return builder_.Finish();
+}
+
+struct DepthToSpaceOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table
+{
+ typedef DepthToSpaceOptionsBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE
+ {
+ VT_BLOCK_SIZE = 4
+ };
+ int32_t block_size() const { return GetField<int32_t>(VT_BLOCK_SIZE, 0); }
+ bool Verify(::flatbuffers::Verifier &verifier) const
+ {
+ return VerifyTableStart(verifier) && VerifyField<int32_t>(verifier, VT_BLOCK_SIZE, 4) &&
+ verifier.EndTable();
+ }
+};
+
+struct DepthToSpaceOptionsBuilder
+{
+ typedef DepthToSpaceOptions Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ void add_block_size(int32_t block_size)
+ {
+ fbb_.AddElement<int32_t>(DepthToSpaceOptions::VT_BLOCK_SIZE, block_size, 0);
+ }
+ explicit DepthToSpaceOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb)
+ {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<DepthToSpaceOptions> Finish()
+ {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<DepthToSpaceOptions>(end);
+ return o;
+ }
+};
+
+inline ::flatbuffers::Offset<DepthToSpaceOptions>
+CreateDepthToSpaceOptions(::flatbuffers::FlatBufferBuilder &_fbb, int32_t block_size = 0)
+{
+ DepthToSpaceOptionsBuilder builder_(_fbb);
+ builder_.add_block_size(block_size);
+ return builder_.Finish();
+}
+
+struct SubOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table
+{
+ typedef SubOptionsBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE
+ {
+ VT_FUSED_ACTIVATION_FUNCTION = 4,
+ VT_POT_SCALE_INT16 = 6
+ };
+ onert_tflite::ActivationFunctionType fused_activation_function() const
+ {
+ return static_cast<onert_tflite::ActivationFunctionType>(
+ GetField<int8_t>(VT_FUSED_ACTIVATION_FUNCTION, 0));
+ }
+ bool pot_scale_int16() const { return GetField<uint8_t>(VT_POT_SCALE_INT16, 1) != 0; }
+ bool Verify(::flatbuffers::Verifier &verifier) const
+ {
+ return VerifyTableStart(verifier) &&
+ VerifyField<int8_t>(verifier, VT_FUSED_ACTIVATION_FUNCTION, 1) &&
+ VerifyField<uint8_t>(verifier, VT_POT_SCALE_INT16, 1) && verifier.EndTable();
+ }
+};
+
+struct SubOptionsBuilder
+{
+ typedef SubOptions Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ void add_fused_activation_function(onert_tflite::ActivationFunctionType fused_activation_function)
+ {
+ fbb_.AddElement<int8_t>(SubOptions::VT_FUSED_ACTIVATION_FUNCTION,
+ static_cast<int8_t>(fused_activation_function), 0);
+ }
+ void add_pot_scale_int16(bool pot_scale_int16)
+ {
+ fbb_.AddElement<uint8_t>(SubOptions::VT_POT_SCALE_INT16, static_cast<uint8_t>(pot_scale_int16),
+ 1);
+ }
+ explicit SubOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb)
+ {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<SubOptions> Finish()
+ {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<SubOptions>(end);
+ return o;
+ }
+};
+
+inline ::flatbuffers::Offset<SubOptions>
+CreateSubOptions(::flatbuffers::FlatBufferBuilder &_fbb,
+ onert_tflite::ActivationFunctionType fused_activation_function =
+ onert_tflite::ActivationFunctionType_NONE,
+ bool pot_scale_int16 = true)
+{
+ SubOptionsBuilder builder_(_fbb);
+ builder_.add_pot_scale_int16(pot_scale_int16);
+ builder_.add_fused_activation_function(fused_activation_function);
+ return builder_.Finish();
+}
+
+struct DivOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table
+{
+ typedef DivOptionsBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE
+ {
+ VT_FUSED_ACTIVATION_FUNCTION = 4
+ };
+ onert_tflite::ActivationFunctionType fused_activation_function() const
+ {
+ return static_cast<onert_tflite::ActivationFunctionType>(
+ GetField<int8_t>(VT_FUSED_ACTIVATION_FUNCTION, 0));
+ }
+ bool Verify(::flatbuffers::Verifier &verifier) const
+ {
+ return VerifyTableStart(verifier) &&
+ VerifyField<int8_t>(verifier, VT_FUSED_ACTIVATION_FUNCTION, 1) && verifier.EndTable();
+ }
+};
+
+struct DivOptionsBuilder
+{
+ typedef DivOptions Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ void add_fused_activation_function(onert_tflite::ActivationFunctionType fused_activation_function)
+ {
+ fbb_.AddElement<int8_t>(DivOptions::VT_FUSED_ACTIVATION_FUNCTION,
+ static_cast<int8_t>(fused_activation_function), 0);
+ }
+ explicit DivOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb)
+ {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<DivOptions> Finish()
+ {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<DivOptions>(end);
+ return o;
+ }
+};
+
+inline ::flatbuffers::Offset<DivOptions>
+CreateDivOptions(::flatbuffers::FlatBufferBuilder &_fbb,
+ onert_tflite::ActivationFunctionType fused_activation_function =
+ onert_tflite::ActivationFunctionType_NONE)
+{
+ DivOptionsBuilder builder_(_fbb);
+ builder_.add_fused_activation_function(fused_activation_function);
+ return builder_.Finish();
+}
+
+struct TopKV2Options FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table
+{
+ typedef TopKV2OptionsBuilder Builder;
+ bool Verify(::flatbuffers::Verifier &verifier) const
+ {
+ return VerifyTableStart(verifier) && verifier.EndTable();
+ }
+};
+
+struct TopKV2OptionsBuilder
+{
+ typedef TopKV2Options Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ explicit TopKV2OptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb)
+ {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<TopKV2Options> Finish()
+ {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<TopKV2Options>(end);
+ return o;
+ }
+};
+
+inline ::flatbuffers::Offset<TopKV2Options>
+CreateTopKV2Options(::flatbuffers::FlatBufferBuilder &_fbb)
+{
+ TopKV2OptionsBuilder builder_(_fbb);
+ return builder_.Finish();
+}
+
+struct EmbeddingLookupSparseOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table
+{
+ typedef EmbeddingLookupSparseOptionsBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE
+ {
+ VT_COMBINER = 4
+ };
+ onert_tflite::CombinerType combiner() const
+ {
+ return static_cast<onert_tflite::CombinerType>(GetField<int8_t>(VT_COMBINER, 0));
+ }
+ bool Verify(::flatbuffers::Verifier &verifier) const
+ {
+ return VerifyTableStart(verifier) && VerifyField<int8_t>(verifier, VT_COMBINER, 1) &&
+ verifier.EndTable();
+ }
+};
+
+struct EmbeddingLookupSparseOptionsBuilder
+{
+ typedef EmbeddingLookupSparseOptions Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ void add_combiner(onert_tflite::CombinerType combiner)
+ {
+ fbb_.AddElement<int8_t>(EmbeddingLookupSparseOptions::VT_COMBINER,
+ static_cast<int8_t>(combiner), 0);
+ }
+ explicit EmbeddingLookupSparseOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb)
+ {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<EmbeddingLookupSparseOptions> Finish()
+ {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<EmbeddingLookupSparseOptions>(end);
+ return o;
+ }
+};
+
+inline ::flatbuffers::Offset<EmbeddingLookupSparseOptions> CreateEmbeddingLookupSparseOptions(
+ ::flatbuffers::FlatBufferBuilder &_fbb,
+ onert_tflite::CombinerType combiner = onert_tflite::CombinerType_SUM)
+{
+ EmbeddingLookupSparseOptionsBuilder builder_(_fbb);
+ builder_.add_combiner(combiner);
+ return builder_.Finish();
+}
+
+struct GatherOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table
+{
+ typedef GatherOptionsBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE
+ {
+ VT_AXIS = 4,
+ VT_BATCH_DIMS = 6
+ };
+ int32_t axis() const { return GetField<int32_t>(VT_AXIS, 0); }
+ int32_t batch_dims() const { return GetField<int32_t>(VT_BATCH_DIMS, 0); }
+ bool Verify(::flatbuffers::Verifier &verifier) const
+ {
+ return VerifyTableStart(verifier) && VerifyField<int32_t>(verifier, VT_AXIS, 4) &&
+ VerifyField<int32_t>(verifier, VT_BATCH_DIMS, 4) && verifier.EndTable();
+ }
+};
+
+struct GatherOptionsBuilder
+{
+ typedef GatherOptions Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ void add_axis(int32_t axis) { fbb_.AddElement<int32_t>(GatherOptions::VT_AXIS, axis, 0); }
+ void add_batch_dims(int32_t batch_dims)
+ {
+ fbb_.AddElement<int32_t>(GatherOptions::VT_BATCH_DIMS, batch_dims, 0);
+ }
+ explicit GatherOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb)
+ {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<GatherOptions> Finish()
+ {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<GatherOptions>(end);
+ return o;
+ }
+};
+
+inline ::flatbuffers::Offset<GatherOptions>
+CreateGatherOptions(::flatbuffers::FlatBufferBuilder &_fbb, int32_t axis = 0,
+ int32_t batch_dims = 0)
+{
+ GatherOptionsBuilder builder_(_fbb);
+ builder_.add_batch_dims(batch_dims);
+ builder_.add_axis(axis);
+ return builder_.Finish();
+}
+
+struct TransposeOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table
+{
+ typedef TransposeOptionsBuilder Builder;
+ bool Verify(::flatbuffers::Verifier &verifier) const
+ {
+ return VerifyTableStart(verifier) && verifier.EndTable();
+ }
+};
+
+struct TransposeOptionsBuilder
+{
+ typedef TransposeOptions Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ explicit TransposeOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb)
+ {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<TransposeOptions> Finish()
+ {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<TransposeOptions>(end);
+ return o;
+ }
+};
+
+inline ::flatbuffers::Offset<TransposeOptions>
+CreateTransposeOptions(::flatbuffers::FlatBufferBuilder &_fbb)
+{
+ TransposeOptionsBuilder builder_(_fbb);
+ return builder_.Finish();
+}
+
+struct ExpOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table
+{
+ typedef ExpOptionsBuilder Builder;
+ bool Verify(::flatbuffers::Verifier &verifier) const
+ {
+ return VerifyTableStart(verifier) && verifier.EndTable();
+ }
+};
+
+struct ExpOptionsBuilder
+{
+ typedef ExpOptions Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ explicit ExpOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb)
+ {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<ExpOptions> Finish()
+ {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<ExpOptions>(end);
+ return o;
+ }
+};
+
+inline ::flatbuffers::Offset<ExpOptions> CreateExpOptions(::flatbuffers::FlatBufferBuilder &_fbb)
+{
+ ExpOptionsBuilder builder_(_fbb);
+ return builder_.Finish();
+}
+
+struct CosOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table
+{
+ typedef CosOptionsBuilder Builder;
+ bool Verify(::flatbuffers::Verifier &verifier) const
+ {
+ return VerifyTableStart(verifier) && verifier.EndTable();
+ }
+};
+
+struct CosOptionsBuilder
+{
+ typedef CosOptions Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ explicit CosOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb)
+ {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<CosOptions> Finish()
+ {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<CosOptions>(end);
+ return o;
+ }
+};
+
+inline ::flatbuffers::Offset<CosOptions> CreateCosOptions(::flatbuffers::FlatBufferBuilder &_fbb)
+{
+ CosOptionsBuilder builder_(_fbb);
+ return builder_.Finish();
+}
+
+struct ReducerOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table
+{
+ typedef ReducerOptionsBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE
+ {
+ VT_KEEP_DIMS = 4
+ };
+ bool keep_dims() const { return GetField<uint8_t>(VT_KEEP_DIMS, 0) != 0; }
+ bool Verify(::flatbuffers::Verifier &verifier) const
+ {
+ return VerifyTableStart(verifier) && VerifyField<uint8_t>(verifier, VT_KEEP_DIMS, 1) &&
+ verifier.EndTable();
+ }
+};
+
+struct ReducerOptionsBuilder
+{
+ typedef ReducerOptions Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ void add_keep_dims(bool keep_dims)
+ {
+ fbb_.AddElement<uint8_t>(ReducerOptions::VT_KEEP_DIMS, static_cast<uint8_t>(keep_dims), 0);
+ }
+ explicit ReducerOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb)
+ {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<ReducerOptions> Finish()
+ {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<ReducerOptions>(end);
+ return o;
+ }
+};
+
+inline ::flatbuffers::Offset<ReducerOptions>
+CreateReducerOptions(::flatbuffers::FlatBufferBuilder &_fbb, bool keep_dims = false)
+{
+ ReducerOptionsBuilder builder_(_fbb);
+ builder_.add_keep_dims(keep_dims);
+ return builder_.Finish();
+}
+
+struct SqueezeOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table
+{
+ typedef SqueezeOptionsBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE
+ {
+ VT_SQUEEZE_DIMS = 4
+ };
+ const ::flatbuffers::Vector<int32_t> *squeeze_dims() const
+ {
+ return GetPointer<const ::flatbuffers::Vector<int32_t> *>(VT_SQUEEZE_DIMS);
+ }
+ bool Verify(::flatbuffers::Verifier &verifier) const
+ {
+ return VerifyTableStart(verifier) && VerifyOffset(verifier, VT_SQUEEZE_DIMS) &&
+ verifier.VerifyVector(squeeze_dims()) && verifier.EndTable();
+ }
+};
+
+struct SqueezeOptionsBuilder
+{
+ typedef SqueezeOptions Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ void add_squeeze_dims(::flatbuffers::Offset<::flatbuffers::Vector<int32_t>> squeeze_dims)
+ {
+ fbb_.AddOffset(SqueezeOptions::VT_SQUEEZE_DIMS, squeeze_dims);
+ }
+ explicit SqueezeOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb)
+ {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<SqueezeOptions> Finish()
+ {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<SqueezeOptions>(end);
+ return o;
+ }
+};
+
+inline ::flatbuffers::Offset<SqueezeOptions>
+CreateSqueezeOptions(::flatbuffers::FlatBufferBuilder &_fbb,
+ ::flatbuffers::Offset<::flatbuffers::Vector<int32_t>> squeeze_dims = 0)
+{
+ SqueezeOptionsBuilder builder_(_fbb);
+ builder_.add_squeeze_dims(squeeze_dims);
+ return builder_.Finish();
+}
+
+inline ::flatbuffers::Offset<SqueezeOptions>
+CreateSqueezeOptionsDirect(::flatbuffers::FlatBufferBuilder &_fbb,
+ const std::vector<int32_t> *squeeze_dims = nullptr)
+{
+ auto squeeze_dims__ = squeeze_dims ? _fbb.CreateVector<int32_t>(*squeeze_dims) : 0;
+ return onert_tflite::CreateSqueezeOptions(_fbb, squeeze_dims__);
+}
+
+struct SplitOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table
+{
+ typedef SplitOptionsBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE
+ {
+ VT_NUM_SPLITS = 4
+ };
+ int32_t num_splits() const { return GetField<int32_t>(VT_NUM_SPLITS, 0); }
+ bool Verify(::flatbuffers::Verifier &verifier) const
+ {
+ return VerifyTableStart(verifier) && VerifyField<int32_t>(verifier, VT_NUM_SPLITS, 4) &&
+ verifier.EndTable();
+ }
+};
+
+struct SplitOptionsBuilder
+{
+ typedef SplitOptions Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ void add_num_splits(int32_t num_splits)
+ {
+ fbb_.AddElement<int32_t>(SplitOptions::VT_NUM_SPLITS, num_splits, 0);
+ }
+ explicit SplitOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb)
+ {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<SplitOptions> Finish()
+ {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<SplitOptions>(end);
+ return o;
+ }
+};
+
+inline ::flatbuffers::Offset<SplitOptions>
+CreateSplitOptions(::flatbuffers::FlatBufferBuilder &_fbb, int32_t num_splits = 0)
+{
+ SplitOptionsBuilder builder_(_fbb);
+ builder_.add_num_splits(num_splits);
+ return builder_.Finish();
+}
+
+struct SplitVOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table
+{
+ typedef SplitVOptionsBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE
+ {
+ VT_NUM_SPLITS = 4
+ };
+ int32_t num_splits() const { return GetField<int32_t>(VT_NUM_SPLITS, 0); }
+ bool Verify(::flatbuffers::Verifier &verifier) const
+ {
+ return VerifyTableStart(verifier) && VerifyField<int32_t>(verifier, VT_NUM_SPLITS, 4) &&
+ verifier.EndTable();
+ }
+};
+
+struct SplitVOptionsBuilder
+{
+ typedef SplitVOptions Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ void add_num_splits(int32_t num_splits)
+ {
+ fbb_.AddElement<int32_t>(SplitVOptions::VT_NUM_SPLITS, num_splits, 0);
+ }
+ explicit SplitVOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb)
+ {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<SplitVOptions> Finish()
+ {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<SplitVOptions>(end);
+ return o;
+ }
+};
+
+inline ::flatbuffers::Offset<SplitVOptions>
+CreateSplitVOptions(::flatbuffers::FlatBufferBuilder &_fbb, int32_t num_splits = 0)
+{
+ SplitVOptionsBuilder builder_(_fbb);
+ builder_.add_num_splits(num_splits);
+ return builder_.Finish();
+}
+
+struct StridedSliceOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table
+{
+ typedef StridedSliceOptionsBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE
+ {
+ VT_BEGIN_MASK = 4,
+ VT_END_MASK = 6,
+ VT_ELLIPSIS_MASK = 8,
+ VT_NEW_AXIS_MASK = 10,
+ VT_SHRINK_AXIS_MASK = 12
+ };
+ int32_t begin_mask() const { return GetField<int32_t>(VT_BEGIN_MASK, 0); }
+ int32_t end_mask() const { return GetField<int32_t>(VT_END_MASK, 0); }
+ int32_t ellipsis_mask() const { return GetField<int32_t>(VT_ELLIPSIS_MASK, 0); }
+ int32_t new_axis_mask() const { return GetField<int32_t>(VT_NEW_AXIS_MASK, 0); }
+ int32_t shrink_axis_mask() const { return GetField<int32_t>(VT_SHRINK_AXIS_MASK, 0); }
+ bool Verify(::flatbuffers::Verifier &verifier) const
+ {
+ return VerifyTableStart(verifier) && VerifyField<int32_t>(verifier, VT_BEGIN_MASK, 4) &&
+ VerifyField<int32_t>(verifier, VT_END_MASK, 4) &&
+ VerifyField<int32_t>(verifier, VT_ELLIPSIS_MASK, 4) &&
+ VerifyField<int32_t>(verifier, VT_NEW_AXIS_MASK, 4) &&
+ VerifyField<int32_t>(verifier, VT_SHRINK_AXIS_MASK, 4) && verifier.EndTable();
+ }
+};
+
+struct StridedSliceOptionsBuilder
+{
+ typedef StridedSliceOptions Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ void add_begin_mask(int32_t begin_mask)
+ {
+ fbb_.AddElement<int32_t>(StridedSliceOptions::VT_BEGIN_MASK, begin_mask, 0);
+ }
+ void add_end_mask(int32_t end_mask)
+ {
+ fbb_.AddElement<int32_t>(StridedSliceOptions::VT_END_MASK, end_mask, 0);
+ }
+ void add_ellipsis_mask(int32_t ellipsis_mask)
+ {
+ fbb_.AddElement<int32_t>(StridedSliceOptions::VT_ELLIPSIS_MASK, ellipsis_mask, 0);
+ }
+ void add_new_axis_mask(int32_t new_axis_mask)
+ {
+ fbb_.AddElement<int32_t>(StridedSliceOptions::VT_NEW_AXIS_MASK, new_axis_mask, 0);
+ }
+ void add_shrink_axis_mask(int32_t shrink_axis_mask)
+ {
+ fbb_.AddElement<int32_t>(StridedSliceOptions::VT_SHRINK_AXIS_MASK, shrink_axis_mask, 0);
+ }
+ explicit StridedSliceOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb)
+ {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<StridedSliceOptions> Finish()
+ {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<StridedSliceOptions>(end);
+ return o;
+ }
+};
+
+inline ::flatbuffers::Offset<StridedSliceOptions>
+CreateStridedSliceOptions(::flatbuffers::FlatBufferBuilder &_fbb, int32_t begin_mask = 0,
+ int32_t end_mask = 0, int32_t ellipsis_mask = 0,
+ int32_t new_axis_mask = 0, int32_t shrink_axis_mask = 0)
+{
+ StridedSliceOptionsBuilder builder_(_fbb);
+ builder_.add_shrink_axis_mask(shrink_axis_mask);
+ builder_.add_new_axis_mask(new_axis_mask);
+ builder_.add_ellipsis_mask(ellipsis_mask);
+ builder_.add_end_mask(end_mask);
+ builder_.add_begin_mask(begin_mask);
+ return builder_.Finish();
+}
+
+struct LogSoftmaxOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table
+{
+ typedef LogSoftmaxOptionsBuilder Builder;
+ bool Verify(::flatbuffers::Verifier &verifier) const
+ {
+ return VerifyTableStart(verifier) && verifier.EndTable();
+ }
+};
+
+struct LogSoftmaxOptionsBuilder
+{
+ typedef LogSoftmaxOptions Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ explicit LogSoftmaxOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb)
+ {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<LogSoftmaxOptions> Finish()
+ {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<LogSoftmaxOptions>(end);
+ return o;
+ }
+};
+
+inline ::flatbuffers::Offset<LogSoftmaxOptions>
+CreateLogSoftmaxOptions(::flatbuffers::FlatBufferBuilder &_fbb)
+{
+ LogSoftmaxOptionsBuilder builder_(_fbb);
+ return builder_.Finish();
+}
+
+struct CastOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table
+{
+ typedef CastOptionsBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE
+ {
+ VT_IN_DATA_TYPE = 4,
+ VT_OUT_DATA_TYPE = 6
+ };
+ onert_tflite::TensorType in_data_type() const
+ {
+ return static_cast<onert_tflite::TensorType>(GetField<int8_t>(VT_IN_DATA_TYPE, 0));
+ }
+ onert_tflite::TensorType out_data_type() const
+ {
+ return static_cast<onert_tflite::TensorType>(GetField<int8_t>(VT_OUT_DATA_TYPE, 0));
+ }
+ bool Verify(::flatbuffers::Verifier &verifier) const
+ {
+ return VerifyTableStart(verifier) && VerifyField<int8_t>(verifier, VT_IN_DATA_TYPE, 1) &&
+ VerifyField<int8_t>(verifier, VT_OUT_DATA_TYPE, 1) && verifier.EndTable();
+ }
+};
+
+struct CastOptionsBuilder
+{
+ typedef CastOptions Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ void add_in_data_type(onert_tflite::TensorType in_data_type)
+ {
+ fbb_.AddElement<int8_t>(CastOptions::VT_IN_DATA_TYPE, static_cast<int8_t>(in_data_type), 0);
+ }
+ void add_out_data_type(onert_tflite::TensorType out_data_type)
+ {
+ fbb_.AddElement<int8_t>(CastOptions::VT_OUT_DATA_TYPE, static_cast<int8_t>(out_data_type), 0);
+ }
+ explicit CastOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb)
+ {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<CastOptions> Finish()
+ {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<CastOptions>(end);
+ return o;
+ }
+};
+
+inline ::flatbuffers::Offset<CastOptions>
+CreateCastOptions(::flatbuffers::FlatBufferBuilder &_fbb,
+ onert_tflite::TensorType in_data_type = onert_tflite::TensorType_FLOAT32,
+ onert_tflite::TensorType out_data_type = onert_tflite::TensorType_FLOAT32)
+{
+ CastOptionsBuilder builder_(_fbb);
+ builder_.add_out_data_type(out_data_type);
+ builder_.add_in_data_type(in_data_type);
+ return builder_.Finish();
+}
+
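+// Note: enum-typed fields such as CastOptions::in_data_type are stored as their int8_t
+// underlying value (see the static_cast in the accessor and in add_in_data_type above),
+// which is why Verify checks them with VerifyField<int8_t>(..., 1). Illustrative sketch
+// using only enumerators that appear in this header:
+//
+//   auto cast = onert_tflite::CreateCastOptions(
+//     fbb, /*in_data_type=*/onert_tflite::TensorType_FLOAT32,
+//     /*out_data_type=*/onert_tflite::TensorType_INT32);
+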
+struct DequantizeOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table
+{
+ typedef DequantizeOptionsBuilder Builder;
+ bool Verify(::flatbuffers::Verifier &verifier) const
+ {
+ return VerifyTableStart(verifier) && verifier.EndTable();
+ }
+};
+
+struct DequantizeOptionsBuilder
+{
+ typedef DequantizeOptions Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ explicit DequantizeOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb)
+ {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<DequantizeOptions> Finish()
+ {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<DequantizeOptions>(end);
+ return o;
+ }
+};
+
+inline ::flatbuffers::Offset<DequantizeOptions>
+CreateDequantizeOptions(::flatbuffers::FlatBufferBuilder &_fbb)
+{
+ DequantizeOptionsBuilder builder_(_fbb);
+ return builder_.Finish();
+}
+
+struct MaximumMinimumOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table
+{
+ typedef MaximumMinimumOptionsBuilder Builder;
+ bool Verify(::flatbuffers::Verifier &verifier) const
+ {
+ return VerifyTableStart(verifier) && verifier.EndTable();
+ }
+};
+
+struct MaximumMinimumOptionsBuilder
+{
+ typedef MaximumMinimumOptions Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ explicit MaximumMinimumOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb)
+ {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<MaximumMinimumOptions> Finish()
+ {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<MaximumMinimumOptions>(end);
+ return o;
+ }
+};
+
+inline ::flatbuffers::Offset<MaximumMinimumOptions>
+CreateMaximumMinimumOptions(::flatbuffers::FlatBufferBuilder &_fbb)
+{
+ MaximumMinimumOptionsBuilder builder_(_fbb);
+ return builder_.Finish();
+}
+
+struct TileOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table
+{
+ typedef TileOptionsBuilder Builder;
+ bool Verify(::flatbuffers::Verifier &verifier) const
+ {
+ return VerifyTableStart(verifier) && verifier.EndTable();
+ }
+};
+
+struct TileOptionsBuilder
+{
+ typedef TileOptions Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ explicit TileOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb)
+ {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<TileOptions> Finish()
+ {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<TileOptions>(end);
+ return o;
+ }
+};
+
+inline ::flatbuffers::Offset<TileOptions> CreateTileOptions(::flatbuffers::FlatBufferBuilder &_fbb)
+{
+ TileOptionsBuilder builder_(_fbb);
+ return builder_.Finish();
+}
+
+struct ArgMaxOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table
+{
+ typedef ArgMaxOptionsBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE
+ {
+ VT_OUTPUT_TYPE = 4
+ };
+ onert_tflite::TensorType output_type() const
+ {
+ return static_cast<onert_tflite::TensorType>(GetField<int8_t>(VT_OUTPUT_TYPE, 0));
+ }
+ bool Verify(::flatbuffers::Verifier &verifier) const
+ {
+ return VerifyTableStart(verifier) && VerifyField<int8_t>(verifier, VT_OUTPUT_TYPE, 1) &&
+ verifier.EndTable();
+ }
+};
+
+struct ArgMaxOptionsBuilder
+{
+ typedef ArgMaxOptions Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ void add_output_type(onert_tflite::TensorType output_type)
+ {
+ fbb_.AddElement<int8_t>(ArgMaxOptions::VT_OUTPUT_TYPE, static_cast<int8_t>(output_type), 0);
+ }
+ explicit ArgMaxOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb)
+ {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<ArgMaxOptions> Finish()
+ {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<ArgMaxOptions>(end);
+ return o;
+ }
+};
+
+inline ::flatbuffers::Offset<ArgMaxOptions>
+CreateArgMaxOptions(::flatbuffers::FlatBufferBuilder &_fbb,
+ onert_tflite::TensorType output_type = onert_tflite::TensorType_FLOAT32)
+{
+ ArgMaxOptionsBuilder builder_(_fbb);
+ builder_.add_output_type(output_type);
+ return builder_.Finish();
+}
+
+struct ArgMinOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table
+{
+ typedef ArgMinOptionsBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE
+ {
+ VT_OUTPUT_TYPE = 4
+ };
+ onert_tflite::TensorType output_type() const
+ {
+ return static_cast<onert_tflite::TensorType>(GetField<int8_t>(VT_OUTPUT_TYPE, 0));
+ }
+ bool Verify(::flatbuffers::Verifier &verifier) const
+ {
+ return VerifyTableStart(verifier) && VerifyField<int8_t>(verifier, VT_OUTPUT_TYPE, 1) &&
+ verifier.EndTable();
+ }
+};
+
+struct ArgMinOptionsBuilder
+{
+ typedef ArgMinOptions Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ void add_output_type(onert_tflite::TensorType output_type)
+ {
+ fbb_.AddElement<int8_t>(ArgMinOptions::VT_OUTPUT_TYPE, static_cast<int8_t>(output_type), 0);
+ }
+ explicit ArgMinOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb)
+ {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<ArgMinOptions> Finish()
+ {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<ArgMinOptions>(end);
+ return o;
+ }
+};
+
+inline ::flatbuffers::Offset<ArgMinOptions>
+CreateArgMinOptions(::flatbuffers::FlatBufferBuilder &_fbb,
+ onert_tflite::TensorType output_type = onert_tflite::TensorType_FLOAT32)
+{
+ ArgMinOptionsBuilder builder_(_fbb);
+ builder_.add_output_type(output_type);
+ return builder_.Finish();
+}
+
+struct GreaterOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table
+{
+ typedef GreaterOptionsBuilder Builder;
+ bool Verify(::flatbuffers::Verifier &verifier) const
+ {
+ return VerifyTableStart(verifier) && verifier.EndTable();
+ }
+};
+
+struct GreaterOptionsBuilder
+{
+ typedef GreaterOptions Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ explicit GreaterOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb)
+ {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<GreaterOptions> Finish()
+ {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<GreaterOptions>(end);
+ return o;
+ }
+};
+
+inline ::flatbuffers::Offset<GreaterOptions>
+CreateGreaterOptions(::flatbuffers::FlatBufferBuilder &_fbb)
+{
+ GreaterOptionsBuilder builder_(_fbb);
+ return builder_.Finish();
+}
+
+struct GreaterEqualOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table
+{
+ typedef GreaterEqualOptionsBuilder Builder;
+ bool Verify(::flatbuffers::Verifier &verifier) const
+ {
+ return VerifyTableStart(verifier) && verifier.EndTable();
+ }
+};
+
+struct GreaterEqualOptionsBuilder
+{
+ typedef GreaterEqualOptions Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ explicit GreaterEqualOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb)
+ {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<GreaterEqualOptions> Finish()
+ {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<GreaterEqualOptions>(end);
+ return o;
+ }
+};
+
+inline ::flatbuffers::Offset<GreaterEqualOptions>
+CreateGreaterEqualOptions(::flatbuffers::FlatBufferBuilder &_fbb)
+{
+ GreaterEqualOptionsBuilder builder_(_fbb);
+ return builder_.Finish();
+}
+
+struct LessOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table
+{
+ typedef LessOptionsBuilder Builder;
+ bool Verify(::flatbuffers::Verifier &verifier) const
+ {
+ return VerifyTableStart(verifier) && verifier.EndTable();
+ }
+};
+
+struct LessOptionsBuilder
+{
+ typedef LessOptions Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ explicit LessOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb)
+ {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<LessOptions> Finish()
+ {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<LessOptions>(end);
+ return o;
+ }
+};
+
+inline ::flatbuffers::Offset<LessOptions> CreateLessOptions(::flatbuffers::FlatBufferBuilder &_fbb)
+{
+ LessOptionsBuilder builder_(_fbb);
+ return builder_.Finish();
+}
+
+struct LessEqualOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table
+{
+ typedef LessEqualOptionsBuilder Builder;
+ bool Verify(::flatbuffers::Verifier &verifier) const
+ {
+ return VerifyTableStart(verifier) && verifier.EndTable();
+ }
+};
+
+struct LessEqualOptionsBuilder
+{
+ typedef LessEqualOptions Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ explicit LessEqualOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb)
+ {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<LessEqualOptions> Finish()
+ {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<LessEqualOptions>(end);
+ return o;
+ }
+};
+
+inline ::flatbuffers::Offset<LessEqualOptions>
+CreateLessEqualOptions(::flatbuffers::FlatBufferBuilder &_fbb)
+{
+ LessEqualOptionsBuilder builder_(_fbb);
+ return builder_.Finish();
+}
+
+struct NegOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table
+{
+ typedef NegOptionsBuilder Builder;
+ bool Verify(::flatbuffers::Verifier &verifier) const
+ {
+ return VerifyTableStart(verifier) && verifier.EndTable();
+ }
+};
+
+struct NegOptionsBuilder
+{
+ typedef NegOptions Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ explicit NegOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb)
+ {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<NegOptions> Finish()
+ {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<NegOptions>(end);
+ return o;
+ }
+};
+
+inline ::flatbuffers::Offset<NegOptions> CreateNegOptions(::flatbuffers::FlatBufferBuilder &_fbb)
+{
+ NegOptionsBuilder builder_(_fbb);
+ return builder_.Finish();
+}
+
+struct SelectOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table
+{
+ typedef SelectOptionsBuilder Builder;
+ bool Verify(::flatbuffers::Verifier &verifier) const
+ {
+ return VerifyTableStart(verifier) && verifier.EndTable();
+ }
+};
+
+struct SelectOptionsBuilder
+{
+ typedef SelectOptions Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ explicit SelectOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb)
+ {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<SelectOptions> Finish()
+ {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<SelectOptions>(end);
+ return o;
+ }
+};
+
+inline ::flatbuffers::Offset<SelectOptions>
+CreateSelectOptions(::flatbuffers::FlatBufferBuilder &_fbb)
+{
+ SelectOptionsBuilder builder_(_fbb);
+ return builder_.Finish();
+}
+
+struct SliceOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table
+{
+ typedef SliceOptionsBuilder Builder;
+ bool Verify(::flatbuffers::Verifier &verifier) const
+ {
+ return VerifyTableStart(verifier) && verifier.EndTable();
+ }
+};
+
+struct SliceOptionsBuilder
+{
+ typedef SliceOptions Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ explicit SliceOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb)
+ {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<SliceOptions> Finish()
+ {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<SliceOptions>(end);
+ return o;
+ }
+};
+
+inline ::flatbuffers::Offset<SliceOptions>
+CreateSliceOptions(::flatbuffers::FlatBufferBuilder &_fbb)
+{
+ SliceOptionsBuilder builder_(_fbb);
+ return builder_.Finish();
+}
+
+struct TransposeConvOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table
+{
+ typedef TransposeConvOptionsBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE
+ {
+ VT_PADDING = 4,
+ VT_STRIDE_W = 6,
+ VT_STRIDE_H = 8
+ };
+ onert_tflite::Padding padding() const
+ {
+ return static_cast<onert_tflite::Padding>(GetField<int8_t>(VT_PADDING, 0));
+ }
+ int32_t stride_w() const { return GetField<int32_t>(VT_STRIDE_W, 0); }
+ int32_t stride_h() const { return GetField<int32_t>(VT_STRIDE_H, 0); }
+ bool Verify(::flatbuffers::Verifier &verifier) const
+ {
+ return VerifyTableStart(verifier) && VerifyField<int8_t>(verifier, VT_PADDING, 1) &&
+ VerifyField<int32_t>(verifier, VT_STRIDE_W, 4) &&
+ VerifyField<int32_t>(verifier, VT_STRIDE_H, 4) && verifier.EndTable();
+ }
+};
+
+struct TransposeConvOptionsBuilder
+{
+ typedef TransposeConvOptions Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ void add_padding(onert_tflite::Padding padding)
+ {
+ fbb_.AddElement<int8_t>(TransposeConvOptions::VT_PADDING, static_cast<int8_t>(padding), 0);
+ }
+ void add_stride_w(int32_t stride_w)
+ {
+ fbb_.AddElement<int32_t>(TransposeConvOptions::VT_STRIDE_W, stride_w, 0);
+ }
+ void add_stride_h(int32_t stride_h)
+ {
+ fbb_.AddElement<int32_t>(TransposeConvOptions::VT_STRIDE_H, stride_h, 0);
+ }
+ explicit TransposeConvOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb)
+ {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<TransposeConvOptions> Finish()
+ {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<TransposeConvOptions>(end);
+ return o;
+ }
+};
+
+inline ::flatbuffers::Offset<TransposeConvOptions>
+CreateTransposeConvOptions(::flatbuffers::FlatBufferBuilder &_fbb,
+ onert_tflite::Padding padding = onert_tflite::Padding_SAME,
+ int32_t stride_w = 0, int32_t stride_h = 0)
+{
+ TransposeConvOptionsBuilder builder_(_fbb);
+ builder_.add_stride_h(stride_h);
+ builder_.add_stride_w(stride_w);
+ builder_.add_padding(padding);
+ return builder_.Finish();
+}
+
+struct ExpandDimsOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table
+{
+ typedef ExpandDimsOptionsBuilder Builder;
+ bool Verify(::flatbuffers::Verifier &verifier) const
+ {
+ return VerifyTableStart(verifier) && verifier.EndTable();
+ }
+};
+
+struct ExpandDimsOptionsBuilder
+{
+ typedef ExpandDimsOptions Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ explicit ExpandDimsOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb)
+ {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<ExpandDimsOptions> Finish()
+ {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<ExpandDimsOptions>(end);
+ return o;
+ }
+};
+
+inline ::flatbuffers::Offset<ExpandDimsOptions>
+CreateExpandDimsOptions(::flatbuffers::FlatBufferBuilder &_fbb)
+{
+ ExpandDimsOptionsBuilder builder_(_fbb);
+ return builder_.Finish();
+}
+
+struct SparseToDenseOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table
+{
+ typedef SparseToDenseOptionsBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE
+ {
+ VT_VALIDATE_INDICES = 4
+ };
+ bool validate_indices() const { return GetField<uint8_t>(VT_VALIDATE_INDICES, 0) != 0; }
+ bool Verify(::flatbuffers::Verifier &verifier) const
+ {
+ return VerifyTableStart(verifier) && VerifyField<uint8_t>(verifier, VT_VALIDATE_INDICES, 1) &&
+ verifier.EndTable();
+ }
+};
+
+struct SparseToDenseOptionsBuilder
+{
+ typedef SparseToDenseOptions Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ void add_validate_indices(bool validate_indices)
+ {
+ fbb_.AddElement<uint8_t>(SparseToDenseOptions::VT_VALIDATE_INDICES,
+ static_cast<uint8_t>(validate_indices), 0);
+ }
+ explicit SparseToDenseOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb)
+ {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<SparseToDenseOptions> Finish()
+ {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<SparseToDenseOptions>(end);
+ return o;
+ }
+};
+
+inline ::flatbuffers::Offset<SparseToDenseOptions>
+CreateSparseToDenseOptions(::flatbuffers::FlatBufferBuilder &_fbb, bool validate_indices = false)
+{
+ SparseToDenseOptionsBuilder builder_(_fbb);
+ builder_.add_validate_indices(validate_indices);
+ return builder_.Finish();
+}
+
+struct EqualOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table
+{
+ typedef EqualOptionsBuilder Builder;
+ bool Verify(::flatbuffers::Verifier &verifier) const
+ {
+ return VerifyTableStart(verifier) && verifier.EndTable();
+ }
+};
+
+struct EqualOptionsBuilder
+{
+ typedef EqualOptions Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ explicit EqualOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb)
+ {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<EqualOptions> Finish()
+ {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<EqualOptions>(end);
+ return o;
+ }
+};
+
+inline ::flatbuffers::Offset<EqualOptions>
+CreateEqualOptions(::flatbuffers::FlatBufferBuilder &_fbb)
+{
+ EqualOptionsBuilder builder_(_fbb);
+ return builder_.Finish();
+}
+
+struct NotEqualOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table
+{
+ typedef NotEqualOptionsBuilder Builder;
+ bool Verify(::flatbuffers::Verifier &verifier) const
+ {
+ return VerifyTableStart(verifier) && verifier.EndTable();
+ }
+};
+
+struct NotEqualOptionsBuilder
+{
+ typedef NotEqualOptions Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ explicit NotEqualOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb)
+ {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<NotEqualOptions> Finish()
+ {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<NotEqualOptions>(end);
+ return o;
+ }
+};
+
+inline ::flatbuffers::Offset<NotEqualOptions>
+CreateNotEqualOptions(::flatbuffers::FlatBufferBuilder &_fbb)
+{
+ NotEqualOptionsBuilder builder_(_fbb);
+ return builder_.Finish();
+}
+
+struct ShapeOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table
+{
+ typedef ShapeOptionsBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE
+ {
+ VT_OUT_TYPE = 4
+ };
+ onert_tflite::TensorType out_type() const
+ {
+ return static_cast<onert_tflite::TensorType>(GetField<int8_t>(VT_OUT_TYPE, 0));
+ }
+ bool Verify(::flatbuffers::Verifier &verifier) const
+ {
+ return VerifyTableStart(verifier) && VerifyField<int8_t>(verifier, VT_OUT_TYPE, 1) &&
+ verifier.EndTable();
+ }
+};
+
+struct ShapeOptionsBuilder
+{
+ typedef ShapeOptions Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ void add_out_type(onert_tflite::TensorType out_type)
+ {
+ fbb_.AddElement<int8_t>(ShapeOptions::VT_OUT_TYPE, static_cast<int8_t>(out_type), 0);
+ }
+ explicit ShapeOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb)
+ {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<ShapeOptions> Finish()
+ {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<ShapeOptions>(end);
+ return o;
+ }
+};
+
+inline ::flatbuffers::Offset<ShapeOptions>
+CreateShapeOptions(::flatbuffers::FlatBufferBuilder &_fbb,
+ onert_tflite::TensorType out_type = onert_tflite::TensorType_FLOAT32)
+{
+ ShapeOptionsBuilder builder_(_fbb);
+ builder_.add_out_type(out_type);
+ return builder_.Finish();
+}
+
+struct RankOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table
+{
+ typedef RankOptionsBuilder Builder;
+ bool Verify(::flatbuffers::Verifier &verifier) const
+ {
+ return VerifyTableStart(verifier) && verifier.EndTable();
+ }
+};
+
+struct RankOptionsBuilder
+{
+ typedef RankOptions Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ explicit RankOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb)
+ {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<RankOptions> Finish()
+ {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<RankOptions>(end);
+ return o;
+ }
+};
+
+inline ::flatbuffers::Offset<RankOptions> CreateRankOptions(::flatbuffers::FlatBufferBuilder &_fbb)
+{
+ RankOptionsBuilder builder_(_fbb);
+ return builder_.Finish();
+}
+
+struct PowOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table
+{
+ typedef PowOptionsBuilder Builder;
+ bool Verify(::flatbuffers::Verifier &verifier) const
+ {
+ return VerifyTableStart(verifier) && verifier.EndTable();
+ }
+};
+
+struct PowOptionsBuilder
+{
+ typedef PowOptions Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ explicit PowOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb)
+ {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<PowOptions> Finish()
+ {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<PowOptions>(end);
+ return o;
+ }
+};
+
+inline ::flatbuffers::Offset<PowOptions> CreatePowOptions(::flatbuffers::FlatBufferBuilder &_fbb)
+{
+ PowOptionsBuilder builder_(_fbb);
+ return builder_.Finish();
+}
+
+struct FakeQuantOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table
+{
+ typedef FakeQuantOptionsBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE
+ {
+ VT_MIN = 4,
+ VT_MAX = 6,
+ VT_NUM_BITS = 8,
+ VT_NARROW_RANGE = 10
+ };
+ float min() const { return GetField<float>(VT_MIN, 0.0f); }
+ float max() const { return GetField<float>(VT_MAX, 0.0f); }
+ int32_t num_bits() const { return GetField<int32_t>(VT_NUM_BITS, 0); }
+ bool narrow_range() const { return GetField<uint8_t>(VT_NARROW_RANGE, 0) != 0; }
+ bool Verify(::flatbuffers::Verifier &verifier) const
+ {
+ return VerifyTableStart(verifier) && VerifyField<float>(verifier, VT_MIN, 4) &&
+ VerifyField<float>(verifier, VT_MAX, 4) &&
+ VerifyField<int32_t>(verifier, VT_NUM_BITS, 4) &&
+ VerifyField<uint8_t>(verifier, VT_NARROW_RANGE, 1) && verifier.EndTable();
+ }
+};
+
+struct FakeQuantOptionsBuilder
+{
+ typedef FakeQuantOptions Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ void add_min(float min) { fbb_.AddElement<float>(FakeQuantOptions::VT_MIN, min, 0.0f); }
+ void add_max(float max) { fbb_.AddElement<float>(FakeQuantOptions::VT_MAX, max, 0.0f); }
+ void add_num_bits(int32_t num_bits)
+ {
+ fbb_.AddElement<int32_t>(FakeQuantOptions::VT_NUM_BITS, num_bits, 0);
+ }
+ void add_narrow_range(bool narrow_range)
+ {
+ fbb_.AddElement<uint8_t>(FakeQuantOptions::VT_NARROW_RANGE, static_cast<uint8_t>(narrow_range),
+ 0);
+ }
+ explicit FakeQuantOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb)
+ {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<FakeQuantOptions> Finish()
+ {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<FakeQuantOptions>(end);
+ return o;
+ }
+};
+
+inline ::flatbuffers::Offset<FakeQuantOptions>
+CreateFakeQuantOptions(::flatbuffers::FlatBufferBuilder &_fbb, float min = 0.0f, float max = 0.0f,
+ int32_t num_bits = 0, bool narrow_range = false)
+{
+ FakeQuantOptionsBuilder builder_(_fbb);
+ builder_.add_num_bits(num_bits);
+ builder_.add_max(max);
+ builder_.add_min(min);
+ builder_.add_narrow_range(narrow_range);
+ return builder_.Finish();
+}
+
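+// Note on the Create* helpers: fields are added largest-scalar-first (here the three
+// 4-byte fields before the 1-byte narrow_range) so FlatBuffers can minimize alignment
+// padding; read order is unaffected because accessors go through the vtable. AddElement
+// also skips a field whose value equals the schema default passed as its last argument
+// (unless ForceDefaults is enabled on the builder), so CreateFakeQuantOptions(fbb) with
+// all defaults writes no field data at all.
+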
+struct PackOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table
+{
+ typedef PackOptionsBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE
+ {
+ VT_VALUES_COUNT = 4,
+ VT_AXIS = 6
+ };
+ int32_t values_count() const { return GetField<int32_t>(VT_VALUES_COUNT, 0); }
+ int32_t axis() const { return GetField<int32_t>(VT_AXIS, 0); }
+ bool Verify(::flatbuffers::Verifier &verifier) const
+ {
+ return VerifyTableStart(verifier) && VerifyField<int32_t>(verifier, VT_VALUES_COUNT, 4) &&
+ VerifyField<int32_t>(verifier, VT_AXIS, 4) && verifier.EndTable();
+ }
+};
+
+struct PackOptionsBuilder
+{
+ typedef PackOptions Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ void add_values_count(int32_t values_count)
+ {
+ fbb_.AddElement<int32_t>(PackOptions::VT_VALUES_COUNT, values_count, 0);
+ }
+ void add_axis(int32_t axis) { fbb_.AddElement<int32_t>(PackOptions::VT_AXIS, axis, 0); }
+ explicit PackOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb)
+ {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<PackOptions> Finish()
+ {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<PackOptions>(end);
+ return o;
+ }
+};
+
+inline ::flatbuffers::Offset<PackOptions> CreatePackOptions(::flatbuffers::FlatBufferBuilder &_fbb,
+ int32_t values_count = 0,
+ int32_t axis = 0)
+{
+ PackOptionsBuilder builder_(_fbb);
+ builder_.add_axis(axis);
+ builder_.add_values_count(values_count);
+ return builder_.Finish();
+}
+
+struct LogicalOrOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table
+{
+ typedef LogicalOrOptionsBuilder Builder;
+ bool Verify(::flatbuffers::Verifier &verifier) const
+ {
+ return VerifyTableStart(verifier) && verifier.EndTable();
+ }
+};
+
+struct LogicalOrOptionsBuilder
+{
+ typedef LogicalOrOptions Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ explicit LogicalOrOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb)
+ {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<LogicalOrOptions> Finish()
+ {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<LogicalOrOptions>(end);
+ return o;
+ }
+};
+
+inline ::flatbuffers::Offset<LogicalOrOptions>
+CreateLogicalOrOptions(::flatbuffers::FlatBufferBuilder &_fbb)
+{
+ LogicalOrOptionsBuilder builder_(_fbb);
+ return builder_.Finish();
+}
+
+struct OneHotOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table
+{
+ typedef OneHotOptionsBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE
+ {
+ VT_AXIS = 4
+ };
+ int32_t axis() const { return GetField<int32_t>(VT_AXIS, 0); }
+ bool Verify(::flatbuffers::Verifier &verifier) const
+ {
+ return VerifyTableStart(verifier) && VerifyField<int32_t>(verifier, VT_AXIS, 4) &&
+ verifier.EndTable();
+ }
+};
+
+struct OneHotOptionsBuilder
+{
+ typedef OneHotOptions Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ void add_axis(int32_t axis) { fbb_.AddElement<int32_t>(OneHotOptions::VT_AXIS, axis, 0); }
+ explicit OneHotOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb)
+ {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<OneHotOptions> Finish()
+ {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<OneHotOptions>(end);
+ return o;
+ }
+};
+
+inline ::flatbuffers::Offset<OneHotOptions>
+CreateOneHotOptions(::flatbuffers::FlatBufferBuilder &_fbb, int32_t axis = 0)
+{
+ OneHotOptionsBuilder builder_(_fbb);
+ builder_.add_axis(axis);
+ return builder_.Finish();
+}
+
+struct AbsOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table
+{
+ typedef AbsOptionsBuilder Builder;
+ bool Verify(::flatbuffers::Verifier &verifier) const
+ {
+ return VerifyTableStart(verifier) && verifier.EndTable();
+ }
+};
+
+struct AbsOptionsBuilder
+{
+ typedef AbsOptions Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ explicit AbsOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb)
+ {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<AbsOptions> Finish()
+ {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<AbsOptions>(end);
+ return o;
+ }
+};
+
+inline ::flatbuffers::Offset<AbsOptions> CreateAbsOptions(::flatbuffers::FlatBufferBuilder &_fbb)
+{
+ AbsOptionsBuilder builder_(_fbb);
+ return builder_.Finish();
+}
+
+struct HardSwishOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table
+{
+ typedef HardSwishOptionsBuilder Builder;
+ bool Verify(::flatbuffers::Verifier &verifier) const
+ {
+ return VerifyTableStart(verifier) && verifier.EndTable();
+ }
+};
+
+struct HardSwishOptionsBuilder
+{
+ typedef HardSwishOptions Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ explicit HardSwishOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb)
+ {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<HardSwishOptions> Finish()
+ {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<HardSwishOptions>(end);
+ return o;
+ }
+};
+
+inline ::flatbuffers::Offset<HardSwishOptions>
+CreateHardSwishOptions(::flatbuffers::FlatBufferBuilder &_fbb)
+{
+ HardSwishOptionsBuilder builder_(_fbb);
+ return builder_.Finish();
+}
+
+struct LogicalAndOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table
+{
+ typedef LogicalAndOptionsBuilder Builder;
+ bool Verify(::flatbuffers::Verifier &verifier) const
+ {
+ return VerifyTableStart(verifier) && verifier.EndTable();
+ }
+};
+
+struct LogicalAndOptionsBuilder
+{
+ typedef LogicalAndOptions Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ explicit LogicalAndOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb)
+ {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<LogicalAndOptions> Finish()
+ {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<LogicalAndOptions>(end);
+ return o;
+ }
+};
+
+inline ::flatbuffers::Offset<LogicalAndOptions>
+CreateLogicalAndOptions(::flatbuffers::FlatBufferBuilder &_fbb)
+{
+ LogicalAndOptionsBuilder builder_(_fbb);
+ return builder_.Finish();
+}
+
+struct LogicalNotOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table
+{
+ typedef LogicalNotOptionsBuilder Builder;
+ bool Verify(::flatbuffers::Verifier &verifier) const
+ {
+ return VerifyTableStart(verifier) && verifier.EndTable();
+ }
+};
+
+struct LogicalNotOptionsBuilder
+{
+ typedef LogicalNotOptions Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ explicit LogicalNotOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb)
+ {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<LogicalNotOptions> Finish()
+ {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<LogicalNotOptions>(end);
+ return o;
+ }
+};
+
+inline ::flatbuffers::Offset<LogicalNotOptions>
+CreateLogicalNotOptions(::flatbuffers::FlatBufferBuilder &_fbb)
+{
+ LogicalNotOptionsBuilder builder_(_fbb);
+ return builder_.Finish();
+}
+
+struct UnpackOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table
+{
+ typedef UnpackOptionsBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE
+ {
+ VT_NUM = 4,
+ VT_AXIS = 6
+ };
+ int32_t num() const { return GetField<int32_t>(VT_NUM, 0); }
+ int32_t axis() const { return GetField<int32_t>(VT_AXIS, 0); }
+ bool Verify(::flatbuffers::Verifier &verifier) const
+ {
+ return VerifyTableStart(verifier) && VerifyField<int32_t>(verifier, VT_NUM, 4) &&
+ VerifyField<int32_t>(verifier, VT_AXIS, 4) && verifier.EndTable();
+ }
+};
+
+struct UnpackOptionsBuilder
+{
+ typedef UnpackOptions Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ void add_num(int32_t num) { fbb_.AddElement<int32_t>(UnpackOptions::VT_NUM, num, 0); }
+ void add_axis(int32_t axis) { fbb_.AddElement<int32_t>(UnpackOptions::VT_AXIS, axis, 0); }
+ explicit UnpackOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb)
+ {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<UnpackOptions> Finish()
+ {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<UnpackOptions>(end);
+ return o;
+ }
+};
+
+inline ::flatbuffers::Offset<UnpackOptions>
+CreateUnpackOptions(::flatbuffers::FlatBufferBuilder &_fbb, int32_t num = 0, int32_t axis = 0)
+{
+ UnpackOptionsBuilder builder_(_fbb);
+ builder_.add_axis(axis);
+ builder_.add_num(num);
+ return builder_.Finish();
+}
+
+struct FloorDivOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table
+{
+ typedef FloorDivOptionsBuilder Builder;
+ bool Verify(::flatbuffers::Verifier &verifier) const
+ {
+ return VerifyTableStart(verifier) && verifier.EndTable();
+ }
+};
+
+struct FloorDivOptionsBuilder
+{
+ typedef FloorDivOptions Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ explicit FloorDivOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb)
+ {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<FloorDivOptions> Finish()
+ {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<FloorDivOptions>(end);
+ return o;
+ }
+};
+
+inline ::flatbuffers::Offset<FloorDivOptions>
+CreateFloorDivOptions(::flatbuffers::FlatBufferBuilder &_fbb)
+{
+ FloorDivOptionsBuilder builder_(_fbb);
+ return builder_.Finish();
+}
+
+struct SquareOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table
+{
+ typedef SquareOptionsBuilder Builder;
+ bool Verify(::flatbuffers::Verifier &verifier) const
+ {
+ return VerifyTableStart(verifier) && verifier.EndTable();
+ }
+};
+
+struct SquareOptionsBuilder
+{
+ typedef SquareOptions Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ explicit SquareOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb)
+ {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<SquareOptions> Finish()
+ {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<SquareOptions>(end);
+ return o;
+ }
+};
+
+inline ::flatbuffers::Offset<SquareOptions>
+CreateSquareOptions(::flatbuffers::FlatBufferBuilder &_fbb)
+{
+ SquareOptionsBuilder builder_(_fbb);
+ return builder_.Finish();
+}
+
+struct ZerosLikeOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table
+{
+ typedef ZerosLikeOptionsBuilder Builder;
+ bool Verify(::flatbuffers::Verifier &verifier) const
+ {
+ return VerifyTableStart(verifier) && verifier.EndTable();
+ }
+};
+
+struct ZerosLikeOptionsBuilder
+{
+ typedef ZerosLikeOptions Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ explicit ZerosLikeOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb)
+ {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<ZerosLikeOptions> Finish()
+ {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<ZerosLikeOptions>(end);
+ return o;
+ }
+};
+
+inline ::flatbuffers::Offset<ZerosLikeOptions>
+CreateZerosLikeOptions(::flatbuffers::FlatBufferBuilder &_fbb)
+{
+ ZerosLikeOptionsBuilder builder_(_fbb);
+ return builder_.Finish();
+}
+
+struct FillOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table
+{
+ typedef FillOptionsBuilder Builder;
+ bool Verify(::flatbuffers::Verifier &verifier) const
+ {
+ return VerifyTableStart(verifier) && verifier.EndTable();
+ }
+};
+
+struct FillOptionsBuilder
+{
+ typedef FillOptions Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ explicit FillOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb)
+ {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<FillOptions> Finish()
+ {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<FillOptions>(end);
+ return o;
+ }
+};
+
+inline ::flatbuffers::Offset<FillOptions> CreateFillOptions(::flatbuffers::FlatBufferBuilder &_fbb)
+{
+ FillOptionsBuilder builder_(_fbb);
+ return builder_.Finish();
+}
+
+struct FloorModOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table
+{
+ typedef FloorModOptionsBuilder Builder;
+ bool Verify(::flatbuffers::Verifier &verifier) const
+ {
+ return VerifyTableStart(verifier) && verifier.EndTable();
+ }
+};
+
+struct FloorModOptionsBuilder
+{
+ typedef FloorModOptions Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ explicit FloorModOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb)
+ {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<FloorModOptions> Finish()
+ {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<FloorModOptions>(end);
+ return o;
+ }
+};
+
+inline ::flatbuffers::Offset<FloorModOptions>
+CreateFloorModOptions(::flatbuffers::FlatBufferBuilder &_fbb)
+{
+ FloorModOptionsBuilder builder_(_fbb);
+ return builder_.Finish();
+}
+
+struct RangeOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table
+{
+ typedef RangeOptionsBuilder Builder;
+ bool Verify(::flatbuffers::Verifier &verifier) const
+ {
+ return VerifyTableStart(verifier) && verifier.EndTable();
+ }
+};
+
+struct RangeOptionsBuilder
+{
+ typedef RangeOptions Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ explicit RangeOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb)
+ {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<RangeOptions> Finish()
+ {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<RangeOptions>(end);
+ return o;
+ }
+};
+
+inline ::flatbuffers::Offset<RangeOptions>
+CreateRangeOptions(::flatbuffers::FlatBufferBuilder &_fbb)
+{
+ RangeOptionsBuilder builder_(_fbb);
+ return builder_.Finish();
+}
+
+struct LeakyReluOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table
+{
+ typedef LeakyReluOptionsBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE
+ {
+ VT_ALPHA = 4
+ };
+ float alpha() const { return GetField<float>(VT_ALPHA, 0.0f); }
+ bool Verify(::flatbuffers::Verifier &verifier) const
+ {
+ return VerifyTableStart(verifier) && VerifyField<float>(verifier, VT_ALPHA, 4) &&
+ verifier.EndTable();
+ }
+};
+
+struct LeakyReluOptionsBuilder
+{
+ typedef LeakyReluOptions Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ void add_alpha(float alpha) { fbb_.AddElement<float>(LeakyReluOptions::VT_ALPHA, alpha, 0.0f); }
+ explicit LeakyReluOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb)
+ {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<LeakyReluOptions> Finish()
+ {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<LeakyReluOptions>(end);
+ return o;
+ }
+};
+
+inline ::flatbuffers::Offset<LeakyReluOptions>
+CreateLeakyReluOptions(::flatbuffers::FlatBufferBuilder &_fbb, float alpha = 0.0f)
+{
+ LeakyReluOptionsBuilder builder_(_fbb);
+ builder_.add_alpha(alpha);
+ return builder_.Finish();
+}
+
+struct SquaredDifferenceOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table
+{
+ typedef SquaredDifferenceOptionsBuilder Builder;
+ bool Verify(::flatbuffers::Verifier &verifier) const
+ {
+ return VerifyTableStart(verifier) && verifier.EndTable();
+ }
+};
+
+struct SquaredDifferenceOptionsBuilder
+{
+ typedef SquaredDifferenceOptions Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ explicit SquaredDifferenceOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb)
+ {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<SquaredDifferenceOptions> Finish()
+ {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<SquaredDifferenceOptions>(end);
+ return o;
+ }
+};
+
+inline ::flatbuffers::Offset<SquaredDifferenceOptions>
+CreateSquaredDifferenceOptions(::flatbuffers::FlatBufferBuilder &_fbb)
+{
+ SquaredDifferenceOptionsBuilder builder_(_fbb);
+ return builder_.Finish();
+}
+
+struct MirrorPadOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table
+{
+ typedef MirrorPadOptionsBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE
+ {
+ VT_MODE = 4
+ };
+ onert_tflite::MirrorPadMode mode() const
+ {
+ return static_cast<onert_tflite::MirrorPadMode>(GetField<int8_t>(VT_MODE, 0));
+ }
+ bool Verify(::flatbuffers::Verifier &verifier) const
+ {
+ return VerifyTableStart(verifier) && VerifyField<int8_t>(verifier, VT_MODE, 1) &&
+ verifier.EndTable();
+ }
+};
+
+struct MirrorPadOptionsBuilder
+{
+ typedef MirrorPadOptions Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ void add_mode(onert_tflite::MirrorPadMode mode)
+ {
+ fbb_.AddElement<int8_t>(MirrorPadOptions::VT_MODE, static_cast<int8_t>(mode), 0);
+ }
+ explicit MirrorPadOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb)
+ {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<MirrorPadOptions> Finish()
+ {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<MirrorPadOptions>(end);
+ return o;
+ }
+};
+
+inline ::flatbuffers::Offset<MirrorPadOptions>
+CreateMirrorPadOptions(::flatbuffers::FlatBufferBuilder &_fbb,
+ onert_tflite::MirrorPadMode mode = onert_tflite::MirrorPadMode_REFLECT)
+{
+ MirrorPadOptionsBuilder builder_(_fbb);
+ builder_.add_mode(mode);
+ return builder_.Finish();
+}
+
+struct UniqueOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table
+{
+ typedef UniqueOptionsBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE
+ {
+ VT_IDX_OUT_TYPE = 4
+ };
+ onert_tflite::TensorType idx_out_type() const
+ {
+ return static_cast<onert_tflite::TensorType>(GetField<int8_t>(VT_IDX_OUT_TYPE, 2));
+ }
+ bool Verify(::flatbuffers::Verifier &verifier) const
+ {
+ return VerifyTableStart(verifier) && VerifyField<int8_t>(verifier, VT_IDX_OUT_TYPE, 1) &&
+ verifier.EndTable();
+ }
+};
+
+struct UniqueOptionsBuilder
+{
+ typedef UniqueOptions Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ void add_idx_out_type(onert_tflite::TensorType idx_out_type)
+ {
+ fbb_.AddElement<int8_t>(UniqueOptions::VT_IDX_OUT_TYPE, static_cast<int8_t>(idx_out_type), 2);
+ }
+ explicit UniqueOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb)
+ {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<UniqueOptions> Finish()
+ {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<UniqueOptions>(end);
+ return o;
+ }
+};
+
+inline ::flatbuffers::Offset<UniqueOptions>
+CreateUniqueOptions(::flatbuffers::FlatBufferBuilder &_fbb,
+ onert_tflite::TensorType idx_out_type = onert_tflite::TensorType_INT32)
+{
+ UniqueOptionsBuilder builder_(_fbb);
+ builder_.add_idx_out_type(idx_out_type);
+ return builder_.Finish();
+}
+
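+// Note: idx_out_type has a non-zero schema default; GetField and AddElement both use 2,
+// i.e. TensorType_INT32. A table built with CreateUniqueOptions(fbb) therefore stores
+// nothing for this field yet still reads back TensorType_INT32 from idx_out_type().
+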
+struct ReverseV2Options FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table
+{
+ typedef ReverseV2OptionsBuilder Builder;
+ bool Verify(::flatbuffers::Verifier &verifier) const
+ {
+ return VerifyTableStart(verifier) && verifier.EndTable();
+ }
+};
+
+struct ReverseV2OptionsBuilder
+{
+ typedef ReverseV2Options Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ explicit ReverseV2OptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb)
+ {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<ReverseV2Options> Finish()
+ {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<ReverseV2Options>(end);
+ return o;
+ }
+};
+
+inline ::flatbuffers::Offset<ReverseV2Options>
+CreateReverseV2Options(::flatbuffers::FlatBufferBuilder &_fbb)
+{
+ ReverseV2OptionsBuilder builder_(_fbb);
+ return builder_.Finish();
+}
+
+struct AddNOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table
+{
+ typedef AddNOptionsBuilder Builder;
+ bool Verify(::flatbuffers::Verifier &verifier) const
+ {
+ return VerifyTableStart(verifier) && verifier.EndTable();
+ }
+};
+
+struct AddNOptionsBuilder
+{
+ typedef AddNOptions Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ explicit AddNOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb)
+ {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<AddNOptions> Finish()
+ {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<AddNOptions>(end);
+ return o;
+ }
+};
+
+inline ::flatbuffers::Offset<AddNOptions> CreateAddNOptions(::flatbuffers::FlatBufferBuilder &_fbb)
+{
+ AddNOptionsBuilder builder_(_fbb);
+ return builder_.Finish();
+}
+
+struct GatherNdOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table
+{
+ typedef GatherNdOptionsBuilder Builder;
+ bool Verify(::flatbuffers::Verifier &verifier) const
+ {
+ return VerifyTableStart(verifier) && verifier.EndTable();
+ }
+};
+
+struct GatherNdOptionsBuilder
+{
+ typedef GatherNdOptions Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ explicit GatherNdOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb)
+ {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<GatherNdOptions> Finish()
+ {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<GatherNdOptions>(end);
+ return o;
+ }
+};
+
+inline ::flatbuffers::Offset<GatherNdOptions>
+CreateGatherNdOptions(::flatbuffers::FlatBufferBuilder &_fbb)
+{
+ GatherNdOptionsBuilder builder_(_fbb);
+ return builder_.Finish();
+}
+
+struct WhereOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table
+{
+ typedef WhereOptionsBuilder Builder;
+ bool Verify(::flatbuffers::Verifier &verifier) const
+ {
+ return VerifyTableStart(verifier) && verifier.EndTable();
+ }
+};
+
+struct WhereOptionsBuilder
+{
+ typedef WhereOptions Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ explicit WhereOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb)
+ {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<WhereOptions> Finish()
+ {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<WhereOptions>(end);
+ return o;
+ }
+};
+
+inline ::flatbuffers::Offset<WhereOptions>
+CreateWhereOptions(::flatbuffers::FlatBufferBuilder &_fbb)
+{
+ WhereOptionsBuilder builder_(_fbb);
+ return builder_.Finish();
+}
+
+struct ReverseSequenceOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table
+{
+ typedef ReverseSequenceOptionsBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE
+ {
+ VT_SEQ_DIM = 4,
+ VT_BATCH_DIM = 6
+ };
+ int32_t seq_dim() const { return GetField<int32_t>(VT_SEQ_DIM, 0); }
+ int32_t batch_dim() const { return GetField<int32_t>(VT_BATCH_DIM, 0); }
+ bool Verify(::flatbuffers::Verifier &verifier) const
+ {
+ return VerifyTableStart(verifier) && VerifyField<int32_t>(verifier, VT_SEQ_DIM, 4) &&
+ VerifyField<int32_t>(verifier, VT_BATCH_DIM, 4) && verifier.EndTable();
+ }
+};
+
+struct ReverseSequenceOptionsBuilder
+{
+ typedef ReverseSequenceOptions Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ void add_seq_dim(int32_t seq_dim)
+ {
+ fbb_.AddElement<int32_t>(ReverseSequenceOptions::VT_SEQ_DIM, seq_dim, 0);
+ }
+ void add_batch_dim(int32_t batch_dim)
+ {
+ fbb_.AddElement<int32_t>(ReverseSequenceOptions::VT_BATCH_DIM, batch_dim, 0);
+ }
+ explicit ReverseSequenceOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb)
+ {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<ReverseSequenceOptions> Finish()
+ {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<ReverseSequenceOptions>(end);
+ return o;
+ }
+};
+
+inline ::flatbuffers::Offset<ReverseSequenceOptions>
+CreateReverseSequenceOptions(::flatbuffers::FlatBufferBuilder &_fbb, int32_t seq_dim = 0,
+ int32_t batch_dim = 0)
+{
+ ReverseSequenceOptionsBuilder builder_(_fbb);
+ builder_.add_batch_dim(batch_dim);
+ builder_.add_seq_dim(seq_dim);
+ return builder_.Finish();
+}
+
+struct MatrixDiagOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table
+{
+ typedef MatrixDiagOptionsBuilder Builder;
+ bool Verify(::flatbuffers::Verifier &verifier) const
+ {
+ return VerifyTableStart(verifier) && verifier.EndTable();
+ }
+};
+
+struct MatrixDiagOptionsBuilder
+{
+ typedef MatrixDiagOptions Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ explicit MatrixDiagOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb)
+ {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<MatrixDiagOptions> Finish()
+ {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<MatrixDiagOptions>(end);
+ return o;
+ }
+};
+
+inline ::flatbuffers::Offset<MatrixDiagOptions>
+CreateMatrixDiagOptions(::flatbuffers::FlatBufferBuilder &_fbb)
+{
+ MatrixDiagOptionsBuilder builder_(_fbb);
+ return builder_.Finish();
+}
+
+struct QuantizeOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table
+{
+ typedef QuantizeOptionsBuilder Builder;
+ bool Verify(::flatbuffers::Verifier &verifier) const
+ {
+ return VerifyTableStart(verifier) && verifier.EndTable();
+ }
+};
+
+struct QuantizeOptionsBuilder
+{
+ typedef QuantizeOptions Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ explicit QuantizeOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb)
+ {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<QuantizeOptions> Finish()
+ {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<QuantizeOptions>(end);
+ return o;
+ }
+};
+
+inline ::flatbuffers::Offset<QuantizeOptions>
+CreateQuantizeOptions(::flatbuffers::FlatBufferBuilder &_fbb)
+{
+ QuantizeOptionsBuilder builder_(_fbb);
+ return builder_.Finish();
+}
+
+struct MatrixSetDiagOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table
+{
+ typedef MatrixSetDiagOptionsBuilder Builder;
+ bool Verify(::flatbuffers::Verifier &verifier) const
+ {
+ return VerifyTableStart(verifier) && verifier.EndTable();
+ }
+};
+
+struct MatrixSetDiagOptionsBuilder
+{
+ typedef MatrixSetDiagOptions Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ explicit MatrixSetDiagOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb)
+ {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<MatrixSetDiagOptions> Finish()
+ {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<MatrixSetDiagOptions>(end);
+ return o;
+ }
+};
+
+inline ::flatbuffers::Offset<MatrixSetDiagOptions>
+CreateMatrixSetDiagOptions(::flatbuffers::FlatBufferBuilder &_fbb)
+{
+ MatrixSetDiagOptionsBuilder builder_(_fbb);
+ return builder_.Finish();
+}
+
+struct IfOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table
+{
+ typedef IfOptionsBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE
+ {
+ VT_THEN_SUBGRAPH_INDEX = 4,
+ VT_ELSE_SUBGRAPH_INDEX = 6
+ };
+ int32_t then_subgraph_index() const { return GetField<int32_t>(VT_THEN_SUBGRAPH_INDEX, 0); }
+ int32_t else_subgraph_index() const { return GetField<int32_t>(VT_ELSE_SUBGRAPH_INDEX, 0); }
+ bool Verify(::flatbuffers::Verifier &verifier) const
+ {
+ return VerifyTableStart(verifier) &&
+ VerifyField<int32_t>(verifier, VT_THEN_SUBGRAPH_INDEX, 4) &&
+ VerifyField<int32_t>(verifier, VT_ELSE_SUBGRAPH_INDEX, 4) && verifier.EndTable();
+ }
+};
+
+struct IfOptionsBuilder
+{
+ typedef IfOptions Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ void add_then_subgraph_index(int32_t then_subgraph_index)
+ {
+ fbb_.AddElement<int32_t>(IfOptions::VT_THEN_SUBGRAPH_INDEX, then_subgraph_index, 0);
+ }
+ void add_else_subgraph_index(int32_t else_subgraph_index)
+ {
+ fbb_.AddElement<int32_t>(IfOptions::VT_ELSE_SUBGRAPH_INDEX, else_subgraph_index, 0);
+ }
+ explicit IfOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb)
+ {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<IfOptions> Finish()
+ {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<IfOptions>(end);
+ return o;
+ }
+};
+
+inline ::flatbuffers::Offset<IfOptions> CreateIfOptions(::flatbuffers::FlatBufferBuilder &_fbb,
+ int32_t then_subgraph_index = 0,
+ int32_t else_subgraph_index = 0)
+{
+ IfOptionsBuilder builder_(_fbb);
+ builder_.add_else_subgraph_index(else_subgraph_index);
+ builder_.add_then_subgraph_index(then_subgraph_index);
+ return builder_.Finish();
+}
+
+struct CallOnceOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table
+{
+ typedef CallOnceOptionsBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE
+ {
+ VT_INIT_SUBGRAPH_INDEX = 4
+ };
+ int32_t init_subgraph_index() const { return GetField<int32_t>(VT_INIT_SUBGRAPH_INDEX, 0); }
+ bool Verify(::flatbuffers::Verifier &verifier) const
+ {
+ return VerifyTableStart(verifier) &&
+ VerifyField<int32_t>(verifier, VT_INIT_SUBGRAPH_INDEX, 4) && verifier.EndTable();
+ }
+};
+
+struct CallOnceOptionsBuilder
+{
+ typedef CallOnceOptions Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ void add_init_subgraph_index(int32_t init_subgraph_index)
+ {
+ fbb_.AddElement<int32_t>(CallOnceOptions::VT_INIT_SUBGRAPH_INDEX, init_subgraph_index, 0);
+ }
+ explicit CallOnceOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb)
+ {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<CallOnceOptions> Finish()
+ {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<CallOnceOptions>(end);
+ return o;
+ }
+};
+
+inline ::flatbuffers::Offset<CallOnceOptions>
+CreateCallOnceOptions(::flatbuffers::FlatBufferBuilder &_fbb, int32_t init_subgraph_index = 0)
+{
+ CallOnceOptionsBuilder builder_(_fbb);
+ builder_.add_init_subgraph_index(init_subgraph_index);
+ return builder_.Finish();
+}
+
+struct WhileOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table
+{
+ typedef WhileOptionsBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE
+ {
+ VT_COND_SUBGRAPH_INDEX = 4,
+ VT_BODY_SUBGRAPH_INDEX = 6
+ };
+ int32_t cond_subgraph_index() const { return GetField<int32_t>(VT_COND_SUBGRAPH_INDEX, 0); }
+ int32_t body_subgraph_index() const { return GetField<int32_t>(VT_BODY_SUBGRAPH_INDEX, 0); }
+ bool Verify(::flatbuffers::Verifier &verifier) const
+ {
+ return VerifyTableStart(verifier) &&
+ VerifyField<int32_t>(verifier, VT_COND_SUBGRAPH_INDEX, 4) &&
+ VerifyField<int32_t>(verifier, VT_BODY_SUBGRAPH_INDEX, 4) && verifier.EndTable();
+ }
+};
+
+struct WhileOptionsBuilder
+{
+ typedef WhileOptions Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ void add_cond_subgraph_index(int32_t cond_subgraph_index)
+ {
+ fbb_.AddElement<int32_t>(WhileOptions::VT_COND_SUBGRAPH_INDEX, cond_subgraph_index, 0);
+ }
+ void add_body_subgraph_index(int32_t body_subgraph_index)
+ {
+ fbb_.AddElement<int32_t>(WhileOptions::VT_BODY_SUBGRAPH_INDEX, body_subgraph_index, 0);
+ }
+ explicit WhileOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb)
+ {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<WhileOptions> Finish()
+ {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<WhileOptions>(end);
+ return o;
+ }
+};
+
+inline ::flatbuffers::Offset<WhileOptions>
+CreateWhileOptions(::flatbuffers::FlatBufferBuilder &_fbb, int32_t cond_subgraph_index = 0,
+ int32_t body_subgraph_index = 0)
+{
+ WhileOptionsBuilder builder_(_fbb);
+ builder_.add_body_subgraph_index(body_subgraph_index);
+ builder_.add_cond_subgraph_index(cond_subgraph_index);
+ return builder_.Finish();
+}
+
+struct NonMaxSuppressionV4Options FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table
+{
+ typedef NonMaxSuppressionV4OptionsBuilder Builder;
+ bool Verify(::flatbuffers::Verifier &verifier) const
+ {
+ return VerifyTableStart(verifier) && verifier.EndTable();
+ }
+};
+
+struct NonMaxSuppressionV4OptionsBuilder
+{
+ typedef NonMaxSuppressionV4Options Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ explicit NonMaxSuppressionV4OptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb)
+ {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<NonMaxSuppressionV4Options> Finish()
+ {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<NonMaxSuppressionV4Options>(end);
+ return o;
+ }
+};
+
+inline ::flatbuffers::Offset<NonMaxSuppressionV4Options>
+CreateNonMaxSuppressionV4Options(::flatbuffers::FlatBufferBuilder &_fbb)
+{
+ NonMaxSuppressionV4OptionsBuilder builder_(_fbb);
+ return builder_.Finish();
+}
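+
+// Commentary: option tables with no fields (NonMaxSuppressionV4Options and the
+// similar empty tables below) still get a struct, a builder, and a Create
+// helper so that the BuiltinOptions union keeps a distinct type tag for each
+// operator, even when the operator takes no attributes.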
+
+struct NonMaxSuppressionV5Options FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table
+{
+ typedef NonMaxSuppressionV5OptionsBuilder Builder;
+ bool Verify(::flatbuffers::Verifier &verifier) const
+ {
+ return VerifyTableStart(verifier) && verifier.EndTable();
+ }
+};
+
+struct NonMaxSuppressionV5OptionsBuilder
+{
+ typedef NonMaxSuppressionV5Options Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ explicit NonMaxSuppressionV5OptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb)
+ {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<NonMaxSuppressionV5Options> Finish()
+ {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<NonMaxSuppressionV5Options>(end);
+ return o;
+ }
+};
+
+inline ::flatbuffers::Offset<NonMaxSuppressionV5Options>
+CreateNonMaxSuppressionV5Options(::flatbuffers::FlatBufferBuilder &_fbb)
+{
+ NonMaxSuppressionV5OptionsBuilder builder_(_fbb);
+ return builder_.Finish();
+}
+
+struct ScatterNdOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table
+{
+ typedef ScatterNdOptionsBuilder Builder;
+ bool Verify(::flatbuffers::Verifier &verifier) const
+ {
+ return VerifyTableStart(verifier) && verifier.EndTable();
+ }
+};
+
+struct ScatterNdOptionsBuilder
+{
+ typedef ScatterNdOptions Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ explicit ScatterNdOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb)
+ {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<ScatterNdOptions> Finish()
+ {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<ScatterNdOptions>(end);
+ return o;
+ }
+};
+
+inline ::flatbuffers::Offset<ScatterNdOptions>
+CreateScatterNdOptions(::flatbuffers::FlatBufferBuilder &_fbb)
+{
+ ScatterNdOptionsBuilder builder_(_fbb);
+ return builder_.Finish();
+}
+
+struct SelectV2Options FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table
+{
+ typedef SelectV2OptionsBuilder Builder;
+ bool Verify(::flatbuffers::Verifier &verifier) const
+ {
+ return VerifyTableStart(verifier) && verifier.EndTable();
+ }
+};
+
+struct SelectV2OptionsBuilder
+{
+ typedef SelectV2Options Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ explicit SelectV2OptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb)
+ {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<SelectV2Options> Finish()
+ {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<SelectV2Options>(end);
+ return o;
+ }
+};
+
+inline ::flatbuffers::Offset<SelectV2Options>
+CreateSelectV2Options(::flatbuffers::FlatBufferBuilder &_fbb)
+{
+ SelectV2OptionsBuilder builder_(_fbb);
+ return builder_.Finish();
+}
+
+struct DensifyOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table
+{
+ typedef DensifyOptionsBuilder Builder;
+ bool Verify(::flatbuffers::Verifier &verifier) const
+ {
+ return VerifyTableStart(verifier) && verifier.EndTable();
+ }
+};
+
+struct DensifyOptionsBuilder
+{
+ typedef DensifyOptions Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ explicit DensifyOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb)
+ {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<DensifyOptions> Finish()
+ {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<DensifyOptions>(end);
+ return o;
+ }
+};
+
+inline ::flatbuffers::Offset<DensifyOptions>
+CreateDensifyOptions(::flatbuffers::FlatBufferBuilder &_fbb)
+{
+ DensifyOptionsBuilder builder_(_fbb);
+ return builder_.Finish();
+}
+
+struct SegmentSumOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table
+{
+ typedef SegmentSumOptionsBuilder Builder;
+ bool Verify(::flatbuffers::Verifier &verifier) const
+ {
+ return VerifyTableStart(verifier) && verifier.EndTable();
+ }
+};
+
+struct SegmentSumOptionsBuilder
+{
+ typedef SegmentSumOptions Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ explicit SegmentSumOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb)
+ {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<SegmentSumOptions> Finish()
+ {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<SegmentSumOptions>(end);
+ return o;
+ }
+};
+
+inline ::flatbuffers::Offset<SegmentSumOptions>
+CreateSegmentSumOptions(::flatbuffers::FlatBufferBuilder &_fbb)
+{
+ SegmentSumOptionsBuilder builder_(_fbb);
+ return builder_.Finish();
+}
+
+struct BatchMatMulOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table
+{
+ typedef BatchMatMulOptionsBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE
+ {
+ VT_ADJ_X = 4,
+ VT_ADJ_Y = 6,
+ VT_ASYMMETRIC_QUANTIZE_INPUTS = 8
+ };
+ bool adj_x() const { return GetField<uint8_t>(VT_ADJ_X, 0) != 0; }
+ bool adj_y() const { return GetField<uint8_t>(VT_ADJ_Y, 0) != 0; }
+ bool asymmetric_quantize_inputs() const
+ {
+ return GetField<uint8_t>(VT_ASYMMETRIC_QUANTIZE_INPUTS, 0) != 0;
+ }
+ bool Verify(::flatbuffers::Verifier &verifier) const
+ {
+ return VerifyTableStart(verifier) && VerifyField<uint8_t>(verifier, VT_ADJ_X, 1) &&
+ VerifyField<uint8_t>(verifier, VT_ADJ_Y, 1) &&
+ VerifyField<uint8_t>(verifier, VT_ASYMMETRIC_QUANTIZE_INPUTS, 1) && verifier.EndTable();
+ }
+};
+
+struct BatchMatMulOptionsBuilder
+{
+ typedef BatchMatMulOptions Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ void add_adj_x(bool adj_x)
+ {
+ fbb_.AddElement<uint8_t>(BatchMatMulOptions::VT_ADJ_X, static_cast<uint8_t>(adj_x), 0);
+ }
+ void add_adj_y(bool adj_y)
+ {
+ fbb_.AddElement<uint8_t>(BatchMatMulOptions::VT_ADJ_Y, static_cast<uint8_t>(adj_y), 0);
+ }
+ void add_asymmetric_quantize_inputs(bool asymmetric_quantize_inputs)
+ {
+ fbb_.AddElement<uint8_t>(BatchMatMulOptions::VT_ASYMMETRIC_QUANTIZE_INPUTS,
+ static_cast<uint8_t>(asymmetric_quantize_inputs), 0);
+ }
+ explicit BatchMatMulOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb)
+ {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<BatchMatMulOptions> Finish()
+ {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<BatchMatMulOptions>(end);
+ return o;
+ }
+};
+
+inline ::flatbuffers::Offset<BatchMatMulOptions>
+CreateBatchMatMulOptions(::flatbuffers::FlatBufferBuilder &_fbb, bool adj_x = false,
+ bool adj_y = false, bool asymmetric_quantize_inputs = false)
+{
+ BatchMatMulOptionsBuilder builder_(_fbb);
+ builder_.add_asymmetric_quantize_inputs(asymmetric_quantize_inputs);
+ builder_.add_adj_y(adj_y);
+ builder_.add_adj_x(adj_x);
+ return builder_.Finish();
+}
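+
+// Illustrative sketch, not generated output. The add_* calls inside the Create
+// helper are emitted by flatc in its own order; fields are addressed by vtable
+// offset, so the call order does not change the resulting table. The example
+// values below are hypothetical.
+//
+//   ::flatbuffers::FlatBufferBuilder fbb;
+//   auto bmm_options = onert_tflite::CreateBatchMatMulOptions(
+//     fbb, /*adj_x=*/false, /*adj_y=*/true); // adj_y: treat rhs as transposed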
+
+struct CumsumOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table
+{
+ typedef CumsumOptionsBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE
+ {
+ VT_EXCLUSIVE = 4,
+ VT_REVERSE = 6
+ };
+ bool exclusive() const { return GetField<uint8_t>(VT_EXCLUSIVE, 0) != 0; }
+ bool reverse() const { return GetField<uint8_t>(VT_REVERSE, 0) != 0; }
+ bool Verify(::flatbuffers::Verifier &verifier) const
+ {
+ return VerifyTableStart(verifier) && VerifyField<uint8_t>(verifier, VT_EXCLUSIVE, 1) &&
+ VerifyField<uint8_t>(verifier, VT_REVERSE, 1) && verifier.EndTable();
+ }
+};
+
+struct CumsumOptionsBuilder
+{
+ typedef CumsumOptions Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ void add_exclusive(bool exclusive)
+ {
+ fbb_.AddElement<uint8_t>(CumsumOptions::VT_EXCLUSIVE, static_cast<uint8_t>(exclusive), 0);
+ }
+ void add_reverse(bool reverse)
+ {
+ fbb_.AddElement<uint8_t>(CumsumOptions::VT_REVERSE, static_cast<uint8_t>(reverse), 0);
+ }
+ explicit CumsumOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb)
+ {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<CumsumOptions> Finish()
+ {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<CumsumOptions>(end);
+ return o;
+ }
+};
+
+inline ::flatbuffers::Offset<CumsumOptions>
+CreateCumsumOptions(::flatbuffers::FlatBufferBuilder &_fbb, bool exclusive = false,
+ bool reverse = false)
+{
+ CumsumOptionsBuilder builder_(_fbb);
+ builder_.add_reverse(reverse);
+ builder_.add_exclusive(exclusive);
+ return builder_.Finish();
+}
+
+struct BroadcastToOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table
+{
+ typedef BroadcastToOptionsBuilder Builder;
+ bool Verify(::flatbuffers::Verifier &verifier) const
+ {
+ return VerifyTableStart(verifier) && verifier.EndTable();
+ }
+};
+
+struct BroadcastToOptionsBuilder
+{
+ typedef BroadcastToOptions Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ explicit BroadcastToOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb)
+ {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<BroadcastToOptions> Finish()
+ {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<BroadcastToOptions>(end);
+ return o;
+ }
+};
+
+inline ::flatbuffers::Offset<BroadcastToOptions>
+CreateBroadcastToOptions(::flatbuffers::FlatBufferBuilder &_fbb)
+{
+ BroadcastToOptionsBuilder builder_(_fbb);
+ return builder_.Finish();
+}
+
+struct Rfft2dOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table
+{
+ typedef Rfft2dOptionsBuilder Builder;
+ bool Verify(::flatbuffers::Verifier &verifier) const
+ {
+ return VerifyTableStart(verifier) && verifier.EndTable();
+ }
+};
+
+struct Rfft2dOptionsBuilder
+{
+ typedef Rfft2dOptions Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ explicit Rfft2dOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb)
+ {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<Rfft2dOptions> Finish()
+ {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<Rfft2dOptions>(end);
+ return o;
+ }
+};
+
+inline ::flatbuffers::Offset<Rfft2dOptions>
+CreateRfft2dOptions(::flatbuffers::FlatBufferBuilder &_fbb)
+{
+ Rfft2dOptionsBuilder builder_(_fbb);
+ return builder_.Finish();
+}
+
+struct HashtableOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table
+{
+ typedef HashtableOptionsBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE
+ {
+ VT_TABLE_ID = 4,
+ VT_KEY_DTYPE = 6,
+ VT_VALUE_DTYPE = 8
+ };
+ int32_t table_id() const { return GetField<int32_t>(VT_TABLE_ID, 0); }
+ onert_tflite::TensorType key_dtype() const
+ {
+ return static_cast<onert_tflite::TensorType>(GetField<int8_t>(VT_KEY_DTYPE, 0));
+ }
+ onert_tflite::TensorType value_dtype() const
+ {
+ return static_cast<onert_tflite::TensorType>(GetField<int8_t>(VT_VALUE_DTYPE, 0));
+ }
+ bool Verify(::flatbuffers::Verifier &verifier) const
+ {
+ return VerifyTableStart(verifier) && VerifyField<int32_t>(verifier, VT_TABLE_ID, 4) &&
+ VerifyField<int8_t>(verifier, VT_KEY_DTYPE, 1) &&
+ VerifyField<int8_t>(verifier, VT_VALUE_DTYPE, 1) && verifier.EndTable();
+ }
+};
+
+struct HashtableOptionsBuilder
+{
+ typedef HashtableOptions Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ void add_table_id(int32_t table_id)
+ {
+ fbb_.AddElement<int32_t>(HashtableOptions::VT_TABLE_ID, table_id, 0);
+ }
+ void add_key_dtype(onert_tflite::TensorType key_dtype)
+ {
+ fbb_.AddElement<int8_t>(HashtableOptions::VT_KEY_DTYPE, static_cast<int8_t>(key_dtype), 0);
+ }
+ void add_value_dtype(onert_tflite::TensorType value_dtype)
+ {
+ fbb_.AddElement<int8_t>(HashtableOptions::VT_VALUE_DTYPE, static_cast<int8_t>(value_dtype), 0);
+ }
+ explicit HashtableOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb)
+ {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<HashtableOptions> Finish()
+ {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<HashtableOptions>(end);
+ return o;
+ }
+};
+
+inline ::flatbuffers::Offset<HashtableOptions>
+CreateHashtableOptions(::flatbuffers::FlatBufferBuilder &_fbb, int32_t table_id = 0,
+ onert_tflite::TensorType key_dtype = onert_tflite::TensorType_FLOAT32,
+ onert_tflite::TensorType value_dtype = onert_tflite::TensorType_FLOAT32)
+{
+ HashtableOptionsBuilder builder_(_fbb);
+ builder_.add_table_id(table_id);
+ builder_.add_value_dtype(value_dtype);
+ builder_.add_key_dtype(key_dtype);
+ return builder_.Finish();
+}
+
+struct HashtableFindOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table
+{
+ typedef HashtableFindOptionsBuilder Builder;
+ bool Verify(::flatbuffers::Verifier &verifier) const
+ {
+ return VerifyTableStart(verifier) && verifier.EndTable();
+ }
+};
+
+struct HashtableFindOptionsBuilder
+{
+ typedef HashtableFindOptions Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ explicit HashtableFindOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb)
+ {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<HashtableFindOptions> Finish()
+ {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<HashtableFindOptions>(end);
+ return o;
+ }
+};
+
+inline ::flatbuffers::Offset<HashtableFindOptions>
+CreateHashtableFindOptions(::flatbuffers::FlatBufferBuilder &_fbb)
+{
+ HashtableFindOptionsBuilder builder_(_fbb);
+ return builder_.Finish();
+}
+
+struct HashtableImportOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table
+{
+ typedef HashtableImportOptionsBuilder Builder;
+ bool Verify(::flatbuffers::Verifier &verifier) const
+ {
+ return VerifyTableStart(verifier) && verifier.EndTable();
+ }
+};
+
+struct HashtableImportOptionsBuilder
+{
+ typedef HashtableImportOptions Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ explicit HashtableImportOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb)
+ {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<HashtableImportOptions> Finish()
+ {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<HashtableImportOptions>(end);
+ return o;
+ }
+};
+
+inline ::flatbuffers::Offset<HashtableImportOptions>
+CreateHashtableImportOptions(::flatbuffers::FlatBufferBuilder &_fbb)
+{
+ HashtableImportOptionsBuilder builder_(_fbb);
+ return builder_.Finish();
+}
+
+struct HashtableSizeOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table
+{
+ typedef HashtableSizeOptionsBuilder Builder;
+ bool Verify(::flatbuffers::Verifier &verifier) const
+ {
+ return VerifyTableStart(verifier) && verifier.EndTable();
+ }
+};
+
+struct HashtableSizeOptionsBuilder
+{
+ typedef HashtableSizeOptions Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ explicit HashtableSizeOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb)
+ {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<HashtableSizeOptions> Finish()
+ {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<HashtableSizeOptions>(end);
+ return o;
+ }
+};
+
+inline ::flatbuffers::Offset<HashtableSizeOptions>
+CreateHashtableSizeOptions(::flatbuffers::FlatBufferBuilder &_fbb)
+{
+ HashtableSizeOptionsBuilder builder_(_fbb);
+ return builder_.Finish();
+}
+
+struct VarHandleOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table
+{
+ typedef VarHandleOptionsBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE
+ {
+ VT_CONTAINER = 4,
+ VT_SHARED_NAME = 6
+ };
+ const ::flatbuffers::String *container() const
+ {
+ return GetPointer<const ::flatbuffers::String *>(VT_CONTAINER);
+ }
+ const ::flatbuffers::String *shared_name() const
+ {
+ return GetPointer<const ::flatbuffers::String *>(VT_SHARED_NAME);
+ }
+ bool Verify(::flatbuffers::Verifier &verifier) const
+ {
+ return VerifyTableStart(verifier) && VerifyOffset(verifier, VT_CONTAINER) &&
+ verifier.VerifyString(container()) && VerifyOffset(verifier, VT_SHARED_NAME) &&
+ verifier.VerifyString(shared_name()) && verifier.EndTable();
+ }
+};
+
+struct VarHandleOptionsBuilder
+{
+ typedef VarHandleOptions Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ void add_container(::flatbuffers::Offset<::flatbuffers::String> container)
+ {
+ fbb_.AddOffset(VarHandleOptions::VT_CONTAINER, container);
+ }
+ void add_shared_name(::flatbuffers::Offset<::flatbuffers::String> shared_name)
+ {
+ fbb_.AddOffset(VarHandleOptions::VT_SHARED_NAME, shared_name);
+ }
+ explicit VarHandleOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb)
+ {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<VarHandleOptions> Finish()
+ {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<VarHandleOptions>(end);
+ return o;
+ }
+};
+
+inline ::flatbuffers::Offset<VarHandleOptions>
+CreateVarHandleOptions(::flatbuffers::FlatBufferBuilder &_fbb,
+ ::flatbuffers::Offset<::flatbuffers::String> container = 0,
+ ::flatbuffers::Offset<::flatbuffers::String> shared_name = 0)
+{
+ VarHandleOptionsBuilder builder_(_fbb);
+ builder_.add_shared_name(shared_name);
+ builder_.add_container(container);
+ return builder_.Finish();
+}
+
+inline ::flatbuffers::Offset<VarHandleOptions>
+CreateVarHandleOptionsDirect(::flatbuffers::FlatBufferBuilder &_fbb,
+ const char *container = nullptr, const char *shared_name = nullptr)
+{
+ auto container__ = container ? _fbb.CreateString(container) : 0;
+ auto shared_name__ = shared_name ? _fbb.CreateString(shared_name) : 0;
+ return onert_tflite::CreateVarHandleOptions(_fbb, container__, shared_name__);
+}
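+
+// Illustrative note: CreateVarHandleOptionsDirect is a convenience wrapper that
+// builds the string offsets itself. The names below are hypothetical examples.
+//
+//   ::flatbuffers::FlatBufferBuilder fbb;
+//   auto vh_options =
+//     onert_tflite::CreateVarHandleOptionsDirect(fbb, "container_0", "var_0");
+//   // Equivalent to calling fbb.CreateString() for each name and passing the
+//   // resulting offsets to CreateVarHandleOptions.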
+
+struct ReadVariableOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table
+{
+ typedef ReadVariableOptionsBuilder Builder;
+ bool Verify(::flatbuffers::Verifier &verifier) const
+ {
+ return VerifyTableStart(verifier) && verifier.EndTable();
+ }
+};
+
+struct ReadVariableOptionsBuilder
+{
+ typedef ReadVariableOptions Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ explicit ReadVariableOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb)
+ {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<ReadVariableOptions> Finish()
+ {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<ReadVariableOptions>(end);
+ return o;
+ }
+};
+
+inline ::flatbuffers::Offset<ReadVariableOptions>
+CreateReadVariableOptions(::flatbuffers::FlatBufferBuilder &_fbb)
+{
+ ReadVariableOptionsBuilder builder_(_fbb);
+ return builder_.Finish();
+}
+
+struct AssignVariableOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table
+{
+ typedef AssignVariableOptionsBuilder Builder;
+ bool Verify(::flatbuffers::Verifier &verifier) const
+ {
+ return VerifyTableStart(verifier) && verifier.EndTable();
+ }
+};
+
+struct AssignVariableOptionsBuilder
+{
+ typedef AssignVariableOptions Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ explicit AssignVariableOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb)
+ {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<AssignVariableOptions> Finish()
+ {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<AssignVariableOptions>(end);
+ return o;
+ }
+};
+
+inline ::flatbuffers::Offset<AssignVariableOptions>
+CreateAssignVariableOptions(::flatbuffers::FlatBufferBuilder &_fbb)
+{
+ AssignVariableOptionsBuilder builder_(_fbb);
+ return builder_.Finish();
+}
+
+struct RandomOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table
+{
+ typedef RandomOptionsBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE
+ {
+ VT_SEED = 4,
+ VT_SEED2 = 6
+ };
+ int64_t seed() const { return GetField<int64_t>(VT_SEED, 0); }
+ int64_t seed2() const { return GetField<int64_t>(VT_SEED2, 0); }
+ bool Verify(::flatbuffers::Verifier &verifier) const
+ {
+ return VerifyTableStart(verifier) && VerifyField<int64_t>(verifier, VT_SEED, 8) &&
+ VerifyField<int64_t>(verifier, VT_SEED2, 8) && verifier.EndTable();
+ }
+};
+
+struct RandomOptionsBuilder
+{
+ typedef RandomOptions Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ void add_seed(int64_t seed) { fbb_.AddElement<int64_t>(RandomOptions::VT_SEED, seed, 0); }
+ void add_seed2(int64_t seed2) { fbb_.AddElement<int64_t>(RandomOptions::VT_SEED2, seed2, 0); }
+ explicit RandomOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb)
+ {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<RandomOptions> Finish()
+ {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<RandomOptions>(end);
+ return o;
+ }
+};
+
+inline ::flatbuffers::Offset<RandomOptions>
+CreateRandomOptions(::flatbuffers::FlatBufferBuilder &_fbb, int64_t seed = 0, int64_t seed2 = 0)
+{
+ RandomOptionsBuilder builder_(_fbb);
+ builder_.add_seed2(seed2);
+ builder_.add_seed(seed);
+ return builder_.Finish();
+}
+
+struct BucketizeOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table
+{
+ typedef BucketizeOptionsBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE
+ {
+ VT_BOUNDARIES = 4
+ };
+ const ::flatbuffers::Vector<float> *boundaries() const
+ {
+ return GetPointer<const ::flatbuffers::Vector<float> *>(VT_BOUNDARIES);
+ }
+ bool Verify(::flatbuffers::Verifier &verifier) const
+ {
+ return VerifyTableStart(verifier) && VerifyOffset(verifier, VT_BOUNDARIES) &&
+ verifier.VerifyVector(boundaries()) && verifier.EndTable();
+ }
+};
+
+struct BucketizeOptionsBuilder
+{
+ typedef BucketizeOptions Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ void add_boundaries(::flatbuffers::Offset<::flatbuffers::Vector<float>> boundaries)
+ {
+ fbb_.AddOffset(BucketizeOptions::VT_BOUNDARIES, boundaries);
+ }
+ explicit BucketizeOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb)
+ {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<BucketizeOptions> Finish()
+ {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<BucketizeOptions>(end);
+ return o;
+ }
+};
+
+inline ::flatbuffers::Offset<BucketizeOptions>
+CreateBucketizeOptions(::flatbuffers::FlatBufferBuilder &_fbb,
+ ::flatbuffers::Offset<::flatbuffers::Vector<float>> boundaries = 0)
+{
+ BucketizeOptionsBuilder builder_(_fbb);
+ builder_.add_boundaries(boundaries);
+ return builder_.Finish();
+}
+
+inline ::flatbuffers::Offset<BucketizeOptions>
+CreateBucketizeOptionsDirect(::flatbuffers::FlatBufferBuilder &_fbb,
+ const std::vector<float> *boundaries = nullptr)
+{
+ auto boundaries__ = boundaries ? _fbb.CreateVector<float>(*boundaries) : 0;
+ return onert_tflite::CreateBucketizeOptions(_fbb, boundaries__);
+}
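+
+// Illustrative sketch, not generated output: the Direct helper copies a
+// caller-owned std::vector<float> into the buffer. Boundary values below are
+// hypothetical.
+//
+//   ::flatbuffers::FlatBufferBuilder fbb;
+//   std::vector<float> boundaries{0.0f, 1.5f, 3.0f};
+//   auto b_options = onert_tflite::CreateBucketizeOptionsDirect(fbb, &boundaries);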
+
+struct GeluOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table
+{
+ typedef GeluOptionsBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE
+ {
+ VT_APPROXIMATE = 4
+ };
+ bool approximate() const { return GetField<uint8_t>(VT_APPROXIMATE, 0) != 0; }
+ bool Verify(::flatbuffers::Verifier &verifier) const
+ {
+ return VerifyTableStart(verifier) && VerifyField<uint8_t>(verifier, VT_APPROXIMATE, 1) &&
+ verifier.EndTable();
+ }
+};
+
+struct GeluOptionsBuilder
+{
+ typedef GeluOptions Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ void add_approximate(bool approximate)
+ {
+ fbb_.AddElement<uint8_t>(GeluOptions::VT_APPROXIMATE, static_cast<uint8_t>(approximate), 0);
+ }
+ explicit GeluOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb)
+ {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<GeluOptions> Finish()
+ {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<GeluOptions>(end);
+ return o;
+ }
+};
+
+inline ::flatbuffers::Offset<GeluOptions> CreateGeluOptions(::flatbuffers::FlatBufferBuilder &_fbb,
+ bool approximate = false)
+{
+ GeluOptionsBuilder builder_(_fbb);
+ builder_.add_approximate(approximate);
+ return builder_.Finish();
+}
+
+struct DynamicUpdateSliceOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table
+{
+ typedef DynamicUpdateSliceOptionsBuilder Builder;
+ bool Verify(::flatbuffers::Verifier &verifier) const
+ {
+ return VerifyTableStart(verifier) && verifier.EndTable();
+ }
+};
+
+struct DynamicUpdateSliceOptionsBuilder
+{
+ typedef DynamicUpdateSliceOptions Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ explicit DynamicUpdateSliceOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb)
+ {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<DynamicUpdateSliceOptions> Finish()
+ {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<DynamicUpdateSliceOptions>(end);
+ return o;
+ }
+};
+
+inline ::flatbuffers::Offset<DynamicUpdateSliceOptions>
+CreateDynamicUpdateSliceOptions(::flatbuffers::FlatBufferBuilder &_fbb)
+{
+ DynamicUpdateSliceOptionsBuilder builder_(_fbb);
+ return builder_.Finish();
+}
+
+struct UnsortedSegmentProdOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table
+{
+ typedef UnsortedSegmentProdOptionsBuilder Builder;
+ bool Verify(::flatbuffers::Verifier &verifier) const
+ {
+ return VerifyTableStart(verifier) && verifier.EndTable();
+ }
+};
+
+struct UnsortedSegmentProdOptionsBuilder
+{
+ typedef UnsortedSegmentProdOptions Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ explicit UnsortedSegmentProdOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb)
+ {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<UnsortedSegmentProdOptions> Finish()
+ {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<UnsortedSegmentProdOptions>(end);
+ return o;
+ }
+};
+
+inline ::flatbuffers::Offset<UnsortedSegmentProdOptions>
+CreateUnsortedSegmentProdOptions(::flatbuffers::FlatBufferBuilder &_fbb)
+{
+ UnsortedSegmentProdOptionsBuilder builder_(_fbb);
+ return builder_.Finish();
+}
+
+struct UnsortedSegmentMaxOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table
+{
+ typedef UnsortedSegmentMaxOptionsBuilder Builder;
+ bool Verify(::flatbuffers::Verifier &verifier) const
+ {
+ return VerifyTableStart(verifier) && verifier.EndTable();
+ }
+};
+
+struct UnsortedSegmentMaxOptionsBuilder
+{
+ typedef UnsortedSegmentMaxOptions Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ explicit UnsortedSegmentMaxOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb)
+ {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<UnsortedSegmentMaxOptions> Finish()
+ {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<UnsortedSegmentMaxOptions>(end);
+ return o;
+ }
+};
+
+inline ::flatbuffers::Offset<UnsortedSegmentMaxOptions>
+CreateUnsortedSegmentMaxOptions(::flatbuffers::FlatBufferBuilder &_fbb)
+{
+ UnsortedSegmentMaxOptionsBuilder builder_(_fbb);
+ return builder_.Finish();
+}
+
+struct UnsortedSegmentSumOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table
+{
+ typedef UnsortedSegmentSumOptionsBuilder Builder;
+ bool Verify(::flatbuffers::Verifier &verifier) const
+ {
+ return VerifyTableStart(verifier) && verifier.EndTable();
+ }
+};
+
+struct UnsortedSegmentSumOptionsBuilder
+{
+ typedef UnsortedSegmentSumOptions Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ explicit UnsortedSegmentSumOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb)
+ {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<UnsortedSegmentSumOptions> Finish()
+ {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<UnsortedSegmentSumOptions>(end);
+ return o;
+ }
+};
+
+inline ::flatbuffers::Offset<UnsortedSegmentSumOptions>
+CreateUnsortedSegmentSumOptions(::flatbuffers::FlatBufferBuilder &_fbb)
+{
+ UnsortedSegmentSumOptionsBuilder builder_(_fbb);
+ return builder_.Finish();
+}
+
+struct ATan2Options FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table
+{
+ typedef ATan2OptionsBuilder Builder;
+ bool Verify(::flatbuffers::Verifier &verifier) const
+ {
+ return VerifyTableStart(verifier) && verifier.EndTable();
+ }
+};
+
+struct ATan2OptionsBuilder
+{
+ typedef ATan2Options Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ explicit ATan2OptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb)
+ {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<ATan2Options> Finish()
+ {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<ATan2Options>(end);
+ return o;
+ }
+};
+
+inline ::flatbuffers::Offset<ATan2Options>
+CreateATan2Options(::flatbuffers::FlatBufferBuilder &_fbb)
+{
+ ATan2OptionsBuilder builder_(_fbb);
+ return builder_.Finish();
+}
+
+struct OperatorCode FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table
+{
+ typedef OperatorCodeBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE
+ {
+ VT_DEPRECATED_BUILTIN_CODE = 4,
+ VT_CUSTOM_CODE = 6,
+ VT_VERSION = 8,
+ VT_BUILTIN_CODE = 10
+ };
+ int8_t deprecated_builtin_code() const { return GetField<int8_t>(VT_DEPRECATED_BUILTIN_CODE, 0); }
+ const ::flatbuffers::String *custom_code() const
+ {
+ return GetPointer<const ::flatbuffers::String *>(VT_CUSTOM_CODE);
+ }
+ int32_t version() const { return GetField<int32_t>(VT_VERSION, 1); }
+ onert_tflite::BuiltinOperator builtin_code() const
+ {
+ return static_cast<onert_tflite::BuiltinOperator>(GetField<int32_t>(VT_BUILTIN_CODE, 0));
+ }
+ bool Verify(::flatbuffers::Verifier &verifier) const
+ {
+ return VerifyTableStart(verifier) &&
+ VerifyField<int8_t>(verifier, VT_DEPRECATED_BUILTIN_CODE, 1) &&
+ VerifyOffset(verifier, VT_CUSTOM_CODE) && verifier.VerifyString(custom_code()) &&
+ VerifyField<int32_t>(verifier, VT_VERSION, 4) &&
+ VerifyField<int32_t>(verifier, VT_BUILTIN_CODE, 4) && verifier.EndTable();
+ }
+};
+
+struct OperatorCodeBuilder
+{
+ typedef OperatorCode Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ void add_deprecated_builtin_code(int8_t deprecated_builtin_code)
+ {
+ fbb_.AddElement<int8_t>(OperatorCode::VT_DEPRECATED_BUILTIN_CODE, deprecated_builtin_code, 0);
+ }
+ void add_custom_code(::flatbuffers::Offset<::flatbuffers::String> custom_code)
+ {
+ fbb_.AddOffset(OperatorCode::VT_CUSTOM_CODE, custom_code);
+ }
+ void add_version(int32_t version)
+ {
+ fbb_.AddElement<int32_t>(OperatorCode::VT_VERSION, version, 1);
+ }
+ void add_builtin_code(onert_tflite::BuiltinOperator builtin_code)
+ {
+ fbb_.AddElement<int32_t>(OperatorCode::VT_BUILTIN_CODE, static_cast<int32_t>(builtin_code), 0);
+ }
+ explicit OperatorCodeBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb)
+ {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<OperatorCode> Finish()
+ {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<OperatorCode>(end);
+ return o;
+ }
+};
+
+inline ::flatbuffers::Offset<OperatorCode>
+CreateOperatorCode(::flatbuffers::FlatBufferBuilder &_fbb, int8_t deprecated_builtin_code = 0,
+ ::flatbuffers::Offset<::flatbuffers::String> custom_code = 0,
+ int32_t version = 1,
+ onert_tflite::BuiltinOperator builtin_code = onert_tflite::BuiltinOperator_ADD)
+{
+ OperatorCodeBuilder builder_(_fbb);
+ builder_.add_builtin_code(builtin_code);
+ builder_.add_version(version);
+ builder_.add_custom_code(custom_code);
+ builder_.add_deprecated_builtin_code(deprecated_builtin_code);
+ return builder_.Finish();
+}
+
+inline ::flatbuffers::Offset<OperatorCode> CreateOperatorCodeDirect(
+ ::flatbuffers::FlatBufferBuilder &_fbb, int8_t deprecated_builtin_code = 0,
+ const char *custom_code = nullptr, int32_t version = 1,
+ onert_tflite::BuiltinOperator builtin_code = onert_tflite::BuiltinOperator_ADD)
+{
+ auto custom_code__ = custom_code ? _fbb.CreateString(custom_code) : 0;
+ return onert_tflite::CreateOperatorCode(_fbb, deprecated_builtin_code, custom_code__, version,
+ builtin_code);
+}
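+
+// Commentary, mirroring the upstream TensorFlow Lite schema: deprecated_builtin_code
+// is the original int8 operator-code field, and builtin_code (int32) was added when
+// builtin operator ids outgrew the int8 range; older models may populate only the
+// deprecated field, so readers typically reconcile the two when resolving the code.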
+
+struct Operator FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table
+{
+ typedef OperatorBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE
+ {
+ VT_OPCODE_INDEX = 4,
+ VT_INPUTS = 6,
+ VT_OUTPUTS = 8,
+ VT_BUILTIN_OPTIONS_TYPE = 10,
+ VT_BUILTIN_OPTIONS = 12,
+ VT_CUSTOM_OPTIONS = 14,
+ VT_CUSTOM_OPTIONS_FORMAT = 16,
+ VT_MUTATING_VARIABLE_INPUTS = 18,
+ VT_INTERMEDIATES = 20
+ };
+ uint32_t opcode_index() const { return GetField<uint32_t>(VT_OPCODE_INDEX, 0); }
+ const ::flatbuffers::Vector<int32_t> *inputs() const
+ {
+ return GetPointer<const ::flatbuffers::Vector<int32_t> *>(VT_INPUTS);
+ }
+ const ::flatbuffers::Vector<int32_t> *outputs() const
+ {
+ return GetPointer<const ::flatbuffers::Vector<int32_t> *>(VT_OUTPUTS);
+ }
+ onert_tflite::BuiltinOptions builtin_options_type() const
+ {
+ return static_cast<onert_tflite::BuiltinOptions>(GetField<uint8_t>(VT_BUILTIN_OPTIONS_TYPE, 0));
+ }
+ const void *builtin_options() const { return GetPointer<const void *>(VT_BUILTIN_OPTIONS); }
+ template <typename T> const T *builtin_options_as() const;
+ const onert_tflite::Conv2DOptions *builtin_options_as_Conv2DOptions() const
+ {
+ return builtin_options_type() == onert_tflite::BuiltinOptions_Conv2DOptions
+ ? static_cast<const onert_tflite::Conv2DOptions *>(builtin_options())
+ : nullptr;
+ }
+ const onert_tflite::DepthwiseConv2DOptions *builtin_options_as_DepthwiseConv2DOptions() const
+ {
+ return builtin_options_type() == onert_tflite::BuiltinOptions_DepthwiseConv2DOptions
+ ? static_cast<const onert_tflite::DepthwiseConv2DOptions *>(builtin_options())
+ : nullptr;
+ }
+ const onert_tflite::ConcatEmbeddingsOptions *builtin_options_as_ConcatEmbeddingsOptions() const
+ {
+ return builtin_options_type() == onert_tflite::BuiltinOptions_ConcatEmbeddingsOptions
+ ? static_cast<const onert_tflite::ConcatEmbeddingsOptions *>(builtin_options())
+ : nullptr;
+ }
+ const onert_tflite::LSHProjectionOptions *builtin_options_as_LSHProjectionOptions() const
+ {
+ return builtin_options_type() == onert_tflite::BuiltinOptions_LSHProjectionOptions
+ ? static_cast<const onert_tflite::LSHProjectionOptions *>(builtin_options())
+ : nullptr;
+ }
+ const onert_tflite::Pool2DOptions *builtin_options_as_Pool2DOptions() const
+ {
+ return builtin_options_type() == onert_tflite::BuiltinOptions_Pool2DOptions
+ ? static_cast<const onert_tflite::Pool2DOptions *>(builtin_options())
+ : nullptr;
+ }
+ const onert_tflite::SVDFOptions *builtin_options_as_SVDFOptions() const
+ {
+ return builtin_options_type() == onert_tflite::BuiltinOptions_SVDFOptions
+ ? static_cast<const onert_tflite::SVDFOptions *>(builtin_options())
+ : nullptr;
+ }
+ const onert_tflite::RNNOptions *builtin_options_as_RNNOptions() const
+ {
+ return builtin_options_type() == onert_tflite::BuiltinOptions_RNNOptions
+ ? static_cast<const onert_tflite::RNNOptions *>(builtin_options())
+ : nullptr;
+ }
+ const onert_tflite::FullyConnectedOptions *builtin_options_as_FullyConnectedOptions() const
+ {
+ return builtin_options_type() == onert_tflite::BuiltinOptions_FullyConnectedOptions
+ ? static_cast<const onert_tflite::FullyConnectedOptions *>(builtin_options())
+ : nullptr;
+ }
+ const onert_tflite::SoftmaxOptions *builtin_options_as_SoftmaxOptions() const
+ {
+ return builtin_options_type() == onert_tflite::BuiltinOptions_SoftmaxOptions
+ ? static_cast<const onert_tflite::SoftmaxOptions *>(builtin_options())
+ : nullptr;
+ }
+ const onert_tflite::ConcatenationOptions *builtin_options_as_ConcatenationOptions() const
+ {
+ return builtin_options_type() == onert_tflite::BuiltinOptions_ConcatenationOptions
+ ? static_cast<const onert_tflite::ConcatenationOptions *>(builtin_options())
+ : nullptr;
+ }
+ const onert_tflite::AddOptions *builtin_options_as_AddOptions() const
+ {
+ return builtin_options_type() == onert_tflite::BuiltinOptions_AddOptions
+ ? static_cast<const onert_tflite::AddOptions *>(builtin_options())
+ : nullptr;
+ }
+ const onert_tflite::L2NormOptions *builtin_options_as_L2NormOptions() const
+ {
+ return builtin_options_type() == onert_tflite::BuiltinOptions_L2NormOptions
+ ? static_cast<const onert_tflite::L2NormOptions *>(builtin_options())
+ : nullptr;
+ }
+ const onert_tflite::LocalResponseNormalizationOptions *
+ builtin_options_as_LocalResponseNormalizationOptions() const
+ {
+ return builtin_options_type() == onert_tflite::BuiltinOptions_LocalResponseNormalizationOptions
+ ? static_cast<const onert_tflite::LocalResponseNormalizationOptions *>(
+ builtin_options())
+ : nullptr;
+ }
+ const onert_tflite::LSTMOptions *builtin_options_as_LSTMOptions() const
+ {
+ return builtin_options_type() == onert_tflite::BuiltinOptions_LSTMOptions
+ ? static_cast<const onert_tflite::LSTMOptions *>(builtin_options())
+ : nullptr;
+ }
+ const onert_tflite::ResizeBilinearOptions *builtin_options_as_ResizeBilinearOptions() const
+ {
+ return builtin_options_type() == onert_tflite::BuiltinOptions_ResizeBilinearOptions
+ ? static_cast<const onert_tflite::ResizeBilinearOptions *>(builtin_options())
+ : nullptr;
+ }
+ const onert_tflite::CallOptions *builtin_options_as_CallOptions() const
+ {
+ return builtin_options_type() == onert_tflite::BuiltinOptions_CallOptions
+ ? static_cast<const onert_tflite::CallOptions *>(builtin_options())
+ : nullptr;
+ }
+ const onert_tflite::ReshapeOptions *builtin_options_as_ReshapeOptions() const
+ {
+ return builtin_options_type() == onert_tflite::BuiltinOptions_ReshapeOptions
+ ? static_cast<const onert_tflite::ReshapeOptions *>(builtin_options())
+ : nullptr;
+ }
+ const onert_tflite::SkipGramOptions *builtin_options_as_SkipGramOptions() const
+ {
+ return builtin_options_type() == onert_tflite::BuiltinOptions_SkipGramOptions
+ ? static_cast<const onert_tflite::SkipGramOptions *>(builtin_options())
+ : nullptr;
+ }
+ const onert_tflite::SpaceToDepthOptions *builtin_options_as_SpaceToDepthOptions() const
+ {
+ return builtin_options_type() == onert_tflite::BuiltinOptions_SpaceToDepthOptions
+ ? static_cast<const onert_tflite::SpaceToDepthOptions *>(builtin_options())
+ : nullptr;
+ }
+ const onert_tflite::EmbeddingLookupSparseOptions *
+ builtin_options_as_EmbeddingLookupSparseOptions() const
+ {
+ return builtin_options_type() == onert_tflite::BuiltinOptions_EmbeddingLookupSparseOptions
+ ? static_cast<const onert_tflite::EmbeddingLookupSparseOptions *>(builtin_options())
+ : nullptr;
+ }
+ const onert_tflite::MulOptions *builtin_options_as_MulOptions() const
+ {
+ return builtin_options_type() == onert_tflite::BuiltinOptions_MulOptions
+ ? static_cast<const onert_tflite::MulOptions *>(builtin_options())
+ : nullptr;
+ }
+ const onert_tflite::PadOptions *builtin_options_as_PadOptions() const
+ {
+ return builtin_options_type() == onert_tflite::BuiltinOptions_PadOptions
+ ? static_cast<const onert_tflite::PadOptions *>(builtin_options())
+ : nullptr;
+ }
+ const onert_tflite::GatherOptions *builtin_options_as_GatherOptions() const
+ {
+ return builtin_options_type() == onert_tflite::BuiltinOptions_GatherOptions
+ ? static_cast<const onert_tflite::GatherOptions *>(builtin_options())
+ : nullptr;
+ }
+ const onert_tflite::BatchToSpaceNDOptions *builtin_options_as_BatchToSpaceNDOptions() const
+ {
+ return builtin_options_type() == onert_tflite::BuiltinOptions_BatchToSpaceNDOptions
+ ? static_cast<const onert_tflite::BatchToSpaceNDOptions *>(builtin_options())
+ : nullptr;
+ }
+ const onert_tflite::SpaceToBatchNDOptions *builtin_options_as_SpaceToBatchNDOptions() const
+ {
+ return builtin_options_type() == onert_tflite::BuiltinOptions_SpaceToBatchNDOptions
+ ? static_cast<const onert_tflite::SpaceToBatchNDOptions *>(builtin_options())
+ : nullptr;
+ }
+ const onert_tflite::TransposeOptions *builtin_options_as_TransposeOptions() const
+ {
+ return builtin_options_type() == onert_tflite::BuiltinOptions_TransposeOptions
+ ? static_cast<const onert_tflite::TransposeOptions *>(builtin_options())
+ : nullptr;
+ }
+ const onert_tflite::ReducerOptions *builtin_options_as_ReducerOptions() const
+ {
+ return builtin_options_type() == onert_tflite::BuiltinOptions_ReducerOptions
+ ? static_cast<const onert_tflite::ReducerOptions *>(builtin_options())
+ : nullptr;
+ }
+ const onert_tflite::SubOptions *builtin_options_as_SubOptions() const
+ {
+ return builtin_options_type() == onert_tflite::BuiltinOptions_SubOptions
+ ? static_cast<const onert_tflite::SubOptions *>(builtin_options())
+ : nullptr;
+ }
+ const onert_tflite::DivOptions *builtin_options_as_DivOptions() const
+ {
+ return builtin_options_type() == onert_tflite::BuiltinOptions_DivOptions
+ ? static_cast<const onert_tflite::DivOptions *>(builtin_options())
+ : nullptr;
+ }
+ const onert_tflite::SqueezeOptions *builtin_options_as_SqueezeOptions() const
+ {
+ return builtin_options_type() == onert_tflite::BuiltinOptions_SqueezeOptions
+ ? static_cast<const onert_tflite::SqueezeOptions *>(builtin_options())
+ : nullptr;
+ }
+ const onert_tflite::SequenceRNNOptions *builtin_options_as_SequenceRNNOptions() const
+ {
+ return builtin_options_type() == onert_tflite::BuiltinOptions_SequenceRNNOptions
+ ? static_cast<const onert_tflite::SequenceRNNOptions *>(builtin_options())
+ : nullptr;
+ }
+ const onert_tflite::StridedSliceOptions *builtin_options_as_StridedSliceOptions() const
+ {
+ return builtin_options_type() == onert_tflite::BuiltinOptions_StridedSliceOptions
+ ? static_cast<const onert_tflite::StridedSliceOptions *>(builtin_options())
+ : nullptr;
+ }
+ const onert_tflite::ExpOptions *builtin_options_as_ExpOptions() const
+ {
+ return builtin_options_type() == onert_tflite::BuiltinOptions_ExpOptions
+ ? static_cast<const onert_tflite::ExpOptions *>(builtin_options())
+ : nullptr;
+ }
+ const onert_tflite::TopKV2Options *builtin_options_as_TopKV2Options() const
+ {
+ return builtin_options_type() == onert_tflite::BuiltinOptions_TopKV2Options
+ ? static_cast<const onert_tflite::TopKV2Options *>(builtin_options())
+ : nullptr;
+ }
+ const onert_tflite::SplitOptions *builtin_options_as_SplitOptions() const
+ {
+ return builtin_options_type() == onert_tflite::BuiltinOptions_SplitOptions
+ ? static_cast<const onert_tflite::SplitOptions *>(builtin_options())
+ : nullptr;
+ }
+ const onert_tflite::LogSoftmaxOptions *builtin_options_as_LogSoftmaxOptions() const
+ {
+ return builtin_options_type() == onert_tflite::BuiltinOptions_LogSoftmaxOptions
+ ? static_cast<const onert_tflite::LogSoftmaxOptions *>(builtin_options())
+ : nullptr;
+ }
+ const onert_tflite::CastOptions *builtin_options_as_CastOptions() const
+ {
+ return builtin_options_type() == onert_tflite::BuiltinOptions_CastOptions
+ ? static_cast<const onert_tflite::CastOptions *>(builtin_options())
+ : nullptr;
+ }
+ const onert_tflite::DequantizeOptions *builtin_options_as_DequantizeOptions() const
+ {
+ return builtin_options_type() == onert_tflite::BuiltinOptions_DequantizeOptions
+ ? static_cast<const onert_tflite::DequantizeOptions *>(builtin_options())
+ : nullptr;
+ }
+ const onert_tflite::MaximumMinimumOptions *builtin_options_as_MaximumMinimumOptions() const
+ {
+ return builtin_options_type() == onert_tflite::BuiltinOptions_MaximumMinimumOptions
+ ? static_cast<const onert_tflite::MaximumMinimumOptions *>(builtin_options())
+ : nullptr;
+ }
+ const onert_tflite::ArgMaxOptions *builtin_options_as_ArgMaxOptions() const
+ {
+ return builtin_options_type() == onert_tflite::BuiltinOptions_ArgMaxOptions
+ ? static_cast<const onert_tflite::ArgMaxOptions *>(builtin_options())
+ : nullptr;
+ }
+ const onert_tflite::LessOptions *builtin_options_as_LessOptions() const
+ {
+ return builtin_options_type() == onert_tflite::BuiltinOptions_LessOptions
+ ? static_cast<const onert_tflite::LessOptions *>(builtin_options())
+ : nullptr;
+ }
+ const onert_tflite::NegOptions *builtin_options_as_NegOptions() const
+ {
+ return builtin_options_type() == onert_tflite::BuiltinOptions_NegOptions
+ ? static_cast<const onert_tflite::NegOptions *>(builtin_options())
+ : nullptr;
+ }
+ const onert_tflite::PadV2Options *builtin_options_as_PadV2Options() const
+ {
+ return builtin_options_type() == onert_tflite::BuiltinOptions_PadV2Options
+ ? static_cast<const onert_tflite::PadV2Options *>(builtin_options())
+ : nullptr;
+ }
+ const onert_tflite::GreaterOptions *builtin_options_as_GreaterOptions() const
+ {
+ return builtin_options_type() == onert_tflite::BuiltinOptions_GreaterOptions
+ ? static_cast<const onert_tflite::GreaterOptions *>(builtin_options())
+ : nullptr;
+ }
+ const onert_tflite::GreaterEqualOptions *builtin_options_as_GreaterEqualOptions() const
+ {
+ return builtin_options_type() == onert_tflite::BuiltinOptions_GreaterEqualOptions
+ ? static_cast<const onert_tflite::GreaterEqualOptions *>(builtin_options())
+ : nullptr;
+ }
+ const onert_tflite::LessEqualOptions *builtin_options_as_LessEqualOptions() const
+ {
+ return builtin_options_type() == onert_tflite::BuiltinOptions_LessEqualOptions
+ ? static_cast<const onert_tflite::LessEqualOptions *>(builtin_options())
+ : nullptr;
+ }
+ const onert_tflite::SelectOptions *builtin_options_as_SelectOptions() const
+ {
+ return builtin_options_type() == onert_tflite::BuiltinOptions_SelectOptions
+ ? static_cast<const onert_tflite::SelectOptions *>(builtin_options())
+ : nullptr;
+ }
+ const onert_tflite::SliceOptions *builtin_options_as_SliceOptions() const
+ {
+ return builtin_options_type() == onert_tflite::BuiltinOptions_SliceOptions
+ ? static_cast<const onert_tflite::SliceOptions *>(builtin_options())
+ : nullptr;
+ }
+ const onert_tflite::TransposeConvOptions *builtin_options_as_TransposeConvOptions() const
+ {
+ return builtin_options_type() == onert_tflite::BuiltinOptions_TransposeConvOptions
+ ? static_cast<const onert_tflite::TransposeConvOptions *>(builtin_options())
+ : nullptr;
+ }
+ const onert_tflite::SparseToDenseOptions *builtin_options_as_SparseToDenseOptions() const
+ {
+ return builtin_options_type() == onert_tflite::BuiltinOptions_SparseToDenseOptions
+ ? static_cast<const onert_tflite::SparseToDenseOptions *>(builtin_options())
+ : nullptr;
+ }
+ const onert_tflite::TileOptions *builtin_options_as_TileOptions() const
+ {
+ return builtin_options_type() == onert_tflite::BuiltinOptions_TileOptions
+ ? static_cast<const onert_tflite::TileOptions *>(builtin_options())
+ : nullptr;
+ }
+ const onert_tflite::ExpandDimsOptions *builtin_options_as_ExpandDimsOptions() const
+ {
+ return builtin_options_type() == onert_tflite::BuiltinOptions_ExpandDimsOptions
+ ? static_cast<const onert_tflite::ExpandDimsOptions *>(builtin_options())
+ : nullptr;
+ }
+ const onert_tflite::EqualOptions *builtin_options_as_EqualOptions() const
+ {
+ return builtin_options_type() == onert_tflite::BuiltinOptions_EqualOptions
+ ? static_cast<const onert_tflite::EqualOptions *>(builtin_options())
+ : nullptr;
+ }
+ const onert_tflite::NotEqualOptions *builtin_options_as_NotEqualOptions() const
+ {
+ return builtin_options_type() == onert_tflite::BuiltinOptions_NotEqualOptions
+ ? static_cast<const onert_tflite::NotEqualOptions *>(builtin_options())
+ : nullptr;
+ }
+ const onert_tflite::ShapeOptions *builtin_options_as_ShapeOptions() const
+ {
+ return builtin_options_type() == onert_tflite::BuiltinOptions_ShapeOptions
+ ? static_cast<const onert_tflite::ShapeOptions *>(builtin_options())
+ : nullptr;
+ }
+ const onert_tflite::PowOptions *builtin_options_as_PowOptions() const
+ {
+ return builtin_options_type() == onert_tflite::BuiltinOptions_PowOptions
+ ? static_cast<const onert_tflite::PowOptions *>(builtin_options())
+ : nullptr;
+ }
+ const onert_tflite::ArgMinOptions *builtin_options_as_ArgMinOptions() const
+ {
+ return builtin_options_type() == onert_tflite::BuiltinOptions_ArgMinOptions
+ ? static_cast<const onert_tflite::ArgMinOptions *>(builtin_options())
+ : nullptr;
+ }
+ const onert_tflite::FakeQuantOptions *builtin_options_as_FakeQuantOptions() const
+ {
+ return builtin_options_type() == onert_tflite::BuiltinOptions_FakeQuantOptions
+ ? static_cast<const onert_tflite::FakeQuantOptions *>(builtin_options())
+ : nullptr;
+ }
+ const onert_tflite::PackOptions *builtin_options_as_PackOptions() const
+ {
+ return builtin_options_type() == onert_tflite::BuiltinOptions_PackOptions
+ ? static_cast<const onert_tflite::PackOptions *>(builtin_options())
+ : nullptr;
+ }
+ const onert_tflite::LogicalOrOptions *builtin_options_as_LogicalOrOptions() const
+ {
+ return builtin_options_type() == onert_tflite::BuiltinOptions_LogicalOrOptions
+ ? static_cast<const onert_tflite::LogicalOrOptions *>(builtin_options())
+ : nullptr;
+ }
+ const onert_tflite::OneHotOptions *builtin_options_as_OneHotOptions() const
+ {
+ return builtin_options_type() == onert_tflite::BuiltinOptions_OneHotOptions
+ ? static_cast<const onert_tflite::OneHotOptions *>(builtin_options())
+ : nullptr;
+ }
+ const onert_tflite::LogicalAndOptions *builtin_options_as_LogicalAndOptions() const
+ {
+ return builtin_options_type() == onert_tflite::BuiltinOptions_LogicalAndOptions
+ ? static_cast<const onert_tflite::LogicalAndOptions *>(builtin_options())
+ : nullptr;
+ }
+ const onert_tflite::LogicalNotOptions *builtin_options_as_LogicalNotOptions() const
+ {
+ return builtin_options_type() == onert_tflite::BuiltinOptions_LogicalNotOptions
+ ? static_cast<const onert_tflite::LogicalNotOptions *>(builtin_options())
+ : nullptr;
+ }
+ const onert_tflite::UnpackOptions *builtin_options_as_UnpackOptions() const
+ {
+ return builtin_options_type() == onert_tflite::BuiltinOptions_UnpackOptions
+ ? static_cast<const onert_tflite::UnpackOptions *>(builtin_options())
+ : nullptr;
+ }
+ const onert_tflite::FloorDivOptions *builtin_options_as_FloorDivOptions() const
+ {
+ return builtin_options_type() == onert_tflite::BuiltinOptions_FloorDivOptions
+ ? static_cast<const onert_tflite::FloorDivOptions *>(builtin_options())
+ : nullptr;
+ }
+ const onert_tflite::SquareOptions *builtin_options_as_SquareOptions() const
+ {
+ return builtin_options_type() == onert_tflite::BuiltinOptions_SquareOptions
+ ? static_cast<const onert_tflite::SquareOptions *>(builtin_options())
+ : nullptr;
+ }
+ const onert_tflite::ZerosLikeOptions *builtin_options_as_ZerosLikeOptions() const
+ {
+ return builtin_options_type() == onert_tflite::BuiltinOptions_ZerosLikeOptions
+ ? static_cast<const onert_tflite::ZerosLikeOptions *>(builtin_options())
+ : nullptr;
+ }
+ const onert_tflite::FillOptions *builtin_options_as_FillOptions() const
+ {
+ return builtin_options_type() == onert_tflite::BuiltinOptions_FillOptions
+ ? static_cast<const onert_tflite::FillOptions *>(builtin_options())
+ : nullptr;
+ }
+ const onert_tflite::BidirectionalSequenceLSTMOptions *
+ builtin_options_as_BidirectionalSequenceLSTMOptions() const
+ {
+ return builtin_options_type() == onert_tflite::BuiltinOptions_BidirectionalSequenceLSTMOptions
+ ? static_cast<const onert_tflite::BidirectionalSequenceLSTMOptions *>(
+ builtin_options())
+ : nullptr;
+ }
+ const onert_tflite::BidirectionalSequenceRNNOptions *
+ builtin_options_as_BidirectionalSequenceRNNOptions() const
+ {
+ return builtin_options_type() == onert_tflite::BuiltinOptions_BidirectionalSequenceRNNOptions
+ ? static_cast<const onert_tflite::BidirectionalSequenceRNNOptions *>(builtin_options())
+ : nullptr;
+ }
+ const onert_tflite::UnidirectionalSequenceLSTMOptions *
+ builtin_options_as_UnidirectionalSequenceLSTMOptions() const
+ {
+ return builtin_options_type() == onert_tflite::BuiltinOptions_UnidirectionalSequenceLSTMOptions
+ ? static_cast<const onert_tflite::UnidirectionalSequenceLSTMOptions *>(
+ builtin_options())
+ : nullptr;
+ }
+ const onert_tflite::FloorModOptions *builtin_options_as_FloorModOptions() const
+ {
+ return builtin_options_type() == onert_tflite::BuiltinOptions_FloorModOptions
+ ? static_cast<const onert_tflite::FloorModOptions *>(builtin_options())
+ : nullptr;
+ }
+ const onert_tflite::RangeOptions *builtin_options_as_RangeOptions() const
+ {
+ return builtin_options_type() == onert_tflite::BuiltinOptions_RangeOptions
+ ? static_cast<const onert_tflite::RangeOptions *>(builtin_options())
+ : nullptr;
+ }
+ const onert_tflite::ResizeNearestNeighborOptions *
+ builtin_options_as_ResizeNearestNeighborOptions() const
+ {
+ return builtin_options_type() == onert_tflite::BuiltinOptions_ResizeNearestNeighborOptions
+ ? static_cast<const onert_tflite::ResizeNearestNeighborOptions *>(builtin_options())
+ : nullptr;
+ }
+ const onert_tflite::LeakyReluOptions *builtin_options_as_LeakyReluOptions() const
+ {
+ return builtin_options_type() == onert_tflite::BuiltinOptions_LeakyReluOptions
+ ? static_cast<const onert_tflite::LeakyReluOptions *>(builtin_options())
+ : nullptr;
+ }
+ const onert_tflite::SquaredDifferenceOptions *builtin_options_as_SquaredDifferenceOptions() const
+ {
+ return builtin_options_type() == onert_tflite::BuiltinOptions_SquaredDifferenceOptions
+ ? static_cast<const onert_tflite::SquaredDifferenceOptions *>(builtin_options())
+ : nullptr;
+ }
+ const onert_tflite::MirrorPadOptions *builtin_options_as_MirrorPadOptions() const
+ {
+ return builtin_options_type() == onert_tflite::BuiltinOptions_MirrorPadOptions
+ ? static_cast<const onert_tflite::MirrorPadOptions *>(builtin_options())
+ : nullptr;
+ }
+ const onert_tflite::AbsOptions *builtin_options_as_AbsOptions() const
+ {
+ return builtin_options_type() == onert_tflite::BuiltinOptions_AbsOptions
+ ? static_cast<const onert_tflite::AbsOptions *>(builtin_options())
+ : nullptr;
+ }
+ const onert_tflite::SplitVOptions *builtin_options_as_SplitVOptions() const
+ {
+ return builtin_options_type() == onert_tflite::BuiltinOptions_SplitVOptions
+ ? static_cast<const onert_tflite::SplitVOptions *>(builtin_options())
+ : nullptr;
+ }
+ const onert_tflite::UniqueOptions *builtin_options_as_UniqueOptions() const
+ {
+ return builtin_options_type() == onert_tflite::BuiltinOptions_UniqueOptions
+ ? static_cast<const onert_tflite::UniqueOptions *>(builtin_options())
+ : nullptr;
+ }
+ const onert_tflite::ReverseV2Options *builtin_options_as_ReverseV2Options() const
+ {
+ return builtin_options_type() == onert_tflite::BuiltinOptions_ReverseV2Options
+ ? static_cast<const onert_tflite::ReverseV2Options *>(builtin_options())
+ : nullptr;
+ }
+ const onert_tflite::AddNOptions *builtin_options_as_AddNOptions() const
+ {
+ return builtin_options_type() == onert_tflite::BuiltinOptions_AddNOptions
+ ? static_cast<const onert_tflite::AddNOptions *>(builtin_options())
+ : nullptr;
+ }
+ const onert_tflite::GatherNdOptions *builtin_options_as_GatherNdOptions() const
+ {
+ return builtin_options_type() == onert_tflite::BuiltinOptions_GatherNdOptions
+ ? static_cast<const onert_tflite::GatherNdOptions *>(builtin_options())
+ : nullptr;
+ }
+ const onert_tflite::CosOptions *builtin_options_as_CosOptions() const
+ {
+ return builtin_options_type() == onert_tflite::BuiltinOptions_CosOptions
+ ? static_cast<const onert_tflite::CosOptions *>(builtin_options())
+ : nullptr;
+ }
+ const onert_tflite::WhereOptions *builtin_options_as_WhereOptions() const
+ {
+ return builtin_options_type() == onert_tflite::BuiltinOptions_WhereOptions
+ ? static_cast<const onert_tflite::WhereOptions *>(builtin_options())
+ : nullptr;
+ }
+ const onert_tflite::RankOptions *builtin_options_as_RankOptions() const
+ {
+ return builtin_options_type() == onert_tflite::BuiltinOptions_RankOptions
+ ? static_cast<const onert_tflite::RankOptions *>(builtin_options())
+ : nullptr;
+ }
+ const onert_tflite::ReverseSequenceOptions *builtin_options_as_ReverseSequenceOptions() const
+ {
+ return builtin_options_type() == onert_tflite::BuiltinOptions_ReverseSequenceOptions
+ ? static_cast<const onert_tflite::ReverseSequenceOptions *>(builtin_options())
+ : nullptr;
+ }
+ const onert_tflite::MatrixDiagOptions *builtin_options_as_MatrixDiagOptions() const
+ {
+ return builtin_options_type() == onert_tflite::BuiltinOptions_MatrixDiagOptions
+ ? static_cast<const onert_tflite::MatrixDiagOptions *>(builtin_options())
+ : nullptr;
+ }
+ const onert_tflite::QuantizeOptions *builtin_options_as_QuantizeOptions() const
+ {
+ return builtin_options_type() == onert_tflite::BuiltinOptions_QuantizeOptions
+ ? static_cast<const onert_tflite::QuantizeOptions *>(builtin_options())
+ : nullptr;
+ }
+ const onert_tflite::MatrixSetDiagOptions *builtin_options_as_MatrixSetDiagOptions() const
+ {
+ return builtin_options_type() == onert_tflite::BuiltinOptions_MatrixSetDiagOptions
+ ? static_cast<const onert_tflite::MatrixSetDiagOptions *>(builtin_options())
+ : nullptr;
+ }
+ const onert_tflite::HardSwishOptions *builtin_options_as_HardSwishOptions() const
+ {
+ return builtin_options_type() == onert_tflite::BuiltinOptions_HardSwishOptions
+ ? static_cast<const onert_tflite::HardSwishOptions *>(builtin_options())
+ : nullptr;
+ }
+ const onert_tflite::IfOptions *builtin_options_as_IfOptions() const
+ {
+ return builtin_options_type() == onert_tflite::BuiltinOptions_IfOptions
+ ? static_cast<const onert_tflite::IfOptions *>(builtin_options())
+ : nullptr;
+ }
+ const onert_tflite::WhileOptions *builtin_options_as_WhileOptions() const
+ {
+ return builtin_options_type() == onert_tflite::BuiltinOptions_WhileOptions
+ ? static_cast<const onert_tflite::WhileOptions *>(builtin_options())
+ : nullptr;
+ }
+ const onert_tflite::DepthToSpaceOptions *builtin_options_as_DepthToSpaceOptions() const
+ {
+ return builtin_options_type() == onert_tflite::BuiltinOptions_DepthToSpaceOptions
+ ? static_cast<const onert_tflite::DepthToSpaceOptions *>(builtin_options())
+ : nullptr;
+ }
+ const onert_tflite::NonMaxSuppressionV4Options *
+ builtin_options_as_NonMaxSuppressionV4Options() const
+ {
+ return builtin_options_type() == onert_tflite::BuiltinOptions_NonMaxSuppressionV4Options
+ ? static_cast<const onert_tflite::NonMaxSuppressionV4Options *>(builtin_options())
+ : nullptr;
+ }
+ const onert_tflite::NonMaxSuppressionV5Options *
+ builtin_options_as_NonMaxSuppressionV5Options() const
+ {
+ return builtin_options_type() == onert_tflite::BuiltinOptions_NonMaxSuppressionV5Options
+ ? static_cast<const onert_tflite::NonMaxSuppressionV5Options *>(builtin_options())
+ : nullptr;
+ }
+ const onert_tflite::ScatterNdOptions *builtin_options_as_ScatterNdOptions() const
+ {
+ return builtin_options_type() == onert_tflite::BuiltinOptions_ScatterNdOptions
+ ? static_cast<const onert_tflite::ScatterNdOptions *>(builtin_options())
+ : nullptr;
+ }
+ const onert_tflite::SelectV2Options *builtin_options_as_SelectV2Options() const
+ {
+ return builtin_options_type() == onert_tflite::BuiltinOptions_SelectV2Options
+ ? static_cast<const onert_tflite::SelectV2Options *>(builtin_options())
+ : nullptr;
+ }
+ const onert_tflite::DensifyOptions *builtin_options_as_DensifyOptions() const
+ {
+ return builtin_options_type() == onert_tflite::BuiltinOptions_DensifyOptions
+ ? static_cast<const onert_tflite::DensifyOptions *>(builtin_options())
+ : nullptr;
+ }
+ const onert_tflite::SegmentSumOptions *builtin_options_as_SegmentSumOptions() const
+ {
+ return builtin_options_type() == onert_tflite::BuiltinOptions_SegmentSumOptions
+ ? static_cast<const onert_tflite::SegmentSumOptions *>(builtin_options())
+ : nullptr;
+ }
+ const onert_tflite::BatchMatMulOptions *builtin_options_as_BatchMatMulOptions() const
+ {
+ return builtin_options_type() == onert_tflite::BuiltinOptions_BatchMatMulOptions
+ ? static_cast<const onert_tflite::BatchMatMulOptions *>(builtin_options())
+ : nullptr;
+ }
+ const onert_tflite::CumsumOptions *builtin_options_as_CumsumOptions() const
+ {
+ return builtin_options_type() == onert_tflite::BuiltinOptions_CumsumOptions
+ ? static_cast<const onert_tflite::CumsumOptions *>(builtin_options())
+ : nullptr;
+ }
+ const onert_tflite::CallOnceOptions *builtin_options_as_CallOnceOptions() const
+ {
+ return builtin_options_type() == onert_tflite::BuiltinOptions_CallOnceOptions
+ ? static_cast<const onert_tflite::CallOnceOptions *>(builtin_options())
+ : nullptr;
+ }
+ const onert_tflite::BroadcastToOptions *builtin_options_as_BroadcastToOptions() const
+ {
+ return builtin_options_type() == onert_tflite::BuiltinOptions_BroadcastToOptions
+ ? static_cast<const onert_tflite::BroadcastToOptions *>(builtin_options())
+ : nullptr;
+ }
+ const onert_tflite::Rfft2dOptions *builtin_options_as_Rfft2dOptions() const
+ {
+ return builtin_options_type() == onert_tflite::BuiltinOptions_Rfft2dOptions
+ ? static_cast<const onert_tflite::Rfft2dOptions *>(builtin_options())
+ : nullptr;
+ }
+ const onert_tflite::Conv3DOptions *builtin_options_as_Conv3DOptions() const
+ {
+ return builtin_options_type() == onert_tflite::BuiltinOptions_Conv3DOptions
+ ? static_cast<const onert_tflite::Conv3DOptions *>(builtin_options())
+ : nullptr;
+ }
+ const onert_tflite::HashtableOptions *builtin_options_as_HashtableOptions() const
+ {
+ return builtin_options_type() == onert_tflite::BuiltinOptions_HashtableOptions
+ ? static_cast<const onert_tflite::HashtableOptions *>(builtin_options())
+ : nullptr;
+ }
+ const onert_tflite::HashtableFindOptions *builtin_options_as_HashtableFindOptions() const
+ {
+ return builtin_options_type() == onert_tflite::BuiltinOptions_HashtableFindOptions
+ ? static_cast<const onert_tflite::HashtableFindOptions *>(builtin_options())
+ : nullptr;
+ }
+ const onert_tflite::HashtableImportOptions *builtin_options_as_HashtableImportOptions() const
+ {
+ return builtin_options_type() == onert_tflite::BuiltinOptions_HashtableImportOptions
+ ? static_cast<const onert_tflite::HashtableImportOptions *>(builtin_options())
+ : nullptr;
+ }
+ const onert_tflite::HashtableSizeOptions *builtin_options_as_HashtableSizeOptions() const
+ {
+ return builtin_options_type() == onert_tflite::BuiltinOptions_HashtableSizeOptions
+ ? static_cast<const onert_tflite::HashtableSizeOptions *>(builtin_options())
+ : nullptr;
+ }
+ const onert_tflite::VarHandleOptions *builtin_options_as_VarHandleOptions() const
+ {
+ return builtin_options_type() == onert_tflite::BuiltinOptions_VarHandleOptions
+ ? static_cast<const onert_tflite::VarHandleOptions *>(builtin_options())
+ : nullptr;
+ }
+ const onert_tflite::ReadVariableOptions *builtin_options_as_ReadVariableOptions() const
+ {
+ return builtin_options_type() == onert_tflite::BuiltinOptions_ReadVariableOptions
+ ? static_cast<const onert_tflite::ReadVariableOptions *>(builtin_options())
+ : nullptr;
+ }
+ const onert_tflite::AssignVariableOptions *builtin_options_as_AssignVariableOptions() const
+ {
+ return builtin_options_type() == onert_tflite::BuiltinOptions_AssignVariableOptions
+ ? static_cast<const onert_tflite::AssignVariableOptions *>(builtin_options())
+ : nullptr;
+ }
+ const onert_tflite::RandomOptions *builtin_options_as_RandomOptions() const
+ {
+ return builtin_options_type() == onert_tflite::BuiltinOptions_RandomOptions
+ ? static_cast<const onert_tflite::RandomOptions *>(builtin_options())
+ : nullptr;
+ }
+ const onert_tflite::BucketizeOptions *builtin_options_as_BucketizeOptions() const
+ {
+ return builtin_options_type() == onert_tflite::BuiltinOptions_BucketizeOptions
+ ? static_cast<const onert_tflite::BucketizeOptions *>(builtin_options())
+ : nullptr;
+ }
+ const onert_tflite::GeluOptions *builtin_options_as_GeluOptions() const
+ {
+ return builtin_options_type() == onert_tflite::BuiltinOptions_GeluOptions
+ ? static_cast<const onert_tflite::GeluOptions *>(builtin_options())
+ : nullptr;
+ }
+ const onert_tflite::DynamicUpdateSliceOptions *
+ builtin_options_as_DynamicUpdateSliceOptions() const
+ {
+ return builtin_options_type() == onert_tflite::BuiltinOptions_DynamicUpdateSliceOptions
+ ? static_cast<const onert_tflite::DynamicUpdateSliceOptions *>(builtin_options())
+ : nullptr;
+ }
+ const onert_tflite::UnsortedSegmentProdOptions *
+ builtin_options_as_UnsortedSegmentProdOptions() const
+ {
+ return builtin_options_type() == onert_tflite::BuiltinOptions_UnsortedSegmentProdOptions
+ ? static_cast<const onert_tflite::UnsortedSegmentProdOptions *>(builtin_options())
+ : nullptr;
+ }
+ const onert_tflite::UnsortedSegmentMaxOptions *
+ builtin_options_as_UnsortedSegmentMaxOptions() const
+ {
+ return builtin_options_type() == onert_tflite::BuiltinOptions_UnsortedSegmentMaxOptions
+ ? static_cast<const onert_tflite::UnsortedSegmentMaxOptions *>(builtin_options())
+ : nullptr;
+ }
+ const onert_tflite::UnsortedSegmentSumOptions *
+ builtin_options_as_UnsortedSegmentSumOptions() const
+ {
+ return builtin_options_type() == onert_tflite::BuiltinOptions_UnsortedSegmentSumOptions
+ ? static_cast<const onert_tflite::UnsortedSegmentSumOptions *>(builtin_options())
+ : nullptr;
+ }
+ const onert_tflite::ATan2Options *builtin_options_as_ATan2Options() const
+ {
+ return builtin_options_type() == onert_tflite::BuiltinOptions_ATan2Options
+ ? static_cast<const onert_tflite::ATan2Options *>(builtin_options())
+ : nullptr;
+ }
+ const ::flatbuffers::Vector<uint8_t> *custom_options() const
+ {
+ return GetPointer<const ::flatbuffers::Vector<uint8_t> *>(VT_CUSTOM_OPTIONS);
+ }
+ onert_tflite::CustomOptionsFormat custom_options_format() const
+ {
+ return static_cast<onert_tflite::CustomOptionsFormat>(
+ GetField<int8_t>(VT_CUSTOM_OPTIONS_FORMAT, 0));
+ }
+ const ::flatbuffers::Vector<uint8_t> *mutating_variable_inputs() const
+ {
+ return GetPointer<const ::flatbuffers::Vector<uint8_t> *>(VT_MUTATING_VARIABLE_INPUTS);
+ }
+ const ::flatbuffers::Vector<int32_t> *intermediates() const
+ {
+ return GetPointer<const ::flatbuffers::Vector<int32_t> *>(VT_INTERMEDIATES);
+ }
+ bool Verify(::flatbuffers::Verifier &verifier) const
+ {
+ return VerifyTableStart(verifier) && VerifyField<uint32_t>(verifier, VT_OPCODE_INDEX, 4) &&
+ VerifyOffset(verifier, VT_INPUTS) && verifier.VerifyVector(inputs()) &&
+ VerifyOffset(verifier, VT_OUTPUTS) && verifier.VerifyVector(outputs()) &&
+ VerifyField<uint8_t>(verifier, VT_BUILTIN_OPTIONS_TYPE, 1) &&
+ VerifyOffset(verifier, VT_BUILTIN_OPTIONS) &&
+ VerifyBuiltinOptions(verifier, builtin_options(), builtin_options_type()) &&
+ VerifyOffset(verifier, VT_CUSTOM_OPTIONS) && verifier.VerifyVector(custom_options()) &&
+ VerifyField<int8_t>(verifier, VT_CUSTOM_OPTIONS_FORMAT, 1) &&
+ VerifyOffset(verifier, VT_MUTATING_VARIABLE_INPUTS) &&
+ verifier.VerifyVector(mutating_variable_inputs()) &&
+ VerifyOffset(verifier, VT_INTERMEDIATES) && verifier.VerifyVector(intermediates()) &&
+ verifier.EndTable();
+ }
+};
+
+template <>
+inline const onert_tflite::Conv2DOptions *
+Operator::builtin_options_as<onert_tflite::Conv2DOptions>() const
+{
+ return builtin_options_as_Conv2DOptions();
+}
+
+template <>
+inline const onert_tflite::DepthwiseConv2DOptions *
+Operator::builtin_options_as<onert_tflite::DepthwiseConv2DOptions>() const
+{
+ return builtin_options_as_DepthwiseConv2DOptions();
+}
+
+template <>
+inline const onert_tflite::ConcatEmbeddingsOptions *
+Operator::builtin_options_as<onert_tflite::ConcatEmbeddingsOptions>() const
+{
+ return builtin_options_as_ConcatEmbeddingsOptions();
+}
+
+template <>
+inline const onert_tflite::LSHProjectionOptions *
+Operator::builtin_options_as<onert_tflite::LSHProjectionOptions>() const
+{
+ return builtin_options_as_LSHProjectionOptions();
+}
+
+template <>
+inline const onert_tflite::Pool2DOptions *
+Operator::builtin_options_as<onert_tflite::Pool2DOptions>() const
+{
+ return builtin_options_as_Pool2DOptions();
+}
+
+template <>
+inline const onert_tflite::SVDFOptions *
+Operator::builtin_options_as<onert_tflite::SVDFOptions>() const
+{
+ return builtin_options_as_SVDFOptions();
+}
+
+template <>
+inline const onert_tflite::RNNOptions *
+Operator::builtin_options_as<onert_tflite::RNNOptions>() const
+{
+ return builtin_options_as_RNNOptions();
+}
+
+template <>
+inline const onert_tflite::FullyConnectedOptions *
+Operator::builtin_options_as<onert_tflite::FullyConnectedOptions>() const
+{
+ return builtin_options_as_FullyConnectedOptions();
+}
+
+template <>
+inline const onert_tflite::SoftmaxOptions *
+Operator::builtin_options_as<onert_tflite::SoftmaxOptions>() const
+{
+ return builtin_options_as_SoftmaxOptions();
+}
+
+template <>
+inline const onert_tflite::ConcatenationOptions *
+Operator::builtin_options_as<onert_tflite::ConcatenationOptions>() const
+{
+ return builtin_options_as_ConcatenationOptions();
+}
+
+template <>
+inline const onert_tflite::AddOptions *
+Operator::builtin_options_as<onert_tflite::AddOptions>() const
+{
+ return builtin_options_as_AddOptions();
+}
+
+template <>
+inline const onert_tflite::L2NormOptions *
+Operator::builtin_options_as<onert_tflite::L2NormOptions>() const
+{
+ return builtin_options_as_L2NormOptions();
+}
+
+template <>
+inline const onert_tflite::LocalResponseNormalizationOptions *
+Operator::builtin_options_as<onert_tflite::LocalResponseNormalizationOptions>() const
+{
+ return builtin_options_as_LocalResponseNormalizationOptions();
+}
+
+template <>
+inline const onert_tflite::LSTMOptions *
+Operator::builtin_options_as<onert_tflite::LSTMOptions>() const
+{
+ return builtin_options_as_LSTMOptions();
+}
+
+template <>
+inline const onert_tflite::ResizeBilinearOptions *
+Operator::builtin_options_as<onert_tflite::ResizeBilinearOptions>() const
+{
+ return builtin_options_as_ResizeBilinearOptions();
+}
+
+template <>
+inline const onert_tflite::CallOptions *
+Operator::builtin_options_as<onert_tflite::CallOptions>() const
+{
+ return builtin_options_as_CallOptions();
+}
+
+template <>
+inline const onert_tflite::ReshapeOptions *
+Operator::builtin_options_as<onert_tflite::ReshapeOptions>() const
+{
+ return builtin_options_as_ReshapeOptions();
+}
+
+template <>
+inline const onert_tflite::SkipGramOptions *
+Operator::builtin_options_as<onert_tflite::SkipGramOptions>() const
+{
+ return builtin_options_as_SkipGramOptions();
+}
+
+template <>
+inline const onert_tflite::SpaceToDepthOptions *
+Operator::builtin_options_as<onert_tflite::SpaceToDepthOptions>() const
+{
+ return builtin_options_as_SpaceToDepthOptions();
+}
+
+template <>
+inline const onert_tflite::EmbeddingLookupSparseOptions *
+Operator::builtin_options_as<onert_tflite::EmbeddingLookupSparseOptions>() const
+{
+ return builtin_options_as_EmbeddingLookupSparseOptions();
+}
+
+template <>
+inline const onert_tflite::MulOptions *
+Operator::builtin_options_as<onert_tflite::MulOptions>() const
+{
+ return builtin_options_as_MulOptions();
+}
+
+template <>
+inline const onert_tflite::PadOptions *
+Operator::builtin_options_as<onert_tflite::PadOptions>() const
+{
+ return builtin_options_as_PadOptions();
+}
+
+template <>
+inline const onert_tflite::GatherOptions *
+Operator::builtin_options_as<onert_tflite::GatherOptions>() const
+{
+ return builtin_options_as_GatherOptions();
+}
+
+template <>
+inline const onert_tflite::BatchToSpaceNDOptions *
+Operator::builtin_options_as<onert_tflite::BatchToSpaceNDOptions>() const
+{
+ return builtin_options_as_BatchToSpaceNDOptions();
+}
+
+template <>
+inline const onert_tflite::SpaceToBatchNDOptions *
+Operator::builtin_options_as<onert_tflite::SpaceToBatchNDOptions>() const
+{
+ return builtin_options_as_SpaceToBatchNDOptions();
+}
+
+template <>
+inline const onert_tflite::TransposeOptions *
+Operator::builtin_options_as<onert_tflite::TransposeOptions>() const
+{
+ return builtin_options_as_TransposeOptions();
+}
+
+template <>
+inline const onert_tflite::ReducerOptions *
+Operator::builtin_options_as<onert_tflite::ReducerOptions>() const
+{
+ return builtin_options_as_ReducerOptions();
+}
+
+template <>
+inline const onert_tflite::SubOptions *
+Operator::builtin_options_as<onert_tflite::SubOptions>() const
+{
+ return builtin_options_as_SubOptions();
+}
+
+template <>
+inline const onert_tflite::DivOptions *
+Operator::builtin_options_as<onert_tflite::DivOptions>() const
+{
+ return builtin_options_as_DivOptions();
+}
+
+template <>
+inline const onert_tflite::SqueezeOptions *
+Operator::builtin_options_as<onert_tflite::SqueezeOptions>() const
+{
+ return builtin_options_as_SqueezeOptions();
+}
+
+template <>
+inline const onert_tflite::SequenceRNNOptions *
+Operator::builtin_options_as<onert_tflite::SequenceRNNOptions>() const
+{
+ return builtin_options_as_SequenceRNNOptions();
+}
+
+template <>
+inline const onert_tflite::StridedSliceOptions *
+Operator::builtin_options_as<onert_tflite::StridedSliceOptions>() const
+{
+ return builtin_options_as_StridedSliceOptions();
+}
+
+template <>
+inline const onert_tflite::ExpOptions *
+Operator::builtin_options_as<onert_tflite::ExpOptions>() const
+{
+ return builtin_options_as_ExpOptions();
+}
+
+template <>
+inline const onert_tflite::TopKV2Options *
+Operator::builtin_options_as<onert_tflite::TopKV2Options>() const
+{
+ return builtin_options_as_TopKV2Options();
+}
+
+template <>
+inline const onert_tflite::SplitOptions *
+Operator::builtin_options_as<onert_tflite::SplitOptions>() const
+{
+ return builtin_options_as_SplitOptions();
+}
+
+template <>
+inline const onert_tflite::LogSoftmaxOptions *
+Operator::builtin_options_as<onert_tflite::LogSoftmaxOptions>() const
+{
+ return builtin_options_as_LogSoftmaxOptions();
+}
+
+template <>
+inline const onert_tflite::CastOptions *
+Operator::builtin_options_as<onert_tflite::CastOptions>() const
+{
+ return builtin_options_as_CastOptions();
+}
+
+template <>
+inline const onert_tflite::DequantizeOptions *
+Operator::builtin_options_as<onert_tflite::DequantizeOptions>() const
+{
+ return builtin_options_as_DequantizeOptions();
+}
+
+template <>
+inline const onert_tflite::MaximumMinimumOptions *
+Operator::builtin_options_as<onert_tflite::MaximumMinimumOptions>() const
+{
+ return builtin_options_as_MaximumMinimumOptions();
+}
+
+template <>
+inline const onert_tflite::ArgMaxOptions *
+Operator::builtin_options_as<onert_tflite::ArgMaxOptions>() const
+{
+ return builtin_options_as_ArgMaxOptions();
+}
+
+template <>
+inline const onert_tflite::LessOptions *
+Operator::builtin_options_as<onert_tflite::LessOptions>() const
+{
+ return builtin_options_as_LessOptions();
+}
+
+template <>
+inline const onert_tflite::NegOptions *
+Operator::builtin_options_as<onert_tflite::NegOptions>() const
+{
+ return builtin_options_as_NegOptions();
+}
+
+template <>
+inline const onert_tflite::PadV2Options *
+Operator::builtin_options_as<onert_tflite::PadV2Options>() const
+{
+ return builtin_options_as_PadV2Options();
+}
+
+template <>
+inline const onert_tflite::GreaterOptions *
+Operator::builtin_options_as<onert_tflite::GreaterOptions>() const
+{
+ return builtin_options_as_GreaterOptions();
+}
+
+template <>
+inline const onert_tflite::GreaterEqualOptions *
+Operator::builtin_options_as<onert_tflite::GreaterEqualOptions>() const
+{
+ return builtin_options_as_GreaterEqualOptions();
+}
+
+template <>
+inline const onert_tflite::LessEqualOptions *
+Operator::builtin_options_as<onert_tflite::LessEqualOptions>() const
+{
+ return builtin_options_as_LessEqualOptions();
+}
+
+template <>
+inline const onert_tflite::SelectOptions *
+Operator::builtin_options_as<onert_tflite::SelectOptions>() const
+{
+ return builtin_options_as_SelectOptions();
+}
+
+template <>
+inline const onert_tflite::SliceOptions *
+Operator::builtin_options_as<onert_tflite::SliceOptions>() const
+{
+ return builtin_options_as_SliceOptions();
+}
+
+template <>
+inline const onert_tflite::TransposeConvOptions *
+Operator::builtin_options_as<onert_tflite::TransposeConvOptions>() const
+{
+ return builtin_options_as_TransposeConvOptions();
+}
+
+template <>
+inline const onert_tflite::SparseToDenseOptions *
+Operator::builtin_options_as<onert_tflite::SparseToDenseOptions>() const
+{
+ return builtin_options_as_SparseToDenseOptions();
+}
+
+template <>
+inline const onert_tflite::TileOptions *
+Operator::builtin_options_as<onert_tflite::TileOptions>() const
+{
+ return builtin_options_as_TileOptions();
+}
+
+template <>
+inline const onert_tflite::ExpandDimsOptions *
+Operator::builtin_options_as<onert_tflite::ExpandDimsOptions>() const
+{
+ return builtin_options_as_ExpandDimsOptions();
+}
+
+template <>
+inline const onert_tflite::EqualOptions *
+Operator::builtin_options_as<onert_tflite::EqualOptions>() const
+{
+ return builtin_options_as_EqualOptions();
+}
+
+template <>
+inline const onert_tflite::NotEqualOptions *
+Operator::builtin_options_as<onert_tflite::NotEqualOptions>() const
+{
+ return builtin_options_as_NotEqualOptions();
+}
+
+template <>
+inline const onert_tflite::ShapeOptions *
+Operator::builtin_options_as<onert_tflite::ShapeOptions>() const
+{
+ return builtin_options_as_ShapeOptions();
+}
+
+template <>
+inline const onert_tflite::PowOptions *
+Operator::builtin_options_as<onert_tflite::PowOptions>() const
+{
+ return builtin_options_as_PowOptions();
+}
+
+template <>
+inline const onert_tflite::ArgMinOptions *
+Operator::builtin_options_as<onert_tflite::ArgMinOptions>() const
+{
+ return builtin_options_as_ArgMinOptions();
+}
+
+template <>
+inline const onert_tflite::FakeQuantOptions *
+Operator::builtin_options_as<onert_tflite::FakeQuantOptions>() const
+{
+ return builtin_options_as_FakeQuantOptions();
+}
+
+template <>
+inline const onert_tflite::PackOptions *
+Operator::builtin_options_as<onert_tflite::PackOptions>() const
+{
+ return builtin_options_as_PackOptions();
+}
+
+template <>
+inline const onert_tflite::LogicalOrOptions *
+Operator::builtin_options_as<onert_tflite::LogicalOrOptions>() const
+{
+ return builtin_options_as_LogicalOrOptions();
+}
+
+template <>
+inline const onert_tflite::OneHotOptions *
+Operator::builtin_options_as<onert_tflite::OneHotOptions>() const
+{
+ return builtin_options_as_OneHotOptions();
+}
+
+template <>
+inline const onert_tflite::LogicalAndOptions *
+Operator::builtin_options_as<onert_tflite::LogicalAndOptions>() const
+{
+ return builtin_options_as_LogicalAndOptions();
+}
+
+template <>
+inline const onert_tflite::LogicalNotOptions *
+Operator::builtin_options_as<onert_tflite::LogicalNotOptions>() const
+{
+ return builtin_options_as_LogicalNotOptions();
+}
+
+template <>
+inline const onert_tflite::UnpackOptions *
+Operator::builtin_options_as<onert_tflite::UnpackOptions>() const
+{
+ return builtin_options_as_UnpackOptions();
+}
+
+template <>
+inline const onert_tflite::FloorDivOptions *
+Operator::builtin_options_as<onert_tflite::FloorDivOptions>() const
+{
+ return builtin_options_as_FloorDivOptions();
+}
+
+template <>
+inline const onert_tflite::SquareOptions *
+Operator::builtin_options_as<onert_tflite::SquareOptions>() const
+{
+ return builtin_options_as_SquareOptions();
+}
+
+template <>
+inline const onert_tflite::ZerosLikeOptions *
+Operator::builtin_options_as<onert_tflite::ZerosLikeOptions>() const
+{
+ return builtin_options_as_ZerosLikeOptions();
+}
+
+template <>
+inline const onert_tflite::FillOptions *
+Operator::builtin_options_as<onert_tflite::FillOptions>() const
+{
+ return builtin_options_as_FillOptions();
+}
+
+template <>
+inline const onert_tflite::BidirectionalSequenceLSTMOptions *
+Operator::builtin_options_as<onert_tflite::BidirectionalSequenceLSTMOptions>() const
+{
+ return builtin_options_as_BidirectionalSequenceLSTMOptions();
+}
+
+template <>
+inline const onert_tflite::BidirectionalSequenceRNNOptions *
+Operator::builtin_options_as<onert_tflite::BidirectionalSequenceRNNOptions>() const
+{
+ return builtin_options_as_BidirectionalSequenceRNNOptions();
+}
+
+template <>
+inline const onert_tflite::UnidirectionalSequenceLSTMOptions *
+Operator::builtin_options_as<onert_tflite::UnidirectionalSequenceLSTMOptions>() const
+{
+ return builtin_options_as_UnidirectionalSequenceLSTMOptions();
+}
+
+template <>
+inline const onert_tflite::FloorModOptions *
+Operator::builtin_options_as<onert_tflite::FloorModOptions>() const
+{
+ return builtin_options_as_FloorModOptions();
+}
+
+template <>
+inline const onert_tflite::RangeOptions *
+Operator::builtin_options_as<onert_tflite::RangeOptions>() const
+{
+ return builtin_options_as_RangeOptions();
+}
+
+template <>
+inline const onert_tflite::ResizeNearestNeighborOptions *
+Operator::builtin_options_as<onert_tflite::ResizeNearestNeighborOptions>() const
+{
+ return builtin_options_as_ResizeNearestNeighborOptions();
+}
+
+template <>
+inline const onert_tflite::LeakyReluOptions *
+Operator::builtin_options_as<onert_tflite::LeakyReluOptions>() const
+{
+ return builtin_options_as_LeakyReluOptions();
+}
+
+template <>
+inline const onert_tflite::SquaredDifferenceOptions *
+Operator::builtin_options_as<onert_tflite::SquaredDifferenceOptions>() const
+{
+ return builtin_options_as_SquaredDifferenceOptions();
+}
+
+template <>
+inline const onert_tflite::MirrorPadOptions *
+Operator::builtin_options_as<onert_tflite::MirrorPadOptions>() const
+{
+ return builtin_options_as_MirrorPadOptions();
+}
+
+template <>
+inline const onert_tflite::AbsOptions *
+Operator::builtin_options_as<onert_tflite::AbsOptions>() const
+{
+ return builtin_options_as_AbsOptions();
+}
+
+template <>
+inline const onert_tflite::SplitVOptions *
+Operator::builtin_options_as<onert_tflite::SplitVOptions>() const
+{
+ return builtin_options_as_SplitVOptions();
+}
+
+template <>
+inline const onert_tflite::UniqueOptions *
+Operator::builtin_options_as<onert_tflite::UniqueOptions>() const
+{
+ return builtin_options_as_UniqueOptions();
+}
+
+template <>
+inline const onert_tflite::ReverseV2Options *
+Operator::builtin_options_as<onert_tflite::ReverseV2Options>() const
+{
+ return builtin_options_as_ReverseV2Options();
+}
+
+template <>
+inline const onert_tflite::AddNOptions *
+Operator::builtin_options_as<onert_tflite::AddNOptions>() const
+{
+ return builtin_options_as_AddNOptions();
+}
+
+template <>
+inline const onert_tflite::GatherNdOptions *
+Operator::builtin_options_as<onert_tflite::GatherNdOptions>() const
+{
+ return builtin_options_as_GatherNdOptions();
+}
+
+template <>
+inline const onert_tflite::CosOptions *
+Operator::builtin_options_as<onert_tflite::CosOptions>() const
+{
+ return builtin_options_as_CosOptions();
+}
+
+template <>
+inline const onert_tflite::WhereOptions *
+Operator::builtin_options_as<onert_tflite::WhereOptions>() const
+{
+ return builtin_options_as_WhereOptions();
+}
+
+template <>
+inline const onert_tflite::RankOptions *
+Operator::builtin_options_as<onert_tflite::RankOptions>() const
+{
+ return builtin_options_as_RankOptions();
+}
+
+template <>
+inline const onert_tflite::ReverseSequenceOptions *
+Operator::builtin_options_as<onert_tflite::ReverseSequenceOptions>() const
+{
+ return builtin_options_as_ReverseSequenceOptions();
+}
+
+template <>
+inline const onert_tflite::MatrixDiagOptions *
+Operator::builtin_options_as<onert_tflite::MatrixDiagOptions>() const
+{
+ return builtin_options_as_MatrixDiagOptions();
+}
+
+template <>
+inline const onert_tflite::QuantizeOptions *
+Operator::builtin_options_as<onert_tflite::QuantizeOptions>() const
+{
+ return builtin_options_as_QuantizeOptions();
+}
+
+template <>
+inline const onert_tflite::MatrixSetDiagOptions *
+Operator::builtin_options_as<onert_tflite::MatrixSetDiagOptions>() const
+{
+ return builtin_options_as_MatrixSetDiagOptions();
+}
+
+template <>
+inline const onert_tflite::HardSwishOptions *
+Operator::builtin_options_as<onert_tflite::HardSwishOptions>() const
+{
+ return builtin_options_as_HardSwishOptions();
+}
+
+template <>
+inline const onert_tflite::IfOptions *Operator::builtin_options_as<onert_tflite::IfOptions>() const
+{
+ return builtin_options_as_IfOptions();
+}
+
+template <>
+inline const onert_tflite::WhileOptions *
+Operator::builtin_options_as<onert_tflite::WhileOptions>() const
+{
+ return builtin_options_as_WhileOptions();
+}
+
+template <>
+inline const onert_tflite::DepthToSpaceOptions *
+Operator::builtin_options_as<onert_tflite::DepthToSpaceOptions>() const
+{
+ return builtin_options_as_DepthToSpaceOptions();
+}
+
+template <>
+inline const onert_tflite::NonMaxSuppressionV4Options *
+Operator::builtin_options_as<onert_tflite::NonMaxSuppressionV4Options>() const
+{
+ return builtin_options_as_NonMaxSuppressionV4Options();
+}
+
+template <>
+inline const onert_tflite::NonMaxSuppressionV5Options *
+Operator::builtin_options_as<onert_tflite::NonMaxSuppressionV5Options>() const
+{
+ return builtin_options_as_NonMaxSuppressionV5Options();
+}
+
+template <>
+inline const onert_tflite::ScatterNdOptions *
+Operator::builtin_options_as<onert_tflite::ScatterNdOptions>() const
+{
+ return builtin_options_as_ScatterNdOptions();
+}
+
+template <>
+inline const onert_tflite::SelectV2Options *
+Operator::builtin_options_as<onert_tflite::SelectV2Options>() const
+{
+ return builtin_options_as_SelectV2Options();
+}
+
+template <>
+inline const onert_tflite::DensifyOptions *
+Operator::builtin_options_as<onert_tflite::DensifyOptions>() const
+{
+ return builtin_options_as_DensifyOptions();
+}
+
+template <>
+inline const onert_tflite::SegmentSumOptions *
+Operator::builtin_options_as<onert_tflite::SegmentSumOptions>() const
+{
+ return builtin_options_as_SegmentSumOptions();
+}
+
+template <>
+inline const onert_tflite::BatchMatMulOptions *
+Operator::builtin_options_as<onert_tflite::BatchMatMulOptions>() const
+{
+ return builtin_options_as_BatchMatMulOptions();
+}
+
+template <>
+inline const onert_tflite::CumsumOptions *
+Operator::builtin_options_as<onert_tflite::CumsumOptions>() const
+{
+ return builtin_options_as_CumsumOptions();
+}
+
+template <>
+inline const onert_tflite::CallOnceOptions *
+Operator::builtin_options_as<onert_tflite::CallOnceOptions>() const
+{
+ return builtin_options_as_CallOnceOptions();
+}
+
+template <>
+inline const onert_tflite::BroadcastToOptions *
+Operator::builtin_options_as<onert_tflite::BroadcastToOptions>() const
+{
+ return builtin_options_as_BroadcastToOptions();
+}
+
+template <>
+inline const onert_tflite::Rfft2dOptions *
+Operator::builtin_options_as<onert_tflite::Rfft2dOptions>() const
+{
+ return builtin_options_as_Rfft2dOptions();
+}
+
+template <>
+inline const onert_tflite::Conv3DOptions *
+Operator::builtin_options_as<onert_tflite::Conv3DOptions>() const
+{
+ return builtin_options_as_Conv3DOptions();
+}
+
+template <>
+inline const onert_tflite::HashtableOptions *
+Operator::builtin_options_as<onert_tflite::HashtableOptions>() const
+{
+ return builtin_options_as_HashtableOptions();
+}
+
+template <>
+inline const onert_tflite::HashtableFindOptions *
+Operator::builtin_options_as<onert_tflite::HashtableFindOptions>() const
+{
+ return builtin_options_as_HashtableFindOptions();
+}
+
+template <>
+inline const onert_tflite::HashtableImportOptions *
+Operator::builtin_options_as<onert_tflite::HashtableImportOptions>() const
+{
+ return builtin_options_as_HashtableImportOptions();
+}
+
+template <>
+inline const onert_tflite::HashtableSizeOptions *
+Operator::builtin_options_as<onert_tflite::HashtableSizeOptions>() const
+{
+ return builtin_options_as_HashtableSizeOptions();
+}
+
+template <>
+inline const onert_tflite::VarHandleOptions *
+Operator::builtin_options_as<onert_tflite::VarHandleOptions>() const
+{
+ return builtin_options_as_VarHandleOptions();
+}
+
+template <>
+inline const onert_tflite::ReadVariableOptions *
+Operator::builtin_options_as<onert_tflite::ReadVariableOptions>() const
+{
+ return builtin_options_as_ReadVariableOptions();
+}
+
+template <>
+inline const onert_tflite::AssignVariableOptions *
+Operator::builtin_options_as<onert_tflite::AssignVariableOptions>() const
+{
+ return builtin_options_as_AssignVariableOptions();
+}
+
+template <>
+inline const onert_tflite::RandomOptions *
+Operator::builtin_options_as<onert_tflite::RandomOptions>() const
+{
+ return builtin_options_as_RandomOptions();
+}
+
+template <>
+inline const onert_tflite::BucketizeOptions *
+Operator::builtin_options_as<onert_tflite::BucketizeOptions>() const
+{
+ return builtin_options_as_BucketizeOptions();
+}
+
+template <>
+inline const onert_tflite::GeluOptions *
+Operator::builtin_options_as<onert_tflite::GeluOptions>() const
+{
+ return builtin_options_as_GeluOptions();
+}
+
+template <>
+inline const onert_tflite::DynamicUpdateSliceOptions *
+Operator::builtin_options_as<onert_tflite::DynamicUpdateSliceOptions>() const
+{
+ return builtin_options_as_DynamicUpdateSliceOptions();
+}
+
+template <>
+inline const onert_tflite::UnsortedSegmentProdOptions *
+Operator::builtin_options_as<onert_tflite::UnsortedSegmentProdOptions>() const
+{
+ return builtin_options_as_UnsortedSegmentProdOptions();
+}
+
+template <>
+inline const onert_tflite::UnsortedSegmentMaxOptions *
+Operator::builtin_options_as<onert_tflite::UnsortedSegmentMaxOptions>() const
+{
+ return builtin_options_as_UnsortedSegmentMaxOptions();
+}
+
+template <>
+inline const onert_tflite::UnsortedSegmentSumOptions *
+Operator::builtin_options_as<onert_tflite::UnsortedSegmentSumOptions>() const
+{
+ return builtin_options_as_UnsortedSegmentSumOptions();
+}
+
+template <>
+inline const onert_tflite::ATan2Options *
+Operator::builtin_options_as<onert_tflite::ATan2Options>() const
+{
+ return builtin_options_as_ATan2Options();
+}
+
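+// Illustrative usage sketch (not part of the generated schema): the typed accessors
+// above return nullptr unless builtin_options_type() matches, so reading an option
+// table is a null check. `op` is assumed to be a valid `const onert_tflite::Operator *`
+// taken from a verified model buffer.
+//
+//   if (const auto *conv = op->builtin_options_as<onert_tflite::Conv2DOptions>())
+//   {
+//     // Same result as op->builtin_options_as_Conv2DOptions(); non-null only when
+//     // builtin_options_type() == onert_tflite::BuiltinOptions_Conv2DOptions.
+//   }
+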
+struct OperatorBuilder
+{
+ typedef Operator Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ void add_opcode_index(uint32_t opcode_index)
+ {
+ fbb_.AddElement<uint32_t>(Operator::VT_OPCODE_INDEX, opcode_index, 0);
+ }
+ void add_inputs(::flatbuffers::Offset<::flatbuffers::Vector<int32_t>> inputs)
+ {
+ fbb_.AddOffset(Operator::VT_INPUTS, inputs);
+ }
+ void add_outputs(::flatbuffers::Offset<::flatbuffers::Vector<int32_t>> outputs)
+ {
+ fbb_.AddOffset(Operator::VT_OUTPUTS, outputs);
+ }
+ void add_builtin_options_type(onert_tflite::BuiltinOptions builtin_options_type)
+ {
+ fbb_.AddElement<uint8_t>(Operator::VT_BUILTIN_OPTIONS_TYPE,
+ static_cast<uint8_t>(builtin_options_type), 0);
+ }
+ void add_builtin_options(::flatbuffers::Offset<void> builtin_options)
+ {
+ fbb_.AddOffset(Operator::VT_BUILTIN_OPTIONS, builtin_options);
+ }
+ void add_custom_options(::flatbuffers::Offset<::flatbuffers::Vector<uint8_t>> custom_options)
+ {
+ fbb_.AddOffset(Operator::VT_CUSTOM_OPTIONS, custom_options);
+ }
+ void add_custom_options_format(onert_tflite::CustomOptionsFormat custom_options_format)
+ {
+ fbb_.AddElement<int8_t>(Operator::VT_CUSTOM_OPTIONS_FORMAT,
+ static_cast<int8_t>(custom_options_format), 0);
+ }
+ void add_mutating_variable_inputs(
+ ::flatbuffers::Offset<::flatbuffers::Vector<uint8_t>> mutating_variable_inputs)
+ {
+ fbb_.AddOffset(Operator::VT_MUTATING_VARIABLE_INPUTS, mutating_variable_inputs);
+ }
+ void add_intermediates(::flatbuffers::Offset<::flatbuffers::Vector<int32_t>> intermediates)
+ {
+ fbb_.AddOffset(Operator::VT_INTERMEDIATES, intermediates);
+ }
+ explicit OperatorBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb)
+ {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<Operator> Finish()
+ {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<Operator>(end);
+ return o;
+ }
+};
+
+inline ::flatbuffers::Offset<Operator> CreateOperator(
+ ::flatbuffers::FlatBufferBuilder &_fbb, uint32_t opcode_index = 0,
+ ::flatbuffers::Offset<::flatbuffers::Vector<int32_t>> inputs = 0,
+ ::flatbuffers::Offset<::flatbuffers::Vector<int32_t>> outputs = 0,
+ onert_tflite::BuiltinOptions builtin_options_type = onert_tflite::BuiltinOptions_NONE,
+ ::flatbuffers::Offset<void> builtin_options = 0,
+ ::flatbuffers::Offset<::flatbuffers::Vector<uint8_t>> custom_options = 0,
+ onert_tflite::CustomOptionsFormat custom_options_format =
+ onert_tflite::CustomOptionsFormat_FLEXBUFFERS,
+ ::flatbuffers::Offset<::flatbuffers::Vector<uint8_t>> mutating_variable_inputs = 0,
+ ::flatbuffers::Offset<::flatbuffers::Vector<int32_t>> intermediates = 0)
+{
+ OperatorBuilder builder_(_fbb);
+ builder_.add_intermediates(intermediates);
+ builder_.add_mutating_variable_inputs(mutating_variable_inputs);
+ builder_.add_custom_options(custom_options);
+ builder_.add_builtin_options(builtin_options);
+ builder_.add_outputs(outputs);
+ builder_.add_inputs(inputs);
+ builder_.add_opcode_index(opcode_index);
+ builder_.add_custom_options_format(custom_options_format);
+ builder_.add_builtin_options_type(builtin_options_type);
+ return builder_.Finish();
+}
+
+inline ::flatbuffers::Offset<Operator> CreateOperatorDirect(
+ ::flatbuffers::FlatBufferBuilder &_fbb, uint32_t opcode_index = 0,
+ const std::vector<int32_t> *inputs = nullptr, const std::vector<int32_t> *outputs = nullptr,
+ onert_tflite::BuiltinOptions builtin_options_type = onert_tflite::BuiltinOptions_NONE,
+ ::flatbuffers::Offset<void> builtin_options = 0,
+ const std::vector<uint8_t> *custom_options = nullptr,
+ onert_tflite::CustomOptionsFormat custom_options_format =
+ onert_tflite::CustomOptionsFormat_FLEXBUFFERS,
+ const std::vector<uint8_t> *mutating_variable_inputs = nullptr,
+ const std::vector<int32_t> *intermediates = nullptr)
+{
+ auto inputs__ = inputs ? _fbb.CreateVector<int32_t>(*inputs) : 0;
+ auto outputs__ = outputs ? _fbb.CreateVector<int32_t>(*outputs) : 0;
+ auto custom_options__ = custom_options ? _fbb.CreateVector<uint8_t>(*custom_options) : 0;
+ auto mutating_variable_inputs__ =
+ mutating_variable_inputs ? _fbb.CreateVector<uint8_t>(*mutating_variable_inputs) : 0;
+ auto intermediates__ = intermediates ? _fbb.CreateVector<int32_t>(*intermediates) : 0;
+ return onert_tflite::CreateOperator(_fbb, opcode_index, inputs__, outputs__, builtin_options_type,
+ builtin_options, custom_options__, custom_options_format,
+ mutating_variable_inputs__, intermediates__);
+}
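+
+// Illustrative usage sketch (assumptions noted inline): building an Operator with the
+// CreateOperatorDirect() helper above. CreateAddOptions() is assumed to be the
+// corresponding generated helper for AddOptions; Offset::Union() erases the option
+// table's type so it can be stored in the builtin_options union field.
+//
+//   ::flatbuffers::FlatBufferBuilder fbb;
+//   std::vector<int32_t> op_inputs{0, 1}, op_outputs{2};
+//   auto add_options = onert_tflite::CreateAddOptions(fbb);
+//   auto op = onert_tflite::CreateOperatorDirect(
+//     fbb, /*opcode_index=*/0, &op_inputs, &op_outputs,
+//     onert_tflite::BuiltinOptions_AddOptions, add_options.Union());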
+
+struct SubGraph FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table
+{
+ typedef SubGraphBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE
+ {
+ VT_TENSORS = 4,
+ VT_INPUTS = 6,
+ VT_OUTPUTS = 8,
+ VT_OPERATORS = 10,
+ VT_NAME = 12
+ };
+ const ::flatbuffers::Vector<::flatbuffers::Offset<onert_tflite::Tensor>> *tensors() const
+ {
+ return GetPointer<const ::flatbuffers::Vector<::flatbuffers::Offset<onert_tflite::Tensor>> *>(
+ VT_TENSORS);
+ }
+ const ::flatbuffers::Vector<int32_t> *inputs() const
+ {
+ return GetPointer<const ::flatbuffers::Vector<int32_t> *>(VT_INPUTS);
+ }
+ const ::flatbuffers::Vector<int32_t> *outputs() const
+ {
+ return GetPointer<const ::flatbuffers::Vector<int32_t> *>(VT_OUTPUTS);
+ }
+ const ::flatbuffers::Vector<::flatbuffers::Offset<onert_tflite::Operator>> *operators() const
+ {
+ return GetPointer<const ::flatbuffers::Vector<::flatbuffers::Offset<onert_tflite::Operator>> *>(
+ VT_OPERATORS);
+ }
+ const ::flatbuffers::String *name() const
+ {
+ return GetPointer<const ::flatbuffers::String *>(VT_NAME);
+ }
+ bool Verify(::flatbuffers::Verifier &verifier) const
+ {
+ return VerifyTableStart(verifier) && VerifyOffset(verifier, VT_TENSORS) &&
+ verifier.VerifyVector(tensors()) && verifier.VerifyVectorOfTables(tensors()) &&
+ VerifyOffset(verifier, VT_INPUTS) && verifier.VerifyVector(inputs()) &&
+ VerifyOffset(verifier, VT_OUTPUTS) && verifier.VerifyVector(outputs()) &&
+ VerifyOffset(verifier, VT_OPERATORS) && verifier.VerifyVector(operators()) &&
+ verifier.VerifyVectorOfTables(operators()) && VerifyOffset(verifier, VT_NAME) &&
+ verifier.VerifyString(name()) && verifier.EndTable();
+ }
+};
+
+struct SubGraphBuilder
+{
+ typedef SubGraph Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ void add_tensors(
+ ::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset<onert_tflite::Tensor>>>
+ tensors)
+ {
+ fbb_.AddOffset(SubGraph::VT_TENSORS, tensors);
+ }
+ void add_inputs(::flatbuffers::Offset<::flatbuffers::Vector<int32_t>> inputs)
+ {
+ fbb_.AddOffset(SubGraph::VT_INPUTS, inputs);
+ }
+ void add_outputs(::flatbuffers::Offset<::flatbuffers::Vector<int32_t>> outputs)
+ {
+ fbb_.AddOffset(SubGraph::VT_OUTPUTS, outputs);
+ }
+ void add_operators(
+ ::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset<onert_tflite::Operator>>>
+ operators)
+ {
+ fbb_.AddOffset(SubGraph::VT_OPERATORS, operators);
+ }
+ void add_name(::flatbuffers::Offset<::flatbuffers::String> name)
+ {
+ fbb_.AddOffset(SubGraph::VT_NAME, name);
+ }
+ explicit SubGraphBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb)
+ {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<SubGraph> Finish()
+ {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<SubGraph>(end);
+ return o;
+ }
+};
+
+inline ::flatbuffers::Offset<SubGraph> CreateSubGraph(
+ ::flatbuffers::FlatBufferBuilder &_fbb,
+ ::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset<onert_tflite::Tensor>>>
+ tensors = 0,
+ ::flatbuffers::Offset<::flatbuffers::Vector<int32_t>> inputs = 0,
+ ::flatbuffers::Offset<::flatbuffers::Vector<int32_t>> outputs = 0,
+ ::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset<onert_tflite::Operator>>>
+ operators = 0,
+ ::flatbuffers::Offset<::flatbuffers::String> name = 0)
+{
+ SubGraphBuilder builder_(_fbb);
+ builder_.add_name(name);
+ builder_.add_operators(operators);
+ builder_.add_outputs(outputs);
+ builder_.add_inputs(inputs);
+ builder_.add_tensors(tensors);
+ return builder_.Finish();
+}
+
+inline ::flatbuffers::Offset<SubGraph> CreateSubGraphDirect(
+ ::flatbuffers::FlatBufferBuilder &_fbb,
+ const std::vector<::flatbuffers::Offset<onert_tflite::Tensor>> *tensors = nullptr,
+ const std::vector<int32_t> *inputs = nullptr, const std::vector<int32_t> *outputs = nullptr,
+ const std::vector<::flatbuffers::Offset<onert_tflite::Operator>> *operators = nullptr,
+ const char *name = nullptr)
+{
+ auto tensors__ =
+ tensors ? _fbb.CreateVector<::flatbuffers::Offset<onert_tflite::Tensor>>(*tensors) : 0;
+ auto inputs__ = inputs ? _fbb.CreateVector<int32_t>(*inputs) : 0;
+ auto outputs__ = outputs ? _fbb.CreateVector<int32_t>(*outputs) : 0;
+ auto operators__ =
+ operators ? _fbb.CreateVector<::flatbuffers::Offset<onert_tflite::Operator>>(*operators) : 0;
+ auto name__ = name ? _fbb.CreateString(name) : 0;
+ return onert_tflite::CreateSubGraph(_fbb, tensors__, inputs__, outputs__, operators__, name__);
+}
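+
+// Illustrative usage sketch (continuing the hypothetical `fbb` and `op` from the
+// Operator example above): a SubGraph is assembled from already-serialized tensors
+// and operators plus the subgraph-level input/output tensor indices.
+//
+//   std::vector<::flatbuffers::Offset<onert_tflite::Tensor>> tensors = /* CreateTensor... */;
+//   std::vector<int32_t> graph_inputs{0, 1}, graph_outputs{2};
+//   std::vector<::flatbuffers::Offset<onert_tflite::Operator>> operators{op};
+//   auto subgraph = onert_tflite::CreateSubGraphDirect(fbb, &tensors, &graph_inputs,
+//                                                      &graph_outputs, &operators, "main");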
+
+struct Buffer FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table
+{
+ typedef BufferBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE
+ {
+ VT_DATA = 4
+ };
+ const ::flatbuffers::Vector<uint8_t> *data() const
+ {
+ return GetPointer<const ::flatbuffers::Vector<uint8_t> *>(VT_DATA);
+ }
+ bool Verify(::flatbuffers::Verifier &verifier) const
+ {
+ return VerifyTableStart(verifier) && VerifyOffset(verifier, VT_DATA) &&
+ verifier.VerifyVector(data()) && verifier.EndTable();
+ }
+};
+
+struct BufferBuilder
+{
+ typedef Buffer Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ void add_data(::flatbuffers::Offset<::flatbuffers::Vector<uint8_t>> data)
+ {
+ fbb_.AddOffset(Buffer::VT_DATA, data);
+ }
+ explicit BufferBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb)
+ {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<Buffer> Finish()
+ {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<Buffer>(end);
+ return o;
+ }
+};
+
+inline ::flatbuffers::Offset<Buffer>
+CreateBuffer(::flatbuffers::FlatBufferBuilder &_fbb,
+ ::flatbuffers::Offset<::flatbuffers::Vector<uint8_t>> data = 0)
+{
+ BufferBuilder builder_(_fbb);
+ builder_.add_data(data);
+ return builder_.Finish();
+}
+
+inline ::flatbuffers::Offset<Buffer> CreateBufferDirect(::flatbuffers::FlatBufferBuilder &_fbb,
+ const std::vector<uint8_t> *data = nullptr)
+{
+ if (data)
+ {
+ _fbb.ForceVectorAlignment(data->size(), sizeof(uint8_t), 16);
+ }
+ auto data__ = data ? _fbb.CreateVector<uint8_t>(*data) : 0;
+ return onert_tflite::CreateBuffer(_fbb, data__);
+}
+
+struct Metadata FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table
+{
+ typedef MetadataBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE
+ {
+ VT_NAME = 4,
+ VT_BUFFER = 6
+ };
+ const ::flatbuffers::String *name() const
+ {
+ return GetPointer<const ::flatbuffers::String *>(VT_NAME);
+ }
+ uint32_t buffer() const { return GetField<uint32_t>(VT_BUFFER, 0); }
+ bool Verify(::flatbuffers::Verifier &verifier) const
+ {
+ return VerifyTableStart(verifier) && VerifyOffset(verifier, VT_NAME) &&
+ verifier.VerifyString(name()) && VerifyField<uint32_t>(verifier, VT_BUFFER, 4) &&
+ verifier.EndTable();
+ }
+};
+
+struct MetadataBuilder
+{
+ typedef Metadata Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ void add_name(::flatbuffers::Offset<::flatbuffers::String> name)
+ {
+ fbb_.AddOffset(Metadata::VT_NAME, name);
+ }
+ void add_buffer(uint32_t buffer) { fbb_.AddElement<uint32_t>(Metadata::VT_BUFFER, buffer, 0); }
+ explicit MetadataBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb)
+ {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<Metadata> Finish()
+ {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<Metadata>(end);
+ return o;
+ }
+};
+
+inline ::flatbuffers::Offset<Metadata>
+CreateMetadata(::flatbuffers::FlatBufferBuilder &_fbb,
+ ::flatbuffers::Offset<::flatbuffers::String> name = 0, uint32_t buffer = 0)
+{
+ MetadataBuilder builder_(_fbb);
+ builder_.add_buffer(buffer);
+ builder_.add_name(name);
+ return builder_.Finish();
+}
+
+inline ::flatbuffers::Offset<Metadata> CreateMetadataDirect(::flatbuffers::FlatBufferBuilder &_fbb,
+ const char *name = nullptr,
+ uint32_t buffer = 0)
+{
+ auto name__ = name ? _fbb.CreateString(name) : 0;
+ return onert_tflite::CreateMetadata(_fbb, name__, buffer);
+}
+
+struct TensorMap FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table
+{
+ typedef TensorMapBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE
+ {
+ VT_NAME = 4,
+ VT_TENSOR_INDEX = 6
+ };
+ const ::flatbuffers::String *name() const
+ {
+ return GetPointer<const ::flatbuffers::String *>(VT_NAME);
+ }
+ uint32_t tensor_index() const { return GetField<uint32_t>(VT_TENSOR_INDEX, 0); }
+ bool Verify(::flatbuffers::Verifier &verifier) const
+ {
+ return VerifyTableStart(verifier) && VerifyOffset(verifier, VT_NAME) &&
+ verifier.VerifyString(name()) && VerifyField<uint32_t>(verifier, VT_TENSOR_INDEX, 4) &&
+ verifier.EndTable();
+ }
+};
+
+struct TensorMapBuilder
+{
+ typedef TensorMap Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ void add_name(::flatbuffers::Offset<::flatbuffers::String> name)
+ {
+ fbb_.AddOffset(TensorMap::VT_NAME, name);
+ }
+ void add_tensor_index(uint32_t tensor_index)
+ {
+ fbb_.AddElement<uint32_t>(TensorMap::VT_TENSOR_INDEX, tensor_index, 0);
+ }
+ explicit TensorMapBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb)
+ {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<TensorMap> Finish()
+ {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<TensorMap>(end);
+ return o;
+ }
+};
+
+inline ::flatbuffers::Offset<TensorMap>
+CreateTensorMap(::flatbuffers::FlatBufferBuilder &_fbb,
+ ::flatbuffers::Offset<::flatbuffers::String> name = 0, uint32_t tensor_index = 0)
+{
+ TensorMapBuilder builder_(_fbb);
+ builder_.add_tensor_index(tensor_index);
+ builder_.add_name(name);
+ return builder_.Finish();
+}
+
+inline ::flatbuffers::Offset<TensorMap>
+CreateTensorMapDirect(::flatbuffers::FlatBufferBuilder &_fbb, const char *name = nullptr,
+ uint32_t tensor_index = 0)
+{
+ auto name__ = name ? _fbb.CreateString(name) : 0;
+ return onert_tflite::CreateTensorMap(_fbb, name__, tensor_index);
+}
+
+struct SignatureDef FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table
+{
+ typedef SignatureDefBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE
+ {
+ VT_INPUTS = 4,
+ VT_OUTPUTS = 6,
+ VT_SIGNATURE_KEY = 8,
+ VT_SUBGRAPH_INDEX = 12
+ };
+ const ::flatbuffers::Vector<::flatbuffers::Offset<onert_tflite::TensorMap>> *inputs() const
+ {
+ return GetPointer<
+ const ::flatbuffers::Vector<::flatbuffers::Offset<onert_tflite::TensorMap>> *>(VT_INPUTS);
+ }
+ const ::flatbuffers::Vector<::flatbuffers::Offset<onert_tflite::TensorMap>> *outputs() const
+ {
+ return GetPointer<
+ const ::flatbuffers::Vector<::flatbuffers::Offset<onert_tflite::TensorMap>> *>(VT_OUTPUTS);
+ }
+ const ::flatbuffers::String *signature_key() const
+ {
+ return GetPointer<const ::flatbuffers::String *>(VT_SIGNATURE_KEY);
+ }
+ uint32_t subgraph_index() const { return GetField<uint32_t>(VT_SUBGRAPH_INDEX, 0); }
+ bool Verify(::flatbuffers::Verifier &verifier) const
+ {
+ return VerifyTableStart(verifier) && VerifyOffset(verifier, VT_INPUTS) &&
+ verifier.VerifyVector(inputs()) && verifier.VerifyVectorOfTables(inputs()) &&
+ VerifyOffset(verifier, VT_OUTPUTS) && verifier.VerifyVector(outputs()) &&
+ verifier.VerifyVectorOfTables(outputs()) && VerifyOffset(verifier, VT_SIGNATURE_KEY) &&
+ verifier.VerifyString(signature_key()) &&
+ VerifyField<uint32_t>(verifier, VT_SUBGRAPH_INDEX, 4) && verifier.EndTable();
+ }
+};
+
+struct SignatureDefBuilder
+{
+ typedef SignatureDef Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ void add_inputs(
+ ::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset<onert_tflite::TensorMap>>>
+ inputs)
+ {
+ fbb_.AddOffset(SignatureDef::VT_INPUTS, inputs);
+ }
+ void add_outputs(
+ ::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset<onert_tflite::TensorMap>>>
+ outputs)
+ {
+ fbb_.AddOffset(SignatureDef::VT_OUTPUTS, outputs);
+ }
+ void add_signature_key(::flatbuffers::Offset<::flatbuffers::String> signature_key)
+ {
+ fbb_.AddOffset(SignatureDef::VT_SIGNATURE_KEY, signature_key);
+ }
+ void add_subgraph_index(uint32_t subgraph_index)
+ {
+ fbb_.AddElement<uint32_t>(SignatureDef::VT_SUBGRAPH_INDEX, subgraph_index, 0);
+ }
+ explicit SignatureDefBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb)
+ {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<SignatureDef> Finish()
+ {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<SignatureDef>(end);
+ return o;
+ }
+};
+
+inline ::flatbuffers::Offset<SignatureDef> CreateSignatureDef(
+ ::flatbuffers::FlatBufferBuilder &_fbb,
+ ::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset<onert_tflite::TensorMap>>>
+ inputs = 0,
+ ::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset<onert_tflite::TensorMap>>>
+ outputs = 0,
+ ::flatbuffers::Offset<::flatbuffers::String> signature_key = 0, uint32_t subgraph_index = 0)
+{
+ SignatureDefBuilder builder_(_fbb);
+ builder_.add_subgraph_index(subgraph_index);
+ builder_.add_signature_key(signature_key);
+ builder_.add_outputs(outputs);
+ builder_.add_inputs(inputs);
+ return builder_.Finish();
+}
+
+inline ::flatbuffers::Offset<SignatureDef> CreateSignatureDefDirect(
+ ::flatbuffers::FlatBufferBuilder &_fbb,
+ const std::vector<::flatbuffers::Offset<onert_tflite::TensorMap>> *inputs = nullptr,
+ const std::vector<::flatbuffers::Offset<onert_tflite::TensorMap>> *outputs = nullptr,
+ const char *signature_key = nullptr, uint32_t subgraph_index = 0)
+{
+ auto inputs__ =
+ inputs ? _fbb.CreateVector<::flatbuffers::Offset<onert_tflite::TensorMap>>(*inputs) : 0;
+ auto outputs__ =
+ outputs ? _fbb.CreateVector<::flatbuffers::Offset<onert_tflite::TensorMap>>(*outputs) : 0;
+ auto signature_key__ = signature_key ? _fbb.CreateString(signature_key) : 0;
+ return onert_tflite::CreateSignatureDef(_fbb, inputs__, outputs__, signature_key__,
+ subgraph_index);
+}
+
+struct Model FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table
+{
+ typedef ModelBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE
+ {
+ VT_VERSION = 4,
+ VT_OPERATOR_CODES = 6,
+ VT_SUBGRAPHS = 8,
+ VT_DESCRIPTION = 10,
+ VT_BUFFERS = 12,
+ VT_METADATA_BUFFER = 14,
+ VT_METADATA = 16,
+ VT_SIGNATURE_DEFS = 18
+ };
+ uint32_t version() const { return GetField<uint32_t>(VT_VERSION, 0); }
+ const ::flatbuffers::Vector<::flatbuffers::Offset<onert_tflite::OperatorCode>> *
+ operator_codes() const
+ {
+ return GetPointer<
+ const ::flatbuffers::Vector<::flatbuffers::Offset<onert_tflite::OperatorCode>> *>(
+ VT_OPERATOR_CODES);
+ }
+ const ::flatbuffers::Vector<::flatbuffers::Offset<onert_tflite::SubGraph>> *subgraphs() const
+ {
+ return GetPointer<const ::flatbuffers::Vector<::flatbuffers::Offset<onert_tflite::SubGraph>> *>(
+ VT_SUBGRAPHS);
+ }
+ const ::flatbuffers::String *description() const
+ {
+ return GetPointer<const ::flatbuffers::String *>(VT_DESCRIPTION);
+ }
+ const ::flatbuffers::Vector<::flatbuffers::Offset<onert_tflite::Buffer>> *buffers() const
+ {
+ return GetPointer<const ::flatbuffers::Vector<::flatbuffers::Offset<onert_tflite::Buffer>> *>(
+ VT_BUFFERS);
+ }
+ const ::flatbuffers::Vector<int32_t> *metadata_buffer() const
+ {
+ return GetPointer<const ::flatbuffers::Vector<int32_t> *>(VT_METADATA_BUFFER);
+ }
+ const ::flatbuffers::Vector<::flatbuffers::Offset<onert_tflite::Metadata>> *metadata() const
+ {
+ return GetPointer<const ::flatbuffers::Vector<::flatbuffers::Offset<onert_tflite::Metadata>> *>(
+ VT_METADATA);
+ }
+ const ::flatbuffers::Vector<::flatbuffers::Offset<onert_tflite::SignatureDef>> *
+ signature_defs() const
+ {
+ return GetPointer<
+ const ::flatbuffers::Vector<::flatbuffers::Offset<onert_tflite::SignatureDef>> *>(
+ VT_SIGNATURE_DEFS);
+ }
+ bool Verify(::flatbuffers::Verifier &verifier) const
+ {
+ return VerifyTableStart(verifier) && VerifyField<uint32_t>(verifier, VT_VERSION, 4) &&
+ VerifyOffset(verifier, VT_OPERATOR_CODES) && verifier.VerifyVector(operator_codes()) &&
+ verifier.VerifyVectorOfTables(operator_codes()) &&
+ VerifyOffset(verifier, VT_SUBGRAPHS) && verifier.VerifyVector(subgraphs()) &&
+ verifier.VerifyVectorOfTables(subgraphs()) && VerifyOffset(verifier, VT_DESCRIPTION) &&
+ verifier.VerifyString(description()) && VerifyOffset(verifier, VT_BUFFERS) &&
+ verifier.VerifyVector(buffers()) && verifier.VerifyVectorOfTables(buffers()) &&
+ VerifyOffset(verifier, VT_METADATA_BUFFER) && verifier.VerifyVector(metadata_buffer()) &&
+ VerifyOffset(verifier, VT_METADATA) && verifier.VerifyVector(metadata()) &&
+ verifier.VerifyVectorOfTables(metadata()) && VerifyOffset(verifier, VT_SIGNATURE_DEFS) &&
+ verifier.VerifyVector(signature_defs()) &&
+ verifier.VerifyVectorOfTables(signature_defs()) && verifier.EndTable();
+ }
+};
+
+struct ModelBuilder
+{
+ typedef Model Table;
+ ::flatbuffers::FlatBufferBuilder &fbb_;
+ ::flatbuffers::uoffset_t start_;
+ void add_version(uint32_t version) { fbb_.AddElement<uint32_t>(Model::VT_VERSION, version, 0); }
+ void add_operator_codes(
+ ::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset<onert_tflite::OperatorCode>>>
+ operator_codes)
+ {
+ fbb_.AddOffset(Model::VT_OPERATOR_CODES, operator_codes);
+ }
+ void add_subgraphs(
+ ::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset<onert_tflite::SubGraph>>>
+ subgraphs)
+ {
+ fbb_.AddOffset(Model::VT_SUBGRAPHS, subgraphs);
+ }
+ void add_description(::flatbuffers::Offset<::flatbuffers::String> description)
+ {
+ fbb_.AddOffset(Model::VT_DESCRIPTION, description);
+ }
+ void add_buffers(
+ ::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset<onert_tflite::Buffer>>>
+ buffers)
+ {
+ fbb_.AddOffset(Model::VT_BUFFERS, buffers);
+ }
+ void add_metadata_buffer(::flatbuffers::Offset<::flatbuffers::Vector<int32_t>> metadata_buffer)
+ {
+ fbb_.AddOffset(Model::VT_METADATA_BUFFER, metadata_buffer);
+ }
+ void add_metadata(
+ ::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset<onert_tflite::Metadata>>>
+ metadata)
+ {
+ fbb_.AddOffset(Model::VT_METADATA, metadata);
+ }
+ void add_signature_defs(
+ ::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset<onert_tflite::SignatureDef>>>
+ signature_defs)
+ {
+ fbb_.AddOffset(Model::VT_SIGNATURE_DEFS, signature_defs);
+ }
+ explicit ModelBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb)
+ {
+ start_ = fbb_.StartTable();
+ }
+ ::flatbuffers::Offset<Model> Finish()
+ {
+ const auto end = fbb_.EndTable(start_);
+ auto o = ::flatbuffers::Offset<Model>(end);
+ return o;
+ }
+};
+
+inline ::flatbuffers::Offset<Model> CreateModel(
+ ::flatbuffers::FlatBufferBuilder &_fbb, uint32_t version = 0,
+ ::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset<onert_tflite::OperatorCode>>>
+ operator_codes = 0,
+ ::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset<onert_tflite::SubGraph>>>
+ subgraphs = 0,
+ ::flatbuffers::Offset<::flatbuffers::String> description = 0,
+ ::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset<onert_tflite::Buffer>>>
+ buffers = 0,
+ ::flatbuffers::Offset<::flatbuffers::Vector<int32_t>> metadata_buffer = 0,
+ ::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset<onert_tflite::Metadata>>>
+ metadata = 0,
+ ::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset<onert_tflite::SignatureDef>>>
+ signature_defs = 0)
+{
+ ModelBuilder builder_(_fbb);
+ builder_.add_signature_defs(signature_defs);
+ builder_.add_metadata(metadata);
+ builder_.add_metadata_buffer(metadata_buffer);
+ builder_.add_buffers(buffers);
+ builder_.add_description(description);
+ builder_.add_subgraphs(subgraphs);
+ builder_.add_operator_codes(operator_codes);
+ builder_.add_version(version);
+ return builder_.Finish();
+}
+
+inline ::flatbuffers::Offset<Model> CreateModelDirect(
+ ::flatbuffers::FlatBufferBuilder &_fbb, uint32_t version = 0,
+ const std::vector<::flatbuffers::Offset<onert_tflite::OperatorCode>> *operator_codes = nullptr,
+ const std::vector<::flatbuffers::Offset<onert_tflite::SubGraph>> *subgraphs = nullptr,
+ const char *description = nullptr,
+ const std::vector<::flatbuffers::Offset<onert_tflite::Buffer>> *buffers = nullptr,
+ const std::vector<int32_t> *metadata_buffer = nullptr,
+ const std::vector<::flatbuffers::Offset<onert_tflite::Metadata>> *metadata = nullptr,
+ const std::vector<::flatbuffers::Offset<onert_tflite::SignatureDef>> *signature_defs = nullptr)
+{
+ auto operator_codes__ =
+ operator_codes
+ ? _fbb.CreateVector<::flatbuffers::Offset<onert_tflite::OperatorCode>>(*operator_codes)
+ : 0;
+ auto subgraphs__ =
+ subgraphs ? _fbb.CreateVector<::flatbuffers::Offset<onert_tflite::SubGraph>>(*subgraphs) : 0;
+ auto description__ = description ? _fbb.CreateString(description) : 0;
+ auto buffers__ =
+ buffers ? _fbb.CreateVector<::flatbuffers::Offset<onert_tflite::Buffer>>(*buffers) : 0;
+ auto metadata_buffer__ = metadata_buffer ? _fbb.CreateVector<int32_t>(*metadata_buffer) : 0;
+ auto metadata__ =
+ metadata ? _fbb.CreateVector<::flatbuffers::Offset<onert_tflite::Metadata>>(*metadata) : 0;
+ auto signature_defs__ =
+ signature_defs
+ ? _fbb.CreateVector<::flatbuffers::Offset<onert_tflite::SignatureDef>>(*signature_defs)
+ : 0;
+ return onert_tflite::CreateModel(_fbb, version, operator_codes__, subgraphs__, description__,
+ buffers__, metadata_buffer__, metadata__, signature_defs__);
+}
+
+inline bool VerifyQuantizationDetails(::flatbuffers::Verifier &verifier, const void *obj,
+ QuantizationDetails type)
+{
+ switch (type)
+ {
+ case QuantizationDetails_NONE:
+ {
+ return true;
+ }
+ case QuantizationDetails_CustomQuantization:
+ {
+ auto ptr = reinterpret_cast<const onert_tflite::CustomQuantization *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ default:
+ return true;
+ }
+}
+
+inline bool
+VerifyQuantizationDetailsVector(::flatbuffers::Verifier &verifier,
+ const ::flatbuffers::Vector<::flatbuffers::Offset<void>> *values,
+ const ::flatbuffers::Vector<uint8_t> *types)
+{
+ if (!values || !types)
+ return !values && !types;
+ if (values->size() != types->size())
+ return false;
+ for (::flatbuffers::uoffset_t i = 0; i < values->size(); ++i)
+ {
+ if (!VerifyQuantizationDetails(verifier, values->Get(i),
+ types->GetEnum<QuantizationDetails>(i)))
+ {
+ return false;
+ }
+ }
+ return true;
+}
+
+inline bool VerifySparseIndexVector(::flatbuffers::Verifier &verifier, const void *obj,
+ SparseIndexVector type)
+{
+ switch (type)
+ {
+ case SparseIndexVector_NONE:
+ {
+ return true;
+ }
+ case SparseIndexVector_Int32Vector:
+ {
+ auto ptr = reinterpret_cast<const onert_tflite::Int32Vector *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case SparseIndexVector_Uint16Vector:
+ {
+ auto ptr = reinterpret_cast<const onert_tflite::Uint16Vector *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case SparseIndexVector_Uint8Vector:
+ {
+ auto ptr = reinterpret_cast<const onert_tflite::Uint8Vector *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ default:
+ return true;
+ }
+}
+
+inline bool
+VerifySparseIndexVectorVector(::flatbuffers::Verifier &verifier,
+ const ::flatbuffers::Vector<::flatbuffers::Offset<void>> *values,
+ const ::flatbuffers::Vector<uint8_t> *types)
+{
+ if (!values || !types)
+ return !values && !types;
+ if (values->size() != types->size())
+ return false;
+ for (::flatbuffers::uoffset_t i = 0; i < values->size(); ++i)
+ {
+ if (!VerifySparseIndexVector(verifier, values->Get(i), types->GetEnum<SparseIndexVector>(i)))
+ {
+ return false;
+ }
+ }
+ return true;
+}
+
+inline bool VerifyBuiltinOptions(::flatbuffers::Verifier &verifier, const void *obj,
+ BuiltinOptions type)
+{
+ switch (type)
+ {
+ case BuiltinOptions_NONE:
+ {
+ return true;
+ }
+ case BuiltinOptions_Conv2DOptions:
+ {
+ auto ptr = reinterpret_cast<const onert_tflite::Conv2DOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case BuiltinOptions_DepthwiseConv2DOptions:
+ {
+ auto ptr = reinterpret_cast<const onert_tflite::DepthwiseConv2DOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case BuiltinOptions_ConcatEmbeddingsOptions:
+ {
+ auto ptr = reinterpret_cast<const onert_tflite::ConcatEmbeddingsOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case BuiltinOptions_LSHProjectionOptions:
+ {
+ auto ptr = reinterpret_cast<const onert_tflite::LSHProjectionOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case BuiltinOptions_Pool2DOptions:
+ {
+ auto ptr = reinterpret_cast<const onert_tflite::Pool2DOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case BuiltinOptions_SVDFOptions:
+ {
+ auto ptr = reinterpret_cast<const onert_tflite::SVDFOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case BuiltinOptions_RNNOptions:
+ {
+ auto ptr = reinterpret_cast<const onert_tflite::RNNOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case BuiltinOptions_FullyConnectedOptions:
+ {
+ auto ptr = reinterpret_cast<const onert_tflite::FullyConnectedOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case BuiltinOptions_SoftmaxOptions:
+ {
+ auto ptr = reinterpret_cast<const onert_tflite::SoftmaxOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case BuiltinOptions_ConcatenationOptions:
+ {
+ auto ptr = reinterpret_cast<const onert_tflite::ConcatenationOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case BuiltinOptions_AddOptions:
+ {
+ auto ptr = reinterpret_cast<const onert_tflite::AddOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case BuiltinOptions_L2NormOptions:
+ {
+ auto ptr = reinterpret_cast<const onert_tflite::L2NormOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case BuiltinOptions_LocalResponseNormalizationOptions:
+ {
+ auto ptr = reinterpret_cast<const onert_tflite::LocalResponseNormalizationOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case BuiltinOptions_LSTMOptions:
+ {
+ auto ptr = reinterpret_cast<const onert_tflite::LSTMOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case BuiltinOptions_ResizeBilinearOptions:
+ {
+ auto ptr = reinterpret_cast<const onert_tflite::ResizeBilinearOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case BuiltinOptions_CallOptions:
+ {
+ auto ptr = reinterpret_cast<const onert_tflite::CallOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case BuiltinOptions_ReshapeOptions:
+ {
+ auto ptr = reinterpret_cast<const onert_tflite::ReshapeOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case BuiltinOptions_SkipGramOptions:
+ {
+ auto ptr = reinterpret_cast<const onert_tflite::SkipGramOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case BuiltinOptions_SpaceToDepthOptions:
+ {
+ auto ptr = reinterpret_cast<const onert_tflite::SpaceToDepthOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case BuiltinOptions_EmbeddingLookupSparseOptions:
+ {
+ auto ptr = reinterpret_cast<const onert_tflite::EmbeddingLookupSparseOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case BuiltinOptions_MulOptions:
+ {
+ auto ptr = reinterpret_cast<const onert_tflite::MulOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case BuiltinOptions_PadOptions:
+ {
+ auto ptr = reinterpret_cast<const onert_tflite::PadOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case BuiltinOptions_GatherOptions:
+ {
+ auto ptr = reinterpret_cast<const onert_tflite::GatherOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case BuiltinOptions_BatchToSpaceNDOptions:
+ {
+ auto ptr = reinterpret_cast<const onert_tflite::BatchToSpaceNDOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case BuiltinOptions_SpaceToBatchNDOptions:
+ {
+ auto ptr = reinterpret_cast<const onert_tflite::SpaceToBatchNDOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case BuiltinOptions_TransposeOptions:
+ {
+ auto ptr = reinterpret_cast<const onert_tflite::TransposeOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case BuiltinOptions_ReducerOptions:
+ {
+ auto ptr = reinterpret_cast<const onert_tflite::ReducerOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case BuiltinOptions_SubOptions:
+ {
+ auto ptr = reinterpret_cast<const onert_tflite::SubOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case BuiltinOptions_DivOptions:
+ {
+ auto ptr = reinterpret_cast<const onert_tflite::DivOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case BuiltinOptions_SqueezeOptions:
+ {
+ auto ptr = reinterpret_cast<const onert_tflite::SqueezeOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case BuiltinOptions_SequenceRNNOptions:
+ {
+ auto ptr = reinterpret_cast<const onert_tflite::SequenceRNNOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case BuiltinOptions_StridedSliceOptions:
+ {
+ auto ptr = reinterpret_cast<const onert_tflite::StridedSliceOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case BuiltinOptions_ExpOptions:
+ {
+ auto ptr = reinterpret_cast<const onert_tflite::ExpOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case BuiltinOptions_TopKV2Options:
+ {
+ auto ptr = reinterpret_cast<const onert_tflite::TopKV2Options *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case BuiltinOptions_SplitOptions:
+ {
+ auto ptr = reinterpret_cast<const onert_tflite::SplitOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case BuiltinOptions_LogSoftmaxOptions:
+ {
+ auto ptr = reinterpret_cast<const onert_tflite::LogSoftmaxOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case BuiltinOptions_CastOptions:
+ {
+ auto ptr = reinterpret_cast<const onert_tflite::CastOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case BuiltinOptions_DequantizeOptions:
+ {
+ auto ptr = reinterpret_cast<const onert_tflite::DequantizeOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case BuiltinOptions_MaximumMinimumOptions:
+ {
+ auto ptr = reinterpret_cast<const onert_tflite::MaximumMinimumOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case BuiltinOptions_ArgMaxOptions:
+ {
+ auto ptr = reinterpret_cast<const onert_tflite::ArgMaxOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case BuiltinOptions_LessOptions:
+ {
+ auto ptr = reinterpret_cast<const onert_tflite::LessOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case BuiltinOptions_NegOptions:
+ {
+ auto ptr = reinterpret_cast<const onert_tflite::NegOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case BuiltinOptions_PadV2Options:
+ {
+ auto ptr = reinterpret_cast<const onert_tflite::PadV2Options *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case BuiltinOptions_GreaterOptions:
+ {
+ auto ptr = reinterpret_cast<const onert_tflite::GreaterOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case BuiltinOptions_GreaterEqualOptions:
+ {
+ auto ptr = reinterpret_cast<const onert_tflite::GreaterEqualOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case BuiltinOptions_LessEqualOptions:
+ {
+ auto ptr = reinterpret_cast<const onert_tflite::LessEqualOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case BuiltinOptions_SelectOptions:
+ {
+ auto ptr = reinterpret_cast<const onert_tflite::SelectOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case BuiltinOptions_SliceOptions:
+ {
+ auto ptr = reinterpret_cast<const onert_tflite::SliceOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case BuiltinOptions_TransposeConvOptions:
+ {
+ auto ptr = reinterpret_cast<const onert_tflite::TransposeConvOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case BuiltinOptions_SparseToDenseOptions:
+ {
+ auto ptr = reinterpret_cast<const onert_tflite::SparseToDenseOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case BuiltinOptions_TileOptions:
+ {
+ auto ptr = reinterpret_cast<const onert_tflite::TileOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case BuiltinOptions_ExpandDimsOptions:
+ {
+ auto ptr = reinterpret_cast<const onert_tflite::ExpandDimsOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case BuiltinOptions_EqualOptions:
+ {
+ auto ptr = reinterpret_cast<const onert_tflite::EqualOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case BuiltinOptions_NotEqualOptions:
+ {
+ auto ptr = reinterpret_cast<const onert_tflite::NotEqualOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case BuiltinOptions_ShapeOptions:
+ {
+ auto ptr = reinterpret_cast<const onert_tflite::ShapeOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case BuiltinOptions_PowOptions:
+ {
+ auto ptr = reinterpret_cast<const onert_tflite::PowOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case BuiltinOptions_ArgMinOptions:
+ {
+ auto ptr = reinterpret_cast<const onert_tflite::ArgMinOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case BuiltinOptions_FakeQuantOptions:
+ {
+ auto ptr = reinterpret_cast<const onert_tflite::FakeQuantOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case BuiltinOptions_PackOptions:
+ {
+ auto ptr = reinterpret_cast<const onert_tflite::PackOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case BuiltinOptions_LogicalOrOptions:
+ {
+ auto ptr = reinterpret_cast<const onert_tflite::LogicalOrOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case BuiltinOptions_OneHotOptions:
+ {
+ auto ptr = reinterpret_cast<const onert_tflite::OneHotOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case BuiltinOptions_LogicalAndOptions:
+ {
+ auto ptr = reinterpret_cast<const onert_tflite::LogicalAndOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case BuiltinOptions_LogicalNotOptions:
+ {
+ auto ptr = reinterpret_cast<const onert_tflite::LogicalNotOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case BuiltinOptions_UnpackOptions:
+ {
+ auto ptr = reinterpret_cast<const onert_tflite::UnpackOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case BuiltinOptions_FloorDivOptions:
+ {
+ auto ptr = reinterpret_cast<const onert_tflite::FloorDivOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case BuiltinOptions_SquareOptions:
+ {
+ auto ptr = reinterpret_cast<const onert_tflite::SquareOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case BuiltinOptions_ZerosLikeOptions:
+ {
+ auto ptr = reinterpret_cast<const onert_tflite::ZerosLikeOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case BuiltinOptions_FillOptions:
+ {
+ auto ptr = reinterpret_cast<const onert_tflite::FillOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case BuiltinOptions_BidirectionalSequenceLSTMOptions:
+ {
+ auto ptr = reinterpret_cast<const onert_tflite::BidirectionalSequenceLSTMOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case BuiltinOptions_BidirectionalSequenceRNNOptions:
+ {
+ auto ptr = reinterpret_cast<const onert_tflite::BidirectionalSequenceRNNOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case BuiltinOptions_UnidirectionalSequenceLSTMOptions:
+ {
+ auto ptr = reinterpret_cast<const onert_tflite::UnidirectionalSequenceLSTMOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case BuiltinOptions_FloorModOptions:
+ {
+ auto ptr = reinterpret_cast<const onert_tflite::FloorModOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case BuiltinOptions_RangeOptions:
+ {
+ auto ptr = reinterpret_cast<const onert_tflite::RangeOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case BuiltinOptions_ResizeNearestNeighborOptions:
+ {
+ auto ptr = reinterpret_cast<const onert_tflite::ResizeNearestNeighborOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case BuiltinOptions_LeakyReluOptions:
+ {
+ auto ptr = reinterpret_cast<const onert_tflite::LeakyReluOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case BuiltinOptions_SquaredDifferenceOptions:
+ {
+ auto ptr = reinterpret_cast<const onert_tflite::SquaredDifferenceOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case BuiltinOptions_MirrorPadOptions:
+ {
+ auto ptr = reinterpret_cast<const onert_tflite::MirrorPadOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case BuiltinOptions_AbsOptions:
+ {
+ auto ptr = reinterpret_cast<const onert_tflite::AbsOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case BuiltinOptions_SplitVOptions:
+ {
+ auto ptr = reinterpret_cast<const onert_tflite::SplitVOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case BuiltinOptions_UniqueOptions:
+ {
+ auto ptr = reinterpret_cast<const onert_tflite::UniqueOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case BuiltinOptions_ReverseV2Options:
+ {
+ auto ptr = reinterpret_cast<const onert_tflite::ReverseV2Options *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case BuiltinOptions_AddNOptions:
+ {
+ auto ptr = reinterpret_cast<const onert_tflite::AddNOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case BuiltinOptions_GatherNdOptions:
+ {
+ auto ptr = reinterpret_cast<const onert_tflite::GatherNdOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case BuiltinOptions_CosOptions:
+ {
+ auto ptr = reinterpret_cast<const onert_tflite::CosOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case BuiltinOptions_WhereOptions:
+ {
+ auto ptr = reinterpret_cast<const onert_tflite::WhereOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case BuiltinOptions_RankOptions:
+ {
+ auto ptr = reinterpret_cast<const onert_tflite::RankOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case BuiltinOptions_ReverseSequenceOptions:
+ {
+ auto ptr = reinterpret_cast<const onert_tflite::ReverseSequenceOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case BuiltinOptions_MatrixDiagOptions:
+ {
+ auto ptr = reinterpret_cast<const onert_tflite::MatrixDiagOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case BuiltinOptions_QuantizeOptions:
+ {
+ auto ptr = reinterpret_cast<const onert_tflite::QuantizeOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case BuiltinOptions_MatrixSetDiagOptions:
+ {
+ auto ptr = reinterpret_cast<const onert_tflite::MatrixSetDiagOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case BuiltinOptions_HardSwishOptions:
+ {
+ auto ptr = reinterpret_cast<const onert_tflite::HardSwishOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case BuiltinOptions_IfOptions:
+ {
+ auto ptr = reinterpret_cast<const onert_tflite::IfOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case BuiltinOptions_WhileOptions:
+ {
+ auto ptr = reinterpret_cast<const onert_tflite::WhileOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case BuiltinOptions_DepthToSpaceOptions:
+ {
+ auto ptr = reinterpret_cast<const onert_tflite::DepthToSpaceOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case BuiltinOptions_NonMaxSuppressionV4Options:
+ {
+ auto ptr = reinterpret_cast<const onert_tflite::NonMaxSuppressionV4Options *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case BuiltinOptions_NonMaxSuppressionV5Options:
+ {
+ auto ptr = reinterpret_cast<const onert_tflite::NonMaxSuppressionV5Options *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case BuiltinOptions_ScatterNdOptions:
+ {
+ auto ptr = reinterpret_cast<const onert_tflite::ScatterNdOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case BuiltinOptions_SelectV2Options:
+ {
+ auto ptr = reinterpret_cast<const onert_tflite::SelectV2Options *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case BuiltinOptions_DensifyOptions:
+ {
+ auto ptr = reinterpret_cast<const onert_tflite::DensifyOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case BuiltinOptions_SegmentSumOptions:
+ {
+ auto ptr = reinterpret_cast<const onert_tflite::SegmentSumOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case BuiltinOptions_BatchMatMulOptions:
+ {
+ auto ptr = reinterpret_cast<const onert_tflite::BatchMatMulOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case BuiltinOptions_CumsumOptions:
+ {
+ auto ptr = reinterpret_cast<const onert_tflite::CumsumOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case BuiltinOptions_CallOnceOptions:
+ {
+ auto ptr = reinterpret_cast<const onert_tflite::CallOnceOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case BuiltinOptions_BroadcastToOptions:
+ {
+ auto ptr = reinterpret_cast<const onert_tflite::BroadcastToOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case BuiltinOptions_Rfft2dOptions:
+ {
+ auto ptr = reinterpret_cast<const onert_tflite::Rfft2dOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case BuiltinOptions_Conv3DOptions:
+ {
+ auto ptr = reinterpret_cast<const onert_tflite::Conv3DOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case BuiltinOptions_HashtableOptions:
+ {
+ auto ptr = reinterpret_cast<const onert_tflite::HashtableOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case BuiltinOptions_HashtableFindOptions:
+ {
+ auto ptr = reinterpret_cast<const onert_tflite::HashtableFindOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case BuiltinOptions_HashtableImportOptions:
+ {
+ auto ptr = reinterpret_cast<const onert_tflite::HashtableImportOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case BuiltinOptions_HashtableSizeOptions:
+ {
+ auto ptr = reinterpret_cast<const onert_tflite::HashtableSizeOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case BuiltinOptions_VarHandleOptions:
+ {
+ auto ptr = reinterpret_cast<const onert_tflite::VarHandleOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case BuiltinOptions_ReadVariableOptions:
+ {
+ auto ptr = reinterpret_cast<const onert_tflite::ReadVariableOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case BuiltinOptions_AssignVariableOptions:
+ {
+ auto ptr = reinterpret_cast<const onert_tflite::AssignVariableOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case BuiltinOptions_RandomOptions:
+ {
+ auto ptr = reinterpret_cast<const onert_tflite::RandomOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case BuiltinOptions_BucketizeOptions:
+ {
+ auto ptr = reinterpret_cast<const onert_tflite::BucketizeOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case BuiltinOptions_GeluOptions:
+ {
+ auto ptr = reinterpret_cast<const onert_tflite::GeluOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case BuiltinOptions_DynamicUpdateSliceOptions:
+ {
+ auto ptr = reinterpret_cast<const onert_tflite::DynamicUpdateSliceOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case BuiltinOptions_UnsortedSegmentProdOptions:
+ {
+ auto ptr = reinterpret_cast<const onert_tflite::UnsortedSegmentProdOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case BuiltinOptions_UnsortedSegmentMaxOptions:
+ {
+ auto ptr = reinterpret_cast<const onert_tflite::UnsortedSegmentMaxOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case BuiltinOptions_UnsortedSegmentSumOptions:
+ {
+ auto ptr = reinterpret_cast<const onert_tflite::UnsortedSegmentSumOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case BuiltinOptions_ATan2Options:
+ {
+ auto ptr = reinterpret_cast<const onert_tflite::ATan2Options *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ default:
+ return true;
+ }
+}
+
+inline bool
+VerifyBuiltinOptionsVector(::flatbuffers::Verifier &verifier,
+ const ::flatbuffers::Vector<::flatbuffers::Offset<void>> *values,
+ const ::flatbuffers::Vector<uint8_t> *types)
+{
+ if (!values || !types)
+ return !values && !types;
+ if (values->size() != types->size())
+ return false;
+ for (::flatbuffers::uoffset_t i = 0; i < values->size(); ++i)
+ {
+ if (!VerifyBuiltinOptions(verifier, values->Get(i), types->GetEnum<BuiltinOptions>(i)))
+ {
+ return false;
+ }
+ }
+ return true;
+}
+
+inline const onert_tflite::Model *GetModel(const void *buf)
+{
+ return ::flatbuffers::GetRoot<onert_tflite::Model>(buf);
+}
+
+inline const onert_tflite::Model *GetSizePrefixedModel(const void *buf)
+{
+ return ::flatbuffers::GetSizePrefixedRoot<onert_tflite::Model>(buf);
+}
+
+inline const char *ModelIdentifier() { return "TFL3"; }
+
+inline bool ModelBufferHasIdentifier(const void *buf)
+{
+ return ::flatbuffers::BufferHasIdentifier(buf, ModelIdentifier());
+}
+
+inline bool SizePrefixedModelBufferHasIdentifier(const void *buf)
+{
+ return ::flatbuffers::BufferHasIdentifier(buf, ModelIdentifier(), true);
+}
+
+inline bool VerifyModelBuffer(::flatbuffers::Verifier &verifier)
+{
+ return verifier.VerifyBuffer<onert_tflite::Model>(ModelIdentifier());
+}
+
+inline bool VerifySizePrefixedModelBuffer(::flatbuffers::Verifier &verifier)
+{
+ return verifier.VerifySizePrefixedBuffer<onert_tflite::Model>(ModelIdentifier());
+}
+
+inline const char *ModelExtension() { return "tflite"; }
+
+inline void FinishModelBuffer(::flatbuffers::FlatBufferBuilder &fbb,
+ ::flatbuffers::Offset<onert_tflite::Model> root)
+{
+ fbb.Finish(root, ModelIdentifier());
+}
+
+inline void FinishSizePrefixedModelBuffer(::flatbuffers::FlatBufferBuilder &fbb,
+ ::flatbuffers::Offset<onert_tflite::Model> root)
+{
+ fbb.FinishSizePrefixed(root, ModelIdentifier());
+}
+
+} // namespace onert_tflite
+
+#endif // FLATBUFFERS_GENERATED_TFLITESCHEMA_ONERT_TFLITE_H_
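For orientation (not generated content): the accessors above are read-only views over a serialized model buffer. A minimal sketch of how a reader might verify and inspect such a buffer, assuming the generated header above is on the include path and that data/size come from the caller:

#include <flatbuffers/flatbuffers.h>
// ... plus the generated header above, which declares onert_tflite::Model

#include <cstddef>
#include <cstdint>

bool inspect_model(const uint8_t *data, std::size_t size)
{
  ::flatbuffers::Verifier verifier(data, size);
  if (!onert_tflite::VerifyModelBuffer(verifier)) // runs Model::Verify() against the "TFL3" identifier
    return false;

  const onert_tflite::Model *model = onert_tflite::GetModel(data);
  // Vector/string accessors return nullptr when a field is absent, so guard before dereferencing
  const auto num_subgraphs = model->subgraphs() ? model->subgraphs()->size() : 0;
  const char *description = model->description() ? model->description()->c_str() : "";
  (void)description;
  return num_subgraphs > 0;
}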
diff --git a/runtime/onert/core/src/odc/CodegenLoader.cc b/runtime/onert/core/src/odc/CodegenLoader.cc
new file mode 100644
index 000000000..764074fe3
--- /dev/null
+++ b/runtime/onert/core/src/odc/CodegenLoader.cc
@@ -0,0 +1,91 @@
+/*
+ * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "CodegenLoader.h"
+
+#include <dlfcn.h>
+#include <iostream>
+#include <memory>
+
+static const char *SHARED_LIB_EXT =
+#if defined(__APPLE__) && defined(__MACH__)
+ ".dylib";
+#else
+ ".so";
+#endif
+
+namespace onert
+{
+namespace odc
+{
+
+CodegenLoader &CodegenLoader::instance()
+{
+ static CodegenLoader singleton;
+ return singleton;
+}
+
+void CodegenLoader::loadLibrary(const char *target)
+{
+ if (get() != nullptr)
+ return;
+
+ const std::string codegen_so = "lib" + std::string{target} + SHARED_LIB_EXT;
+#ifdef __ANDROID__
+ void *handle = dlopen(codegen_so.c_str(), RTLD_LAZY | RTLD_LOCAL);
+#else
+ void *handle = dlmopen(LM_ID_NEWLM, codegen_so.c_str(), RTLD_LAZY | RTLD_LOCAL);
+#endif
+ if (handle == nullptr)
+ {
+ throw std::runtime_error("CodegenLoader: " + std::string{dlerror()});
+ }
+
+ const auto factory = (factory_t)dlsym(handle, "create_codegen");
+ if (factory == nullptr)
+ {
+ const std::string dlerror_msg = dlerror();
+ dlclose(handle);
+ throw std::runtime_error("CodegenLoader: " + dlerror_msg);
+ }
+
+ const auto destroyer = (codegen_destory_t)dlsym(handle, "destroy_codegen");
+ _codegen = std::unique_ptr<ICodegen, codegen_destory_t>(factory(), destroyer);
+ if (_codegen == nullptr)
+ {
+ dlclose(handle);
+ throw std::runtime_error("CodegenLoader: unable to create codegen");
+ }
+
+  // Keep the backend handle (avoids the "handle lost without dlclose()" warning)
+ _dlhandle = std::unique_ptr<void, dlhandle_destroy_t>{
+ handle, [filename = codegen_so](void *h) {
+ if (dlclose(h))
+ throw std::runtime_error("CodegenLoader: Failed to unload backend " + filename);
+ }};
+}
+
+void CodegenLoader::unloadLibrary()
+{
+ if (get() == nullptr)
+ return;
+
+ _codegen.reset(nullptr);
+ _dlhandle.reset(nullptr);
+}
+
+} // namespace odc
+} // namespace onert
diff --git a/runtime/onert/core/src/odc/CodegenLoader.h b/runtime/onert/core/src/odc/CodegenLoader.h
new file mode 100644
index 000000000..397256058
--- /dev/null
+++ b/runtime/onert/core/src/odc/CodegenLoader.h
@@ -0,0 +1,96 @@
+/*
+ * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_ODC_CODEGEN_LOADER_H__
+#define __ONERT_ODC_CODEGEN_LOADER_H__
+
+#include "odc/ICodegen.h"
+
+#include <functional>
+#include <memory>
+
+namespace onert
+{
+namespace odc
+{
+
+/**
+ * @brief Class to manage loading and unloading of dynamic library containing
+ * implementation of ICodegen interface.
+ */
+class CodegenLoader
+{
+public:
+ /**
+ * @brief Typedef for function pointer to destroy loaded library handle
+ */
+ using dlhandle_destroy_t = std::function<void(void *)>;
+ /**
+ * @brief Typedef for function pointer to create instance of ICodegen
+ */
+ using factory_t = ICodegen *(*)();
+ /**
+ * @brief Typedef for function pointer to destroy instance of ICodegen
+ */
+ using codegen_destory_t = void (*)(ICodegen *);
+
+ /**
+ * @brief Get singleton instance of CodegenLoader
+ * @return Reference to singleton instance of CodegenLoader
+ */
+ static CodegenLoader &instance();
+
+ // delete copy constructor and assignment operator
+ CodegenLoader(CodegenLoader const &) = delete;
+ CodegenLoader &operator=(CodegenLoader const &) = delete;
+
+private:
+ // cannot create instance of CodegenLoader outside of this class
+ CodegenLoader() = default;
+ ~CodegenLoader() = default;
+
+public:
+ /**
+ * @brief Load dynamic library containing implementation of ICodegen
+ * @param[in] target Target backend name
+   *            This target string is used to find the backend library.
+   *            The backend library name must follow this pattern:
+   *              'lib' + {backend extension} + '-gen' + {lib extension}
+   *            and the target string is that name without the 'lib' prefix and the {lib extension} suffix.
+   *            For example, if the backend extension is 'aaa', the backend library name
+   *            is 'libaaa-gen.so' and the target string is 'aaa-gen'.
+ */
+ void loadLibrary(const char *target);
+ /**
+ * @brief Unload dynamic library containing implementation of ICodegen
+ */
+ void unloadLibrary();
+ /**
+ * @brief Get instance of ICodegen created through factory method
+ * @return Pointer to instance of ICodegen
+ */
+ const ICodegen *get() const { return _codegen.get(); }
+
+private:
+ // Note: Keep handle to avoid svace warning of "handle lost without dlclose()"
+ std::unique_ptr<void, dlhandle_destroy_t> _dlhandle;
+ std::unique_ptr<ICodegen, codegen_destory_t> _codegen{nullptr, nullptr};
+};
+
+} // namespace odc
+} // namespace onert
+
+#endif // __ONERT_ODC_CODEGEN_LOADER_H__
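A hedged usage sketch of the CodegenLoader API declared above; it mirrors what CodegenManager::codegen() does in the next file, with an illustrative 'aaa-gen' target and file paths:

#include "CodegenLoader.h"

bool run_codegen_once()
{
  auto &loader = onert::odc::CodegenLoader::instance();
  loader.loadLibrary("aaa-gen"); // resolves libaaa-gen.so (or .dylib), calls create_codegen(); throws on failure
  const onert::odc::ICodegen *codegen = loader.get();
  const auto result = codegen->codegen("model.circle", "model.out"); // 0 means success
  loader.unloadLibrary(); // destroys the ICodegen instance, then dlclose()s the handle
  return result == 0;
}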
diff --git a/runtime/onert/core/src/odc/CodegenManager.cc b/runtime/onert/core/src/odc/CodegenManager.cc
new file mode 100644
index 000000000..45f10a69d
--- /dev/null
+++ b/runtime/onert/core/src/odc/CodegenManager.cc
@@ -0,0 +1,56 @@
+/*
+ * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "CodegenLoader.h"
+#include "odc/CodegenManager.h"
+#include "util/Utils.h"
+
+#include <mutex>
+
+namespace onert
+{
+namespace odc
+{
+
+bool CodegenManager::codegen(const std::string &model_path, const char *target,
+ CodegenPreference pref)
+{
+ if (target == nullptr)
+ throw std::runtime_error("Target string is not set");
+
+ if (_export_model_path.empty())
+ throw std::runtime_error("Export model path is not set");
+
+ if (model_path.empty())
+ throw std::runtime_error("Model path does not exist");
+
+ // codegen function is thread-unsafe
+ static std::mutex lock;
+ std::lock_guard<std::mutex> guard(lock);
+
+ auto &codegen_loader = CodegenLoader::instance();
+ codegen_loader.loadLibrary(target);
+ const auto code_generator = codegen_loader.get();
+ // TODO Use compile preference
+ UNUSED_RELEASE(pref);
+ const auto result = code_generator->codegen(model_path.c_str(), _export_model_path.c_str());
+ codegen_loader.unloadLibrary();
+
+ return (result == 0);
+}
+
+} // namespace odc
+} // namespace onert
diff --git a/runtime/onert/core/src/odc/QuantizeManager.cc b/runtime/onert/core/src/odc/QuantizeManager.cc
new file mode 100644
index 000000000..fc5725b91
--- /dev/null
+++ b/runtime/onert/core/src/odc/QuantizeManager.cc
@@ -0,0 +1,50 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "QuantizerLoader.h"
+#include "odc/QuantizeManager.h"
+
+#include <iostream>
+#include <mutex>
+
+namespace onert
+{
+namespace odc
+{
+
+bool QuantizeManager::quantize(const std::string &model_path)
+{
+ if (model_path.empty() || _export_model_path.empty())
+ return false;
+
+  // Quantize function is thread-unsafe
+ static std::mutex lock;
+ std::lock_guard<std::mutex> guard(lock);
+
+ auto &quantize_loader = QuantizerLoader::instance();
+ if (quantize_loader.loadLibrary() != 0)
+ return false;
+
+ auto quantizer = quantize_loader.get();
+ auto result = quantizer->quantize(model_path.c_str(), _export_model_path.c_str(), _qtype);
+
+ // TODO Unload quantize library to reduce memory usage
+
+ return (result == 0);
+}
+
+} // namespace odc
+} // namespace onert
diff --git a/runtime/onert/core/src/odc/QuantizeManager.test.cc b/runtime/onert/core/src/odc/QuantizeManager.test.cc
new file mode 100644
index 000000000..3c9f45c6e
--- /dev/null
+++ b/runtime/onert/core/src/odc/QuantizeManager.test.cc
@@ -0,0 +1,38 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "odc/QuantizeManager.h"
+
+#include <gtest/gtest.h>
+
+using namespace onert::odc;
+
+// Test export model path is not set
+TEST(odc_QuantizeManager, neg_export_model_path_not_set)
+{
+ QuantizeManager manager;
+ manager.quantizeType(ODC_QTYPE_WO_I8_SYM);
+ ASSERT_EQ(manager.quantize("model_path"), false);
+}
+
+// Test invalid model path
+TEST(odc_QuantizeManager, neg_invalid_model_path)
+{
+ QuantizeManager manager;
+ manager.exportModelPath("export_model_path.circle");
+ manager.quantizeType(ODC_QTYPE_WO_I8_SYM);
+ ASSERT_EQ(manager.quantize("invalid_model_path.circle"), false);
+}
diff --git a/runtime/onert/core/src/odc/QuantizerLoader.cc b/runtime/onert/core/src/odc/QuantizerLoader.cc
new file mode 100644
index 000000000..8a972e97e
--- /dev/null
+++ b/runtime/onert/core/src/odc/QuantizerLoader.cc
@@ -0,0 +1,104 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "QuantizerLoader.h"
+
+#include <dlfcn.h>
+#include <iostream>
+#include <string>
+
+static const char *SHARED_LIB_EXT =
+#if defined(__APPLE__) && defined(__MACH__)
+ ".dylib";
+#else
+ ".so";
+#endif
+
+namespace onert
+{
+namespace odc
+{
+
+QuantizerLoader &QuantizerLoader::instance()
+{
+ static QuantizerLoader singleton;
+ return singleton;
+}
+
+int32_t QuantizerLoader::loadLibrary()
+{
+ if (get() != nullptr)
+ return 0;
+
+ const std::string quantize_so = std::string("libonert_odc") + SHARED_LIB_EXT;
+ void *handle = dlopen(quantize_so.c_str(), RTLD_LAZY | RTLD_LOCAL);
+ auto dlerror_msg = dlerror();
+
+ if (handle == nullptr)
+ {
+ std::cerr << "Failed to load " << quantize_so << std::endl;
+ std::cerr << dlerror_msg << std::endl;
+ return 1;
+ }
+
+ {
+ const char *factory_name = "create_quantizer";
+ auto factory = (factory_t)dlsym(handle, factory_name);
+ dlerror_msg = dlerror();
+
+ if (factory == nullptr)
+ {
+      std::cerr << "QuantizerLoader: unable to find function " << factory_name << ": " << dlerror_msg
+                << std::endl;
+ dlclose(handle);
+ return 1;
+ }
+
+ auto destroyer = (quantizer_destory_t)dlsym(handle, "destroy_quantizer");
+ _quantizer = std::unique_ptr<IQuantizer, quantizer_destory_t>(factory(), destroyer);
+
+ if (_quantizer == nullptr)
+ {
+ std::cerr << "QuantizerLoader: unable to create quantizer" << std::endl;
+ dlclose(handle);
+ return 1;
+ }
+ }
+
+  // Keep the quantize library handle (avoids the "handle lost without dlclose()" warning)
+ // clang-format off
+ _dlhandle = std::unique_ptr<void, dlhandle_destroy_t>{handle, [filename = quantize_so](void *h) {
+ if (dlclose(h) != 0)
+ std::cerr << "Failed to unload backend " << filename << std::endl;
+ }};
+ // clang-format on
+
+ return 0;
+}
+
+int32_t QuantizerLoader::unloadLibrary()
+{
+ if (get() == nullptr)
+ return 0;
+
+ _quantizer.reset(nullptr);
+ _dlhandle.reset(nullptr);
+
+ return 0;
+}
+
+} // namespace odc
+} // namespace onert
diff --git a/runtime/onert/core/src/odc/QuantizerLoader.h b/runtime/onert/core/src/odc/QuantizerLoader.h
new file mode 100644
index 000000000..36a9f2996
--- /dev/null
+++ b/runtime/onert/core/src/odc/QuantizerLoader.h
@@ -0,0 +1,89 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_ODC_QUANTIZER_LOADER_H__
+#define __ONERT_ODC_QUANTIZER_LOADER_H__
+
+#include "odc/IQuantizer.h"
+
+#include <functional>
+#include <memory>
+
+namespace onert
+{
+namespace odc
+{
+
+/**
+ * @brief Class to manage loading and unloading of dynamic library containing
+ * implementation of IQuantizer interface
+ */
+class QuantizerLoader
+{
+public:
+ /**
+ * @brief Typedef for function pointer to destroy loaded library handle
+ */
+ using dlhandle_destroy_t = std::function<void(void *)>;
+ /**
+ * @brief Typedef for function pointer to create instance of IQuantizer
+ */
+ using factory_t = IQuantizer *(*)();
+ /**
+ * @brief Typedef for function pointer to destroy instance of IQuantizer
+ */
+ using quantizer_destory_t = void (*)(IQuantizer *);
+
+ /**
+ * @brief Get singleton instance of QuantizerLoader
+ * @return Reference to singleton instance of QuantizerLoader
+ */
+ static QuantizerLoader &instance();
+
+private:
+ // Cannot create instance of QuantizerLoader outside of this class
+ QuantizerLoader() = default;
+ QuantizerLoader(QuantizerLoader const &) = delete;
+ QuantizerLoader &operator=(QuantizerLoader const &) = delete;
+ ~QuantizerLoader() = default;
+
+public:
+ /**
+ * @brief Load dynamic library containing implementation of IQuantizer
+ * @return 0 if success, otherwise errno value
+ */
+ int32_t loadLibrary();
+ /**
+ * @brief Unload dynamic library containing implementation of IQuantizer
+ * @return 0 if success, otherwise errno value
+ */
+ int32_t unloadLibrary();
+ /**
+ * @brief Get instance of IQuantizer created through factory method
+ * @return Pointer to instance of IQuantizer
+ */
+ IQuantizer *get() const { return _quantizer.get(); }
+
+private:
+ // Note: Keep handle to avoid svace warning of "handle lost without dlclose()"
+ std::unique_ptr<void, dlhandle_destroy_t> _dlhandle;
+ std::unique_ptr<IQuantizer, quantizer_destory_t> _quantizer{nullptr, nullptr};
+};
+
+} // namespace odc
+} // namespace onert
+
+#endif // __ONERT_ODC_QUANTIZER_LOADER_H__
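Unlike CodegenLoader, this loader reports failures through return codes rather than exceptions. A minimal sketch under that contract; the file paths are illustrative, and the ODC_QTYPE_WO_I8_SYM constant (and the exact type of the third quantize() parameter) is assumed from QuantizeManager.test.cc above:

#include "QuantizerLoader.h"

bool quantize_once()
{
  auto &loader = onert::odc::QuantizerLoader::instance();
  if (loader.loadLibrary() != 0) // loads libonert_odc.so (or .dylib) and creates the IQuantizer
    return false;
  onert::odc::IQuantizer *quantizer = loader.get();
  const auto result = quantizer->quantize("model.circle", "model.q.circle", ODC_QTYPE_WO_I8_SYM);
  loader.unloadLibrary();
  return result == 0;
}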
diff --git a/runtime/onert/core/src/odc/QuantizerLoader.test.cc b/runtime/onert/core/src/odc/QuantizerLoader.test.cc
new file mode 100644
index 000000000..112e65b27
--- /dev/null
+++ b/runtime/onert/core/src/odc/QuantizerLoader.test.cc
@@ -0,0 +1,63 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "QuantizerLoader.h"
+
+#include <gtest/gtest.h>
+
+using namespace onert::odc;
+
+// Test QuantizerLoader singleton
+TEST(odc_QuantizerLoader, singleton)
+{
+ QuantizerLoader &loader1 = QuantizerLoader::instance();
+ QuantizerLoader &loader2 = QuantizerLoader::instance();
+ ASSERT_EQ(&loader1, &loader2);
+}
+
+// Test load quantizer library
+TEST(odc_QuantizerLoader, load)
+{
+ QuantizerLoader &loader = QuantizerLoader::instance();
+  // Unload because it may have been loaded by previous tests
+ ASSERT_EQ(loader.unloadLibrary(), 0);
+
+ if (loader.loadLibrary() == 0)
+ {
+ // Load twice to check if it is thread-safe
+ ASSERT_EQ(loader.loadLibrary(), 0);
+ }
+}
+
+// Get quantizer function without loading quantizer library
+TEST(odc_QuantizerLoader, neg_get)
+{
+ QuantizerLoader &loader = QuantizerLoader::instance();
+  // Unload because it may have been loaded by previous tests
+ ASSERT_EQ(loader.unloadLibrary(), 0);
+ ASSERT_EQ(loader.get(), nullptr);
+}
+
+// Check quantizer function pointer when QuantizerLoader is unloaded
+TEST(odc_QuantizerLoader, neg_unload)
+{
+ QuantizerLoader &loader = QuantizerLoader::instance();
+ if (loader.loadLibrary() == 0)
+ ASSERT_NE(loader.get(), nullptr);
+
+ ASSERT_EQ(loader.unloadLibrary(), 0);
+ ASSERT_EQ(loader.get(), nullptr);
+}
diff --git a/runtime/onert/core/src/util/ChromeTracingEventWriter.cc b/runtime/onert/core/src/util/ChromeTracingEventWriter.cc
new file mode 100644
index 000000000..c3f5179df
--- /dev/null
+++ b/runtime/onert/core/src/util/ChromeTracingEventWriter.cc
@@ -0,0 +1,195 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "EventWriter.h"
+
+#include <cassert>
+#include <sstream>
+#include <utility>
+#include <vector>
+
+// json type for ChromeTracingWriter
+namespace
+{
+
+std::string quote(const std::string &value)
+{
+ std::stringstream ss;
+ ss << '"' << value << '"';
+ return ss.str();
+}
+
+std::string field(const std::string &k, const std::string &v)
+{
+ std::stringstream ss;
+ ss << quote(k) << " : " << quote(v);
+ return ss.str();
+}
+
+struct Content // One Entry in Chrome Event Trace
+{
+ std::vector<std::pair<std::string, std::string>> flds;
+ std::vector<std::pair<std::string, std::string>> args;
+};
+
+std::string object(const Content &content)
+{
+ std::stringstream ss;
+
+ ss << "{ ";
+
+ ss << field(content.flds[0].first, content.flds[0].second);
+
+ for (uint32_t n = 1; n < content.flds.size(); ++n)
+ {
+ ss << ", " << field(content.flds.at(n).first, content.flds.at(n).second);
+ }
+
+ if (content.args.size() > 0)
+ {
+ ss << ", " << quote("args") << " : { ";
+ ss << field(content.args.at(0).first, content.args.at(0).second);
+
+ for (uint32_t n = 1; n < content.args.size(); ++n)
+ {
+ ss << ", " << field(content.args.at(n).first, content.args.at(n).second);
+ }
+
+ ss << "}";
+ }
+
+ ss << " }";
+
+ return ss.str();
+}
+
+void fill(Content &content, const DurationEvent &evt, const std::string &name,
+ const std::string &tid)
+{
+ content.flds.emplace_back("name", name);
+ content.flds.emplace_back("pid", "0");
+ content.flds.emplace_back("tid", tid);
+ content.flds.emplace_back("ph", evt.ph);
+ content.flds.emplace_back("ts", evt.ts);
+ content.args = evt.args;
+}
+
+void fill(Content &content, const CounterEvent &evt)
+{
+ assert(evt.name != "");
+
+ content.flds.emplace_back("name", evt.name);
+ content.flds.emplace_back("pid", "0");
+ content.flds.emplace_back("tid", evt.tid);
+ content.flds.emplace_back("ph", evt.ph);
+ content.flds.emplace_back("ts", evt.ts);
+ content.args = evt.args;
+}
+
+std::string object(const DurationEvent &evt, const std::string &name, const std::string &tid)
+{
+ Content content;
+
+ fill(content, evt, name, tid);
+
+ return ::object(content);
+}
+
+std::string object(const CounterEvent &evt)
+{
+ Content content;
+
+ fill(content, evt);
+
+ for (auto it = evt.values.begin(); it != evt.values.end(); ++it)
+ {
+ content.args.emplace_back(it->first, it->second);
+ }
+
+ return ::object(content);
+}
+
+std::string getSessionLabel(const DurationEvent &evt)
+{
+ return "$" + std::to_string(evt.session_index) + " sess";
+}
+
+std::string getSubgLabel(const DurationEvent &evt)
+{
+ return "$" + std::to_string(evt.subg_index) + " subg";
+}
+
+std::string getOpLabel(const OpSeqDurationEvent &evt)
+{
+ return "@" + std::to_string(evt.op_index) + " " + evt.op_name;
+}
+
+std::string getLabel(const DurationEvent &evt)
+{
+ if (auto evt_ptr = dynamic_cast<const OpSeqDurationEvent *>(&evt))
+ {
+ return getOpLabel(*evt_ptr);
+ }
+ else // SubgDurationEvent
+ {
+ return getSubgLabel(evt);
+ }
+}
+
+std::string getTid(const DurationEvent &evt)
+{
+ if (auto evt_ptr = dynamic_cast<const OpSeqDurationEvent *>(&evt))
+ {
+ return getSessionLabel(*evt_ptr) + ", " + getSubgLabel(*evt_ptr) + ", " + evt_ptr->backend;
+ }
+ else // SubgDurationEvent
+ {
+ return getSessionLabel(evt) + ", " + getSubgLabel(evt);
+ }
+}
+
+} // namespace
+
+void ChromeTracingWriter::flush(const std::vector<std::unique_ptr<EventRecorder>> &recorders)
+{
+ _os << "{\n";
+ _os << " " << quote("traceEvents") << ": [\n";
+
+ for (const auto &recorder : recorders)
+ {
+ flushOneRecord(*recorder);
+ }
+
+ _os << " { }\n";
+ _os << " ]\n";
+ _os << "}\n";
+}
+
+void ChromeTracingWriter::flushOneRecord(const EventRecorder &recorder)
+{
+ for (const auto &evt : recorder.duration_events())
+ {
+ const std::string name = getLabel(*evt);
+ const std::string tid = getTid(*evt);
+
+ _os << " " << object(*evt, name, tid) << ",\n";
+ }
+
+ for (const auto &evt : recorder.counter_events())
+ {
+ _os << " " << object(evt) << ",\n";
+ }
+}
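For orientation (not part of the diff): after flush(), each duration event serialized by object() above appears in the "traceEvents" array roughly as the entry below; the name, tid, backend and timestamp values are purely illustrative.

{ "name" : "@3 Conv2D", "pid" : "0", "tid" : "$0 sess, $1 subg, cpu", "ph" : "B", "ts" : "1234567", "args" : { "session" : "0", "subgraph" : "1"} }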
diff --git a/runtime/onert/core/src/util/ConfigSource.cc b/runtime/onert/core/src/util/ConfigSource.cc
index 45cce662e..b7fcefc7a 100644
--- a/runtime/onert/core/src/util/ConfigSource.cc
+++ b/runtime/onert/core/src/util/ConfigSource.cc
@@ -15,13 +15,15 @@
*/
#include "util/ConfigSource.h"
-#include "util/GeneralConfigSource.h"
-#include "util/EnvConfigSource.h"
+#include "util/logging.h"
+
+#include <misc/EnvConfigSource.h>
+#include <misc/GeneralConfigSource.h>
+#include <misc/IConfigSource.h>
-#include <array>
#include <algorithm>
+#include <array>
#include <cassert>
-
#include <memory>
namespace onert
@@ -29,9 +31,26 @@ namespace onert
namespace util
{
+using namespace nnfw::misc;
+
static std::unique_ptr<IConfigSource> _source;
+static std::unique_ptr<IConfigSource> _source_ext;
void config_source(std::unique_ptr<IConfigSource> &&source) { _source = std::move(source); }
+void config_source_ext(std::unique_ptr<IConfigSource> &&source) { _source_ext = std::move(source); }
+
+void setConfigKeyValues(const CfgKeyValues &keyValues)
+{
+ auto configsrc = std::make_unique<GeneralConfigSource>();
+
+ for (auto it = keyValues.begin(); it != keyValues.end(); ++it)
+ {
+ VERBOSE(NNPKG_CONFIGS) << "(" << it->first << ") = (" << it->second << ")" << std::endl;
+ configsrc->set(it->first, it->second);
+ }
+
+ onert::util::config_source_ext(std::move(configsrc));
+}
static IConfigSource *config_source()
{
@@ -67,6 +86,15 @@ static std::string getConfigOrDefault(const std::string &key)
auto ret = config_source()->get(key);
if (ret.empty())
{
+    // if not set in the environment, search the external config source
+ if (_source_ext.get())
+ {
+ ret = _source_ext.get()->get(key);
+ }
+ }
+  // if still not found, fall back to the defaults
+ if (ret.empty())
+ {
auto itr = defaults.find(key);
if (itr != defaults.end())
{
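A minimal caller-side sketch of the external-config path added above, assuming CfgKeyValues (declared in util/ConfigSource.h) behaves like a string-to-string map; the key and value shown are illustrative only:

#include "util/ConfigSource.h"

void apply_nnpkg_configs()
{
  onert::util::CfgKeyValues kv;
  kv["TRACE_FILEPATH"] = "/tmp/trace"; // illustrative key and value
  onert::util::setConfigKeyValues(kv);
  // getConfigOrDefault() now resolves keys in order: primary source (e.g. environment),
  // then the external source registered above, then the built-in defaults.
}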
diff --git a/runtime/onert/core/src/util/EventCollector.cc b/runtime/onert/core/src/util/EventCollector.cc
index de37276bf..c1b9c4315 100644
--- a/runtime/onert/core/src/util/EventCollector.cc
+++ b/runtime/onert/core/src/util/EventCollector.cc
@@ -14,7 +14,7 @@
* limitations under the License.
*/
-#include "util/EventCollector.h"
+#include "EventCollector.h"
// C++ standard libraries
#include <chrono>
@@ -30,24 +30,62 @@ std::string timestamp(void)
{
auto now = std::chrono::steady_clock::now();
return std::to_string(
- std::chrono::duration_cast<std::chrono::microseconds>(now.time_since_epoch()).count());
+ std::chrono::duration_cast<std::chrono::microseconds>(now.time_since_epoch()).count());
}
-class DurationEventBuilder
+class DurationEventBuilder : public EventCollector::EventVisitor
{
public:
DurationEventBuilder(const std::string &ts) : _ts{ts} {}
- DurationEvent build(const std::string &tid, const std::string &name, const std::string &ph) const
+ std::unique_ptr<SubgDurationEvent> build(const EventCollector::SubgEvent &evt_collected,
+ const std::string &ph) const
{
- DurationEvent evt;
+ auto dur_evt = std::make_unique<SubgDurationEvent>();
- evt.name = name;
- evt.tid = tid;
- evt.ph = ph;
- evt.ts = _ts;
+ // The following will be set by a child of EventsWriter:
+ // dur_evt.name, dur_evt.tid
+ dur_evt->ph = ph;
+ dur_evt->ts = _ts;
+ dur_evt->tracing_ctx = evt_collected.tracing_ctx;
- return evt;
+ dur_evt->session_index = evt_collected.session_index;
+ dur_evt->subg_index = evt_collected.subg_index;
+
+ dur_evt->args = evt_collected.userData;
+ {
+ dur_evt->args.emplace_back("session", std::to_string(evt_collected.session_index));
+ dur_evt->args.emplace_back("subgraph", std::to_string(evt_collected.subg_index));
+ }
+
+ return dur_evt;
+ }
+
+ std::unique_ptr<OpSeqDurationEvent> build(const EventCollector::OpSeqEvent &evt_collected,
+ const std::string &ph) const
+ {
+ auto dur_evt = std::make_unique<OpSeqDurationEvent>();
+
+ // The following will be set by a child of EventsWriter:
+ // dur_evt.name, dur_evt.tid
+ dur_evt->ph = ph;
+ dur_evt->ts = _ts;
+ dur_evt->tracing_ctx = evt_collected.tracing_ctx;
+
+ dur_evt->session_index = evt_collected.session_index;
+ dur_evt->subg_index = evt_collected.subg_index;
+
+ dur_evt->backend = evt_collected.backend;
+ dur_evt->op_index = evt_collected.op_index;
+ dur_evt->op_name = evt_collected.op_name;
+
+ dur_evt->args = evt_collected.userData;
+ {
+ dur_evt->args.emplace_back("session", std::to_string(evt_collected.session_index));
+ dur_evt->args.emplace_back("subgraph", std::to_string(evt_collected.subg_index));
+ }
+
+ return dur_evt;
}
private:
@@ -86,19 +124,26 @@ inline void emit_rusage(EventRecorder *rec, const std::string &ts)
} // namespace
-void EventCollector::onEvent(const Event &event)
+template <typename EventT> void EventCollector::onEvent(const EventT &event)
{
auto ts = timestamp();
+ DurationEventBuilder builder(ts);
+
switch (event.edge)
{
case Edge::BEGIN:
- _rec->emit(DurationEventBuilder(ts).build(event.backend, event.label, "B"));
+ {
+ auto duration_evt = builder.build(event, "B");
+ _rec->emit(std::move(duration_evt));
break;
-
+ }
case Edge::END:
- _rec->emit(DurationEventBuilder(ts).build(event.backend, event.label, "E"));
+ {
+ auto duration_evt = builder.build(event, "E");
+ _rec->emit(std::move(duration_evt));
break;
+ }
}
// TODO: Add resource measurement (e.g. RSS)
@@ -107,3 +152,7 @@ void EventCollector::onEvent(const Event &event)
emit_rusage(_rec, ts);
#endif
}
+
+// template instantiation
+template void EventCollector::onEvent<EventCollector::SubgEvent>(const SubgEvent &event);
+template void EventCollector::onEvent<EventCollector::OpSeqEvent>(const OpSeqEvent &event);
diff --git a/runtime/onert/core/src/util/EventCollector.h b/runtime/onert/core/src/util/EventCollector.h
index 8154be592..effb72373 100644
--- a/runtime/onert/core/src/util/EventCollector.h
+++ b/runtime/onert/core/src/util/EventCollector.h
@@ -17,7 +17,13 @@
#ifndef __ONERT_UTIL_EVENT_COLLECTOR_H__
#define __ONERT_UTIL_EVENT_COLLECTOR_H__
-#include "util/EventRecorder.h"
+#include "EventRecorder.h"
+
+#include "util/TracingCtx.h"
+
+#include <string>
+#include <utility>
+#include <vector>
class EventCollector
{
@@ -28,11 +34,69 @@ public:
END
};
+ struct SubgEvent;
+ struct OpEvent;
+
+ class EventVisitor
+ {
+ public:
+ virtual ~EventVisitor() = default;
+
+ virtual std::unique_ptr<DurationEvent> visit(const SubgEvent &, const std::string &) const
+ {
+ throw std::runtime_error("Please implement");
+ }
+ virtual std::unique_ptr<DurationEvent> visit(const OpEvent &, const std::string &) const
+ {
+ throw std::runtime_error("Please implement");
+ }
+ };
+
struct Event
{
+ const onert::util::TracingCtx *tracing_ctx;
+
Edge edge;
+ uint32_t session_index;
+ uint32_t subg_index;
+
+ // user-defined data: pairs of (key, value)
+ std::vector<std::pair<std::string, std::string>> userData;
+
+ protected:
+ Event(const onert::util::TracingCtx *a_tracing_ctx, Edge a_edge, uint32_t a_subg_index)
+ : tracing_ctx(a_tracing_ctx), edge(a_edge), session_index(tracing_ctx->getSessionId()),
+ subg_index(a_subg_index)
+ { /* empty */
+ }
+
+ virtual ~Event() = default;
+ };
+
+ struct SubgEvent : public Event
+ {
+ // constructor for subgraph start and end event
+ SubgEvent(const onert::util::TracingCtx *a_tracing_ctx, Edge a_edge, uint32_t a_subg_index)
+ : Event(a_tracing_ctx, a_edge, a_subg_index)
+ { /* empty */
+ }
+ };
+
+ // TODO Rename this to OperationEvent
+ struct OpSeqEvent : public Event
+ {
std::string backend;
- std::string label;
+ uint32_t op_index;
+ std::string op_name;
+
+ OpSeqEvent(const onert::util::TracingCtx *a_tracing_ctx, Edge a_edge, uint32_t a_subg_index,
+ const std::string a_backend, uint32_t a_op_index, const std::string a_op_name)
+ : Event(a_tracing_ctx, a_edge, a_subg_index)
+ {
+ backend.assign(a_backend);
+ op_index = a_op_index;
+ op_name.assign(a_op_name);
+ }
};
public:
@@ -42,7 +106,7 @@ public:
}
public:
- void onEvent(const Event &event);
+ template <typename EventT> void onEvent(const EventT &event);
protected:
EventRecorder *_rec;
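The new userData field lets callers attach arbitrary (key, value) pairs to an event; DurationEventBuilder copies them into DurationEvent::args and appends "session" and "subgraph" entries on top. A short hedged sketch, not part of the patch, reusing the placeholder collector and tracing_ctx names from the earlier example:

EventCollector::OpSeqEvent evt{tracing_ctx, EventCollector::Edge::END,
                               /*subg_index=*/0, "cpu", /*op_index=*/3, "Conv2D"};
// Hypothetical key/value pair; it is forwarded into DurationEvent::args as-is
evt.userData.emplace_back("permutation", "NHWC->NCHW");
collector.onEvent(evt);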
diff --git a/runtime/onert/core/src/util/EventCollectorGlobal.cc b/runtime/onert/core/src/util/EventCollectorGlobal.cc
deleted file mode 100644
index d09b95210..000000000
--- a/runtime/onert/core/src/util/EventCollectorGlobal.cc
+++ /dev/null
@@ -1,93 +0,0 @@
-/*
- * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "util/EventCollectorGlobal.h"
-
-#include <cassert>
-#include <fstream>
-#include <iostream>
-
-#include "util/ConfigSource.h"
-
-namespace onert
-{
-namespace util
-{
-
-EventCollectorGlobal::EventCollectorGlobal() : _recorder{}, _collector{&_recorder}
-{
- // DO NOTHING
-}
-
-EventCollectorGlobal::~EventCollectorGlobal()
-{
- if (!_recorder.empty())
- {
- try
- {
- // TODO Need better way for saved file path than the hardcoded path
- std::ofstream ofs{"trace.global.json"};
- _recorder.writeToFile(ofs);
- }
- catch (const std::exception &e)
- {
- std::cerr << "E: Fail to record event in EventCollectorGlobal: " << e.what() << std::endl;
- }
- }
-}
-
-EventCollectorGlobal &EventCollectorGlobal::get()
-{
- static EventCollectorGlobal instance;
- return instance;
-}
-
-EventDurationBlock::EventDurationBlock(const std::string &tag) : _tag{tag}
-{
- auto &glob = EventCollectorGlobal::get();
- glob.collector().onEvent(EventCollector::Event{EventCollector::Edge::BEGIN, "0", _tag});
-}
-EventDurationBlock::~EventDurationBlock()
-{
- auto &glob = EventCollectorGlobal::get();
- glob.collector().onEvent(EventCollector::Event{EventCollector::Edge::END, "0", _tag});
-}
-
-EventDurationManual::EventDurationManual(const std::string &tag) : _tag{tag}, _pair{true} {}
-
-EventDurationManual::~EventDurationManual()
-{
- // Check if it has called begin-end pair
- assert(_pair);
-}
-
-void EventDurationManual::begin()
-{
- _pair = false;
- auto &glob = EventCollectorGlobal::get();
- glob.collector().onEvent(EventCollector::Event{EventCollector::Edge::BEGIN, "0", _tag});
-}
-
-void EventDurationManual::end()
-{
- assert(!_pair);
- _pair = true;
- auto &glob = EventCollectorGlobal::get();
- glob.collector().onEvent(EventCollector::Event{EventCollector::Edge::END, "0", _tag});
-}
-
-} // namespace util
-} // namespace onert
diff --git a/runtime/onert/core/src/util/EventCollectorGlobal.h b/runtime/onert/core/src/util/EventCollectorGlobal.h
deleted file mode 100644
index 1027ec84d..000000000
--- a/runtime/onert/core/src/util/EventCollectorGlobal.h
+++ /dev/null
@@ -1,155 +0,0 @@
-/*
- * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef __ONERT_UTIL_EVENT_COLLECTOR_GLOBAL_H__
-#define __ONERT_UTIL_EVENT_COLLECTOR_GLOBAL_H__
-
-#include "util/EventRecorder.h"
-#include "util/EventCollector.h"
-
-namespace onert
-{
-namespace util
-{
-
-/**
- * @brief Singleton class for event collection from anywhere in code
- *
- */
-class EventCollectorGlobal
-{
-public:
- /**
- * @brief Get the singleton object of this class
- *
- * @return EventCollectorGlobal& Singleton object
- */
- static EventCollectorGlobal &get();
-
-public:
- /**
- * @brief Getter for event collector object
- *
- * @return EventCollector& Collector object
- */
- EventCollector &collector() { return _collector; }
-
-private:
- EventCollectorGlobal();
- ~EventCollectorGlobal();
-
-private:
- EventRecorder _recorder;
- EventCollector _collector;
-};
-
-/**
- * @brief Helper class for emitting duration event which is handled automatically with ctor/dtor
- *
- */
-class EventDurationBlock
-{
-public:
- /**
- * @brief Raise a duration event with type of BEGIN
- *
- * @param tag A label for the duration event
- */
- EventDurationBlock(const std::string &tag);
- /**
- * @brief Raise a duration event with type of END
- *
- */
- ~EventDurationBlock();
-
-private:
- std::string _tag;
-};
-
-/**
- * @brief Helper class for emitting duration event which is handled manually
- *
- * Usage:
- * {
- * ...
- * EventDurationManual duration("some tag");
- * duration.begin();
- * ...
- * ... // Code for duration
- * ...
- * duration.end();
- * }
- *
- */
-class EventDurationManual
-{
-public:
- /**
- * @brief Construct a new Event Duration Manual object
- *
- * @param tag A label for the duration object
- */
- EventDurationManual(const std::string &tag);
- /**
- * @brief Destroy the Event Duration Manual object
- *
- */
- ~EventDurationManual();
-
- /**
- * @brief Raise a duration event with type of BEGIN
- *
- */
- void begin();
- /**
- * @brief Raise a duration event with type of END
- *
- */
- void end();
-
-private:
- std::string _tag;
- bool _pair;
-};
-
-} // namespace util
-} // namespace onert
-
-/**
- * Helper Macro Definitions
- *
- * HOW TO USE
- *
- * void f(args)
- * {
- * EVENT_DURATION_FUNCTION();
- * ...
- * if(cond)
- * {
- * EVENT_DURATION_REGION("if branch");
- * ...
- * }
- * ...
- * }
- */
-
-#define EVENT_DURATION_FUNCTION() \
- ::onert::util::EventDurationBlock __event_duration__##__LINE__ { __FUNCTION__ }
-
-#define EVENT_DURATION_REGION(tag) \
- ::onert::util::EventDurationBlock __event_duration__##__LINE__ { tag }
-
-#endif // __ONERT_UTIL_EVENT_COLLECTOR_GLOBAL_H__
diff --git a/runtime/onert/core/src/util/EventRecorder.cc b/runtime/onert/core/src/util/EventRecorder.cc
index 13a599bed..85a588d38 100644
--- a/runtime/onert/core/src/util/EventRecorder.cc
+++ b/runtime/onert/core/src/util/EventRecorder.cc
@@ -14,396 +14,13 @@
* limitations under the License.
*/
-#include "util/EventRecorder.h"
+#include "EventRecorder.h"
-#include <sstream>
-#include <vector>
-#include <unordered_map>
-#include <json/json.h>
-#include <assert.h>
-#include <utility>
-#include <map>
-#include <set>
-#include <stdint.h>
-
-// json type for Chrome Event Trace
-namespace
-{
-
-std::string quote(const std::string &value)
-{
- std::stringstream ss;
- ss << '"' << value << '"';
- return ss.str();
-}
-
-std::string field(const std::string &k, const std::string &v)
-{
- std::stringstream ss;
- ss << quote(k) << " : " << quote(v);
- return ss.str();
-}
-
-struct Content // One Entry in Chrome Event Trace
-{
- std::vector<std::pair<std::string, std::string>> flds;
- std::vector<std::pair<std::string, std::string>> args;
-};
-
-std::string object(const Content &content)
-{
- std::stringstream ss;
-
- ss << "{ ";
-
- ss << field(content.flds[0].first, content.flds[0].second);
-
- for (uint32_t n = 1; n < content.flds.size(); ++n)
- {
- ss << ", " << field(content.flds.at(n).first, content.flds.at(n).second);
- }
-
- if (content.args.size() > 0)
- {
- ss << ", " << quote("args") << " : { ";
- ss << field(content.args.at(0).first, content.args.at(0).second);
-
- for (uint32_t n = 1; n < content.args.size(); ++n)
- {
- ss << ", " << field(content.args.at(n).first, content.args.at(n).second);
- }
-
- ss << "}";
- }
-
- ss << " }";
-
- return ss.str();
-}
-
-void fill(Content &content, const Event &evt)
-{
- content.flds.emplace_back("name", evt.name);
- content.flds.emplace_back("pid", "0");
- content.flds.emplace_back("tid", evt.tid);
- content.flds.emplace_back("ph", evt.ph);
- content.flds.emplace_back("ts", evt.ts);
-}
-
-std::string object(const DurationEvent &evt)
-{
- Content content;
-
- fill(content, evt);
-
- return ::object(content);
-}
-
-std::string object(const CounterEvent &evt)
-{
- Content content;
-
- fill(content, evt);
-
- for (auto it = evt.values.begin(); it != evt.values.end(); ++it)
- {
- content.args.emplace_back(it->first, it->second);
- }
-
- return ::object(content);
-}
-
-} // namespace
-
-// md table type
-namespace
-{
-
-void writeMDTableRow(std::ostream &os, const std::vector<std::string> &list)
-{
- os << "| ";
- for (auto &key : list)
- {
- os << key << " | ";
- }
- os << "\n";
-}
-
-struct MDContent
-{
- std::string name;
- uint64_t begin_ts;
- uint64_t end_ts;
- uint32_t min_rss;
- uint32_t max_rss;
- uint32_t min_page_reclaims;
- uint32_t max_page_reclaims;
-
- MDContent()
- : begin_ts(0), end_ts(0), min_rss(UINT32_MAX), max_rss(0), min_page_reclaims(UINT32_MAX),
- max_page_reclaims(0)
- {
- // DO NOTHING
- }
-
- virtual ~MDContent() = default;
-
- void updateRss(uint32_t rss)
- {
- if (min_rss == UINT32_MAX)
- min_rss = rss;
- if (max_rss == 0)
- max_rss = rss;
-
- if (min_rss > rss)
- min_rss = rss;
- else if (max_rss < rss)
- max_rss = rss;
- }
-
- void updateMinflt(uint32_t minflt)
- {
- if (min_page_reclaims == UINT32_MAX)
- min_page_reclaims = minflt;
- if (max_page_reclaims == 0)
- max_page_reclaims = minflt;
-
- if (min_page_reclaims > minflt)
- min_page_reclaims = minflt;
- else if (max_page_reclaims < minflt)
- max_page_reclaims = minflt;
- }
-
- virtual void write(std::ostream &os) const = 0;
-};
-
-struct OpSeq : public MDContent
-{
- std::string backend;
- uint64_t graph_latency;
-
- struct OpSeqCmp
- {
- bool operator()(const OpSeq &lhs, const OpSeq &rhs) const
- {
- return lhs.begin_ts < rhs.begin_ts;
- }
- bool operator()(const OpSeq &lhs, const OpSeq &rhs) { return lhs.begin_ts < rhs.begin_ts; }
- bool operator()(OpSeq &lhs, OpSeq &rhs) { return lhs.begin_ts < rhs.begin_ts; }
- };
-
- void write(std::ostream &os) const override
- {
- uint64_t opseq_latency = end_ts - begin_ts;
- double opseq_per = static_cast<double>(opseq_latency) / graph_latency * 100.0;
- writeMDTableRow(os, {name, backend, std::to_string(opseq_latency), std::to_string(opseq_per),
- std::to_string(min_rss), std::to_string(max_rss),
- std::to_string(min_page_reclaims), std::to_string(max_page_reclaims)});
- }
-};
-
-struct Graph : public MDContent
-{
- std::set<OpSeq, OpSeq::OpSeqCmp> opseqs;
-
- void setOpSeqs(const std::map<std::string, OpSeq> &name_to_opseq)
- {
- uint64_t graph_latency = end_ts - begin_ts;
- for (auto it : name_to_opseq)
- {
- auto opseq = it.second;
- opseq.graph_latency = graph_latency;
-
- opseqs.insert(opseq);
-
- updateRss(opseq.min_rss);
- updateRss(opseq.max_rss);
- updateMinflt(opseq.min_page_reclaims);
- updateMinflt(opseq.max_page_reclaims);
- }
- }
-
- void write(std::ostream &os) const override
- {
- static std::vector<std::string> graph_headers{"latency(us)", "rss_min(kb)", "rss_max(kb)",
- "page_reclaims_min", "page_reclaims_max"};
-
- static std::vector<std::string> graph_headers_line{"-----------", "-------", "-------",
- "-----------------", "-----------------"};
-
- // Graph's Header
- writeMDTableRow(os, graph_headers);
- writeMDTableRow(os, graph_headers_line);
-
- // Graph's contents
- writeMDTableRow(os, {std::to_string(end_ts - begin_ts), std::to_string(min_rss),
- std::to_string(max_rss), std::to_string(min_page_reclaims),
- std::to_string(max_page_reclaims)});
-
- os << "\n";
-
- static std::vector<std::string> opseq_headers{
- "OpSeq name", "backend", "latency(us)", "latency(%)",
- "rss_min(kb)", "rss_max(kb)", "page_reclaims_min", "page_reclaims_max"};
-
- static std::vector<std::string> opseq_headers_line{
- "----------", "-------", "-----------", "-----------",
- "-------", "-------", "-----------------", "-----------------"};
-
- os << "## OpSequences \n";
-
- // OpSeq's Header
- writeMDTableRow(os, opseq_headers);
- writeMDTableRow(os, opseq_headers_line);
-
- // OpSeq's contents
- for (auto opseq : opseqs)
- {
- opseq.write(os);
- }
-
- os << "\n";
- }
-};
-
-struct MDTableBuilder
-{
- MDTableBuilder(const std::vector<DurationEvent> &duration_events,
- const std::vector<CounterEvent> &counter_events)
- : _duration_events(duration_events), _counter_events(counter_events)
- {
- for (const auto &evt : _counter_events)
- {
- uint64_t ts = std::stoull(evt.ts);
- auto &name = evt.name;
- assert(name.compare("maxrss") == 0 || name.compare("minflt") == 0);
- assert(evt.values.size() == 1);
- auto &val = evt.values.begin()->second;
- if (_ts_to_values.find(ts) == _ts_to_values.end())
- {
- std::pair<uint32_t, uint32_t> values;
- if (name.compare("maxrss") == 0)
- values.first = std::stoul(val);
- else
- values.second = std::stoul(val);
- _ts_to_values.insert({ts, values});
- }
- else
- {
- auto &values = _ts_to_values.at(ts);
- if (name.compare("maxrss") == 0)
- values.first = std::stoul(val);
- else
- values.second = std::stoul(val);
- }
- }
- }
-
- MDTableBuilder &build()
- {
- for (auto &it : divideGraph())
- {
- size_t begin_idx = it.first;
- size_t end_idx = it.second;
- std::map<std::string, OpSeq> name_to_opseq;
- for (size_t i = begin_idx + 1; i < end_idx; ++i)
- {
- const auto &evt = _duration_events[i];
- assert(evt.name.compare("Graph") != 0);
- assert(evt.ph.compare("B") == 0 || evt.ph.compare("E") == 0);
- if (evt.ph.compare("B") == 0)
- {
- assert(name_to_opseq.find(evt.name) == name_to_opseq.end());
- name_to_opseq.insert({evt.name, makeOpSeq(evt)});
- }
- else
- {
- assert(name_to_opseq.find(evt.name) != name_to_opseq.end());
- auto &opseq = name_to_opseq.at(evt.name);
- updateOpSeq(opseq, evt);
- }
- }
-
- _graphs.emplace_back(makeGraph(begin_idx, end_idx, name_to_opseq));
- }
-
- return *this;
- }
-
- std::vector<std::pair<size_t, size_t>> divideGraph()
- {
- std::vector<std::pair<size_t, size_t>> graph_idx_list; // pair<begin_idx, end_idx>
- for (size_t i = 0, begin_idx = 0; i < _duration_events.size(); ++i)
- {
- const auto &evt = _duration_events.at(i);
- if (evt.name.compare("Graph") == 0)
- {
- if (evt.ph.compare("B") == 0)
- begin_idx = i;
- else
- graph_idx_list.emplace_back(begin_idx, i);
- }
- }
- return graph_idx_list;
- }
-
- OpSeq makeOpSeq(const DurationEvent &evt)
- {
- OpSeq opseq;
- opseq.name = evt.name;
- opseq.begin_ts = std::stoull(evt.ts);
- opseq.updateRss(_ts_to_values.at(opseq.begin_ts).first);
- opseq.updateMinflt(_ts_to_values.at(opseq.begin_ts).second);
- opseq.backend = evt.tid;
- return opseq;
- }
-
- void updateOpSeq(OpSeq &opseq, const DurationEvent &evt)
- {
- opseq.end_ts = std::stoull(evt.ts);
- opseq.updateRss(_ts_to_values.at(opseq.end_ts).first);
- opseq.updateMinflt(_ts_to_values.at(opseq.end_ts).second);
- }
-
- Graph makeGraph(size_t begin_idx, size_t end_idx,
- const std::map<std::string, OpSeq> &name_to_opseq)
- {
- Graph graph;
- graph.name = "Graph";
- graph.begin_ts = std::stoull(_duration_events[begin_idx].ts);
- graph.updateRss(_ts_to_values.at(graph.begin_ts).first);
- graph.updateMinflt(_ts_to_values.at(graph.begin_ts).second);
- graph.end_ts = std::stoull(_duration_events[end_idx].ts);
- graph.updateRss(_ts_to_values.at(graph.end_ts).first);
- graph.updateMinflt(_ts_to_values.at(graph.end_ts).second);
- graph.setOpSeqs(name_to_opseq);
- return graph;
- }
-
- void write(std::ostream &os)
- {
- // Write contents
- for (size_t i = 0; i < _graphs.size(); ++i)
- {
- os << "# Graph " << i << "\n";
- _graphs.at(i).write(os);
- }
- }
-
- const std::vector<DurationEvent> &_duration_events;
- const std::vector<CounterEvent> &_counter_events;
- // timestamp to std::pair<maxrss, minflt>
- std::unordered_map<uint64_t, std::pair<uint32_t, uint32_t>> _ts_to_values;
- std::vector<Graph> _graphs;
-};
-
-} // namespace
-
-void EventRecorder::emit(const DurationEvent &evt)
+void EventRecorder::emit(std::unique_ptr<DurationEvent> &&evt)
{
std::lock_guard<std::mutex> lock{_mu};
- _duration_events.push_back(evt);
+ _duration_events.push_back(std::move(evt));
}
void EventRecorder::emit(const CounterEvent &evt)
@@ -412,146 +29,3 @@ void EventRecorder::emit(const CounterEvent &evt)
_counter_events.push_back(evt);
}
-
-void EventRecorder::writeToFile(std::ostream &os)
-{
- std::lock_guard<std::mutex> lock{_mu};
-
- switch (_write_format)
- {
- case WriteFormat::CHROME_TRACING:
- writeChromeTrace(os);
- break;
- case WriteFormat::SNPE_BENCHMARK:
- writeSNPEBenchmark(os);
- break;
- case WriteFormat::MD_TABLE:
- writeMDTable(os);
- break;
- default:
- assert(!"Invalid value");
- break;
- }
-}
-
-void EventRecorder::writeSNPEBenchmark(std::ostream &os)
-{
- Json::Value root;
- auto &exec_data = root["Execution_Data"] = Json::Value{Json::objectValue};
-
- struct Stat
- {
- uint64_t sum = 0;
- uint64_t count = 0;
- uint64_t max = 0;
- uint64_t min = std::numeric_limits<uint64_t>::max();
-
- void accumulate(uint64_t val)
- {
- sum += val;
- count++;
- max = std::max(max, val);
- min = std::min(min, val);
- }
- };
-
- // Memory
- {
- std::unordered_map<std::string, Stat> mem_stats;
- for (auto &evt : _counter_events)
- {
- auto &mem_stat = mem_stats[evt.name];
- uint64_t val = std::stoull(evt.values["value"]);
- mem_stat.accumulate(val);
- }
-
- auto &mem = exec_data["memory"] = Json::Value{Json::objectValue};
- for (auto &kv : mem_stats)
- {
- auto &key = kv.first;
- auto &val = kv.second;
- mem[key]["Avg_Size"] = val.sum / val.count;
- mem[key]["Max_Size"] = val.max;
- mem[key]["Min_Size"] = val.min;
- mem[key]["Runtime"] = "NA";
- }
- }
-
- // Operation Execution Time
- {
- // NOTE This assumes _duration_events is sorted by "ts" ascending
-
- // 2D keys : stats[tid][name]
- std::unordered_map<std::string, std::unordered_map<std::string, Stat>> stats;
- std::unordered_map<std::string, std::unordered_map<std::string, uint64_t>> begin_timestamps;
- for (auto &evt : _duration_events)
- {
- auto &stat = stats[evt.tid][evt.name];
- auto &begin_ts = begin_timestamps[evt.tid][evt.name];
- uint64_t timestamp = std::stoull(evt.ts);
- if (evt.ph == "B")
- {
- if (begin_ts != 0)
- throw std::runtime_error{"Invalid Data"};
- begin_ts = timestamp;
- }
- else if (evt.ph == "E")
- {
- if (begin_ts == 0 || timestamp < begin_ts)
- throw std::runtime_error{"Invalid Data"};
- stat.accumulate(timestamp - begin_ts);
- begin_ts = 0;
- }
- else
- throw std::runtime_error{"Invalid Data - invalid value for \"ph\" : \"" + evt.ph + "\""};
- }
-
- for (auto &kv : begin_timestamps)
- for (auto &kv2 : kv.second)
- if (kv2.second != 0)
- throw std::runtime_error{"Invalid Data - B and E pair does not match."};
-
- for (auto &kv : stats)
- {
- auto &tid = kv.first;
- auto &map = kv.second;
- auto &json_tid = exec_data[tid] = Json::Value{Json::objectValue};
- for (auto &kv : map)
- {
- auto &name = kv.first;
- auto &val = kv.second;
- json_tid[name]["Avg_Time"] = val.sum / val.count;
- json_tid[name]["Max_Time"] = val.max;
- json_tid[name]["Min_Time"] = val.min;
- json_tid[name]["Runtime"] = tid;
- }
- }
- }
-
- os << root;
-}
-
-void EventRecorder::writeChromeTrace(std::ostream &os)
-{
- os << "{\n";
- os << " " << quote("traceEvents") << ": [\n";
-
- for (auto &evt : _duration_events)
- {
- os << " " << object(evt) << ",\n";
- }
-
- for (auto &evt : _counter_events)
- {
- os << " " << object(evt) << ",\n";
- }
-
- os << " { }\n";
- os << " ]\n";
- os << "}\n";
-}
-
-void EventRecorder::writeMDTable(std::ostream &os)
-{
- MDTableBuilder(_duration_events, _counter_events).build().write(os);
-}
diff --git a/runtime/onert/core/src/util/EventRecorder.h b/runtime/onert/core/src/util/EventRecorder.h
index 37ec1a0f1..5cf03d8ac 100644
--- a/runtime/onert/core/src/util/EventRecorder.h
+++ b/runtime/onert/core/src/util/EventRecorder.h
@@ -17,28 +17,52 @@
#ifndef __ONERT_UTIL_EVENT_RECORDER_H__
#define __ONERT_UTIL_EVENT_RECORDER_H__
+#include "util/TracingCtx.h"
+
#include <map>
#include <memory>
#include <mutex>
-#include <ostream>
#include <vector>
+// refer to https://docs.google.com/document/d/1CvAClvFfyA5R-PhYUmn5OOQtYMH4h6I0nSsKchNAySU/edit#
struct Event
{
- std::string name;
- std::string tid;
- std::string ph; /* REQUIRED */
- std::string ts; /* REQUIRED */
+ const onert::util::TracingCtx *tracing_ctx;
+
+ std::string ph; // Event type.
+  std::string ts; // timestamp of this event on the tracing clock
+ std::vector<std::pair<std::string, std::string>> args; // user-defined data: pairs of (key, value)
+
+ virtual ~Event() = default;
};
struct DurationEvent : public Event
{
- // TO BE FILLED
+ uint32_t session_index = 0;
+ uint32_t subg_index = 0;
+
+protected:
+ DurationEvent() = default;
+};
+
+struct SubgDurationEvent : public DurationEvent
+{ /* same with DurationEvent */
+};
+
+// TODO Rename it to OperationDurationEvent
+struct OpSeqDurationEvent : public DurationEvent
+{
+ // Note: DurationEvent's name and tid will be set by EventWriter
+ std::string backend;
+ uint32_t op_index;
+ std::string op_name;
};
struct CounterEvent : public Event
{
+ std::string name; // name of event
+ std::string tid; // thread ID
std::map<std::string, std::string> values;
};
@@ -50,35 +74,22 @@ struct CounterEvent : public Event
class EventRecorder
{
public:
- enum class WriteFormat
- {
- CHROME_TRACING,
- SNPE_BENCHMARK,
- MD_TABLE,
- };
-
-public:
EventRecorder() = default;
public:
- void emit(const DurationEvent &evt);
+ void emit(std::unique_ptr<DurationEvent> &&evt);
void emit(const CounterEvent &evt);
public:
- bool empty() { return _duration_events.empty() && _counter_events.empty(); }
- void writeToFile(std::ostream &os);
- void setWriteFormat(WriteFormat write_format) { _write_format = write_format; }
-
-private:
- void writeSNPEBenchmark(std::ostream &os);
- void writeChromeTrace(std::ostream &os);
- void writeMDTable(std::ostream &os);
+ const std::vector<std::unique_ptr<DurationEvent>> &duration_events() const
+ {
+ return _duration_events;
+ }
+ const std::vector<CounterEvent> &counter_events() const { return _counter_events; }
private:
std::mutex _mu;
- // TODO: Allow user to control write_format
- WriteFormat _write_format{WriteFormat::SNPE_BENCHMARK};
- std::vector<DurationEvent> _duration_events;
+ std::vector<std::unique_ptr<DurationEvent>> _duration_events;
std::vector<CounterEvent> _counter_events;
};
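EventRecorder now owns polymorphic DurationEvent objects and no longer writes files itself; writers read the recorded events back through the const accessors. A minimal hedged sketch of the new emit/read cycle, not part of the patch (the field values are made up):

#include "EventRecorder.h"

void record_example(EventRecorder &recorder)
{
  auto evt = std::make_unique<OpSeqDurationEvent>();
  evt->ph = "B";
  evt->ts = "1000"; // microseconds since epoch, as a string (hypothetical value)
  evt->backend = "cpu";
  evt->op_index = 3;
  evt->op_name = "Conv2D";
  recorder.emit(std::move(evt));

  // Writers later iterate the recorded events instead of EventRecorder writing files itself
  for (const auto &e : recorder.duration_events())
    (void)e; // each element is a std::unique_ptr<DurationEvent>
}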
diff --git a/runtime/onert/core/src/util/EventWriter.cc b/runtime/onert/core/src/util/EventWriter.cc
new file mode 100644
index 000000000..ca4bd302e
--- /dev/null
+++ b/runtime/onert/core/src/util/EventWriter.cc
@@ -0,0 +1,49 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "EventWriter.h"
+
+#include <cassert>
+
+// initialization
+std::mutex EventWriter::_mutex;
+
+void EventWriter::readyToFlush(std::unique_ptr<EventRecorder> &&recorder)
+{
+ {
+ std::unique_lock<std::mutex> lock{_mutex};
+
+ _recorders.emplace_back(std::move(recorder));
+
+ if (--_ref_count > 0)
+ return;
+ }
+ // The caller of this method is the last instance that uses EventWriter.
+ // Let's write log files.
+
+  // Note: per an internal issue, the SNPE log uses the plain file name ('trace.json') rather than a '.snpe.json' suffix
+ flush(WriteFormat::SNPE_BENCHMARK);
+ flush(WriteFormat::CHROME_TRACING);
+ flush(WriteFormat::MD_TABLE);
+}
+
+void EventWriter::flush(WriteFormat write_format)
+{
+ auto *writer = _actual_writers[write_format].get();
+ assert(writer);
+
+ writer->flush(_recorders);
+}
diff --git a/runtime/onert/core/src/util/EventWriter.h b/runtime/onert/core/src/util/EventWriter.h
new file mode 100644
index 000000000..672820aa9
--- /dev/null
+++ b/runtime/onert/core/src/util/EventWriter.h
@@ -0,0 +1,144 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_UTIL_EVENT_WRITER_H__
+#define __ONERT_UTIL_EVENT_WRITER_H__
+
+#include "EventRecorder.h"
+
+#include <string>
+#include <vector>
+#include <unordered_map>
+#include <mutex>
+#include <fstream>
+
+class EventFormatWriter
+{
+public:
+ EventFormatWriter(const std::string &filepath) : _os{filepath, std::ofstream::out} {}
+ virtual ~EventFormatWriter()
+ { /* empty */
+ }
+
+ virtual void flush(const std::vector<std::unique_ptr<EventRecorder>> &) = 0;
+
+protected:
+ std::ofstream _os;
+};
+
+class SNPEWriter : public EventFormatWriter
+{
+public:
+ SNPEWriter(const std::string &filepath) : EventFormatWriter(filepath)
+ { /* empty */
+ }
+ ~SNPEWriter() {}
+
+ void flush(const std::vector<std::unique_ptr<EventRecorder>> &) override;
+};
+
+class ChromeTracingWriter : public EventFormatWriter
+{
+public:
+ ChromeTracingWriter(const std::string &filepath) : EventFormatWriter(filepath)
+ { /* empty */
+ }
+ ~ChromeTracingWriter() {}
+
+ void flush(const std::vector<std::unique_ptr<EventRecorder>> &) override;
+
+private:
+ void flushOneRecord(const EventRecorder &);
+};
+
+class MDTableWriter : public EventFormatWriter
+{
+public:
+ MDTableWriter(const std::string &filepath) : EventFormatWriter(filepath)
+ { /* empty */
+ }
+ ~MDTableWriter() {}
+
+ void flush(const std::vector<std::unique_ptr<EventRecorder>> &) override;
+};
+
+#include <mutex>
+
+class EventWriter
+{
+public:
+ enum class WriteFormat
+ {
+ CHROME_TRACING,
+ SNPE_BENCHMARK,
+ MD_TABLE,
+ };
+
+ /**
+   * @brief Returns a singleton object
+ */
+ static EventWriter *get(const std::string &workspace_dir)
+ {
+ std::unique_lock<std::mutex> lock{_mutex};
+
+ static EventWriter singleton(workspace_dir);
+ return &singleton;
+ }
+
+ /**
+   * @brief Call this when an observer that uses EventWriter starts
+ */
+ void startToUse()
+ {
+ std::unique_lock<std::mutex> lock{_mutex};
+ _ref_count++;
+ }
+
+ /**
+   * @brief Call this when an observer that uses EventWriter finishes.
+   * After all such observers have called this method, the reference count reaches 0.
+   * Then, EventWriter writes the profiling result files.
+ */
+ void readyToFlush(std::unique_ptr<EventRecorder> &&recorder);
+
+private:
+ EventWriter(const std::string &workspace_dir) : _ref_count(0)
+ {
+ std::string snpe_log_name(workspace_dir + "/trace.json");
+ std::string chrome_tracing_log_name(workspace_dir + "/trace.chrome.json");
+ std::string md_table_log_name(workspace_dir + "/trace.table.md");
+
+ _actual_writers[WriteFormat::SNPE_BENCHMARK] = std::make_unique<SNPEWriter>(snpe_log_name);
+ _actual_writers[WriteFormat::CHROME_TRACING] =
+ std::make_unique<ChromeTracingWriter>(chrome_tracing_log_name);
+ _actual_writers[WriteFormat::MD_TABLE] = std::make_unique<MDTableWriter>(md_table_log_name);
+ };
+
+ void flush(WriteFormat write_format);
+
+private:
+ static std::mutex _mutex;
+
+  // number of observers of an executor that want to write profiling data
+ int32_t _ref_count;
+
+ // one recorder object per executor
+ std::vector<std::unique_ptr<EventRecorder>> _recorders;
+
+ std::unordered_map<WriteFormat, std::unique_ptr<EventFormatWriter>> _actual_writers;
+};
+
+#endif // __ONERT_UTIL_EVENT_WRITER_H__
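EventWriter is a process-wide singleton with a simple reference-count protocol: every observer that wants profiling output registers with startToUse(), and the last readyToFlush() call triggers all three format writers. A usage sketch, not part of the patch, with a hypothetical workspace directory:

#include "EventWriter.h"

void writer_lifecycle_example()
{
  auto *writer = EventWriter::get("/tmp/onert-workspace"); // hypothetical directory
  writer->startToUse();                                    // one call per observer

  auto recorder = std::make_unique<EventRecorder>();
  // ... the observer emits Duration/Counter events into *recorder during execution ...

  // When the last registered observer hands over its recorder, trace.json,
  // trace.chrome.json and trace.table.md are written under the workspace directory.
  writer->readyToFlush(std::move(recorder));
}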
diff --git a/runtime/onert/core/src/util/GeneralConfigSource.cc b/runtime/onert/core/src/util/GeneralConfigSource.cc
deleted file mode 100644
index 7d2757e58..000000000
--- a/runtime/onert/core/src/util/GeneralConfigSource.cc
+++ /dev/null
@@ -1,45 +0,0 @@
-/*
- * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "util/GeneralConfigSource.h"
-#include "util/logging.h"
-
-namespace onert
-{
-namespace util
-{
-
-std::string GeneralConfigSource::get(const std::string &key) const
-{
- auto itr = _map.find(key);
- if (itr == _map.end())
- {
- return "";
- }
- else
- {
- return itr->second;
- }
-}
-
-void GeneralConfigSource::set(const std::string &key, const std::string &val)
-{
- VERBOSE(GeneralConfigSource) << key << " : " << val << std::endl;
- _map[key] = val;
-}
-
-} // namespace util
-} // namespace onert
diff --git a/runtime/onert/core/src/util/EnvConfigSource.cc b/runtime/onert/core/src/util/Index.test.cc
index 0d25b7353..ff73e5e59 100644
--- a/runtime/onert/core/src/util/EnvConfigSource.cc
+++ b/runtime/onert/core/src/util/Index.test.cc
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -14,27 +14,21 @@
* limitations under the License.
*/
-#include "util/EnvConfigSource.h"
+#include "util/Index.h"
-#include <cstdlib>
+#include <gtest/gtest.h>
-namespace onert
-{
-namespace util
-{
+using Index = ::onert::util::Index<uint32_t, struct TestTag>;
-std::string EnvConfigSource::get(const std::string &key) const
+TEST(Index, neg_index_test)
{
- const char *value = std::getenv(key.c_str());
- if (value != nullptr)
- {
- return value;
- }
- else
- {
- return GeneralConfigSource::get(key);
- }
-}
+ Index idx1{1u};
+ Index idx2{2u};
+ Index idx3{idx1};
-} // namespace util
-} // namespace onert
+ ASSERT_EQ(idx1, 1);
+ ASSERT_EQ(idx1, 1u);
+ ASSERT_EQ(idx1.value(), 1u);
+ ASSERT_NE(idx1, idx2);
+ ASSERT_EQ(idx1, idx3);
+}
diff --git a/runtime/onert/core/src/util/MDTableEventWriter.cc b/runtime/onert/core/src/util/MDTableEventWriter.cc
new file mode 100644
index 000000000..e7d90eec4
--- /dev/null
+++ b/runtime/onert/core/src/util/MDTableEventWriter.cc
@@ -0,0 +1,365 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "EventWriter.h"
+
+#include <cassert>
+#include <map>
+#include <set>
+#include <sstream>
+#include <stdint.h>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+// md table type
+namespace
+{
+
+void writeMDTableRow(std::ostream &os, const std::vector<std::string> &list)
+{
+ os << "| ";
+ for (const auto &key : list)
+ {
+ os << key << " | ";
+ }
+ os << "\n";
+}
+
+struct MDContent
+{
+ std::string name;
+ uint64_t begin_ts;
+ uint64_t end_ts;
+ uint32_t min_rss;
+ uint32_t max_rss;
+ uint32_t min_page_reclaims;
+ uint32_t max_page_reclaims;
+
+ MDContent()
+ : begin_ts(0), end_ts(0), min_rss(UINT32_MAX), max_rss(0), min_page_reclaims(UINT32_MAX),
+ max_page_reclaims(0)
+ {
+ // DO NOTHING
+ }
+
+ virtual ~MDContent() = default;
+
+ void updateRss(uint32_t rss)
+ {
+ if (min_rss == UINT32_MAX)
+ min_rss = rss;
+ if (max_rss == 0)
+ max_rss = rss;
+
+ if (min_rss > rss)
+ min_rss = rss;
+ else if (max_rss < rss)
+ max_rss = rss;
+ }
+
+ void updateMinflt(uint32_t minflt)
+ {
+ if (min_page_reclaims == UINT32_MAX)
+ min_page_reclaims = minflt;
+ if (max_page_reclaims == 0)
+ max_page_reclaims = minflt;
+
+ if (min_page_reclaims > minflt)
+ min_page_reclaims = minflt;
+ else if (max_page_reclaims < minflt)
+ max_page_reclaims = minflt;
+ }
+
+ virtual void write(std::ostream &os) const = 0;
+};
+
+struct Operation : public MDContent
+{
+ std::string backend;
+ uint64_t graph_latency;
+
+ struct OperationCmp
+ {
+ bool operator()(const Operation &lhs, const Operation &rhs) const
+ {
+ return lhs.begin_ts < rhs.begin_ts;
+ }
+ bool operator()(const Operation &lhs, const Operation &rhs)
+ {
+ return lhs.begin_ts < rhs.begin_ts;
+ }
+ bool operator()(Operation &lhs, Operation &rhs) { return lhs.begin_ts < rhs.begin_ts; }
+ };
+
+ void write(std::ostream &os) const override
+ {
+ uint64_t op_latency = end_ts - begin_ts;
+ double op_per = static_cast<double>(op_latency) / graph_latency * 100.0;
+ writeMDTableRow(os, {name, backend, std::to_string(op_latency), std::to_string(op_per),
+ std::to_string(min_rss), std::to_string(max_rss),
+ std::to_string(min_page_reclaims), std::to_string(max_page_reclaims)});
+ }
+};
+
+struct Graph : public MDContent
+{
+ std::set<Operation, Operation::OperationCmp> ops;
+ std::string session_index;
+ std::string subgraph_index;
+
+ void setOperations(const std::map<std::string, Operation> &name_to_op)
+ {
+ uint64_t graph_latency = end_ts - begin_ts;
+ for (auto &&it : name_to_op)
+ {
+ auto op = it.second;
+ op.graph_latency = graph_latency;
+
+ ops.insert(op);
+
+ updateRss(op.min_rss);
+ updateRss(op.max_rss);
+ updateMinflt(op.min_page_reclaims);
+ updateMinflt(op.max_page_reclaims);
+ }
+ }
+
+ void write(std::ostream &os) const override
+ {
+ static std::vector<std::string> graph_headers{"latency(us)", "rss_min(kb)", "rss_max(kb)",
+ "page_reclaims_min", "page_reclaims_max"};
+
+ static std::vector<std::string> graph_headers_line{"-----------", "-------", "-------",
+ "-----------------", "-----------------"};
+
+ // Graph's Header
+ writeMDTableRow(os, graph_headers);
+ writeMDTableRow(os, graph_headers_line);
+
+ // Graph's contents
+ writeMDTableRow(os, {std::to_string(end_ts - begin_ts), std::to_string(min_rss),
+ std::to_string(max_rss), std::to_string(min_page_reclaims),
+ std::to_string(max_page_reclaims)});
+
+ os << "\n";
+
+ static std::vector<std::string> op_headers{
+ "Op name", "backend", "latency(us)", "latency(%)",
+ "rss_min(kb)", "rss_max(kb)", "page_reclaims_min", "page_reclaims_max"};
+
+ static std::vector<std::string> op_headers_line{
+ "-------", "-------", "-----------", "-----------",
+ "-------", "-------", "-----------------", "-----------------"};
+
+ os << "## Op \n";
+
+ // Operation's Header
+ writeMDTableRow(os, op_headers);
+ writeMDTableRow(os, op_headers_line);
+
+ // Operation's contents
+ for (auto &&op : ops)
+ {
+ op.write(os);
+ }
+
+ os << "\n";
+ }
+};
+
+std::string getLabel(const OpSeqDurationEvent &evt)
+{
+ std::string subg_label("$" + std::to_string(evt.subg_index) + " subgraph");
+ std::string op_label("@" + std::to_string(evt.op_index) + " " + evt.op_name);
+
+ return subg_label + " " + op_label;
+}
+
+struct MDTableBuilder
+{
+ MDTableBuilder(const std::vector<std::unique_ptr<DurationEvent>> &duration_events,
+ const std::vector<CounterEvent> &counter_events)
+ : _duration_events(duration_events), _counter_events(counter_events)
+ {
+// Compiled only in debug builds; enable it in release builds once the overhead is low enough
+#ifdef DEBUG
+ for (const auto &evt : _counter_events)
+ {
+ uint64_t ts = std::stoull(evt.ts);
+ auto &name = evt.name;
+ assert(name.compare("maxrss") == 0 || name.compare("minflt") == 0);
+ assert(evt.values.size() == 1);
+ auto &val = evt.values.begin()->second;
+ if (_ts_to_values.find(ts) == _ts_to_values.end())
+ {
+ std::pair<uint32_t, uint32_t> values;
+ if (name.compare("maxrss") == 0)
+ values.first = std::stoul(val);
+ else
+ values.second = std::stoul(val);
+ _ts_to_values.insert({ts, values});
+ }
+ else
+ {
+ auto &values = _ts_to_values.at(ts);
+ if (name.compare("maxrss") == 0)
+ values.first = std::stoul(val);
+ else
+ values.second = std::stoul(val);
+ }
+ }
+#endif
+ }
+
+ MDTableBuilder &build()
+ {
+ for (const auto &it : divideGraph())
+ {
+ size_t begin_idx = it.first;
+ size_t end_idx = it.second;
+ std::map<std::string, Operation> name_to_op;
+ for (size_t i = begin_idx + 1; i < end_idx; ++i)
+ {
+ const auto *evt = dynamic_cast<const OpSeqDurationEvent *>(_duration_events[i].get());
+ if (evt == nullptr)
+ continue;
+
+ const std::string evt_name = getLabel(*evt);
+ assert(evt->ph.compare("B") == 0 || evt->ph.compare("E") == 0);
+ if (evt->ph.compare("B") == 0)
+ {
+ assert(name_to_op.find(evt_name) == name_to_op.end());
+ name_to_op.insert({evt_name, makeOperation(*evt)});
+ }
+ else
+ {
+ assert(name_to_op.find(evt_name) != name_to_op.end());
+ auto &op = name_to_op.at(evt_name);
+ updateOperation(op, *evt);
+ }
+ }
+
+ _graphs.emplace_back(makeGraph(begin_idx, end_idx, name_to_op));
+ }
+
+ return *this;
+ }
+
+ std::vector<std::pair<size_t, size_t>> divideGraph()
+ {
+ std::vector<std::pair<size_t, size_t>> graph_idx_list; // pair<begin_idx, end_idx>
+ for (size_t i = 0, begin_idx = 0; i < _duration_events.size(); ++i)
+ {
+ const auto subg_evt = dynamic_cast<const SubgDurationEvent *>(_duration_events.at(i).get());
+ if (subg_evt == nullptr)
+ continue;
+
+ if (subg_evt->ph.compare("B") == 0)
+ begin_idx = i;
+ else
+ graph_idx_list.emplace_back(begin_idx, i);
+ }
+ return graph_idx_list;
+ }
+
+ Operation makeOperation(const OpSeqDurationEvent &evt)
+ {
+ Operation op;
+ const std::string &evt_name = getLabel(evt);
+ op.name = evt_name;
+ op.begin_ts = std::stoull(evt.ts);
+ op.backend = evt.backend;
+#ifdef DEBUG
+ op.updateRss(_ts_to_values.at(op.begin_ts).first);
+ op.updateMinflt(_ts_to_values.at(op.begin_ts).second);
+#else
+ op.updateRss(0);
+ op.updateMinflt(0);
+#endif
+ return op;
+ }
+
+ void updateOperation(Operation &op, const DurationEvent &evt)
+ {
+ op.end_ts = std::stoull(evt.ts);
+#ifdef DEBUG
+ op.updateRss(_ts_to_values.at(op.end_ts).first);
+ op.updateMinflt(_ts_to_values.at(op.end_ts).second);
+#else
+ op.updateRss(0);
+ op.updateMinflt(0);
+#endif
+ }
+
+ Graph makeGraph(size_t begin_idx, size_t end_idx,
+ const std::map<std::string, Operation> &name_to_op)
+ {
+ Graph graph;
+ graph.name = "Subgraph";
+ graph.begin_ts = std::stoull(_duration_events[begin_idx]->ts);
+ graph.end_ts = std::stoull(_duration_events[end_idx]->ts);
+ graph.setOperations(name_to_op);
+
+ for (const auto &arg : _duration_events[end_idx]->args)
+ {
+ if (arg.first == "session")
+ graph.session_index = arg.second;
+ if (arg.first == "subgraph")
+ graph.subgraph_index = arg.second;
+ }
+
+#ifdef DEBUG
+ graph.updateRss(_ts_to_values.at(graph.begin_ts).first);
+ graph.updateMinflt(_ts_to_values.at(graph.begin_ts).second);
+ graph.updateRss(_ts_to_values.at(graph.end_ts).first);
+ graph.updateMinflt(_ts_to_values.at(graph.end_ts).second);
+#else
+ graph.updateRss(0);
+ graph.updateMinflt(0);
+#endif
+ return graph;
+ }
+
+ void write(std::ostream &os)
+ {
+ // Write contents
+ for (size_t i = 0; i < _graphs.size(); ++i)
+ {
+ auto &graph = _graphs.at(i);
+ os << "# Session: " << graph.session_index << ", Subgraph: " << graph.subgraph_index
+ << ", Running count: " << i << "\n";
+ _graphs.at(i).write(os);
+ }
+ }
+
+ const std::vector<std::unique_ptr<DurationEvent>> &_duration_events;
+ const std::vector<CounterEvent> &_counter_events;
+
+ // timestamp to std::pair<maxrss, minflt>
+ std::unordered_map<uint64_t, std::pair<uint32_t, uint32_t>> _ts_to_values;
+ std::vector<Graph> _graphs;
+};
+
+} // namespace
+
+void MDTableWriter::flush(const std::vector<std::unique_ptr<EventRecorder>> &records)
+{
+ for (const auto &recorder : records)
+ {
+ MDTableBuilder(recorder->duration_events(), recorder->counter_events()).build().write(_os);
+ }
+}
diff --git a/runtime/onert/core/src/util/ObjectManager.test.cc b/runtime/onert/core/src/util/ObjectManager.test.cc
new file mode 100644
index 000000000..3fe735732
--- /dev/null
+++ b/runtime/onert/core/src/util/ObjectManager.test.cc
@@ -0,0 +1,211 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "util/Index.h"
+#include "util/ObjectManager.h"
+
+#include <gtest/gtest.h>
+
+using namespace onert;
+
+struct TestTag;
+using Index = typename util::Index<uint32_t, TestTag>;
+
+TEST(ObjectManager, emplace)
+{
+ util::ObjectManager<Index, int> man;
+
+ auto index = man.emplace(100);
+ ASSERT_EQ(man.at(index), 100);
+}
+
+TEST(ObjectManager, neg_remove_1)
+{
+ util::ObjectManager<Index, int> man;
+
+ Index index = man.emplace(100);
+ ASSERT_TRUE(man.exist(index));
+ ASSERT_EQ(man.at(index), 100);
+
+ man.remove(index);
+ ASSERT_FALSE(man.exist(index));
+}
+
+TEST(ObjectManager, neg_remove_2)
+{
+ util::ObjectManager<Index, int> man;
+
+ auto index0 = man.emplace(100);
+ auto index1 = man.emplace(200);
+ ASSERT_TRUE(man.exist(index0));
+ ASSERT_EQ(man.at(index0), 100);
+ ASSERT_TRUE(man.exist(index1));
+ ASSERT_EQ(man.at(index1), 200);
+
+ man.remove(index0);
+ ASSERT_FALSE(man.exist(index0));
+ ASSERT_TRUE(man.exist(index1));
+ ASSERT_EQ(man.at(index1), 200);
+}
+
+TEST(ObjectManager, push)
+{
+ util::ObjectManager<Index, int> man;
+
+ // Not specify index
+ auto index = man.push(std::make_unique<int>(100));
+ ASSERT_EQ(man.at(index), 100);
+
+ // Specify index
+ auto index2 = man.push(std::make_unique<int>(200), Index{33});
+ ASSERT_EQ(index2.value(), 33);
+ ASSERT_EQ(man.at(index2), 200);
+
+ auto index3 = man.push(std::make_unique<int>(300));
+ // NOTE auto-generated index number is always (biggest index in the ObjectManager + 1)
+ ASSERT_EQ(index3.value(), 34);
+ ASSERT_EQ(man.at(index3), 300);
+
+ auto index4 = man.push(std::make_unique<int>(400), Index{22});
+ ASSERT_EQ(index4.value(), 22);
+ ASSERT_EQ(man.at(index4), 400);
+
+ auto index5 = man.push(std::make_unique<int>(500));
+ // NOTE auto-generated index number is always (biggest index in the ObjectManager + 1)
+ ASSERT_EQ(index5.value(), 35);
+ ASSERT_EQ(man.at(index5), 500);
+}
+
+TEST(ObjectManager, neg_push)
+{
+ util::ObjectManager<Index, int> man;
+
+ // Specify index
+ auto index = man.push(std::make_unique<int>(100), Index{55});
+ ASSERT_EQ(index.value(), 55);
+ ASSERT_EQ(man.at(index), 100);
+
+ // Specify the same index
+ auto index2 = man.push(std::make_unique<int>(200), Index{55});
+ ASSERT_FALSE(index2.valid());
+}
+
+static const uint32_t kMaxUInt32 = std::numeric_limits<uint32_t>::max();
+
+TEST(ObjectManager, neg_push_undefined_index)
+{
+ util::ObjectManager<Index, int> man;
+
+ // Try inserting invalid(undefined) index
+ auto index = man.push(std::make_unique<int>(100), Index{kMaxUInt32});
+ ASSERT_FALSE(index.valid());
+ ASSERT_EQ(man.size(), 0);
+}
+
+TEST(ObjectManager, neg_push_max_index)
+{
+ util::ObjectManager<Index, int> man;
+
+ // Insert an object with maximum valid index
+ auto index = man.push(std::make_unique<int>(100), Index{kMaxUInt32 - 1});
+ ASSERT_EQ(index.value(), kMaxUInt32 - 1);
+ ASSERT_EQ(man.at(index), 100);
+ ASSERT_EQ(man.size(), 1);
+
+  // Reached the final index, so the next push/emplace must fail
+ auto index2 = man.push(std::make_unique<int>(200));
+ ASSERT_EQ(man.size(), 1);
+ ASSERT_FALSE(index2.valid());
+}
+
+TEST(ObjectManager, neg_emplace_max_index)
+{
+ util::ObjectManager<Index, int> man;
+
+ // Insert an object with maximum valid index
+ auto index = man.push(std::make_unique<int>(100), Index{kMaxUInt32 - 1});
+ ASSERT_EQ(index.value(), kMaxUInt32 - 1);
+ ASSERT_EQ(man.at(index), 100);
+ ASSERT_EQ(man.size(), 1);
+
+  // Reached the final index, so the next push/emplace must fail
+ auto index3 = man.emplace(200);
+ ASSERT_EQ(man.size(), 1);
+ ASSERT_FALSE(index3.valid());
+}
+
+TEST(ObjectManager, const_iterate)
+{
+ util::ObjectManager<Index, int> man;
+
+ auto index0 = man.emplace(100);
+ auto index1 = man.emplace(200);
+ auto index2 = man.emplace(300);
+
+ int sum = 0;
+ man.iterate([&](const Index &index, const int &val) { sum += val; });
+ ASSERT_EQ(sum, 600);
+}
+
+TEST(ObjectManager, non_const_iterate)
+{
+ util::ObjectManager<Index, int> man;
+
+ auto index0 = man.emplace(100);
+ auto index1 = man.emplace(200);
+ auto index2 = man.emplace(300);
+
+ man.iterate([&](const Index &index, int &val) { val += 1; });
+ ASSERT_EQ(man.at(index0), 101);
+ ASSERT_EQ(man.at(index1), 201);
+ ASSERT_EQ(man.at(index2), 301);
+}
+
+TEST(ObjectManager, set)
+{
+ util::ObjectManager<Index, int> man;
+ auto index = man.set(Index{1}, std::make_unique<int>(100)); // Insert
+ ASSERT_EQ(index, Index{1});
+ auto index2 = man.set(index, std::make_unique<int>(200)); // Overwrite
+ ASSERT_EQ(index2, index);
+ ASSERT_EQ(man.at(index2), 200);
+}
+
+TEST(ObjectManager, neg_set)
+{
+ auto v = std::make_unique<int>(100);
+ util::ObjectManager<Index, int> man;
+ auto index = man.set(Index{}, std::move(v)); // Try set with an invalid index
+ ASSERT_EQ(index, Index{});
+ ASSERT_FALSE(index.valid());
+  ASSERT_NE(v, nullptr); // v must be kept on failure
+}
+
+TEST(ObjectManager, getRawPtr)
+{
+ auto v = std::make_unique<int>(100);
+ auto v_ptr = v.get();
+ util::ObjectManager<Index, int> man;
+ auto index = man.push(std::move(v));
+ ASSERT_EQ(v_ptr, man.getRawPtr(index));
+}
+
+TEST(ObjectManager, neg_getRawPtr)
+{
+ util::ObjectManager<Index, int> man;
+ auto ptr = man.getRawPtr(Index{1});
+ ASSERT_EQ(ptr, nullptr);
+}
diff --git a/runtime/onert/core/src/util/SNPEEventWriter.cc b/runtime/onert/core/src/util/SNPEEventWriter.cc
new file mode 100644
index 000000000..87bbfc662
--- /dev/null
+++ b/runtime/onert/core/src/util/SNPEEventWriter.cc
@@ -0,0 +1,186 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "EventWriter.h"
+
+#include <json/json.h>
+
+#include <cassert>
+#include <unordered_map>
+#include <utility>
+
+/**
+ * @brief Version of SNPE format
+ * In version 1
+ * - There is no "version" field in Json
+ * - Only one subgraph is supported
+ * - Operation name is a form of "$3 ADD"
+ *
+ * In version 2,
+ * - "version" : "2" was added in Json
+ * - Multiple session and multiple subgraphs are supported
+ * - When there is only one session, operation name is a form of "$2 subgraph $3 ADD",
+ * meaning ADD op whose operation index 3 in a subgraph whose index is 2
+ * - When there are two or more sessions, operation name is a form of
+ * "$1 session $2 subgraph $3 ADD", meaning ADD op whose operation index 3
+ * in a subgraph whose index is 2, which was run in 1st session.
+ */
+#define SNPE_JSON_SCHEMA_VERSION "2"
+
+namespace
+{
+
+std::string getLabel(const DurationEvent &evt)
+{
+ if (auto evt_ptr = dynamic_cast<const OpSeqDurationEvent *>(&evt))
+ {
+ std::string subg_label("$" + std::to_string(evt_ptr->subg_index) + " subgraph");
+ std::string op_label("$" + std::to_string(evt_ptr->op_index) + " " + evt_ptr->op_name);
+
+    // Note: at this moment, only one thread runs for EventWriter
+ if (evt_ptr->tracing_ctx->hasMultipleSessions())
+ {
+ std::string session_label("$" + std::to_string(evt_ptr->session_index) + " session");
+ return session_label + " " + subg_label + " " + op_label;
+ }
+ else
+ {
+ // When there is only one session, do not include session info
+ // Refer to https://github.sec.samsung.net/STAR/nnfw/issues/11436#issuecomment-930332
+ return subg_label + " " + op_label;
+ }
+ }
+ else // SubgEvent
+ return "Graph";
+}
+
+std::string getBackend(const DurationEvent &evt)
+{
+ if (auto evt_ptr = dynamic_cast<const OpSeqDurationEvent *>(&evt))
+ return evt_ptr->backend;
+  else // SubgEvent
+ return "runtime";
+}
+
+} // namespace
+
+void SNPEWriter::flush(const std::vector<std::unique_ptr<EventRecorder>> &recorders)
+{
+ struct Stat
+ {
+ uint64_t sum = 0;
+ uint64_t count = 0;
+ uint64_t max = 0;
+ uint64_t min = std::numeric_limits<uint64_t>::max();
+
+ void accumulate(uint64_t val)
+ {
+ sum += val;
+ count++;
+ max = std::max(max, val);
+ min = std::min(min, val);
+ }
+ };
+
+ Json::Value root;
+ root["version"] = SNPE_JSON_SCHEMA_VERSION;
+
+ auto &exec_data = root["Execution_Data"] = Json::Value{Json::objectValue};
+
+ // Memory
+ {
+ std::unordered_map<std::string, Stat> mem_stats;
+ for (const auto &recorder : recorders)
+ {
+ for (const auto &evt : recorder->counter_events())
+ {
+ auto &mem_stat = mem_stats[evt.name];
+ uint64_t val = std::stoull(evt.values.at("value"));
+ mem_stat.accumulate(val);
+ }
+ }
+
+ auto &mem = exec_data["memory"] = Json::Value{Json::objectValue};
+ for (const auto &kv : mem_stats)
+ {
+ auto &key = kv.first;
+ auto &val = kv.second;
+ mem[key]["Avg_Size"] = val.sum / val.count;
+ mem[key]["Max_Size"] = val.max;
+ mem[key]["Min_Size"] = val.min;
+ mem[key]["Runtime"] = "NA";
+ }
+ }
+
+ // Operation Execution Time
+ {
+    // NOTE This assumes each recorder's duration_events() is sorted by "ts" ascending
+
+ // 2D keys : stats[tid][name]
+ std::unordered_map<std::string, std::unordered_map<std::string, Stat>> stats;
+ std::unordered_map<std::string, std::unordered_map<std::string, uint64_t>> begin_timestamps;
+ for (const auto &recorder : recorders)
+ {
+ for (const auto &evt : recorder->duration_events())
+ {
+ std::string evt_name = getLabel(*evt);
+ std::string evt_tid = getBackend(*evt);
+
+ auto &stat = stats[evt_tid][evt_name];
+ auto &begin_ts = begin_timestamps[evt_tid][evt_name];
+ uint64_t timestamp = std::stoull(evt->ts);
+ if (evt->ph == "B")
+ {
+ if (begin_ts != 0)
+ throw std::runtime_error{"Invalid Data"};
+ begin_ts = timestamp;
+ }
+ else if (evt->ph == "E")
+ {
+ if (begin_ts == 0 || timestamp < begin_ts)
+ throw std::runtime_error{"Invalid Data"};
+ stat.accumulate(timestamp - begin_ts);
+ begin_ts = 0;
+ }
+ else
+ throw std::runtime_error{"Invalid Data - invalid value for \"ph\" : \"" + evt->ph + "\""};
+ }
+ }
+
+ for (const auto &kv : begin_timestamps)
+ for (const auto &kv2 : kv.second)
+ if (kv2.second != 0)
+ throw std::runtime_error{"Invalid Data - B and E pair does not match."};
+
+ for (const auto &kv : stats)
+ {
+ const auto &tid = kv.first;
+ const auto &map = kv.second;
+ auto &json_tid = exec_data[tid] = Json::Value{Json::objectValue};
+ for (const auto &kv : map)
+ {
+ auto &name = kv.first;
+ auto &val = kv.second;
+ json_tid[name]["Avg_Time"] = val.sum / val.count;
+ json_tid[name]["Max_Time"] = val.max;
+ json_tid[name]["Min_Time"] = val.min;
+ json_tid[name]["Runtime"] = tid;
+ }
+ }
+ }
+
+ _os << root;
+}
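For reference, the version-2 labels produced by getLabel() above follow the pattern described in the schema comment. A small sketch, not part of the patch, reproducing the layout for the illustrative values session_index=1, subg_index=2, op_index=3, op_name="ADD":

std::string subg_label = "$" + std::to_string(2) + " subgraph";
std::string op_label = "$" + std::to_string(3) + " ADD";
std::string single_session_label = subg_label + " " + op_label;          // "$2 subgraph $3 ADD"
std::string multi_session_label = "$1 session " + single_session_label;  // "$1 session $2 subgraph $3 ADD"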
diff --git a/runtime/onert/core/src/util/ShapeInference.cc b/runtime/onert/core/src/util/ShapeInference.cc
index 95c15049d..2a6fde45b 100644
--- a/runtime/onert/core/src/util/ShapeInference.cc
+++ b/runtime/onert/core/src/util/ShapeInference.cc
@@ -22,6 +22,7 @@
#include "util/logging.h"
#include <cassert>
+#include <numeric>
#include <sstream>
#include <cmath>
@@ -72,6 +73,19 @@ ir::Shape broadcastShapes(const ir::Shape &lhs_shape, const ir::Shape &rhs_shape
} // namespace
+namespace bcq
+{
+inline int getOutputSize(const ir::Shape &cluster_shape, const int32_t *cluster_buf)
+{
+ int size = 0;
+ for (int idx = 0; idx < cluster_shape.dim(0); idx++)
+ {
+ size += cluster_buf[idx * 2 + 1];
+ }
+ return size;
+}
+} // namespace bcq
+
//
// Shape inference
//
@@ -97,10 +111,9 @@ std::pair<int, int> calcConvLikeHeightAndWidth(const int in_h, const int in_w, c
break;
case ir::PaddingType::EXPLICIT:
out_h =
- (in_h + pad.param.top + pad.param.bottom - effective_filter_h_size) / stride.vertical + 1;
+ (in_h + pad.param.top + pad.param.bottom - effective_filter_h_size) / stride.vertical + 1;
out_w =
- (in_w + pad.param.left + pad.param.right - effective_filter_w_size) / stride.horizontal +
- 1;
+ (in_w + pad.param.left + pad.param.right - effective_filter_w_size) / stride.horizontal + 1;
break;
default:
assert(false);
@@ -114,8 +127,13 @@ ir::Shape inferEltwiseShape(const ir::Shape &lhs_shape, const ir::Shape &rhs_sha
return broadcastShapes(lhs_shape, rhs_shape);
}
-ir::Shape inferArgMaxShape(const ir::Shape &input_shape, int axis, int rank)
+ir::Shape inferArgMinMaxShape(const ir::Shape &input_shape, int axis, int rank)
{
+ if (axis < 0 || axis >= rank)
+ {
+ throw std::runtime_error("ArgMinMax shape inference: Wrong axis value " + std::to_string(axis));
+ }
+
ir::Shape out_shape;
for (int idx = 0; idx < rank; ++idx)
{
@@ -167,15 +185,15 @@ ir::Shape inferReduceShape(const ir::Shape &input_shape, const std::vector<int>
else
{
// Calculates size of reducing axis.
- int num_reduce_axis = num_axis;
for (int i = 0; i < num_axis; ++i)
{
int current = axes[i];
+ if (!(-input_num_dims <= current && current < input_num_dims))
+ throw std::runtime_error{"Invalid dim value " + std::to_string(current)};
if (current < 0)
{
current += input_num_dims;
}
- assert(0 <= current && current < input_num_dims);
for (int j = 0; j < i; ++j)
{
int previous = axes[j];
@@ -185,14 +203,12 @@ ir::Shape inferReduceShape(const ir::Shape &input_shape, const std::vector<int>
}
if (current == previous)
{
- --num_reduce_axis;
break;
}
}
}
// Determines output dimensions.
ir::Shape out_shape;
- int num_skip_axis = 0;
for (int idx = 0; idx < input_num_dims; ++idx)
{
bool is_axis = false;
@@ -200,7 +216,6 @@ ir::Shape inferReduceShape(const ir::Shape &input_shape, const std::vector<int>
{
if (axes[axis_idx] == idx || axes[axis_idx] + input_num_dims == idx)
{
- ++num_skip_axis;
is_axis = true;
break;
}
@@ -259,19 +274,24 @@ ir::Shape inferBatchMatMulShape(const ir::Shape &lhs_shape, const ir::Shape &rhs
return output_shape;
}
-ir::Shape inferBroadcastToShape(const ir::Shape wshape, const int32_t *shape_buffer)
+/*
+ * shp_shape : SHAPE input tensor's shape
+ * shp_buf : SHAPE input tensor's buffer
+ */
+ir::Shape inferBroadcastToShape(const ir::Shape shp_shape, const int32_t *shp_buf)
{
- const int num_elements = wshape.num_elements();
+ const int num_elements = shp_shape.num_elements();
assert(num_elements != 0);
- assert(shape_buffer);
+ assert(shp_buf);
ir::Shape new_shape(num_elements);
for (int i = 0; i < num_elements; ++i)
{
- assert(shape_buffer[i] != 0); // It shouldn't be 0.
- new_shape.dim(i) = shape_buffer[i];
+ assert(shp_buf[i] != 0); // It shouldn't be 0.
+ new_shape.dim(i) = shp_buf[i];
}
return new_shape;
@@ -305,6 +325,9 @@ ir::Shape inferConcatShape(const Shapes &in_shapes, const ir::operation::Concat:
ir::Shape inferConv2DShape(const ir::Shape &in_shape, const ir::Shape &ker_shape,
const ir::operation::Conv2D::Param &param, ir::Layout layout)
{
+ if (param.stride.horizontal == 0 || param.stride.vertical == 0)
+ throw std::runtime_error{"Conv2D: stride values must be positive"};
+
auto ifm_shape = in_shape.asFeature(layout);
// Kernel format is [depth_out, kernel_height, kernel_width, depth_in]
@@ -321,6 +344,9 @@ ir::Shape inferDepthwiseConv2DShape(const ir::Shape &in_shape, const ir::Shape &
const ir::operation::DepthwiseConv2D::Param &param,
ir::Layout layout)
{
+ if (param.stride.horizontal == 0 || param.stride.vertical == 0)
+ throw std::runtime_error{"DepthwiseConv2D: stride values must be positive"};
+
assert(layout == ir::Layout::NHWC);
auto ifm_shape = in_shape.asFeature(layout);
@@ -330,7 +356,7 @@ ir::Shape inferDepthwiseConv2DShape(const ir::Shape &in_shape, const ir::Shape &
assert(kf_shape.N == 1);
const auto out_h_w = calcConvLikeHeightAndWidth(ifm_shape.H, ifm_shape.W, kf_shape.H, kf_shape.W,
- param.padding, param.stride);
+ param.padding, param.stride, param.dilation);
return ir::Shape{ifm_shape.N, out_h_w.first, out_h_w.second, kf_shape.C};
}
@@ -354,18 +380,22 @@ ir::Shape inferExpandDimsShape(const ir::Shape &in_shape, int32_t axis)
return out_shape;
}
-ir::Shape inferFillShape(const ir::Shape &in_shape, const int32_t *buffer)
+template <typename T> ir::Shape inferFillShape(const ir::Shape &fill_shape, const T *shape_buf)
{
- ir::Shape out_shape(in_shape.dim(0));
+ ir::Shape out_shape(fill_shape.dim(0));
for (int out_x = 0; out_x < out_shape.rank(); ++out_x)
{
- out_shape.dim(out_x) = buffer[out_x];
+ out_shape.dim(out_x) = static_cast<int32_t>(shape_buf[out_x]);
}
return out_shape;
}
+// template instantiation
+template ir::Shape inferFillShape(const ir::Shape &fill_shape, const int32_t *shape_buf);
+template ir::Shape inferFillShape(const ir::Shape &fill_shape, const int64_t *shape_buf);
+
ir::Shape inferFullyConnectedShape(const ir::Shape &in_shape, const ir::Shape &ker_shape)
{
assert(in_shape.rank() >= 2);
@@ -380,11 +410,60 @@ ir::Shape inferFullyConnectedShape(const ir::Shape &in_shape, const ir::Shape &k
return {ir::Shape({static_cast<int32_t>(batch_size), num_units})};
}
+ir::Shape inferBCQFullyConnectedShape(const ir::Shape &in_shape, const ir::Shape &cluster_shape,
+ const int32_t *cluster_buf)
+{
+ assert(cluster_shape.rank() == 2);
+ assert(cluster_shape.dim(1) == 2);
+
+ const auto input_size = in_shape.dim(1);
+ const auto output_size = bcq::getOutputSize(cluster_shape, cluster_buf);
+
+ return {ir::Shape({output_size, input_size})};
+}
+
+ir::Shape inferBCQGatherShape(const ir::Shape &indices_shape, const ir::Shape &cluster_shape,
+ const int32_t *cluster_buf, int rank,
+ const ir::operation::BCQGather::Param &param)
+{
+ ir::Shape out_shape;
+ ir::Shape in_original_shape;
+
+ assert(cluster_shape.rank() == 2);
+ assert(cluster_shape.dim(1) == 2);
+
+ auto hidden_size = param.input_hidden_size;
+ auto axis = param.axis;
+
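+ // Rebuild the original input's shape from the cluster info: [sum of cluster sizes, hidden_size].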
+ in_original_shape.append(bcq::getOutputSize(cluster_shape, cluster_buf));
+ in_original_shape.append(hidden_size);
+
+ const int indices_rank = indices_shape.rank();
+ for (int idx = 0; idx < rank; ++idx)
+ {
+ if (idx == (int)axis)
+ {
+ for (int indices_idx = 0; indices_idx < indices_rank; indices_idx++)
+ {
+ out_shape.append(indices_shape.dim(indices_idx));
+ }
+ }
+ else
+ {
+ out_shape.append(in_original_shape.dim(idx));
+ }
+ }
+
+ return out_shape;
+}
+
ir::Shape inferGatherShape(const ir::Shape &input_shape, const ir::Shape &indices_shape, int axis,
int rank)
{
ir::Shape out_shape;
+
const int indices_rank = indices_shape.rank();
+
for (int idx = 0; idx < rank; ++idx)
{
if (idx == axis)
@@ -470,6 +549,9 @@ ir::Shape inferPadShape(const ir::Shape &in_shape, const int32_t *pad_buf, const
ir::Shape inferPoolShape(const ir::Shape &in_shape, const ir::operation::Pool2D::Param &param,
const ir::Layout layout)
{
+ if (param.stride.horizontal == 0 || param.stride.vertical == 0)
+ throw std::runtime_error{"Pool2D: stride values must be positive"};
+
assert(layout == ir::Layout::NHWC);
auto ifm_shape = in_shape.asFeature(layout);
const auto out_h_w = calcConvLikeHeightAndWidth(ifm_shape.H, ifm_shape.W, param.kh, param.kw,
@@ -482,6 +564,17 @@ ir::Shape inferResizeBilinearShape(const ir::Shape &in_shape, const int32_t outp
const int32_t output_width)
{
assert(in_shape.rank() == 4);
+ if (output_height < 0)
+ {
+ throw std::runtime_error{"ResizeBilinear: size value must be positive, output_height = " +
+ std::to_string(output_height)};
+ }
+ if (output_width < 0)
+ {
+ throw std::runtime_error{"ResizeBilinear: size value must be positive, output_width = " +
+ std::to_string(output_width)};
+ }
+
ir::Shape ret(in_shape.rank());
ret.dim(0) = in_shape.dim(0);
@@ -497,9 +590,9 @@ template <typename T> ir::Shape inferRangeShape(T start_val, T limit_val, T delt
ir::Shape out_shape(static_cast<int>(1));
out_shape.dim(0) =
- (std::is_integral<T>::value
- ? ((std::abs(start_val - limit_val) + std::abs(delta_val) - 1) / std::abs(delta_val))
- : std::ceil(std::abs((start_val - limit_val) / delta_val)));
+ (std::is_integral<T>::value
+ ? ((std::abs(start_val - limit_val) + std::abs(delta_val) - 1) / std::abs(delta_val))
+ : std::ceil(std::abs((start_val - limit_val) / delta_val)));
return out_shape;
}
@@ -507,16 +600,17 @@ template <typename T> ir::Shape inferRangeShape(T start_val, T limit_val, T delt
template ir::Shape inferRangeShape(int start_val, int limit_val, int delta_val);
template ir::Shape inferRangeShape(float start_val, float limit_val, float delta_val);
-ir::Shape inferReshapeShape(const int32_t *shape_buf, const int32_t shape_num_elements,
- const size_t total_num_elements)
+ir::Shape inferReshapeShape(const ir::Shape &input_shape, const int32_t *shape_buf,
+ const int32_t shape_num_elements)
{
ir::Shape ret(shape_num_elements);
- int32_t flatten_dim = ir::Shape::UNSPECIFIED_DIM;
+ int32_t flatten_dim = ir::Shape::kUnspecifiedDim;
+ auto total_num_elements = input_shape.num_elements();
for (int32_t i = 0; i < shape_num_elements; ++i)
{
if (shape_buf[i] < 0)
{
- if (flatten_dim != ir::Shape::UNSPECIFIED_DIM)
+ if (flatten_dim != ir::Shape::kUnspecifiedDim)
throw std::runtime_error("Reshape: 2nd param has special dim(for flatten) more than twice");
flatten_dim = i;
ret.dim(i) = 1;
@@ -526,12 +620,20 @@ ir::Shape inferReshapeShape(const int32_t *shape_buf, const int32_t shape_num_el
ret.dim(i) = shape_buf[i];
}
}
- if (flatten_dim != ir::Shape::UNSPECIFIED_DIM)
+ if (flatten_dim != ir::Shape::kUnspecifiedDim)
ret.dim(flatten_dim) = total_num_elements / ret.num_elements();
// Check reshapable
if (total_num_elements != static_cast<size_t>(ret.num_elements()))
- throw std::runtime_error("Reshape: 2nd param is not compatible with the shape of input");
+ {
+ // Multi batch case
+ // TODO Handle multi batch case more precisely on runtime level
+ if ((ret.dim(0) == 1) &&
+ (total_num_elements == static_cast<size_t>(ret.num_elements() * input_shape.dim(0))))
+ ret.dim(0) = input_shape.dim(0);
+ else
+ throw std::runtime_error("Reshape: 2nd param is not compatible with the shape of input");
+ }
return ret;
}
@@ -566,9 +668,9 @@ ir::Shape inferSelectShape(const ir::Shape &input_cond_shape, const ir::Shape &i
ir::Shape true_shape = input_true_shape;
ir::Shape false_shape = input_false_shape;
int most_rank =
- (cond_shape.rank() >= true_shape.rank()) && (cond_shape.rank() >= false_shape.rank())
- ? cond_shape.rank()
- : (false_shape.rank() >= true_shape.rank() ? false_shape.rank() : true_shape.rank());
+ (cond_shape.rank() >= true_shape.rank()) && (cond_shape.rank() >= false_shape.rank())
+ ? cond_shape.rank()
+ : (false_shape.rank() >= true_shape.rank() ? false_shape.rank() : true_shape.rank());
ir::Shape calculate_shape(most_rank);
@@ -579,9 +681,9 @@ ir::Shape inferSelectShape(const ir::Shape &input_cond_shape, const ir::Shape &i
for (int i = 0; i < most_rank; ++i)
{
calculate_shape.dim(i) =
- (cond_shape.dim(i) >= true_shape.dim(i)) && (cond_shape.dim(i) >= false_shape.dim(i))
- ? cond_shape.dim(i)
- : (false_shape.dim(i) >= true_shape.dim(i) ? false_shape.dim(i) : true_shape.dim(i));
+ (cond_shape.dim(i) >= true_shape.dim(i)) && (cond_shape.dim(i) >= false_shape.dim(i))
+ ? cond_shape.dim(i)
+ : (false_shape.dim(i) >= true_shape.dim(i) ? false_shape.dim(i) : true_shape.dim(i));
if ((cond_shape.dim(i) != calculate_shape.dim(i) && cond_shape.dim(i) != 1) ||
(true_shape.dim(i) != calculate_shape.dim(i) && true_shape.dim(i) != 1) ||
@@ -613,7 +715,8 @@ ir::Shape inferSelectShape(const ir::Shape &input_cond_shape, const ir::Shape &i
return new_shape;
}
-ir::Shape inferSliceShape(const ir::Shape &input_shape, const int32_t *begins, const int32_t *sizes)
+template <typename T>
+ir::Shape inferSliceShape(const ir::Shape &input_shape, const T *begins_buf, const T *sizes_buf)
{
const uint32_t rank = input_shape.rank();
ir::Shape out_shape(rank);
@@ -623,12 +726,12 @@ ir::Shape inferSliceShape(const ir::Shape &input_shape, const int32_t *begins, c
const auto input_dim = input_shape.dim(idx);
// begin is zero-based
- auto begin = begins[idx];
+ auto begin = begins_buf[idx];
if (begin < 0)
throw std::runtime_error("shape inference Slice: Invalid begin.");
// size is one-based
- auto size = sizes[idx];
+ auto size = sizes_buf[idx];
if (size < -1)
throw std::runtime_error("shape inference Slice: Invalid size.");
@@ -638,18 +741,23 @@ ir::Shape inferSliceShape(const ir::Shape &input_shape, const int32_t *begins, c
}
else
{
- if (input_dim < begin + size)
+ if (input_dim < static_cast<int32_t>(begin + size))
throw std::runtime_error("shape inference Slice: Invalid begin and size.");
}
- out_shape.dim(idx) = size;
+ out_shape.dim(idx) = static_cast<int32_t>(size);
}
return out_shape;
}
+// template instantiation
+template ir::Shape inferSliceShape(const ir::Shape &input_shape, const int32_t *begins_buf,
+ const int32_t *sizes_buf);
+template ir::Shape inferSliceShape(const ir::Shape &input_shape, const int64_t *begins_buf,
+ const int64_t *sizes_buf);
ir::Shape inferSpaceToBatchNDShape(const ir::Shape &input_shape, const ir::Shape &block_shape_shape,
- const ir::Shape &padding_shape, const int32_t *block_shape_data,
- const int32_t *padding_data)
+ const ir::Shape &padding_shape, const int32_t *block_shape_buf,
+ const int32_t *padding_buf)
{
const uint32_t rank = input_shape.rank();
ir::Shape out_shape(rank);
@@ -677,14 +785,14 @@ ir::Shape inferSpaceToBatchNDShape(const ir::Shape &input_shape, const ir::Shape
for (int dim = 0; dim < kSpatialDimensionNum; ++dim)
{
int final_dim_size =
- (input_shape.dim(dim + 1) + padding_data[dim * 2] + padding_data[dim * 2 + 1]);
+ (input_shape.dim(dim + 1) + padding_buf[dim * 2] + padding_buf[dim * 2 + 1]);
- assert(final_dim_size % block_shape_data[dim] == 0);
+ assert(final_dim_size % block_shape_buf[dim] == 0);
- out_shape.dim(dim + 1) = final_dim_size / block_shape_data[dim];
+ out_shape.dim(dim + 1) = final_dim_size / block_shape_buf[dim];
}
- const int output_batch_size = input_shape.dim(0) * block_shape_data[0] * block_shape_data[1];
+ const int output_batch_size = input_shape.dim(0) * block_shape_buf[0] * block_shape_buf[1];
const int output_channel_size = input_shape.dim(3);
out_shape.dim(0) = output_batch_size;
@@ -740,7 +848,7 @@ ir::Shape inferSqueezeShape(const ir::Shape &in_shape, const ir::operation::Sque
if (!(current >= 0 && current < shape_rank && in_shape.dim(current) == 1))
{
throw std::runtime_error(
- "The following conditions must be met: 0 <= dim < Shape rank, dim == 1");
+ "The following conditions must be met: 0 <= dim < Shape rank, dim == 1");
}
if (!should_squeeze[current])
@@ -948,35 +1056,71 @@ ir::Shape inferStridedSliceShape(const ir::Shape &input_shape, const StridedSlic
return out_shape;
}
-ir::Shape inferTileShape(const ir::Shape &in_shape, const int32_t *multiplier)
+ir::Shape inferTileShape(const ir::Shape &in_shape, const int32_t *multiplier_buf,
+ const int32_t multiplier_size)
{
- // assert(in_shape.rank() == multiplier.rank());
+ if (multiplier_size != in_shape.rank())
+ {
+ throw std::runtime_error(
+ "inferTileShape failed, input rank: " + std::to_string(in_shape.rank()) +
+ ", bad multipliers size: " + std::to_string(multiplier_size));
+ }
ir::Shape new_Shape(in_shape.rank());
for (int i = 0; i < in_shape.rank(); ++i)
{
- assert(multiplier[i]); // multiplier[i] shuld not be 0.
- new_Shape.dim(i) = in_shape.dim(i) * multiplier[i];
+ assert(multiplier_buf[i]); // multiplier_buf[i] should not be 0.
+ new_Shape.dim(i) = in_shape.dim(i) * multiplier_buf[i];
}
return new_Shape;
}
-ir::Shape inferTransposeShape(const ir::Shape &in_shape, const std::vector<int> &perm)
+ir::Shape inferTransposeShape(const ir::Shape &in_shape, const int32_t *perm_buf,
+ const int32_t perm_size)
{
- if (static_cast<int>(perm.size()) > in_shape.rank())
+ const auto rank = in_shape.rank();
+ if (perm_size > rank)
+ {
+ throw std::runtime_error("inferTransposeShape failed, bad permutation size: " +
+ std::to_string(perm_size));
+ }
+
+ const int32_t *perm_data = perm_buf;
+ std::vector<int32_t> regular_perm_vec;
+ if (perm_size == 0)
{
- throw std::runtime_error("inferTransposeShape failed, bad rank size: " +
- std::to_string(static_cast<int>(perm.size())));
+ // perm_data will be set to (n-1...0)
+ regular_perm_vec.resize(rank);
+ std::iota(regular_perm_vec.begin(), regular_perm_vec.end(), 0);
+ std::reverse(regular_perm_vec.begin(), regular_perm_vec.end());
+ perm_data = regular_perm_vec.data();
}
- ir::Shape out_shape(static_cast<int>(perm.size()));
- for (int idx = 0; idx < static_cast<int>(perm.size()); idx++)
+ else
+ {
+ assert(rank == perm_size);
+ }
+
+ ir::Shape out_shape(rank);
+ std::vector<bool> visit_perms(rank, false);
+ for (int idx = 0; idx < rank; idx++)
{
- if (perm[idx] < 0 || perm[idx] >= static_cast<int>(perm.size()))
+ const auto perm_val = perm_data[idx];
+ // Check invalid permutation value
+ if (perm_val < 0 || perm_val >= rank)
{
- throw std::runtime_error("inferTransposeShape failed, bad perm value: " +
- std::to_string(perm[idx]));
+ throw std::runtime_error("inferTransposeShape failed, bad permutation value: " +
+ std::to_string(perm_val));
}
- out_shape.dim(idx) = in_shape.dim(perm[idx]);
+
+ // Check duplicated permutation value
+ if (visit_perms.at(perm_val))
+ {
+ throw std::runtime_error("inferTransposeShape failed, duplicated permutation value: " +
+ std::to_string(perm_val));
+ }
+ visit_perms.at(perm_val) = true;
+
+ out_shape.dim(idx) = in_shape.dim(perm_val);
}
return out_shape;
}
diff --git a/runtime/onert/core/src/util/ShapeInference.test.cc b/runtime/onert/core/src/util/ShapeInference.test.cc
new file mode 100644
index 000000000..96579bfa2
--- /dev/null
+++ b/runtime/onert/core/src/util/ShapeInference.test.cc
@@ -0,0 +1,544 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "util/ShapeInference.h"
+
+#include <gtest/gtest.h>
+
+using namespace onert::ir;
+
+TEST(ShapeInference, Elementwise)
+{
+ Shape lhs_shape{1, 299, 299, 3};
+ Shape rhs_shape{3};
+ auto infered_out_shape = onert::shape_inference::inferEltwiseShape(lhs_shape, rhs_shape);
+
+ ASSERT_EQ(infered_out_shape.rank(), 4);
+ ASSERT_EQ(infered_out_shape.dim(0), 1);
+ ASSERT_EQ(infered_out_shape.dim(1), 299);
+ ASSERT_EQ(infered_out_shape.dim(2), 299);
+ ASSERT_EQ(infered_out_shape.dim(3), 3);
+}
+
+TEST(ShapeInference, neg_Elementwise)
+{
+ Shape lhs_shape{1, 299, 299, 3};
+ Shape rhs_shape{5, 3};
+ ASSERT_THROW(onert::shape_inference::inferEltwiseShape(lhs_shape, rhs_shape), std::runtime_error);
+}
+
+TEST(ShapeInference, Pool2DNodeSame)
+{
+ Shape in_shape{10, 6, 12, 20};
+ Stride stride{3, 7};
+ Padding padding{PaddingType::SAME};
+
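+ // With SAME padding the output spatial size is ceil(input / stride): H = ceil(6 / 3) = 2, W = ceil(12 / 7) = 2.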
+ operation::Pool2D::Param avg_pool_param{
+ operation::Pool2D::PoolType::AVG, 3, 6, stride, padding, Activation::NONE};
+ auto infered_out_shape = onert::shape_inference::inferPoolShape(in_shape, avg_pool_param);
+
+ ASSERT_EQ(infered_out_shape.rank(), 4);
+ ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).N, 10);
+ ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).H, 2);
+ ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).W, 2);
+ ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).C, 20);
+
+ operation::Pool2D::Param max_pool_param{
+ operation::Pool2D::PoolType::MAX, 3, 6, stride, padding, Activation::NONE};
+ infered_out_shape = onert::shape_inference::inferPoolShape(in_shape, max_pool_param);
+
+ ASSERT_EQ(infered_out_shape.rank(), 4);
+ ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).N, 10);
+ ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).H, 2);
+ ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).W, 2);
+ ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).C, 20);
+}
+
+TEST(ShapeInference, Pool2DNodeValid)
+{
+ Shape in_shape{10, 6, 12, 20};
+ Stride stride{3, 7};
+ Padding padding{PaddingType::VALID};
+
+ operation::Pool2D::Param avg_pool_param{
+ operation::Pool2D::PoolType::AVG, 3, 6, stride, padding, Activation::NONE};
+ auto infered_out_shape = onert::shape_inference::inferPoolShape(in_shape, avg_pool_param);
+
+ ASSERT_EQ(infered_out_shape.rank(), 4);
+ ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).N, 10);
+ ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).H, 2);
+ ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).W, 1);
+ ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).C, 20);
+
+ operation::Pool2D::Param max_pool_param{
+ operation::Pool2D::PoolType::MAX, 3, 6, stride, padding, Activation::NONE};
+ infered_out_shape = onert::shape_inference::inferPoolShape(in_shape, max_pool_param);
+
+ ASSERT_EQ(infered_out_shape.rank(), 4);
+ ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).N, 10);
+ ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).H, 2);
+ ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).W, 1);
+ ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).C, 20);
+}
+
+TEST(ShapeInference, Pool2DNodeExplicit)
+{
+ Shape in_shape{10, 3, 5, 20};
+
+ Stride stride{3, 7};
+ Padding padding{4, 3, 2, 1};
+
+ operation::Pool2D::Param avg_pool_param{
+ operation::Pool2D::PoolType::AVG, 3, 6, stride, padding, Activation::NONE};
+ auto infered_out_shape = onert::shape_inference::inferPoolShape(in_shape, avg_pool_param);
+
+ ASSERT_EQ(infered_out_shape.rank(), 4);
+ ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).N, 10);
+ ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).H, 2);
+ ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).W, 1);
+ ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).C, 20);
+
+ operation::Pool2D::Param max_pool_param{
+ operation::Pool2D::PoolType::MAX, 3, 6, stride, padding, Activation::NONE};
+ infered_out_shape = onert::shape_inference::inferPoolShape(in_shape, max_pool_param);
+
+ ASSERT_EQ(infered_out_shape.rank(), 4);
+ ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).N, 10);
+ ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).H, 2);
+ ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).W, 1);
+ ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).C, 20);
+}
+
+TEST(ShapeInference, neg_Pool2DNode_InvalidStride)
+{
+ Shape in_shape{10, 6, 12, 20};
+ Stride stride{0, 7};
+ Padding padding{PaddingType::SAME};
+
+ operation::Pool2D::Param avg_pool_param{
+ operation::Pool2D::PoolType::AVG, 3, 6, stride, padding, Activation::NONE};
+ ASSERT_THROW(onert::shape_inference::inferPoolShape(in_shape, avg_pool_param),
+ std::runtime_error);
+}
+
+TEST(ShapeInference, Conv2D)
+{
+ Shape in_shape{10, 6, 12, 20};
+ Shape ker_shape{30, 3, 6, 20};
+
+ operation::Conv2D::Param param{Stride{3, 7}, Padding{PaddingType::VALID}, Activation::NONE,
+ Dilation{1, 1}};
+ auto infered_out_shape = onert::shape_inference::inferConv2DShape(in_shape, ker_shape, param);
+
+ ASSERT_EQ(infered_out_shape.rank(), 4);
+ ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).N, 10);
+ ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).H, 2);
+ ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).W, 1);
+ ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).C, 30);
+
+ param = operation::Conv2D::Param{Stride{3, 7}, Padding{PaddingType::SAME}, Activation::NONE,
+ Dilation{1, 1}};
+ infered_out_shape = onert::shape_inference::inferConv2DShape(in_shape, ker_shape, param);
+
+ ASSERT_EQ(infered_out_shape.rank(), 4);
+ ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).N, 10);
+ ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).H, 2);
+ ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).W, 2);
+ ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).C, 30);
+
+ param =
+ operation::Conv2D::Param{Stride{3, 7}, Padding{4, 3, 2, 1}, Activation::NONE, Dilation{1, 1}};
+ infered_out_shape = onert::shape_inference::inferConv2DShape(in_shape, ker_shape, param);
+
+ ASSERT_EQ(infered_out_shape.rank(), 4);
+ ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).N, 10);
+ ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).H, 3);
+ ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).W, 2);
+ ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).C, 30);
+}
+
+TEST(ShapeInference, neg_Conv2D_InvalidStride)
+{
+ Shape in_shape{10, 6, 12, 20};
+ Shape ker_shape{30, 3, 6, 20};
+
+ operation::Conv2D::Param param{Stride{0, 0}, Padding{PaddingType::VALID}, Activation::NONE,
+ Dilation{1, 1}};
+ ASSERT_THROW(onert::shape_inference::inferConv2DShape(in_shape, ker_shape, param),
+ std::runtime_error);
+}
+
+TEST(ShapeInference, DepthwiseConv2D)
+{
+ Shape in_shape{10, 6, 12, 20};
+ Shape ker_shape{1, 3, 6, 60};
+
+ operation::DepthwiseConv2D::Param param{Stride{3, 7}, Padding{PaddingType::VALID}, 3,
+ Activation::NONE, Dilation{1, 1}};
+ auto infered_out_shape =
+ onert::shape_inference::inferDepthwiseConv2DShape(in_shape, ker_shape, param);
+
+ ASSERT_EQ(infered_out_shape.rank(), 4);
+ ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).N, 10);
+ ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).H, 2);
+ ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).W, 1);
+ ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).C, 60);
+
+ param = operation::DepthwiseConv2D::Param{Stride{3, 7}, Padding{PaddingType::SAME}, 3,
+ Activation::NONE, Dilation{1, 1}};
+ infered_out_shape = onert::shape_inference::inferDepthwiseConv2DShape(in_shape, ker_shape, param);
+
+ ASSERT_EQ(infered_out_shape.rank(), 4);
+ ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).N, 10);
+ ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).H, 2);
+ ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).W, 2);
+ ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).C, 60);
+
+ param = operation::DepthwiseConv2D::Param{Stride{3, 7}, Padding{4, 3, 2, 1}, 3, Activation::NONE,
+ Dilation{1, 1}};
+ infered_out_shape = onert::shape_inference::inferDepthwiseConv2DShape(in_shape, ker_shape, param);
+
+ ASSERT_EQ(infered_out_shape.rank(), 4);
+ ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).N, 10);
+ ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).H, 3);
+ ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).W, 2);
+ ASSERT_EQ(infered_out_shape.asFeature(Layout::NHWC).C, 60);
+}
+
+TEST(ShapeInference, neg_DepthwiseConv2D_InvalidStride)
+{
+ Shape in_shape{10, 6, 12, 20};
+ Shape ker_shape{1, 3, 6, 60};
+
+ operation::DepthwiseConv2D::Param param{Stride{3, 0}, Padding{PaddingType::VALID}, 3,
+ Activation::NONE, Dilation{1, 1}};
+ ASSERT_THROW(onert::shape_inference::inferDepthwiseConv2DShape(in_shape, ker_shape, param),
+ std::runtime_error);
+}
+
+TEST(ShapeInference, Concat)
+{
+ {
+ Shape in1{10, 20, 30, 3, 50};
+ Shape in2{10, 20, 30, 2, 50};
+ Shape in3{10, 20, 30, 2, 50};
+
+ operation::Concat::Param param{3};
+ auto infered_out_shape = onert::shape_inference::inferConcatShape({in1, in2, in3}, param);
+
+ ASSERT_EQ(infered_out_shape.rank(), 5);
+ ASSERT_EQ(infered_out_shape.dim(0), 10);
+ ASSERT_EQ(infered_out_shape.dim(1), 20);
+ ASSERT_EQ(infered_out_shape.dim(2), 30);
+ ASSERT_EQ(infered_out_shape.dim(3), 7);
+ ASSERT_EQ(infered_out_shape.dim(4), 50);
+ }
+ {
+ // case 1. when axis < 0
+ Shape in1{10, 20, 2};
+ Shape in2{10, 20, 3};
+
+ operation::Concat::Param param{-1};
+ auto infered_out_shape = onert::shape_inference::inferConcatShape({in1, in2}, param);
+
+ ASSERT_EQ(infered_out_shape.rank(), 3);
+ ASSERT_EQ(infered_out_shape.dim(0), 10);
+ ASSERT_EQ(infered_out_shape.dim(1), 20);
+ ASSERT_EQ(infered_out_shape.dim(2), 5);
+ }
+ {
+ // case 2. when axis < 0
+ Shape in1{2, 20, 2};
+ Shape in2{3, 20, 2};
+
+ operation::Concat::Param param{-3};
+ auto infered_out_shape = onert::shape_inference::inferConcatShape({in1, in2}, param);
+
+ ASSERT_EQ(infered_out_shape.rank(), 3);
+ ASSERT_EQ(infered_out_shape.dim(0), 5);
+ ASSERT_EQ(infered_out_shape.dim(1), 20);
+ ASSERT_EQ(infered_out_shape.dim(2), 2);
+ }
+}
+
+TEST(ShapeInference, neg_Concat)
+{
+ {
+ operation::Concat::Param param{2};
+ Shape in1{10, 1, 3};
+ Shape in2{10, 2, 4}; // dim[1] should be 1 but 2
+
+ EXPECT_ANY_THROW(onert::shape_inference::inferConcatShape({in1, in2}, param));
+ }
+ { // wrong rank
+ operation::Concat::Param param{2};
+ Shape in1{10, 2, 3, 4};
+ Shape in2{10, 2, 4}; // rank should be 4
+
+ EXPECT_ANY_THROW(onert::shape_inference::inferConcatShape({in1, in2}, param));
+ }
+}
+
+TEST(ShapeInference, ExpandDims)
+{
+ Shape in_shape{30, 40};
+
+ auto check = [&](int32_t axis, Shape &expected) {
+ auto actual = onert::shape_inference::inferExpandDimsShape(in_shape, axis);
+
+ ASSERT_EQ(actual.rank(), 3);
+ for (int32_t dim = 0; dim < expected.rank(); dim++)
+ ASSERT_EQ(actual.dim(dim), expected.dim(dim));
+ };
+
+ { // boundary
+ int32_t axis = 0;
+ Shape expected{1, 30, 40};
+ check(axis, expected);
+ }
+ { // boundary
+ int32_t axis = 2;
+ Shape expected{30, 40, 1};
+ check(axis, expected);
+ }
+ { // inside
+ int32_t axis = 1;
+ Shape expected{30, 1, 40};
+ check(axis, expected);
+ }
+ { // negative boundary
+ int32_t axis = -1;
+ Shape expected{30, 40, 1};
+ check(axis, expected);
+ }
+ { // negative boundary
+ int32_t axis = -3;
+ Shape expected{1, 30, 40};
+ check(axis, expected);
+ }
+}
+
+TEST(ShapeInference, neg_ExpandDims)
+{
+ Shape in_shape{30, 40};
+
+ { // over boundary
+ int32_t axis = 3;
+ ASSERT_THROW(onert::shape_inference::inferExpandDimsShape(in_shape, axis), std::runtime_error);
+ }
+ { // over boundary
+ int32_t axis = -4;
+ ASSERT_THROW(onert::shape_inference::inferExpandDimsShape(in_shape, axis), std::runtime_error);
+ }
+}
+
+TEST(ShapeInference, FullyConnected)
+{
+ Shape in_shape{3, 4, 5, 6};
+ Shape ker_shape{3, 10};
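+ // The flattened input (3*4*5*6 = 360 elements) collapses to 36 rows of 10 (ker dim(1)),
+ // and ker dim(0) = 3 gives the unit count, hence the expected {36, 3}.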
+ auto infered_out_shape = onert::shape_inference::inferFullyConnectedShape(in_shape, ker_shape);
+
+ ASSERT_EQ(infered_out_shape.rank(), 2);
+ ASSERT_EQ(infered_out_shape.dim(0), 36);
+ ASSERT_EQ(infered_out_shape.dim(1), 3);
+}
+
+TEST(ShapeInference, Transpose)
+{
+ auto check = [&](Shape &in_shape, std::vector<int> perm, Shape &expected) {
+ // pre-conditions
+ ASSERT_EQ(in_shape.rank(), perm.size());
+ ASSERT_EQ(expected.rank(), perm.size());
+ auto inferred_out_shape =
+ onert::shape_inference::inferTransposeShape(in_shape, perm.data(), perm.size());
+ // post-conditions
+ ASSERT_EQ(inferred_out_shape.rank(), perm.size());
+ for (int32_t dim = 0; dim < expected.rank(); dim++)
+ {
+ ASSERT_EQ(inferred_out_shape.dim(dim), expected.dim(dim));
+ }
+ };
+ // check for 2-D
+ {
+ Shape in_shape{2, 3};
+ std::vector<int> perm = {1, 0};
+ Shape expected{3, 2};
+ // int32_t rank = 2;
+ check(in_shape, perm, expected);
+ }
+ // check for 3-D
+ {
+ Shape in_shape{1, 2, 3};
+ std::vector<int> perm = {2, 0, 1};
+ Shape expected{3, 1, 2};
+ // int32_t rank = 3;
+ check(in_shape, perm, expected);
+ }
+ // check for 4-D
+ {
+ Shape in_shape{1, 2, 3, 4};
+ std::vector<int> perm = {1, 3, 0, 2};
+ Shape expected{2, 4, 1, 3};
+ // int32_t rank = 4;
+ check(in_shape, perm, expected);
+ }
+}
+
+TEST(ShapeInference, neg_Transpose)
+{
+ Shape in_shape{1, 2, 3};
+ // Invalid parameter size
+ {
+ std::vector<int> perm = {2, 0, 1, 0};
+ // int32_t rank = 3;
+ ASSERT_THROW(onert::shape_inference::inferTransposeShape(in_shape, perm.data(), perm.size()),
+ std::runtime_error);
+ }
+ // Invalid parameter value
+ {
+ std::vector<int> perm = {2, 0, 3};
+ // int32_t rank = 3;
+ ASSERT_THROW(onert::shape_inference::inferTransposeShape(in_shape, perm.data(), perm.size()),
+ std::runtime_error);
+ }
+}
+
+TEST(ShapeInference, Gather)
+{
+ auto check = [&](Shape &input, Shape &indices, Shape &expected, int32_t axis) {
+ int rank = input.rank();
+ auto actual = onert::shape_inference::inferGatherShape(input, indices, axis, rank);
+
+ ASSERT_EQ(actual.rank(), expected.rank());
+
+ for (int32_t dim = 0; dim < expected.rank(); dim++)
+ ASSERT_EQ(actual.dim(dim), expected.dim(dim));
+ };
+
+ // check for 2-D, 3-D, axis 0
+ {
+ Shape input{3, 4};
+ Shape indices{1, 1, 2};
+ int32_t axis = 0;
+ Shape expected{1, 1, 2, 4};
+ check(input, indices, expected, axis);
+ }
+
+ // check for 2-D, 3-D, axis 1
+ {
+ Shape input{3, 4};
+ Shape indices{1, 2, 1};
+ int32_t axis = 1;
+ Shape expected{3, 1, 2, 1};
+ check(input, indices, expected, axis);
+ }
+
+ // check for 3-D, 2-D, axis 0
+ {
+ Shape input{2, 3, 4};
+ Shape indices{1, 2};
+ int32_t axis = 0;
+ Shape expected{1, 2, 3, 4};
+ check(input, indices, expected, axis);
+ }
+
+ // check for 3-D, 2-D, axis 2
+ {
+ Shape input{2, 3, 4};
+ Shape indices{2, 1};
+ int32_t axis = 2;
+ Shape expected{2, 3, 2, 1};
+ check(input, indices, expected, axis);
+ }
+
+ // check for 4D, axis 0
+ {
+ Shape input{1, 2, 3, 4};
+ Shape indices{2};
+ int32_t axis = 0;
+ Shape expected{2, 2, 3, 4};
+ check(input, indices, expected, axis);
+ }
+}
+
+TEST(ShapeInference, BCQFullyConnected)
+{
+ auto check = [&](Shape &in_shape, Shape &cluster_shape, std::vector<int> cluster,
+ Shape &expected) {
+ auto actual =
+ onert::shape_inference::inferBCQFullyConnectedShape(in_shape, cluster_shape, cluster.data());
+ ASSERT_EQ(actual.rank(), expected.rank());
+
+ for (int32_t dim = 0; dim < expected.rank(); dim++)
+ ASSERT_EQ(actual.dim(dim), expected.dim(dim));
+ };
+
+ {
+ Shape in_shape{10, 1};
+ Shape cluster_shape{3, 2};
+ std::vector<int> cluster = {1, 10, 2, 10, 3, 10};
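+ // Second column of each cluster row: 10 + 10 + 10 -> output dim 30.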
+
+ Shape expected{30, 1};
+ check(in_shape, cluster_shape, cluster, expected);
+ }
+
+ {
+ Shape in_shape{1, 1};
+ Shape cluster_shape{1, 2};
+ std::vector<int> cluster = {3, 50};
+
+ Shape expected{50, 1};
+ check(in_shape, cluster_shape, cluster, expected);
+ }
+}
+
+TEST(ShapeInference, BCQGather)
+{
+ auto check = [&](Shape &indices_shape, Shape &cluster_shape, std::vector<int> cluster,
+ uint32_t hidden_size, uint32_t axis, int rank, Shape &expected) {
+ operation::BCQGather::Param param{hidden_size, axis};
+ auto actual = onert::shape_inference::inferBCQGatherShape(indices_shape, cluster_shape,
+ cluster.data(), rank, param);
+ ASSERT_EQ(actual.rank(), expected.rank());
+
+ for (int32_t dim = 0; dim < expected.rank(); dim++)
+ ASSERT_EQ(actual.dim(dim), expected.dim(dim));
+ };
+
+ {
+ Shape indices_shape{5, 1};
+ Shape cluster_shape{3, 2};
+ std::vector<int> cluster = {1, 10, 2, 10, 3, 10};
+ uint32_t hidden_size = 10;
+ uint32_t axis = 0;
+ int rank = 2;
+
+ Shape expected{5, 1, 10};
+ check(indices_shape, cluster_shape, cluster, hidden_size, axis, rank, expected);
+ }
+
+ {
+ Shape indices_shape{5, 1};
+ Shape cluster_shape{3, 2};
+ std::vector<int> cluster = {1, 10, 2, 10, 3, 10};
+ uint32_t hidden_size = 10;
+ uint32_t axis = 1;
+ int rank = 2;
+
+ Shape expected{30, 5, 1};
+ check(indices_shape, cluster_shape, cluster, hidden_size, axis, rank, expected);
+ }
+}
diff --git a/runtime/onert/core/src/util/TracingCtx.cc b/runtime/onert/core/src/util/TracingCtx.cc
new file mode 100644
index 000000000..c05baee60
--- /dev/null
+++ b/runtime/onert/core/src/util/TracingCtx.cc
@@ -0,0 +1,30 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "util/TracingCtx.h"
+
+namespace onert
+{
+namespace util
+{
+
+// initializing static member var
+std::mutex TracingCtx::_session_id_mutex;
+uint32_t TracingCtx::_next_session_id = 0;
+
+} // namespace util
+} // namespace onert