Diffstat (limited to 'compiler')
-rw-r--r-- compiler/bcq-tools/generate_bcq_output_arrays | 126
-rw-r--r-- compiler/bcq-tools/generate_bcq_output_arrays.py | 130
-rw-r--r-- compiler/circle2circle/src/Circle2Circle.cpp | 40
-rw-r--r-- compiler/circlechef/tests/CMakeLists.txt | 51
-rw-r--r-- compiler/circlechef/tests/shape_signature/test.recipe | 45
-rw-r--r-- compiler/circlechef/tests/shape_signature/test.reverse | 0
-rw-r--r-- compiler/common-artifacts/exclude.lst | 54
-rw-r--r-- compiler/exo/src/Circle/CircleExporterUtils.h | 2
-rw-r--r-- compiler/exo/src/Dialect/Service/TFLShapeInferenceRule.cpp | 4
-rw-r--r-- compiler/exo/src/TFLite/TFLExporterUtils.h | 2
-rw-r--r-- compiler/hermes/include/hermes/core/Message.h | 2
-rw-r--r-- compiler/luci-interpreter/src/kernels/Conv2D.cpp | 98
-rw-r--r-- compiler/luci-interpreter/src/kernels/Conv2D.h | 1
-rw-r--r-- compiler/luci-interpreter/src/kernels/Conv2D.test.cpp | 72
-rw-r--r-- compiler/luci-interpreter/src/kernels/DepthwiseConv2D.cpp | 103
-rw-r--r-- compiler/luci-interpreter/src/kernels/DepthwiseConv2D.h | 1
-rw-r--r-- compiler/luci-interpreter/src/kernels/DepthwiseConv2D.test.cpp | 73
-rw-r--r-- compiler/luci-interpreter/src/kernels/TransposeConv.cpp | 104
-rw-r--r-- compiler/luci-interpreter/src/kernels/TransposeConv.h | 1
-rw-r--r-- compiler/luci-interpreter/src/kernels/TransposeConv.test.cpp | 59
-rw-r--r-- compiler/luci-interpreter/src/loader/GraphLoader.cpp | 4
-rw-r--r-- compiler/luci/export/src/CircleExporterImpl.cpp | 7
-rw-r--r-- compiler/luci/export/src/CircleExporterUtils.cpp | 16
-rw-r--r-- compiler/luci/export/src/CircleExporterUtils.h | 2
-rw-r--r-- compiler/luci/export/src/CircleOperationExporter.cpp | 4
-rw-r--r-- compiler/luci/export/src/CircleTensorExporter.cpp | 7
-rw-r--r-- compiler/luci/export/src/Optimize.cpp | 2
-rw-r--r-- compiler/luci/export/src/SerializedData.h | 2
-rw-r--r-- compiler/luci/import/include/luci/Import/CircleReader.h | 2
-rw-r--r-- compiler/luci/import/src/CircleReader.cpp | 16
-rw-r--r-- compiler/luci/import/src/Nodes/CircleFullyConnected.cpp | 7
-rw-r--r-- compiler/luci/lang/include/luci/IR/AttrDilation.h | 14
-rw-r--r-- compiler/luci/lang/include/luci/IR/AttrFilter.h | 14
-rw-r--r-- compiler/luci/lang/include/luci/IR/AttrStride.h | 14
-rw-r--r-- compiler/luci/lang/include/luci/IR/CircleShapeSignature.h | 2
-rw-r--r-- compiler/luci/lang/include/luci/IR/Nodes/CircleFullyConnected.h | 17
-rw-r--r-- compiler/luci/lang/src/AttrDilation.cpp | 36
-rw-r--r-- compiler/luci/lang/src/AttrDilation.test.cpp | 36
-rw-r--r-- compiler/luci/lang/src/AttrFilter.cpp | 36
-rw-r--r-- compiler/luci/lang/src/AttrFilter.test.cpp | 36
-rw-r--r-- compiler/luci/lang/src/AttrStride.cpp | 36
-rw-r--r-- compiler/luci/lang/src/AttrStride.test.cpp | 36
-rw-r--r-- compiler/luci/lang/src/CircleShapeSignature.cpp | 34
-rw-r--r-- compiler/luci/pass/include/luci/CircleOptimizer.h | 8
-rw-r--r-- compiler/luci/pass/include/luci/ModulePass.h | 37
-rw-r--r-- compiler/luci/pass/include/luci/Pass/CircleTypeInferencePass.h | 42
-rw-r--r-- compiler/luci/pass/include/luci/Pass/FuseBCQPass.h | 5
-rw-r--r-- compiler/luci/pass/include/luci/Pass/MigrateLegacyShapeDtypePass.h | 44
-rw-r--r-- compiler/luci/pass/include/luci/Pass/PropagateQuantParamPass.h | 37
-rw-r--r-- compiler/luci/pass/include/luci/Pass/RemoveRedundantTransposePass.h | 37
-rw-r--r-- compiler/luci/pass/include/luci/Pass/ReplaceMulAddWithDepthwiseConvPass.h | 37
-rw-r--r-- compiler/luci/pass/include/luci/Pass/ShapeInferencePass.h | 5
-rw-r--r-- compiler/luci/pass/include/luci/Pass/ShapeSignatureInferencePass.h | 42
-rw-r--r-- compiler/luci/pass/include/luci/Pass/ShuffleWeightTo16x1Float32Pass.h | 37
-rw-r--r-- compiler/luci/pass/include/luci/Pass/SubstitutePackToReshapePass.h | 37
-rw-r--r-- compiler/luci/pass/include/luci/Pass/TypeInferencePass.h | 5
-rw-r--r-- compiler/luci/pass/src/CircleOptimizer.cpp | 85
-rw-r--r-- compiler/luci/pass/src/CircleTypeInferencePass.cpp | 59
-rw-r--r-- compiler/luci/pass/src/FuseBCQPass.cpp | 291
-rw-r--r-- compiler/luci/pass/src/MigrateLegacyShapeDtypePass.cpp | 112
-rw-r--r-- compiler/luci/pass/src/ModulePhase.cpp | 71
-rw-r--r-- compiler/luci/pass/src/ModulePhase.h | 67
-rw-r--r-- compiler/luci/pass/src/ProgressReporter.cpp | 42
-rw-r--r-- compiler/luci/pass/src/ProgressReporter.h | 26
-rw-r--r-- compiler/luci/pass/src/PropagateQuantParamPass.cpp | 102
-rw-r--r-- compiler/luci/pass/src/PropagateQuantParamPass.test.cpp | 118
-rw-r--r-- compiler/luci/pass/src/QuantizeWithMinMaxPass.cpp | 149
-rw-r--r-- compiler/luci/pass/src/RemoveRedundantTranspose.cpp | 127
-rw-r--r-- compiler/luci/pass/src/RemoveRedundantTranspose.test.cpp | 156
-rw-r--r-- compiler/luci/pass/src/ReplaceMulAddWithDepthwiseConvPass.cpp | 223
-rw-r--r-- compiler/luci/pass/src/ReplaceMulAddWithDepthwiseConvPass.test.cpp | 142
-rw-r--r-- compiler/luci/pass/src/ShapeInferencePass.cpp | 13
-rw-r--r-- compiler/luci/pass/src/ShapeSignatureInferencePass.cpp | 63
-rw-r--r-- compiler/luci/pass/src/ShuffleWeightTo16x1Float32Pass.cpp | 139
-rw-r--r-- compiler/luci/pass/src/ShuffleWeightTo16x1Float32Pass.test.cpp | 118
-rw-r--r-- compiler/luci/pass/src/SubstitutePackToReshapePass.cpp | 107
-rw-r--r-- compiler/luci/pass/src/SubstitutePackToReshapePass.test.cpp | 124
-rw-r--r-- compiler/luci/pass/src/TypeInferencePass.cpp | 13
-rw-r--r-- compiler/luci/service/include/luci/Service/CircleShapeInference.h | 153
-rw-r--r-- compiler/luci/service/include/luci/Service/CircleShapeInferenceHelper.h | 36
-rw-r--r-- compiler/luci/service/include/luci/Service/CircleShapeSignatureInference.h (renamed from compiler/luci/service/include/luci/Service/CircleShapeSignatureInferenceRule.h) | 42
-rw-r--r-- compiler/luci/service/include/luci/Service/CircleShapeSignatureInferenceHelper.h | 45
-rw-r--r-- compiler/luci/service/include/luci/Service/CircleTypeInference.h | 153
-rw-r--r-- compiler/luci/service/include/luci/Service/CircleTypeInferenceHelper.h | 34
-rw-r--r-- compiler/luci/service/include/luci/Service/ShapeDescription.h | 3
-rw-r--r-- compiler/luci/service/src/CircleShapeInference.cpp | 60
-rw-r--r-- compiler/luci/service/src/CircleShapeInferenceHelper.cpp | 34
-rw-r--r-- compiler/luci/service/src/CircleShapeInferenceRule.cpp | 4
-rw-r--r-- compiler/luci/service/src/CircleShapeSignatureInference.cpp (renamed from compiler/luci/service/src/CircleShapeSignatureInferenceRule.cpp) | 12
-rw-r--r-- compiler/luci/service/src/CircleShapeSignatureInferenceHelper.cpp | 160
-rw-r--r-- compiler/luci/service/src/CircleTypeInference.cpp | 46
-rw-r--r-- compiler/luci/service/src/CircleTypeInferenceHelper.cpp | 27
-rw-r--r-- compiler/luci/service/src/Nodes/CircleInput.cpp | 27
-rw-r--r-- compiler/luci/service/src/Nodes/CircleMean.cpp | 28
-rw-r--r-- compiler/luci/service/src/Nodes/CircleOutput.cpp | 27
-rw-r--r-- compiler/luci/service/src/Nodes/CircleOutputDummy.cpp | 24
-rw-r--r-- compiler/luci/service/src/Nodes/CircleOutputExclude.cpp | 27
-rw-r--r-- compiler/luci/service/src/Nodes/CircleReduceAny.cpp | 28
-rw-r--r-- compiler/luci/service/src/Nodes/CircleReduceMax.cpp | 28
-rw-r--r-- compiler/luci/service/src/Nodes/CircleReduceMin.cpp | 28
-rw-r--r-- compiler/luci/service/src/Nodes/CircleReduceProd.cpp | 28
-rw-r--r-- compiler/luci/service/src/Nodes/CircleRelu.cpp | 27
-rw-r--r-- compiler/luci/service/src/Nodes/CircleRelu6.cpp | 27
-rw-r--r-- compiler/luci/service/src/Nodes/CircleReluN1To1.cpp | 27
-rw-r--r-- compiler/luci/service/src/Nodes/CircleSum.cpp | 28
-rw-r--r-- compiler/luci/service/src/ShapeDescription.cpp | 13
-rw-r--r-- compiler/luci/service/src/Validate.cpp | 82
-rw-r--r-- compiler/luci/tester/src/ReadTester.cpp | 9
-rw-r--r-- compiler/luci/tester/src/WriteTester.cpp | 9
-rw-r--r-- compiler/moco/support/src/TFShapeInferenceHelper.cpp | 4
-rw-r--r-- compiler/nnc/include/Definitions.h.in | 4
-rw-r--r-- compiler/one-cmds/how-to-use-one-commands.txt | 1
-rw-r--r-- compiler/one-cmds/one-codegen | 29
-rw-r--r-- compiler/one-cmds/one-import-bcq | 4
-rw-r--r-- compiler/one-cmds/one-import-tf | 2
-rw-r--r-- compiler/one-cmds/one-optimize | 4
-rw-r--r-- compiler/one-cmds/tests/one-build_001.cfg | 2
-rw-r--r-- compiler/one-cmds/tests/one-build_002.cfg | 2
-rw-r--r-- compiler/one-cmds/tests/one-build_neg_002.cfg | 2
-rw-r--r-- compiler/one-cmds/tests/one-build_neg_003.cfg | 2
-rw-r--r-- compiler/one-cmds/tests/one-build_neg_004.cfg | 2
-rw-r--r-- compiler/one-cmds/tests/one-import_002.cfg | 2
-rw-r--r-- compiler/one-cmds/tests/one-import_003.cfg | 13
-rw-r--r-- compiler/one-cmds/tests/one-import_003.test | 42
-rw-r--r-- compiler/one-cmds/tests/one-import_004.cfg | 13
-rw-r--r-- compiler/one-cmds/tests/one-import_004.test | 42
-rw-r--r-- compiler/one-cmds/tests/prepare_test_materials.sh | 14
-rw-r--r-- compiler/oops/include/oops/InternalExn.h | 6
-rw-r--r-- compiler/pota-quantization-value-test/CMakeLists.txt | 28
-rw-r--r-- compiler/pota-quantization-value-test/expected_outputs/InstanceNorm_001/channel/uint8/quantization/beta.json | 20
-rw-r--r-- compiler/pota-quantization-value-test/expected_outputs/InstanceNorm_001/channel/uint8/quantization/gamma.json | 20
-rw-r--r-- compiler/pota-quantization-value-test/expected_outputs/InstanceNorm_001/channel/uint8/quantization/ifm.json | 4
-rw-r--r-- compiler/pota-quantization-value-test/expected_outputs/InstanceNorm_001/channel/uint8/quantization/ofm.json | 4
-rw-r--r-- compiler/pota-quantization-value-test/expected_outputs/InstanceNorm_001/channel/uint8/record_minmax/ifm.json | 4
-rw-r--r-- compiler/pota-quantization-value-test/expected_outputs/InstanceNorm_001/channel/uint8/record_minmax/ofm.json | 4
-rw-r--r-- compiler/pota-quantization-value-test/expected_outputs/InstanceNorm_001/layer/uint8/quantization/beta.json | 10
-rw-r--r-- compiler/pota-quantization-value-test/expected_outputs/InstanceNorm_001/layer/uint8/quantization/gamma.json | 10
-rw-r--r-- compiler/pota-quantization-value-test/expected_outputs/InstanceNorm_001/layer/uint8/quantization/ifm.json | 4
-rw-r--r-- compiler/pota-quantization-value-test/expected_outputs/InstanceNorm_001/layer/uint8/quantization/ofm.json | 4
-rw-r--r-- compiler/pota-quantization-value-test/expected_outputs/InstanceNorm_001/layer/uint8/record_minmax/ifm.json | 4
-rw-r--r-- compiler/pota-quantization-value-test/expected_outputs/InstanceNorm_001/layer/uint8/record_minmax/ofm.json | 4
-rw-r--r-- compiler/pota-quantization-value-test/expected_outputs/PRelu_001/channel/int16/quantization/alpha.json | 18
-rw-r--r-- compiler/pota-quantization-value-test/expected_outputs/PRelu_001/channel/int16/quantization/ifm.json | 2
-rw-r--r-- compiler/pota-quantization-value-test/expected_outputs/PRelu_001/channel/int16/quantization/ofm.json | 2
-rw-r--r-- compiler/pota-quantization-value-test/expected_outputs/PRelu_001/channel/int16/record_minmax/ifm.json | 4
-rw-r--r-- compiler/pota-quantization-value-test/expected_outputs/PRelu_001/channel/int16/record_minmax/ofm.json | 4
-rw-r--r-- compiler/pota-quantization-value-test/expected_outputs/PRelu_001/channel/uint8/quantization/alpha.json | 21
-rw-r--r-- compiler/pota-quantization-value-test/expected_outputs/PRelu_001/channel/uint8/quantization/ifm.json | 4
-rw-r--r-- compiler/pota-quantization-value-test/expected_outputs/PRelu_001/channel/uint8/quantization/ofm.json | 4
-rw-r--r-- compiler/pota-quantization-value-test/expected_outputs/PRelu_001/channel/uint8/record_minmax/ifm.json | 4
-rw-r--r-- compiler/pota-quantization-value-test/expected_outputs/PRelu_001/channel/uint8/record_minmax/ofm.json | 4
-rwxr-xr-x compiler/pota-quantization-value-test/gen_h5_explicit_inputs.py | 35
-rw-r--r-- compiler/pota-quantization-value-test/test.lst | 3
-rw-r--r-- compiler/pota-quantization-value-test/test_inputs/InstanceNorm_001/channel/uint8/0.txt | 1
-rw-r--r-- compiler/pota-quantization-value-test/test_inputs/InstanceNorm_001/channel/uint8/1.txt | 1
-rw-r--r-- compiler/pota-quantization-value-test/test_inputs/InstanceNorm_001/channel/uint8/2.txt | 1
-rw-r--r-- compiler/pota-quantization-value-test/test_inputs/InstanceNorm_001/channel/uint8/3.txt | 1
-rw-r--r-- compiler/pota-quantization-value-test/test_inputs/InstanceNorm_001/channel/uint8/4.txt | 1
-rw-r--r-- compiler/pota-quantization-value-test/test_inputs/InstanceNorm_001/layer/uint8/0.txt | 1
-rw-r--r-- compiler/pota-quantization-value-test/test_inputs/InstanceNorm_001/layer/uint8/1.txt | 1
-rw-r--r-- compiler/pota-quantization-value-test/test_inputs/InstanceNorm_001/layer/uint8/2.txt | 1
-rw-r--r-- compiler/pota-quantization-value-test/test_inputs/InstanceNorm_001/layer/uint8/3.txt | 1
-rw-r--r-- compiler/pota-quantization-value-test/test_inputs/InstanceNorm_001/layer/uint8/4.txt | 1
-rw-r--r-- compiler/pota-quantization-value-test/test_inputs/PRelu_001/channel/int16/0.txt | 2
-rw-r--r-- compiler/pota-quantization-value-test/test_inputs/PRelu_001/channel/int16/1.txt | 2
-rw-r--r-- compiler/pota-quantization-value-test/test_inputs/PRelu_001/channel/int16/2.txt | 2
-rw-r--r-- compiler/pota-quantization-value-test/test_inputs/PRelu_001/channel/int16/3.txt | 2
-rw-r--r-- compiler/pota-quantization-value-test/test_inputs/PRelu_001/channel/int16/4.txt | 2
-rw-r--r-- compiler/pota-quantization-value-test/test_inputs/PRelu_001/channel/uint8/0.txt | 1
-rw-r--r-- compiler/pota-quantization-value-test/test_inputs/PRelu_001/channel/uint8/1.txt | 1
-rw-r--r-- compiler/pota-quantization-value-test/test_inputs/PRelu_001/channel/uint8/2.txt | 1
-rw-r--r-- compiler/pota-quantization-value-test/test_inputs/PRelu_001/channel/uint8/3.txt | 1
-rw-r--r-- compiler/pota-quantization-value-test/test_inputs/PRelu_001/channel/uint8/4.txt | 1
-rwxr-xr-x compiler/pota-quantization-value-test/test_record_minmax.sh | 4
-rw-r--r-- compiler/tflchef/core/src/CustomOp/MaxPoolWithArgMax.cpp | 4
-rw-r--r-- compiler/tfldump/src/Dump.cpp | 12
-rw-r--r-- compiler/tfldump/src/OpPrinter.cpp | 1
-rw-r--r-- compiler/tfldump/src/Read.cpp | 1
-rw-r--r-- compiler/tfldump/src/Read.h | 3
-rw-r--r-- compiler/vconone/CMakeLists.txt | 2
180 files changed, 5314 insertions, 497 deletions
diff --git a/compiler/bcq-tools/generate_bcq_output_arrays b/compiler/bcq-tools/generate_bcq_output_arrays
index b71a37410..8544bbd2a 100644
--- a/compiler/bcq-tools/generate_bcq_output_arrays
+++ b/compiler/bcq-tools/generate_bcq_output_arrays
@@ -112,128 +112,22 @@ def print_bcqinfo_output_arrays_v1(flags):
if infoname == "bcqinfo_dequant_weight":
has_dequant_weight = True
- # Ideal situation is that the user nodes of BCQ applicable constant nodes
- # are BCQ applicable operations such as MatMul, GatherV2, etc.
- # However, operations which do not change original values such as
- # Ideneity or Transpose can exist between them. In view of TensorFlow Lite,
- # real user nodes of BCQ applicable constant nodes must be found first.
- # This work is done by BFS search with queue.
-
- prefix_node_dict = {} # key : prefix / value : list of candidates
- matmul_node_prefix_dict = {} # key : Name of MatMul node / value : prefix
-
- queue_prefix = list(prefix_set)
- queue_nodename = [queue_prefix[idx] + ":0" for idx in range(len(queue_prefix))]
-
- while len(queue_prefix) > 0:
- prefix = queue_prefix.pop(0)
- nodename = queue_nodename.pop(0)
- if prefix not in prefix_node_dict.keys():
- prefix_node_dict[prefix] = []
-
- # Usually, output name of op is like "outputname:0"
- # -2 is for removing ":0"
- for op in ops:
- if op.type == "MatMul" and (op.inputs[0].name == nodename
- or op.inputs[1].name == nodename):
- prefix_node_dict[prefix].append(op.outputs[0].name[:-2])
- matmul_node_prefix_dict[op.outputs[0].name[:-2]] = prefix
- elif op.type == "Einsum" and (op.inputs[0].name == nodename
- or op.inputs[1].name == nodename):
- prefix_node_dict[prefix].append(op.outputs[0].name[:-2])
- elif op.type == "GatherV2" and op.inputs[0].name == nodename:
- prefix_node_dict[prefix].append(op.outputs[0].name[:-2])
- elif len(op.outputs) == 1:
- for i in range(len(op.inputs)):
- if op.inputs[i].name == nodename:
- queue_prefix.append(prefix)
- queue_nodename.append(op.outputs[0].name)
- break
-
- # When TensorFlow model is converted to TensorFlow Lite model,
- # more than one operation can be fused as one.
- # For example, MatMul + BiasAdd + ReLU in TensorFlow can be fused as
- # one FullyConnected in TensorFlow Lite.
- # It means that even real user nodes of BCQ applicable constant nodes
- # in TensorFlow are found, they may be real user nodes in TensorFlow Lite.
- # Therefore additional candidates of real user nodes should be found either.
- # Finding additional candidates is done by BFS search with queue.
-
- fuseop_prefix_dict = {} # key : Candidate operation / Value : prefix
-
- # These ops can be candidate. However other candidates may exists after these ops.
- mark_type = ["Add", "AddV2", "BiasAdd", "Reshape", "Transpose"]
-
- # These ops can be candidate. And no more candidates will be found after these ops.
- mark_and_stop_type = ["Relu", "Relu6", "Tanh"]
-
- # These ops cannot be candidates but other candidates may exists after these ops.
- # NOTE : Some of following ops may be removed from the list but not sure for now.
- pass_type = [
- "BatchToSpaceND", "Cast", "DepthToSpace", "ExpandDims", "ResizeBilinear",
- "ResizeNearestNeighbor", "ScatterNd", "SpaceToBatchND", "SpaceToDepth", "Squeeze",
- "Identity", "Pack", "Unpack", "Stack"
- ]
-
- queue_prefix = list(matmul_node_prefix_dict.values())
- queue_nodename = [matmul + ":0" for matmul in matmul_node_prefix_dict.keys()]
-
- visited_nodes = set(queue_nodename)
- while len(queue_prefix) > 0:
- prefix = queue_prefix.pop(0)
- nodename = queue_nodename.pop(0)
-
- # Usually, output name of op is like "outputname:0"
- # -2 is for removing ":0"
- for op in ops:
- for i in range(len(op.inputs)):
- if nodename == op.inputs[i].name:
- if op.type in mark_type:
- if op.outputs[0].name[:-2] not in fuseop_prefix_dict.keys():
- fuseop_prefix_dict[op.outputs[0].name[:-2]] = set()
- fuseop_prefix_dict[op.outputs[0].name[:-2]].add(prefix)
- if op.outputs[0].name not in visited_nodes:
- queue_prefix.append(prefix)
- queue_nodename.append(op.outputs[0].name)
- visited_nodes.add(op.outputs[0].name)
- elif op.type in mark_and_stop_type:
- if op.outputs[0].name[:-2] not in fuseop_prefix_dict.keys():
- fuseop_prefix_dict[op.outputs[0].name[:-2]] = set()
- fuseop_prefix_dict[op.outputs[0].name[:-2]].add(prefix)
- elif op.type in pass_type and op.outputs[0].name not in visited_nodes:
- queue_prefix.append(prefix)
- queue_nodename.append(op.outputs[0].name)
- visited_nodes.add(op.outputs[0].name)
-
# Write the name of metadata node
with open(flags.metadata_path, 'w') as f_metadata:
f_metadata.write("one_compiler/bcqinfo_one_metadata,")
- # Write all pairs of candidate operations and related BCQ information nodes.
+ # Write all pairs of a constant node and related BCQ information nodes.
with open(flags.output_arrays_path, 'w') as f_arrays:
for prefix in prefix_set:
- for fusable_op in prefix_node_dict[prefix]:
- f_arrays.write("," + prefix + "/bcqinfo_do_w_x")
- f_arrays.write("," + prefix + "/bcqinfo_alpha")
- f_arrays.write("," + prefix + "/bcqinfo_packed_binary_code")
- f_arrays.write("," + prefix + "/bcqinfo_number_of_clusters")
- f_arrays.write("," + prefix + "/bcqinfo_size_of_clusters")
- f_arrays.write("," + prefix + "/bcqinfo_qbits_of_clusters")
- f_arrays.write("," + fusable_op)
- if has_dequant_weight:
- f_arrays.write("," + prefix + "/bcqinfo_dequant_weight")
- for fuseop in fuseop_prefix_dict.keys():
- if len(fuseop_prefix_dict[fuseop]) == 1:
- prefix = fuseop_prefix_dict[fuseop].pop()
- f_arrays.write("," + prefix + "/bcqinfo_do_w_x")
- f_arrays.write("," + prefix + "/bcqinfo_alpha")
- f_arrays.write("," + prefix + "/bcqinfo_packed_binary_code")
- f_arrays.write("," + prefix + "/bcqinfo_number_of_clusters")
- f_arrays.write("," + prefix + "/bcqinfo_size_of_clusters")
- f_arrays.write("," + prefix + "/bcqinfo_qbits_of_clusters")
- f_arrays.write("," + fuseop)
- if has_dequant_weight:
- f_arrays.write("," + prefix + "/bcqinfo_dequant_weight")
+ f_arrays.write("," + prefix + "/bcqinfo_do_w_x")
+ f_arrays.write("," + prefix + "/bcqinfo_alpha")
+ f_arrays.write("," + prefix + "/bcqinfo_packed_binary_code")
+ f_arrays.write("," + prefix + "/bcqinfo_number_of_clusters")
+ f_arrays.write("," + prefix + "/bcqinfo_size_of_clusters")
+ f_arrays.write("," + prefix + "/bcqinfo_qbits_of_clusters")
+ f_arrays.write("," + prefix)
+ if has_dequant_weight:
+ f_arrays.write("," + prefix + "/bcqinfo_dequant_weight")
def print_bcq_output_arrays(flags):
diff --git a/compiler/bcq-tools/generate_bcq_output_arrays.py b/compiler/bcq-tools/generate_bcq_output_arrays.py
index 0cc131880..5d9fbe687 100644
--- a/compiler/bcq-tools/generate_bcq_output_arrays.py
+++ b/compiler/bcq-tools/generate_bcq_output_arrays.py
@@ -81,129 +81,23 @@ def get_bcqinfo_output_arrays_v1(input_path, output_arrays):
if infoname == "bcqinfo_dequant_weight":
has_dequant_weight = True
- # Ideal situation is that the user nodes of BCQ applicable constant nodes
- # are BCQ applicable operations such as MatMul, GatherV2, etc.
- # However, operations which do not change original values such as
- # Ideneity or Transpose can exist between them. In view of TensorFlow Lite,
- # real user nodes of BCQ applicable constant nodes must be found first.
- # This work is done by BFS search with queue.
-
- prefix_node_dict = {} # key : prefix / value : list of candidates
- matmul_node_prefix_dict = {} # key : Name of MatMul node / value : prefix
-
- queue_prefix = list(prefix_set)
- queue_nodename = [queue_prefix[idx] + ":0" for idx in range(len(queue_prefix))]
-
- while len(queue_prefix) > 0:
- prefix = queue_prefix.pop(0)
- nodename = queue_nodename.pop(0)
- if prefix not in prefix_node_dict.keys():
- prefix_node_dict[prefix] = []
-
- # Usually, output name of op is like "outputname:0"
- # -2 is for removing ":0"
- for op in ops:
- if op.type == "MatMul" and (op.inputs[0].name == nodename
- or op.inputs[1].name == nodename):
- prefix_node_dict[prefix].append(op.outputs[0].name[:-2])
- matmul_node_prefix_dict[op.outputs[0].name[:-2]] = prefix
- elif op.type == "Einsum" and (op.inputs[0].name == nodename
- or op.inputs[1].name == nodename):
- prefix_node_dict[prefix].append(op.outputs[0].name[:-2])
- elif op.type == "GatherV2" and op.inputs[0].name == nodename:
- prefix_node_dict[prefix].append(op.outputs[0].name[:-2])
- elif len(op.outputs) == 1:
- for i in range(len(op.inputs)):
- if op.inputs[i].name == nodename:
- queue_prefix.append(prefix)
- queue_nodename.append(op.outputs[0].name)
- break
-
- # When TensorFlow model is converted to TensorFlow Lite model,
- # more than one operation can be fused as one.
- # For example, MatMul + BiasAdd + ReLU in TensorFlow can be fused as
- # one FullyConnected in TensorFlow Lite.
- # It means that even real user nodes of BCQ applicable constant nodes
- # in TensorFlow are found, they may be real user nodes in TensorFlow Lite.
- # Therefore additional candidates of real user nodes should be found either.
- # Finding additional candidates is done by BFS search with queue.
-
- fuseop_prefix_dict = {} # key : Candidate operation / Value : prefix
-
- # These ops can be candidate. However other candidates may exists after these ops.
- mark_type = ["Add", "AddV2", "BiasAdd", "Reshape", "Transpose"]
-
- # These ops can be candidate. And no more candidates will be found after these ops.
- mark_and_stop_type = ["Relu", "Relu6", "Tanh"]
-
- # These ops cannot be candidates but other candidates may exists after these ops.
- # NOTE : Some of following ops may be removed from the list but not sure for now.
- pass_type = [
- "BatchToSpaceND", "Cast", "DepthToSpace", "ExpandDims", "ResizeBilinear",
- "ResizeNearestNeighbor", "ScatterNd", "SpaceToBatchND", "SpaceToDepth", "Squeeze",
- "Identity", "Pack", "Unpack", "Stack"
- ]
-
- queue_prefix = list(matmul_node_prefix_dict.values())
- queue_nodename = [matmul + ":0" for matmul in matmul_node_prefix_dict.keys()]
-
- visited_nodes = set(queue_nodename)
- while len(queue_prefix) > 0:
- prefix = queue_prefix.pop(0)
- nodename = queue_nodename.pop(0)
-
- # Usually, output name of op is like "outputname:0"
- # -2 is for removing ":0"
- for op in ops:
- for i in range(len(op.inputs)):
- if nodename == op.inputs[i].name:
- if op.type in mark_type:
- if op.outputs[0].name[:-2] not in fuseop_prefix_dict.keys():
- fuseop_prefix_dict[op.outputs[0].name[:-2]] = set()
- fuseop_prefix_dict[op.outputs[0].name[:-2]].add(prefix)
- if op.outputs[0].name not in visited_nodes:
- queue_prefix.append(prefix)
- queue_nodename.append(op.outputs[0].name)
- visited_nodes.add(op.outputs[0].name)
- elif op.type in mark_and_stop_type:
- if op.outputs[0].name[:-2] not in fuseop_prefix_dict.keys():
- fuseop_prefix_dict[op.outputs[0].name[:-2]] = set()
- fuseop_prefix_dict[op.outputs[0].name[:-2]].add(prefix)
- elif op.type in pass_type and op.outputs[0].name not in visited_nodes:
- queue_prefix.append(prefix)
- queue_nodename.append(op.outputs[0].name)
- visited_nodes.add(op.outputs[0].name)
-
# the name of metadata node
ret_output_arrays = ['one_compiler/bcqinfo_one_metadata']
# given node from user
- ret_output_arrays.append(output_arrays)
+ ret_output_arrays += output_arrays.split(',')
- # all pairs of candidate operations and related BCQ information nodes
+ # all pairs of a constant node and related BCQ information nodes.
for prefix in prefix_set:
- for fusable_op in prefix_node_dict[prefix]:
- ret_output_arrays.append(prefix + '/bcqinfo_do_w_x')
- ret_output_arrays.append(prefix + '/bcqinfo_alpha')
- ret_output_arrays.append(prefix + '/bcqinfo_packed_binary_code')
- ret_output_arrays.append(prefix + '/bcqinfo_number_of_clusters')
- ret_output_arrays.append(prefix + '/bcqinfo_size_of_clusters')
- ret_output_arrays.append(prefix + '/bcqinfo_qbits_of_clusters')
- ret_output_arrays.append(fusable_op)
- if has_dequant_weight:
- ret_output_arrays.append(prefix + '/bcqinfo_dequant_weight')
- for fuseop in fuseop_prefix_dict.keys():
- if len(fuseop_prefix_dict[fuseop]) == 1:
- prefix = fuseop_prefix_dict[fuseop].pop()
- ret_output_arrays.append(prefix + '/bcqinfo_do_w_x')
- ret_output_arrays.append(prefix + '/bcqinfo_alpha')
- ret_output_arrays.append(prefix + '/bcqinfo_packed_binary_code')
- ret_output_arrays.append(prefix + '/bcqinfo_number_of_clusters')
- ret_output_arrays.append(prefix + '/bcqinfo_size_of_clusters')
- ret_output_arrays.append(prefix + '/bcqinfo_qbits_of_clusters')
- ret_output_arrays.append(fuseop)
- if has_dequant_weight:
- ret_output_arrays.append(prefix + '/bcqinfo_dequant_weight')
+ ret_output_arrays.append(prefix + '/bcqinfo_do_w_x')
+ ret_output_arrays.append(prefix + '/bcqinfo_alpha')
+ ret_output_arrays.append(prefix + '/bcqinfo_packed_binary_code')
+ ret_output_arrays.append(prefix + '/bcqinfo_number_of_clusters')
+ ret_output_arrays.append(prefix + '/bcqinfo_size_of_clusters')
+ ret_output_arrays.append(prefix + '/bcqinfo_qbits_of_clusters')
+ ret_output_arrays.append(prefix)
+ if has_dequant_weight:
+ ret_output_arrays.append(prefix + '/bcqinfo_dequant_weight')
return ret_output_arrays
@@ -216,7 +110,7 @@ def get_bcq_output_arrays(input_path, output_arrays):
if model_version == 1:
return get_bcqinfo_output_arrays_v1(input_path, output_arrays)
elif model_version == -1:
- return None
+ return output_arrays.split(',')
else:
err_msg = "BCQ version of the model(v{}) ".format(model_version)
err_msg += "is higher than "
diff --git a/compiler/circle2circle/src/Circle2Circle.cpp b/compiler/circle2circle/src/Circle2Circle.cpp
index 20e3ea9b6..cde5de8fd 100644
--- a/compiler/circle2circle/src/Circle2Circle.cpp
+++ b/compiler/circle2circle/src/Circle2Circle.cpp
@@ -110,6 +110,18 @@ int entry(int argc, char **argv)
.default_value(false)
.help("This will fuse BatchNorm operators of pre-activations to Convolution operator");
+ arser.add_argument("--remove_redundant_transpose")
+ .nargs(0)
+ .required(false)
+ .default_value(false)
+ .help("This will fuse or remove subsequent Transpose operators");
+
+ arser.add_argument("--replace_cw_mul_add_with_depthwise_conv")
+ .nargs(0)
+ .required(false)
+ .default_value(false)
+ .help("This will replace channel-wise mul/add with DepthwiseConv2D operator");
+
arser.add_argument("--resolve_customop_add")
.nargs(0)
.required(false)
@@ -128,6 +140,19 @@ int entry(int argc, char **argv)
.default_value(false)
.help("This will convert Custom(Matmul) to Matmul operator");
+ arser.add_argument("--shuffle_weight_to_16x1float32")
+ .nargs(0)
+ .required(false)
+ .default_value(false)
+ .help("This will convert weight format of FullyConnected to SHUFFLED16x1FLOAT32. Note that "
+ "it only converts weights whose row is a multiple of 16");
+
+ arser.add_argument("--substitute_pack_to_reshape")
+ .nargs(0)
+ .required(false)
+ .default_value(false)
+ .help("This will convert single input Pack to Reshape");
+
arser.add_argument("--mute_warnings")
.nargs(0)
.required(false)
@@ -196,6 +221,8 @@ int entry(int argc, char **argv)
options->enable(Algorithms::ResolveCustomOpAdd);
options->enable(Algorithms::ResolveCustomOpBatchMatMul);
options->enable(Algorithms::ResolveCustomOpMatMul);
+ options->enable(Algorithms::RemoveRedundantTranspose);
+ options->enable(Algorithms::SubstitutePackToReshape);
}
if (arser.get<bool>("--fold_dequantize"))
options->enable(Algorithms::FoldDequantize);
@@ -213,12 +240,20 @@ int entry(int argc, char **argv)
options->enable(Algorithms::MakeBatchNormGammaPositive);
if (arser.get<bool>("--fuse_preactivation_batchnorm"))
options->enable(Algorithms::FusePreActivationBatchNorm);
+ if (arser.get<bool>("--remove_redundant_transpose"))
+ options->enable(Algorithms::RemoveRedundantTranspose);
+ if (arser.get<bool>("--replace_cw_mul_add_with_depthwise_conv"))
+ options->enable(Algorithms::ReplaceMulAddWithDepthwiseConv);
if (arser.get<bool>("--resolve_customop_add"))
options->enable(Algorithms::ResolveCustomOpAdd);
if (arser.get<bool>("--resolve_customop_batchmatmul"))
options->enable(Algorithms::ResolveCustomOpBatchMatMul);
if (arser.get<bool>("--resolve_customop_matmul"))
options->enable(Algorithms::ResolveCustomOpMatMul);
+ if (arser.get<bool>("--shuffle_weight_to_16x1float32"))
+ options->enable(Algorithms::ShuffleWeightTo16x1Float32);
+ if (arser.get<bool>("--substitute_pack_to_reshape"))
+ options->enable(Algorithms::SubstitutePackToReshape);
if (arser.get<bool>("--mute_warnings"))
settings->set(luci::UserSettings::Key::MuteWarnings, true);
@@ -281,11 +316,14 @@ int entry(int argc, char **argv)
luci::Importer importer;
auto module = importer.importModule(circle_model);
+ // call luci optimizations for module
+ optimizer.optimize(module.get());
+
for (size_t idx = 0; idx < module->size(); ++idx)
{
auto graph = module->graph(idx);
- // call luci optimizations
+ // call luci optimizations for graph
optimizer.optimize(graph);
optimizer.sparsify(graph);
diff --git a/compiler/circlechef/tests/CMakeLists.txt b/compiler/circlechef/tests/CMakeLists.txt
index 4dc58addf..773ff5403 100644
--- a/compiler/circlechef/tests/CMakeLists.txt
+++ b/compiler/circlechef/tests/CMakeLists.txt
@@ -26,6 +26,32 @@ foreach(RECIPE IN ITEMS ${RECIPES})
list(APPEND TESTFILES ${RECIPE_OUTPUT_FILE})
endforeach(RECIPE)
+# Add local files
+file(GLOB RECIPES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*/test.recipe")
+
+foreach(RECIPE IN ITEMS ${RECIPES})
+ get_filename_component(RECIPE_PREFIX ${RECIPE} DIRECTORY)
+
+ set(RECIPE_SOURCE_FILE "${RECIPE_PREFIX}.recipe")
+ set(RECIPE_OUTPUT_FILE "${RECIPE_PREFIX}.circle")
+
+ # Copy .recipe
+ add_custom_command(OUTPUT ${RECIPE_SOURCE_FILE}
+ COMMAND ${CMAKE_COMMAND} -E copy_if_different
+ "${CMAKE_CURRENT_SOURCE_DIR}/${RECIPE}" ${RECIPE_SOURCE_FILE}
+ DEPENDS "${CMAKE_CURRENT_SOURCE_DIR}/${RECIPE}"
+ COMMENT "Generating ${RECIPE_SOURCE_FILE}")
+
+ # Generate .circle
+ add_custom_command(OUTPUT ${RECIPE_OUTPUT_FILE}
+ COMMAND circlechef-file ${RECIPE_SOURCE_FILE} ${RECIPE_OUTPUT_FILE}
+ DEPENDS circlechef-file ${RECIPE_SOURCE_FILE}
+ COMMENT "Generating ${RECIPE_OUTPUT_FILE}")
+
+ list(APPEND TESTS ${RECIPE_PREFIX})
+ list(APPEND TESTFILES ${RECIPE_OUTPUT_FILE})
+endforeach(RECIPE)
+
#Test circlechef-reverse
file(GLOB GEN_CIRCLEFILES RELATIVE ${CIRCLERECIPES_DIR} "${CIRCLERECIPES_DIR}/*/test.reverse")
# Note: While in development, circlechef-reverse may not handle the operator.
@@ -58,6 +84,31 @@ foreach(CIRCLEFILE IN ITEMS ${GEN_CIRCLEFILES})
list(APPEND TESTFILES ${RECIPE_GEN_OUTPUT_FILE2})
endforeach(CIRCLEFILE)
+# Test local circlechef-reverse
+file(GLOB GEN_CIRCLEFILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*/test.reverse")
+
+foreach(CIRCLEFILE IN ITEMS ${GEN_CIRCLEFILES})
+ get_filename_component(CIRCLE_PREFIX ${CIRCLEFILE} DIRECTORY)
+
+ set(RECIPE_OUTPUT_FILE "${CIRCLE_PREFIX}.circle")
+ set(RECIPE_GEN_OUTPUT_FILE "${CIRCLE_PREFIX}.gen.recipe")
+ set(RECIPE_GEN_OUTPUT_FILE2 "${CIRCLE_PREFIX}.gen.circle")
+
+ # Generate .gen.recipe from generated .circle
+ add_custom_command(OUTPUT ${RECIPE_GEN_OUTPUT_FILE}
+ COMMAND circlechef-reverse ${RECIPE_OUTPUT_FILE} ${RECIPE_GEN_OUTPUT_FILE}
+ DEPENDS circlechef-reverse ${RECIPE_OUTPUT_FILE}
+ COMMENT "Generating ${RECIPE_GEN_OUTPUT_FILE}")
+
+ add_custom_command(OUTPUT ${RECIPE_GEN_OUTPUT_FILE2}
+ COMMAND circlechef-file ${RECIPE_GEN_OUTPUT_FILE} ${RECIPE_GEN_OUTPUT_FILE2}
+ DEPENDS circlechef-file ${RECIPE_GEN_OUTPUT_FILE}
+ COMMENT "Generating ${RECIPE_GEN_OUTPUT_FILE2}")
+
+ list(APPEND TESTS ${CIRCLE_PREFIX}.gen)
+ list(APPEND TESTFILES ${RECIPE_GEN_OUTPUT_FILE2})
+endforeach(CIRCLEFILE)
+
# Add a dummy target to create a target-level dependency.
# TODO Find a way to create a dependency between circlechef_test and generated testfiles.
add_custom_target(circlechef_testfiles ALL DEPENDS ${TESTFILES})
diff --git a/compiler/circlechef/tests/shape_signature/test.recipe b/compiler/circlechef/tests/shape_signature/test.recipe
new file mode 100644
index 000000000..37968ab0b
--- /dev/null
+++ b/compiler/circlechef/tests/shape_signature/test.recipe
@@ -0,0 +1,45 @@
+operand {
+ name: "ifm"
+ type: FLOAT32
+ shape { dim: 1 dim: 8 dim: 6 dim: 12 }
+ shape_signature { dim: -1 dim: 8 dim: 6 dim: 12 }
+}
+operand {
+ name: "gamma"
+ type: FLOAT32
+ shape { dim: 12 }
+ filler {
+ tag: "gaussian"
+ arg: "0.0"
+ arg: "1.0"
+ }
+}
+operand {
+ name: "beta"
+ type: FLOAT32
+ shape { dim: 12 }
+ filler {
+ tag: "gaussian"
+ arg: "0.0"
+ arg: "1.0"
+ }
+}
+operand {
+ name: "ofm"
+ type: FLOAT32
+ shape { dim: 1 dim: 8 dim: 6 dim: 12 }
+ shape_signature { dim: -1 dim: 8 dim: 6 dim: 12 }
+}
+operation {
+ type: "InstanceNorm"
+ input: "ifm"
+ input: "gamma"
+ input: "beta"
+ output: "ofm"
+ instance_norm_options {
+ epsilon: 0.00001
+ activation: NONE
+ }
+}
+input: "ifm"
+output: "ofm"
diff --git a/compiler/circlechef/tests/shape_signature/test.reverse b/compiler/circlechef/tests/shape_signature/test.reverse
new file mode 100644
index 000000000..e69de29bb
--- /dev/null
+++ b/compiler/circlechef/tests/shape_signature/test.reverse
diff --git a/compiler/common-artifacts/exclude.lst b/compiler/common-artifacts/exclude.lst
index b2abfd583..34a4d2c6a 100644
--- a/compiler/common-artifacts/exclude.lst
+++ b/compiler/common-artifacts/exclude.lst
@@ -16,10 +16,6 @@ tcgenerate(AddN_000)
tcgenerate(Add_001) # runtime doesn't support
tcgenerate(Add_U8_000)
tcgenerate(All_000)
-tcgenerate(ArgMax_U8_000)
-tcgenerate(ArgMax_U8_001)
-tcgenerate(ArgMax_U8_002)
-tcgenerate(ArgMax_U8_003)
tcgenerate(ArgMin_000)
tcgenerate(ArgMin_001)
tcgenerate(ArgMin_002)
@@ -35,58 +31,35 @@ tcgenerate(BatchToSpaceND_000)
tcgenerate(Cast_000)
tcgenerate(Cast_001)
tcgenerate(Ceil_000)
-tcgenerate(Concatenation_U8_000)
tcgenerate(Conv2D_003) # runtime doesn't support dilation
-tcgenerate(Conv2D_U8_000)
-tcgenerate(Conv2D_U8_001)
tcgenerate(Cos_000)
tcgenerate(DepthwiseConv2D_001) # runtime doesn't support dilation
tcgenerate(DepthwiseConv2D_003) # runtime doesn't support dilation
-tcgenerate(DepthwiseConv2D_U8_000)
tcgenerate(DepthwiseConv2D_U8_001) # luci-interpreter doesn't support channel-wise quantization yet
tcgenerate(Dequantize_000) # runtime and luci-interpreter doesn't support Dequantize op yet
-tcgenerate(Div_000)
-tcgenerate(Equal_000)
-tcgenerate(Exp_000)
tcgenerate(ExpandDims_000)
tcgenerate(ExpandDims_001)
tcgenerate(ExpandDims_002)
tcgenerate(ExpandDims_003)
tcgenerate(Fill_000)
tcgenerate(Fill_001)
-tcgenerate(Floor_000)
-tcgenerate(FloorDiv_000)
-tcgenerate(FloorDiv_001)
tcgenerate(FloorMod_000)
tcgenerate(FloorMod_001)
-tcgenerate(FullyConnected_002)
tcgenerate(FullyConnected_U8_000)
tcgenerate(Gather_000)
tcgenerate(GatherNd_000)
tcgenerate(GatherNd_001)
-tcgenerate(Greater_000)
-tcgenerate(GreaterEqual_000)
tcgenerate(If_000)
tcgenerate(If_001)
tcgenerate(L2Pool2D_U8_000)
-tcgenerate(Less_000)
-tcgenerate(LessEqual_000)
tcgenerate(Log_000)
-tcgenerate(LogicalAnd_000)
-tcgenerate(LogicalNot_000)
-tcgenerate(LogicalOr_000)
-tcgenerate(LogSoftmax_000)
tcgenerate(MatMul_000)
tcgenerate(MatrixBandPart_000)
tcgenerate(MatrixDiag_000)
tcgenerate(MatrixSetDiag_000)
-tcgenerate(Maximum_000)
-tcgenerate(MaxPool2D_U8_000)
tcgenerate(MaxPoolWithArgMax_000)
tcgenerate(MaxPoolWithArgMax_001)
tcgenerate(MaxPoolWithArgMax_002)
-tcgenerate(Mean_U8_000)
-tcgenerate(Minimum_000)
tcgenerate(NonMaxSuppressionV4_000)
tcgenerate(NonMaxSuppressionV4_001)
tcgenerate(NonMaxSuppressionV5_000)
@@ -99,36 +72,38 @@ tcgenerate(Net_InstanceNorm_001)
tcgenerate(Net_InstanceNorm_002)
tcgenerate(Net_InstanceNorm_003)
tcgenerate(Net_ZeroDim_001) # luci-interpreter doesn't support zero dim
-tcgenerate(NotEqual_000)
tcgenerate(OneHot_000)
tcgenerate(OneHot_001)
tcgenerate(OneHot_002)
tcgenerate(OneHot_003)
tcgenerate(Pack_000)
tcgenerate(Pack_U8_000)
-tcgenerate(Pad_U8_000)
tcgenerate(PadV2_000)
-tcgenerate(Pow_000)
tcgenerate(Range_000)
tcgenerate(Rank_000)
tcgenerate(ReduceAny_000)
tcgenerate(ReduceAny_001)
tcgenerate(ReduceAny_002)
tcgenerate(ReduceAny_003)
+tcgenerate(ReduceAny_dynamic_000)
+tcgenerate(ReduceAny_dynamic_001)
+tcgenerate(ReduceAny_dynamic_002)
+tcgenerate(ReduceAny_dynamic_003)
tcgenerate(ReduceMax_000)
+tcgenerate(ReduceMax_dynamic_000)
tcgenerate(ReduceMin_000)
+tcgenerate(ReduceMin_dynamic_000)
tcgenerate(ReduceProd_000)
tcgenerate(ReduceProd_001)
tcgenerate(ReduceProd_002)
tcgenerate(ReduceProd_003)
-tcgenerate(ReLU_000)
-tcgenerate(ReLU6_000)
+tcgenerate(ReduceProd_dynamic_000)
+tcgenerate(ReduceProd_dynamic_001)
+tcgenerate(ReduceProd_dynamic_002)
+tcgenerate(ReduceProd_dynamic_003)
tcgenerate(ReLUN1To1_000)
+tcgenerate(ReLUN1To1_dynamic_000)
tcgenerate(Reshape_003) # luci-interpreter doesn't support reshape without built-in option
-tcgenerate(Reshape_U8_000)
-tcgenerate(ResizeBilinear_000)
-tcgenerate(ResizeBilinear_U8_000) # luci-interpreter
-tcgenerate(ResizeNearestNeighbor_000)
tcgenerate(ReverseSequence_000)
tcgenerate(ReverseV2_000)
tcgenerate(Round_000)
@@ -142,7 +117,6 @@ tcgenerate(SelectV2_001)
tcgenerate(SelectV2_002)
tcgenerate(Shape_000)
tcgenerate(Sin_000)
-tcgenerate(Softmax_U8_000)
tcgenerate(SpaceToBatchND_000)
tcgenerate(SpaceToBatchND_001)
tcgenerate(SpaceToBatchND_002)
@@ -151,11 +125,10 @@ tcgenerate(SparseToDense_000)
tcgenerate(SplitV_000)
tcgenerate(Square_000)
tcgenerate(SquaredDifference_000)
-tcgenerate(Sub_000)
-tcgenerate(Sub_001)
-tcgenerate(Sub_U8_000)
tcgenerate(Sum_000)
tcgenerate(Sum_001)
+tcgenerate(Sum_dynamic_000)
+tcgenerate(Sum_dynamic_001)
tcgenerate(Tile_000)
tcgenerate(Tile_U8_000)
tcgenerate(TopKV2_000)
@@ -184,3 +157,4 @@ tcgenerate(BCQFullyConnected_001)
tcgenerate(BCQGather_000)
tcgenerate(CircleBatchMatMul_000)
tcgenerate(InstanceNorm_000)
+tcgenerate(InstanceNorm_001)
diff --git a/compiler/exo/src/Circle/CircleExporterUtils.h b/compiler/exo/src/Circle/CircleExporterUtils.h
index fdd162bae..78f0cf7ed 100644
--- a/compiler/exo/src/Circle/CircleExporterUtils.h
+++ b/compiler/exo/src/Circle/CircleExporterUtils.h
@@ -65,7 +65,7 @@ namespace circle_detail
{
/**
- * @breif Record the information of T/F Lite SubGraph and its mapping to loco
+ * @brief Record the information of T/F Lite SubGraph and its mapping to loco
*/
struct SubGraphContext
{
diff --git a/compiler/exo/src/Dialect/Service/TFLShapeInferenceRule.cpp b/compiler/exo/src/Dialect/Service/TFLShapeInferenceRule.cpp
index f4bb10364..26cc561e1 100644
--- a/compiler/exo/src/Dialect/Service/TFLShapeInferenceRule.cpp
+++ b/compiler/exo/src/Dialect/Service/TFLShapeInferenceRule.cpp
@@ -116,7 +116,7 @@ private:
};
/**
- * @breif Expand shape x and y to same rank by align right and filling with 1
+ * @brief Expand shape x and y to same rank by align right and filling with 1
*/
void expand_rank(loco::TensorShape &x, loco::TensorShape &y)
{
@@ -136,7 +136,7 @@ void expand_rank(loco::TensorShape &x, loco::TensorShape &y)
}
/**
- * @breif Returns shape of expanded dimension of input x and y having same rank
+ * @brief Returns shape of expanded dimension of input x and y having same rank
*/
loco::TensorShape expand_dimension(const loco::TensorShape &x, const loco::TensorShape &y)
{
diff --git a/compiler/exo/src/TFLite/TFLExporterUtils.h b/compiler/exo/src/TFLite/TFLExporterUtils.h
index dbd7a52fb..f2fe6075e 100644
--- a/compiler/exo/src/TFLite/TFLExporterUtils.h
+++ b/compiler/exo/src/TFLite/TFLExporterUtils.h
@@ -65,7 +65,7 @@ namespace tflite_detail
{
/**
- * @breif Record the information of T/F Lite SubGraph and its mapping to loco
+ * @brief Record the information of T/F Lite SubGraph and its mapping to loco
*/
struct SubGraphContext
{
diff --git a/compiler/hermes/include/hermes/core/Message.h b/compiler/hermes/include/hermes/core/Message.h
index 28cfd7942..460163f64 100644
--- a/compiler/hermes/include/hermes/core/Message.h
+++ b/compiler/hermes/include/hermes/core/Message.h
@@ -37,7 +37,7 @@ public:
public:
/// @brief The number of lines
uint32_t lines(void) const { return _lines.size(); }
- /// @breif The content of a specific line
+ /// @brief The content of a specific line
const std::string &line(uint32_t n) const { return _lines.at(n); }
private:
diff --git a/compiler/luci-interpreter/src/kernels/Conv2D.cpp b/compiler/luci-interpreter/src/kernels/Conv2D.cpp
index 47e2498f1..c5069e403 100644
--- a/compiler/luci-interpreter/src/kernels/Conv2D.cpp
+++ b/compiler/luci-interpreter/src/kernels/Conv2D.cpp
@@ -135,7 +135,17 @@ void Conv2D::execute() const
}
throw std::runtime_error("Unsupported type.");
case DataType::U8:
- evalQuantized();
+ if (filter()->scales().size() == 1)
+ {
+ evalQuantized();
+ }
+ else if (filter()->scales().size() > 1)
+ {
+ LUCI_INTERPRETER_CHECK(filter()->shape().num_dims() == 4);
+ LUCI_INTERPRETER_CHECK(filter()->scales().size() ==
+ static_cast<size_t>(filter()->shape().dim(0)));
+ evalQuantizedPerChannel();
+ }
break;
case DataType::S16:
evalQuantizedS16();
@@ -219,6 +229,92 @@ void Conv2D::evalQuantized() const
getTensorData<uint8_t>(_im2col.get()), gemmlowp_context.get());
}
+void Conv2D::evalQuantizedPerChannel() const
+{
+ const auto *input_data = getTensorData<uint8_t>(input());
+ const auto *filter_data = getTensorData<uint8_t>(filter());
+ const auto *bias_data = getTensorData<int32_t>(bias());
+ auto *output_data = getTensorData<uint8_t>(output());
+
+ const Shape &input_shape = input()->shape();
+ const Shape &filter_shape = filter()->shape();
+ const Shape &output_shape = output()->shape();
+
+ const int32_t batches = input_shape.dim(0);
+ const int32_t input_height = input_shape.dim(1);
+ const int32_t input_width = input_shape.dim(2);
+ const int32_t input_depth = input_shape.dim(3);
+ const int32_t output_depth = filter_shape.dim(0);
+ const int32_t filter_height = filter_shape.dim(1);
+ const int32_t filter_width = filter_shape.dim(2);
+ const int32_t output_height = output_shape.dim(1);
+ const int32_t output_width = output_shape.dim(2);
+
+ const int32_t stride_height = _params.stride_height;
+ const int32_t stride_width = _params.stride_width;
+ const int32_t dilation_height_factor = _params.dilation_height_factor;
+ const int32_t dilation_width_factor = _params.dilation_width_factor;
+
+ int32_t activation_min{};
+ int32_t activation_max{};
+ calculateActivationRangeQuantized(_params.activation, output(), &activation_min, &activation_max);
+
+ const std::vector<double> effective_output_scale =
+ getQuantizedConvolutionMultiplers(input()->scale(), filter()->scales(), output()->scale());
+
+ const std::vector<ChannelQuantMultipliers> multipliers_raw =
+ quantizeMultipliers(effective_output_scale);
+ BroadcastableWrapper<ChannelQuantMultipliers> quant_multipliers(multipliers_raw);
+
+ for (int32_t batch = 0; batch < batches; ++batch)
+ {
+ for (int32_t out_y = 0; out_y < output_height; ++out_y)
+ {
+ for (int32_t out_x = 0; out_x < output_width; ++out_x)
+ {
+ for (int32_t out_c = 0; out_c < output_depth; ++out_c)
+ {
+ const int32_t in_y_origin = out_y * stride_height - _padding_height;
+ const int32_t in_x_origin = out_x * stride_width - _padding_width;
+ int32_t acc = 0;
+ for (int32_t filter_y = 0; filter_y < filter_height; ++filter_y)
+ {
+ for (int32_t filter_x = 0; filter_x < filter_width; ++filter_x)
+ {
+ const int32_t in_y = in_y_origin + dilation_height_factor * filter_y;
+ const int32_t in_x = in_x_origin + dilation_width_factor * filter_x;
+ if ((in_y >= 0 && in_y < input_height) && (in_x >= 0 && in_x < input_width))
+ {
+ for (int32_t in_c = 0; in_c < input_depth; ++in_c)
+ {
+ const uint8_t input_val =
+ input_data[calcOffset(input_shape, batch, in_y, in_x, in_c)];
+ const uint8_t filter_val =
+ filter_data[calcOffset(filter_shape, out_c, filter_y, filter_x, in_c)];
+ acc += static_cast<int32_t>(input_val - input()->zero_point()) *
+ static_cast<int32_t>(filter_val - filter()->zero_points()[out_c]);
+ }
+ }
+ }
+ }
+ if (bias_data)
+ {
+ acc += bias_data[out_c];
+ }
+
+ int32_t scaled_acc = tflite::MultiplyByQuantizedMultiplier(
+ acc, quant_multipliers[out_c].multiplier, quant_multipliers[out_c].shift);
+
+ scaled_acc += output()->zero_point();
+ scaled_acc = std::max(scaled_acc, activation_min);
+ scaled_acc = std::min(scaled_acc, activation_max);
+ output_data[calcOffset(output_shape, batch, out_y, out_x, out_c)] = scaled_acc;
+ }
+ }
+ }
+ }
+}
+
void Conv2D::evalQuantizedS16() const
{
const auto *input_data = getTensorData<int16_t>(input());
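Note: the new per-channel path differs from the existing per-tensor evalQuantized only in how each accumulator is rescaled: every output channel uses its own filter scale and zero point. A float-reference sketch of that requantization step (hypothetical helper; the kernel itself uses the fixed-point multiplier/shift pairs produced by quantizeMultipliers and tflite::MultiplyByQuantizedMultiplier):

```python
def requantize_per_channel(acc, out_c, input_scale, filter_scales, output_scale,
                           output_zero_point, act_min, act_max):
    # Effective rescale for this output channel, as in getQuantizedConvolutionMultiplers.
    effective_scale = input_scale * filter_scales[out_c] / output_scale
    scaled = int(round(acc * effective_scale)) + output_zero_point
    # Clamp into the activation range computed by calculateActivationRangeQuantized.
    return max(act_min, min(act_max, scaled))
```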
diff --git a/compiler/luci-interpreter/src/kernels/Conv2D.h b/compiler/luci-interpreter/src/kernels/Conv2D.h
index 83ac67d3d..86f73c251 100644
--- a/compiler/luci-interpreter/src/kernels/Conv2D.h
+++ b/compiler/luci-interpreter/src/kernels/Conv2D.h
@@ -44,6 +44,7 @@ public:
private:
void evalFloat() const;
void evalQuantized() const;
+ void evalQuantizedPerChannel() const;
void evalQuantizedS16() const;
private:
diff --git a/compiler/luci-interpreter/src/kernels/Conv2D.test.cpp b/compiler/luci-interpreter/src/kernels/Conv2D.test.cpp
index 7aa66a898..35a0c5491 100644
--- a/compiler/luci-interpreter/src/kernels/Conv2D.test.cpp
+++ b/compiler/luci-interpreter/src/kernels/Conv2D.test.cpp
@@ -169,6 +169,78 @@ TEST(Conv2DTest, Uint8)
EXPECT_THAT(extractTensorShape(output_tensor), ::testing::ElementsAreArray(ref_output_shape));
}
+TEST(Conv2DTest, Uint8_CWQ)
+{
+ const int output_channels = 3;
+ std::vector<float> input_data{
+ // First batch
+ 1, 1, 1, 1, // row = 1
+ 2, 2, 2, 2, // row = 2
+ // Second batch
+ 1, 2, 3, 4, // row = 1
+ 1, 2, 3, 4, // row = 2
+ };
+ std::vector<float> filter_data{
+ 1, 2, 3, 4, // first 2x2 filter
+ -1, 1, -1, 1, // second 2x2 filter
+ -1, -1, 1, 1, // third 2x2 filter
+ };
+ std::vector<float> bias_data{1, 2, 3};
+ Shape filter_shape{output_channels, 2, 2, 1};
+
+ std::pair<float, int32_t> input_quant_param = quantizationParams<uint8_t>(0, 4);
+ std::pair<float, int32_t> output_quant_param = quantizationParams<uint8_t>(-127, 128);
+
+ std::vector<std::pair<float, int32_t>> filter_quant_params;
+ filter_quant_params.push_back(quantizationParams<uint8_t>(0, 4));
+ filter_quant_params.push_back(quantizationParams<uint8_t>(-1, 1));
+ filter_quant_params.push_back(quantizationParams<uint8_t>(-1, 1));
+
+ std::vector<float> filter_scales;
+ std::vector<int32_t> filter_zerops;
+ for (auto iter : filter_quant_params)
+ {
+ filter_scales.push_back(iter.first);
+ filter_zerops.push_back(iter.second);
+ }
+
+ std::vector<float> bias_scales;
+ for (int i = 0; i < output_channels; ++i)
+ bias_scales.push_back(filter_quant_params[i].first * input_quant_param.first);
+ std::vector<int32_t> zerop(output_channels, 0);
+
+ Tensor input_tensor = makeInputTensor<DataType::U8>({2, 2, 4, 1}, input_quant_param.first,
+ input_quant_param.second, input_data);
+ Tensor filter_tensor =
+ makeInputTensor<DataType::U8>(filter_shape, filter_scales, filter_zerops, 0, filter_data);
+ Tensor bias_tensor =
+ makeInputTensor<DataType::S32>({output_channels}, bias_scales, zerop, 0, bias_data);
+ Tensor output_tensor =
+ makeOutputTensor(DataType::U8, output_quant_param.first, output_quant_param.second);
+
+ Conv2DParams params{};
+ params.padding = Padding::VALID;
+ params.stride_height = 2;
+ params.stride_width = 2;
+ params.dilation_height_factor = 1;
+ params.dilation_width_factor = 1;
+ params.activation = Activation::NONE;
+
+ Conv2D kernel(&input_tensor, &filter_tensor, &bias_tensor, &output_tensor, params);
+ kernel.configure();
+ kernel.execute();
+
+ std::vector<float> ref_output_data{
+ 18, 2, 5, // first batch, left
+ 18, 2, 5, // first batch, right
+ 17, 4, 3, // second batch, left
+ 37, 4, 3, // second batch, right
+ };
+ std::vector<int32_t> ref_output_shape{2, 1, 2, 3};
+ EXPECT_THAT(dequantizeTensorData(output_tensor), FloatArrayNear(ref_output_data));
+ EXPECT_THAT(extractTensorShape(output_tensor), ::testing::ElementsAreArray(ref_output_shape));
+}
+
TEST(Conv2DTest, SInt16)
{
Shape input_shape{1, 4, 3, 2};
diff --git a/compiler/luci-interpreter/src/kernels/DepthwiseConv2D.cpp b/compiler/luci-interpreter/src/kernels/DepthwiseConv2D.cpp
index 1957f3c9d..921133191 100644
--- a/compiler/luci-interpreter/src/kernels/DepthwiseConv2D.cpp
+++ b/compiler/luci-interpreter/src/kernels/DepthwiseConv2D.cpp
@@ -111,7 +111,17 @@ void DepthwiseConv2D::execute() const
}
throw std::runtime_error("Unsupported type.");
case DataType::U8:
- evalQuantized();
+ if (filter()->scales().size() == 1)
+ {
+ evalQuantized();
+ }
+ else if (filter()->scales().size() > 1)
+ {
+ LUCI_INTERPRETER_CHECK(filter()->shape().num_dims() == 4);
+ LUCI_INTERPRETER_CHECK(filter()->scales().size() ==
+ static_cast<size_t>(filter()->shape().dim(3)));
+ evalQuantizedPerChannel();
+ }
break;
case DataType::S16:
evalQuantizedS16();
@@ -144,6 +154,97 @@ void DepthwiseConv2D::evalFloat() const
getTensorShape(output()), getTensorData<float>(output()));
}
+void DepthwiseConv2D::evalQuantizedPerChannel() const
+{
+ const auto *input_data = getTensorData<uint8_t>(input());
+ const auto *filter_data = getTensorData<uint8_t>(filter());
+ const auto *bias_data = getTensorData<int32_t>(bias());
+ auto *output_data = getTensorData<uint8_t>(output());
+
+ const Shape &input_shape = input()->shape();
+ const Shape &filter_shape = filter()->shape();
+ const Shape &output_shape = output()->shape();
+
+ const int32_t batches = input_shape.dim(0);
+ const int32_t input_height = input_shape.dim(1);
+ const int32_t input_width = input_shape.dim(2);
+ const int32_t input_depth = input_shape.dim(3);
+ const int32_t filter_height = filter_shape.dim(1);
+ const int32_t filter_width = filter_shape.dim(2);
+ const int32_t output_height = output_shape.dim(1);
+ const int32_t output_width = output_shape.dim(2);
+
+ const int32_t stride_height = _params.stride_height;
+ const int32_t stride_width = _params.stride_width;
+ const int32_t dilation_height_factor = _params.dilation_height_factor;
+ const int32_t dilation_width_factor = _params.dilation_width_factor;
+ const int32_t depth_multiplier = _params.depth_multiplier;
+
+ int32_t activation_min{};
+ int32_t activation_max{};
+ calculateActivationRangeQuantized(_params.activation, output(), &activation_min, &activation_max);
+
+ const std::vector<double> effective_output_scales =
+ getQuantizedConvolutionMultiplers(input()->scale(), filter()->scales(), output()->scale());
+
+ std::vector<ChannelQuantMultipliers> quant_multipliers_raw =
+ quantizeMultipliers(effective_output_scales);
+ BroadcastableWrapper<ChannelQuantMultipliers> quant_multipliers(quant_multipliers_raw);
+
+ for (int batch = 0; batch < batches; ++batch)
+ {
+ for (int out_y = 0; out_y < output_height; ++out_y)
+ {
+ for (int out_x = 0; out_x < output_width; ++out_x)
+ {
+ for (int in_channel = 0; in_channel < input_depth; ++in_channel)
+ {
+ for (int m = 0; m < depth_multiplier; ++m)
+ {
+ const int output_channel = m + in_channel * depth_multiplier;
+ const int in_x_origin = (out_x * stride_width) - _padding_width;
+ const int in_y_origin = (out_y * stride_height) - _padding_height;
+ int32 acc = 0;
+ for (int filter_y = 0; filter_y < filter_height; ++filter_y)
+ {
+ for (int filter_x = 0; filter_x < filter_width; ++filter_x)
+ {
+ const int in_x = in_x_origin + dilation_width_factor * filter_x;
+ const int in_y = in_y_origin + dilation_height_factor * filter_y;
+ // Zero padding by omitting the areas outside the image.
+ const bool is_point_inside_image =
+ (in_x >= 0) && (in_x < input_width) && (in_y >= 0) && (in_y < input_height);
+ if (is_point_inside_image)
+ {
+ int32 input_val =
+ input_data[calcOffset(input_shape, batch, in_y, in_x, in_channel)];
+ int32 filter_val =
+ filter_data[calcOffset(filter_shape, 0, filter_y, filter_x, output_channel)];
+ acc += (filter_val - filter()->zero_points()[output_channel]) *
+ (input_val - input()->zero_point());
+ }
+ }
+ }
+ if (bias_data)
+ {
+ acc += bias_data[output_channel];
+ }
+ int32_t output_multiplier = quant_multipliers[output_channel].multiplier;
+ int output_shift = quant_multipliers[output_channel].shift;
+ int32_t scaled_acc =
+ tflite::MultiplyByQuantizedMultiplier(acc, output_multiplier, output_shift);
+ scaled_acc += output()->zero_point();
+ scaled_acc = std::max(scaled_acc, activation_min);
+ scaled_acc = std::min(scaled_acc, activation_max);
+ output_data[calcOffset(output_shape, batch, out_y, out_x, output_channel)] =
+ static_cast<uint8_t>(scaled_acc);
+ }
+ }
+ }
+ }
+ }
+}
+
void DepthwiseConv2D::evalQuantized() const
{
const auto input_scale = static_cast<double>(input()->scale());
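Note: the depthwise variant applies the same per-channel rescaling, but its scales are laid out along the last filter dimension (the check above requires scales().size() == dim(3)), and each input channel fans out to depth_multiplier output channels. A small sketch of that channel mapping (assumed helper, mirroring the loop above):

```python
def depthwise_output_channels(input_depth, depth_multiplier):
    # Output channel index used to select filter values, zero points and multipliers.
    return [in_c * depth_multiplier + m
            for in_c in range(input_depth)
            for m in range(depth_multiplier)]

# depthwise_output_channels(2, 2) -> [0, 1, 2, 3]
```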
diff --git a/compiler/luci-interpreter/src/kernels/DepthwiseConv2D.h b/compiler/luci-interpreter/src/kernels/DepthwiseConv2D.h
index 400bebe5a..6d700dd0f 100644
--- a/compiler/luci-interpreter/src/kernels/DepthwiseConv2D.h
+++ b/compiler/luci-interpreter/src/kernels/DepthwiseConv2D.h
@@ -42,6 +42,7 @@ public:
private:
void evalFloat() const;
void evalQuantized() const;
+ void evalQuantizedPerChannel() const;
void evalQuantizedS16() const;
private:
diff --git a/compiler/luci-interpreter/src/kernels/DepthwiseConv2D.test.cpp b/compiler/luci-interpreter/src/kernels/DepthwiseConv2D.test.cpp
index 0c76b585e..f79e888a1 100644
--- a/compiler/luci-interpreter/src/kernels/DepthwiseConv2D.test.cpp
+++ b/compiler/luci-interpreter/src/kernels/DepthwiseConv2D.test.cpp
@@ -220,6 +220,79 @@ TEST(DepthwiseConv2DTest, SInt16_CWQ_weights)
EXPECT_THAT(dequantizeTensorData(output_tensor), FloatArrayNear(ref_output_data));
}
+TEST(DepthwiseConv2DTest, Uint8_CWQ_weights)
+{
+ const int output_channels = 4;
+ Shape input_shape{1, 3, 2, 2};
+ Shape filter_shape{1, 2, 2, output_channels};
+ Shape bias_shape{4};
+ std::vector<int32_t> ref_output_shape{1, 2, 1, output_channels};
+
+ std::vector<float> input_data{
+ 1, 2, 7, 8, //
+ 3, 4, 9, 10, //
+ 5, 6, 11, 12, //
+ };
+ std::vector<float> filter_data{
+ 1, 2, 3, 4, //
+ -9, 10, -11, 12, //
+ 5, 6, 7, 8, //
+ 13, -14, 15, -16, //
+ };
+ std::vector<float> bias_data{1, 2, 3, 4};
+ std::vector<float> ref_output_data{
+ 71, -34, 99, -20, //
+ 91, -26, 127, -4, //
+ };
+
+ std::pair<float, int32_t> input_quant_param = quantizationParams<uint8_t>(0, 16);
+ std::pair<float, int32_t> output_quant_param = quantizationParams<uint8_t>(-127, 128);
+
+ std::vector<std::pair<float, int32_t>> filter_quant_params;
+ filter_quant_params.push_back(quantizationParams<uint8_t>(-9, 13));
+ filter_quant_params.push_back(quantizationParams<uint8_t>(-14, 10));
+ filter_quant_params.push_back(quantizationParams<uint8_t>(-11, 15));
+ filter_quant_params.push_back(quantizationParams<uint8_t>(-16, 12));
+
+ std::vector<float> filter_scales;
+ std::vector<int32_t> filter_zerops;
+ for (auto iter : filter_quant_params)
+ {
+ filter_scales.push_back(iter.first);
+ filter_zerops.push_back(iter.second);
+ }
+
+ std::vector<float> bias_scales;
+ for (int i = 0; i < output_channels; ++i)
+ bias_scales.push_back(filter_quant_params[i].first * input_quant_param.first);
+ std::vector<int32_t> zerop(output_channels, 0);
+
+ Tensor input_tensor = makeInputTensor<DataType::U8>(input_shape, input_quant_param.first,
+ input_quant_param.second, input_data);
+ Tensor filter_tensor =
+ makeInputTensor<DataType::U8>(filter_shape, filter_scales, filter_zerops, 3, filter_data);
+ Tensor bias_tensor = makeInputTensor<DataType::S32>(bias_shape, bias_scales, zerop, 0, bias_data);
+ Tensor output_tensor =
+ makeOutputTensor(DataType::U8, output_quant_param.first, output_quant_param.second);
+
+ DepthwiseConv2DParams params{};
+ params.padding = Padding::VALID;
+ params.depth_multiplier = 2;
+ params.stride_height = 1;
+ params.stride_width = 1;
+ params.dilation_height_factor = 1;
+ params.dilation_width_factor = 1;
+ params.activation = Activation::NONE;
+
+ DepthwiseConv2D kernel(&input_tensor, &filter_tensor, &bias_tensor, &output_tensor, params);
+ kernel.configure();
+ kernel.execute();
+
+ EXPECT_THAT(extractTensorShape(output_tensor), ::testing::ElementsAreArray(ref_output_shape));
+ EXPECT_THAT(dequantizeTensorData(output_tensor),
+ FloatArrayNear(ref_output_data, output_quant_param.first));
+}
+
TEST(DepthwiseConv2DTest, InvalidBiasType_NEG)
{
Shape input_shape{1, 4, 2, 2};
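Note on the channel-wise (CWQ) test above: the invariant it sets up by hand is that each bias channel is quantized with scale equal to input_scale * filter_scale[c] and zero point 0, so the int32 bias can be added straight onto the accumulator. A minimal sketch of that derivation, not part of the patch, reusing the quantizationParams helper the tests already call:

// Sketch only: deriving per-channel bias scales for a U8 CWQ kernel test.
std::pair<float, int32_t> input_qp = quantizationParams<uint8_t>(0.0f, 16.0f);
std::vector<std::pair<float, int32_t>> filter_qp{quantizationParams<uint8_t>(-9.0f, 13.0f),
                                                 quantizationParams<uint8_t>(-14.0f, 10.0f)};

std::vector<float> bias_scales;
for (const auto &qp : filter_qp)
  bias_scales.push_back(qp.first * input_qp.first); // bias_scale[c] = input_scale * filter_scale[c]
std::vector<int32_t> bias_zerops(filter_qp.size(), 0); // bias zero points stay 0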
diff --git a/compiler/luci-interpreter/src/kernels/TransposeConv.cpp b/compiler/luci-interpreter/src/kernels/TransposeConv.cpp
index b0ee905dc..491ae51ae 100644
--- a/compiler/luci-interpreter/src/kernels/TransposeConv.cpp
+++ b/compiler/luci-interpreter/src/kernels/TransposeConv.cpp
@@ -93,7 +93,17 @@ void TransposeConv::execute() const
evalFloat();
break;
case DataType::U8:
- evalQuantized();
+ if (filter()->scales().size() == 1)
+ {
+ evalQuantized();
+ }
+ else if (filter()->scales().size() > 1)
+ {
+ LUCI_INTERPRETER_CHECK(filter()->shape().num_dims() == 4);
+ LUCI_INTERPRETER_CHECK(filter()->scales().size() ==
+ static_cast<size_t>(filter()->shape().dim(0)));
+ evalQuantizedPerChannel();
+ }
break;
case DataType::S16:
evalQuantizedS16();
@@ -147,6 +157,98 @@ void TransposeConv::evalQuantized() const
getTensorData<int32_t>(_scratch_tensor.get()));
}
+void TransposeConv::evalQuantizedPerChannel() const
+{
+ const auto *input_data = getTensorData<uint8_t>(input());
+ const auto *filter_data = getTensorData<uint8_t>(filter());
+ const auto *bias_data = getTensorData<int32_t>(bias());
+ auto *output_data = getTensorData<uint8_t>(output());
+ auto *scratch_data = getTensorData<int32_t>(_scratch_tensor.get());
+
+ const Shape &input_shape = input()->shape();
+ const Shape &filter_shape = filter()->shape();
+ const Shape &output_shape = output()->shape();
+
+ const int32_t batches = input_shape.dim(0);
+ const int32_t input_height = input_shape.dim(1);
+ const int32_t input_width = input_shape.dim(2);
+ const int32_t input_depth = input_shape.dim(3);
+ const int32_t output_depth = filter_shape.dim(0);
+ const int32_t filter_height = filter_shape.dim(1);
+ const int32_t filter_width = filter_shape.dim(2);
+ const int32_t output_height = output_shape.dim(1);
+ const int32_t output_width = output_shape.dim(2);
+
+ const int32_t stride_height = _params.stride_height;
+ const int32_t stride_width = _params.stride_width;
+
+ int32_t activation_min{};
+ int32_t activation_max{};
+ calculateActivationRangeQuantized(Activation::NONE, output(), &activation_min, &activation_max);
+
+ std::memset(scratch_data, 0, _scratch_tensor->shape().num_elements() * sizeof(int32_t));
+
+ BroadcastableWrapper<ChannelQuantMultipliers> output_multipliers(_quant_multipliers);
+ for (int32_t batch = 0; batch < batches; ++batch)
+ {
+ for (int32_t in_y = 0; in_y < input_height; ++in_y)
+ {
+ for (int32_t in_x = 0; in_x < input_width; ++in_x)
+ {
+ for (int32_t in_c = 0; in_c < input_depth; ++in_c)
+ {
+ const int32_t out_y_origin = in_y * stride_height - _padding_height;
+ const int32_t out_x_origin = in_x * stride_width - _padding_width;
+ for (int32_t filter_y = 0; filter_y < filter_height; ++filter_y)
+ {
+ for (int32_t filter_x = 0; filter_x < filter_width; ++filter_x)
+ {
+ const int32_t out_x = out_x_origin + filter_x;
+ const int32_t out_y = out_y_origin + filter_y;
+ if ((out_y >= 0 && out_y < output_height) && (out_x >= 0 && out_x < output_width))
+ {
+ for (int32_t out_c = 0; out_c < output_depth; ++out_c)
+ {
+ const uint8_t input_val =
+ input_data[calcOffset(input_shape, batch, in_y, in_x, in_c)];
+ const uint8_t filter_val =
+ filter_data[calcOffset(filter_shape, out_c, filter_y, filter_x, in_c)];
+ scratch_data[calcOffset(output_shape, batch, out_y, out_x, out_c)] +=
+ static_cast<int32_t>(input_val - input()->zero_point()) *
+ static_cast<int32_t>(filter_val - filter()->zero_points()[out_c]);
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+ for (int32_t out_y = 0; out_y < output_height; ++out_y)
+ {
+ for (int32_t out_x = 0; out_x < output_width; ++out_x)
+ {
+ for (int32_t out_c = 0; out_c < output_depth; ++out_c)
+ {
+ int32_t acc = scratch_data[calcOffset(output_shape, batch, out_y, out_x, out_c)];
+ if (bias_data)
+ {
+ acc += bias_data[out_c];
+ }
+
+ int32_t scaled_acc = tflite::MultiplyByQuantizedMultiplier(
+ acc, output_multipliers[out_c].multiplier, output_multipliers[out_c].shift);
+
+ scaled_acc += output()->zero_point();
+ scaled_acc = std::max(scaled_acc, activation_min);
+ scaled_acc = std::min(scaled_acc, activation_max);
+
+          output_data[calcOffset(output_shape, batch, out_y, out_x, out_c)] =
+            static_cast<uint8_t>(scaled_acc);
+ }
+ }
+ }
+ }
+}
+
void TransposeConv::evalQuantizedS16() const
{
const auto *input_data = getTensorData<int16_t>(input());
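The per-channel path above scatters raw int32 products into the scratch tensor first and only requantizes each output element afterwards, using that channel's {multiplier, shift} pair. A sketch of the requantize-and-clamp step, not part of the patch, assuming tflite::MultiplyByQuantizedMultiplier from the TFLite internal headers these kernels already use:

#include <algorithm>
#include <cstdint>

int32_t requantize(int32_t acc, int32_t multiplier, int shift, int32_t output_zero_point,
                   int32_t activation_min, int32_t activation_max)
{
  int32_t scaled = tflite::MultiplyByQuantizedMultiplier(acc, multiplier, shift);
  scaled += output_zero_point;               // move into the output's quantized domain
  scaled = std::max(scaled, activation_min); // clamp to the fused-activation range
  scaled = std::min(scaled, activation_max);
  return scaled;                             // caller narrows to uint8_t when storing
}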
diff --git a/compiler/luci-interpreter/src/kernels/TransposeConv.h b/compiler/luci-interpreter/src/kernels/TransposeConv.h
index f51e16976..2e0beece8 100644
--- a/compiler/luci-interpreter/src/kernels/TransposeConv.h
+++ b/compiler/luci-interpreter/src/kernels/TransposeConv.h
@@ -47,6 +47,7 @@ public:
private:
void evalFloat() const;
void evalQuantized() const;
+ void evalQuantizedPerChannel() const;
void evalQuantizedS16() const;
private:
diff --git a/compiler/luci-interpreter/src/kernels/TransposeConv.test.cpp b/compiler/luci-interpreter/src/kernels/TransposeConv.test.cpp
index 8564de01d..b1309c128 100644
--- a/compiler/luci-interpreter/src/kernels/TransposeConv.test.cpp
+++ b/compiler/luci-interpreter/src/kernels/TransposeConv.test.cpp
@@ -154,6 +154,65 @@ TEST(TransposeConvTest, UInt8)
EXPECT_THAT(dequantizeTensorData(output_tensor), FloatArrayNear(ref_output_data));
}
+TEST(TransposeConvTest, UInt8_CWQ)
+{
+ const int32_t output_channels = 2;
+ std::vector<float> input_data{1, 2, 3, 4};
+ std::vector<float> filter_data{1, 3, 5, 7, 9, 11, 13, 15, 17, 2, 4, 6, 8, 10, 12, 14, 16, 18};
+ std::vector<float> bias_data{3, 4};
+ std::vector<int32_t> output_shape_data{1, 5, 5, 2};
+ std::vector<float> ref_output_data{
+ 4, 6, 6, 8, 10, 14, 9, 12, 13, 16, //
+ 10, 12, 12, 14, 28, 32, 21, 24, 25, 28, //
+ 19, 24, 27, 32, 65, 76, 45, 52, 57, 64, //
+ 24, 28, 30, 34, 64, 72, 39, 44, 47, 52, //
+ 42, 46, 48, 52, 106, 114, 63, 68, 71, 76, //
+ };
+
+ // Choose quantization parameters carefully.
+ auto input_quant = quantizationParams<uint8_t>(-8.0, 7.9375); // s = 1 / 16, zp = 128
+ auto output_quant = quantizationParams<uint8_t>(-64.0, 191.0); // s = 1, zp = 64
+
+ std::vector<std::pair<float, int32_t>> filter_quant_params;
+ filter_quant_params.push_back(quantizationParams<uint8_t>(0, 17));
+ filter_quant_params.push_back(quantizationParams<uint8_t>(0, 18));
+
+ std::vector<float> filter_scales;
+ std::vector<int32_t> filter_zerops;
+ for (auto iter : filter_quant_params)
+ {
+ filter_scales.push_back(iter.first);
+ filter_zerops.push_back(iter.second);
+ }
+
+ std::vector<float> bias_scales;
+ for (int i = 0; i < output_channels; ++i)
+ bias_scales.push_back(filter_quant_params[i].first * input_quant.first);
+ std::vector<int32_t> zerop(output_channels, 0);
+
+ Tensor input_tensor = makeInputTensor<DataType::U8>({1, 2, 2, 1}, input_quant.first,
+ input_quant.second, input_data);
+ Tensor filter_tensor = makeInputTensor<DataType::U8>({output_channels, 3, 3, 1}, filter_scales,
+ filter_zerops, 0, filter_data);
+ Tensor bias_tensor =
+ makeInputTensor<DataType::S32>({output_channels}, bias_scales, zerop, 0, bias_data);
+ Tensor output_shape_tensor = makeInputTensor<DataType::S32>({4}, output_shape_data);
+ Tensor output_tensor = makeOutputTensor(DataType::U8, output_quant.first, output_quant.second);
+
+ TransposeConvParams params{};
+ params.padding = Padding::VALID;
+ params.stride_height = 2;
+ params.stride_width = 2;
+
+ TransposeConv kernel(&output_shape_tensor, &filter_tensor, &input_tensor, &bias_tensor,
+ &output_tensor, params);
+ kernel.configure();
+ kernel.execute();
+
+ EXPECT_THAT(extractTensorShape(output_tensor), ::testing::ElementsAreArray(output_shape_data));
+ EXPECT_THAT(dequantizeTensorData(output_tensor), FloatArrayNear(ref_output_data));
+}
+
TEST(TransposeConvTest, SInt16)
{
std::vector<float> input_data{1, 2, 3, 4};
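The "choose quantization parameters carefully" comment in the test above refers to picking float ranges whose derived scale and zero point come out exact. Assuming the test helper quantizes affinely over the full uint8 range [0, 255], the quoted values check out as follows (a sketch, not part of the patch):

#include <cmath>
#include <cstdint>

// input range [-8.0, 7.9375]
float fmin = -8.0f, fmax = 7.9375f;
float scale = (fmax - fmin) / 255.0f;                                 // 15.9375 / 255 = 0.0625 = 1/16
int32_t zero_point = static_cast<int32_t>(std::round(-fmin / scale)); // 8.0 / 0.0625 = 128
// The output range [-64.0, 191.0] gives scale = 255 / 255 = 1 and zero_point = 64 the same way.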
diff --git a/compiler/luci-interpreter/src/loader/GraphLoader.cpp b/compiler/luci-interpreter/src/loader/GraphLoader.cpp
index c52d99e6f..09e923597 100644
--- a/compiler/luci-interpreter/src/loader/GraphLoader.cpp
+++ b/compiler/luci-interpreter/src/loader/GraphLoader.cpp
@@ -57,8 +57,12 @@ const void *getNodeData(const luci::CircleConst *node, size_t *data_size)
return getNodeDataImpl<DataType::U8>(node, data_size);
case DataType::FLOAT32:
return getNodeDataImpl<DataType::FLOAT32>(node, data_size);
+ case DataType::S16:
+ return getNodeDataImpl<DataType::S16>(node, data_size);
case DataType::S32:
return getNodeDataImpl<DataType::S32>(node, data_size);
+ case DataType::S64:
+ return getNodeDataImpl<DataType::S64>(node, data_size);
default:
throw std::runtime_error("Unsupported type.");
}
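With S16 and S64 handled, the loader can materialize constants produced by 16-bit symmetric quantization as well as 64-bit index tensors. A hypothetical caller of the getNodeData entry point shown above (illustrative only, not from the patch):

// const_node is assumed to be a luci::CircleConst * with dtype() == DataType::S64.
size_t data_size = 0;
const void *raw = getNodeData(const_node, &data_size);
const auto *indices = static_cast<const int64_t *>(raw);
const size_t count = data_size / sizeof(int64_t); // number of S64 elements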
diff --git a/compiler/luci/export/src/CircleExporterImpl.cpp b/compiler/luci/export/src/CircleExporterImpl.cpp
index 860cebf6e..df7542797 100644
--- a/compiler/luci/export/src/CircleExporterImpl.cpp
+++ b/compiler/luci/export/src/CircleExporterImpl.cpp
@@ -16,7 +16,6 @@
#include "CircleExporterImpl.h"
#include "Optimize.h"
-#include "TypeBridge.h"
#include "CircleTensorExporter.h"
#include "CircleOperationExporter.h"
#include "CircleExporterUtils.h"
@@ -150,9 +149,6 @@ void CircleExporterImpl::exportGraph(loco::Graph *graph)
// do graph optimization
optimize(graph);
- // copy shape/dtype inference data to CircleNode
- copy_shape_dtype(graph);
-
_builder.Clear();
SerializedModelData md;
@@ -223,9 +219,6 @@ void CircleExporterImpl::exportModule(Module *module)
optimize(graph);
- // copy shape/dtype inference data to CircleNode
- copy_shape_dtype(graph);
-
SerializedGraphData gd;
// set Subgraph name
diff --git a/compiler/luci/export/src/CircleExporterUtils.cpp b/compiler/luci/export/src/CircleExporterUtils.cpp
index 1fdb40e51..3715513e0 100644
--- a/compiler/luci/export/src/CircleExporterUtils.cpp
+++ b/compiler/luci/export/src/CircleExporterUtils.cpp
@@ -87,6 +87,22 @@ circle::MirrorPadMode to_circle_mirrorpadmode(luci::MirrorPadMode mode)
}
}
+circle::FullyConnectedOptionsWeightsFormat
+to_circle_weightsformat(luci::CircleFullyConnected::WeightsFormat format)
+{
+ switch (format)
+ {
+ case luci::CircleFullyConnected::WeightsFormat::DEFAULT:
+ return circle::FullyConnectedOptionsWeightsFormat_DEFAULT;
+ case luci::CircleFullyConnected::WeightsFormat::SHUFFLED4x16INT8:
+ return circle::FullyConnectedOptionsWeightsFormat_SHUFFLED4x16INT8;
+ case luci::CircleFullyConnected::WeightsFormat::SHUFFLED16x1FLOAT32:
+ return circle::FullyConnectedOptionsWeightsFormat_SHUFFLED16x1FLOAT32;
+ default:
+ INTERNAL_EXN_V("trying to convert unsupported luci::WeightsFormat", oops::to_uint32(format));
+ }
+}
+
circle::DimensionType to_circle_dimensiontype(luci::DimensionType type)
{
switch (type)
diff --git a/compiler/luci/export/src/CircleExporterUtils.h b/compiler/luci/export/src/CircleExporterUtils.h
index 7857213b2..95310b353 100644
--- a/compiler/luci/export/src/CircleExporterUtils.h
+++ b/compiler/luci/export/src/CircleExporterUtils.h
@@ -32,6 +32,8 @@ namespace luci
circle::ActivationFunctionType to_circle_actfunc(luci::FusedActFunc func);
circle::TensorType to_circle_tensortype(loco::DataType type);
circle::MirrorPadMode to_circle_mirrorpadmode(luci::MirrorPadMode mode);
+circle::FullyConnectedOptionsWeightsFormat
+to_circle_weightsformat(luci::CircleFullyConnected::WeightsFormat format);
circle::DimensionType to_circle_dimensiontype(luci::DimensionType type);
flatbuffers::Offset<void> to_circle_sparse_index_vector(flatbuffers::FlatBufferBuilder &fb,
const SparseIndexVector &sparse_idx_vec);
diff --git a/compiler/luci/export/src/CircleOperationExporter.cpp b/compiler/luci/export/src/CircleOperationExporter.cpp
index c937109cd..4343cf3c9 100644
--- a/compiler/luci/export/src/CircleOperationExporter.cpp
+++ b/compiler/luci/export/src/CircleOperationExporter.cpp
@@ -21,7 +21,6 @@
#include <luci/IR/CircleNode.h>
#include <luci/IR/CircleNodes.h>
#include <luci/IR/CircleNodeVisitor.h>
-#include <luci/Service/CircleShapeInference.h>
#include <luci/UserSettings.h>
#include <luci/Log.h>
@@ -930,7 +929,8 @@ void OperationExporter::visit(luci::CircleFullyConnected *node)
{
export_simple(
node, circle::BuiltinOperator_FULLY_CONNECTED, circle::BuiltinOptions_FullyConnectedOptions,
- CreateFullyConnectedOptions(_ctx.builder, to_circle_actfunc(node->fusedActivationFunction()))
+ CreateFullyConnectedOptions(_ctx.builder, to_circle_actfunc(node->fusedActivationFunction()),
+ to_circle_weightsformat(node->weights_format()))
.Union());
}
diff --git a/compiler/luci/export/src/CircleTensorExporter.cpp b/compiler/luci/export/src/CircleTensorExporter.cpp
index 1429d2810..9bdfa0079 100644
--- a/compiler/luci/export/src/CircleTensorExporter.cpp
+++ b/compiler/luci/export/src/CircleTensorExporter.cpp
@@ -111,10 +111,10 @@ void allocateCircleTensorInfo(CircleNode *node, CircleTensorContext &ctx)
CircleTensoInfo tensor_info;
tensor_info.name(tensor_name);
- tensor_info.dtype(to_circle_tensortype(luci::node_dtype(node)));
+ tensor_info.dtype(to_circle_tensortype(node->dtype()));
tensor_info.shape_signature(node->shape_signature());
if (node->shape_status() == ShapeStatus::VALID)
- tensor_info.shape(to_shape_description(luci::node_shape(node)));
+ tensor_info.shape(to_shape_description(node));
tensor_info.shape_status(node->shape_status());
tensor_info.content(dynamic_cast<luci::CircleConst *>(node));
@@ -243,6 +243,9 @@ flatbuffers::Offset<Vector<int32_t>> encodeShape(FlatBufferBuilder &builder,
flatbuffers::Offset<Vector<int32_t>> encodeShapeSignature(FlatBufferBuilder &builder,
const ShapeSignature &shape_signature)
{
+ if (shape_signature.rank() == 0)
+ return 0;
+
return builder.CreateVector(shape_signature.as_vector());
}
diff --git a/compiler/luci/export/src/Optimize.cpp b/compiler/luci/export/src/Optimize.cpp
index 6fa50b564..036a4a2f9 100644
--- a/compiler/luci/export/src/Optimize.cpp
+++ b/compiler/luci/export/src/Optimize.cpp
@@ -18,6 +18,7 @@
#include "ProgressReporter.h"
#include <luci/Pass/ShapeInferencePass.h>
+#include <luci/Pass/ShapeSignatureInferencePass.h>
#include <luci/Pass/TypeInferencePass.h>
#include <logo/Phase.h>
@@ -34,6 +35,7 @@ void optimize(loco::Graph *g)
// prepare type and shape before optimization
phase.emplace_back(std::make_unique<TypeInferencePass>());
phase.emplace_back(std::make_unique<ShapeInferencePass>());
+ phase.emplace_back(std::make_unique<ShapeSignatureInferencePass>());
// TODO add more optimization passes (with a knob)
}
diff --git a/compiler/luci/export/src/SerializedData.h b/compiler/luci/export/src/SerializedData.h
index 46b1ac2d5..c41f50edd 100644
--- a/compiler/luci/export/src/SerializedData.h
+++ b/compiler/luci/export/src/SerializedData.h
@@ -64,7 +64,7 @@ namespace luci
{
/**
- * @breif Record the information of T/F Lite SubGraph and its mapping to loco
+ * @brief Record the information of T/F Lite SubGraph and its mapping to loco
*/
struct SubGraphContext
{
diff --git a/compiler/luci/import/include/luci/Import/CircleReader.h b/compiler/luci/import/include/luci/Import/CircleReader.h
index 8636b1d9a..8e210dd77 100644
--- a/compiler/luci/import/include/luci/Import/CircleReader.h
+++ b/compiler/luci/import/include/luci/Import/CircleReader.h
@@ -46,6 +46,8 @@ loco::DataType luci_datatype(circle::TensorType type);
FusedActFunc luci_actfunc(const circle::ActivationFunctionType type);
Padding luci_padding(const circle::Padding padding);
MirrorPadMode luci_mirrorpad_mode(const circle::MirrorPadMode mode);
+luci::CircleFullyConnected::WeightsFormat
+luci_weights_format(const circle::FullyConnectedOptionsWeightsFormat weights_format);
std::unique_ptr<CircleQuantParam>
luci_quantparam(const circle::QuantizationParametersT *quantization);
diff --git a/compiler/luci/import/src/CircleReader.cpp b/compiler/luci/import/src/CircleReader.cpp
index 068de5239..b33c920b1 100644
--- a/compiler/luci/import/src/CircleReader.cpp
+++ b/compiler/luci/import/src/CircleReader.cpp
@@ -151,6 +151,22 @@ MirrorPadMode luci_mirrorpad_mode(const circle::MirrorPadMode mode)
return MirrorPadMode::UNDEFINED;
}
+luci::CircleFullyConnected::WeightsFormat
+luci_weights_format(const circle::FullyConnectedOptionsWeightsFormat weights_format)
+{
+ switch (weights_format)
+ {
+ case circle::FullyConnectedOptionsWeightsFormat_DEFAULT:
+ return luci::CircleFullyConnected::WeightsFormat::DEFAULT;
+ case circle::FullyConnectedOptionsWeightsFormat_SHUFFLED4x16INT8:
+ return luci::CircleFullyConnected::WeightsFormat::SHUFFLED4x16INT8;
+ case circle::FullyConnectedOptionsWeightsFormat_SHUFFLED16x1FLOAT32:
+ return luci::CircleFullyConnected::WeightsFormat::SHUFFLED16x1FLOAT32;
+ default:
+ throw std::runtime_error("Invalid FullyConnectedOptionsWeightsFormat");
+ }
+}
+
DimensionType luci_dim_type(const circle::DimensionType dim_type)
{
switch (dim_type)
diff --git a/compiler/luci/import/src/Nodes/CircleFullyConnected.cpp b/compiler/luci/import/src/Nodes/CircleFullyConnected.cpp
index 65a863bde..17293ad7a 100644
--- a/compiler/luci/import/src/Nodes/CircleFullyConnected.cpp
+++ b/compiler/luci/import/src/Nodes/CircleFullyConnected.cpp
@@ -53,12 +53,7 @@ CircleNode *CircleFullyConnectedGraphBuilder::build_node(const circle::OperatorT
const auto *options = op.builtin_options.AsFullyConnectedOptions();
node->fusedActivationFunction(luci_actfunc(options->fused_activation_function));
- if (options->weights_format != circle::FullyConnectedOptionsWeightsFormat_DEFAULT)
- {
- throw oops::UserExn(
- "Unsupported weights format",
- circle::EnumNameFullyConnectedOptionsWeightsFormat(options->weights_format));
- }
+ node->weights_format(luci_weights_format(options->weights_format));
return node;
}
diff --git a/compiler/luci/lang/include/luci/IR/AttrDilation.h b/compiler/luci/lang/include/luci/IR/AttrDilation.h
index c2b28d77d..ed8232576 100644
--- a/compiler/luci/lang/include/luci/IR/AttrDilation.h
+++ b/compiler/luci/lang/include/luci/IR/AttrDilation.h
@@ -27,15 +27,17 @@ class Dilation final
public:
Dilation() : _w(1), _h(1) {}
- int32_t w() const { return _w; }
- void w(int32_t w) { _w = w; }
+ uint32_t w() const { return _w; }
+ void w(uint32_t w) { _w = w; }
+ void w(int32_t w);
- int32_t h() const { return _h; }
- void h(int32_t h) { _h = h; }
+ uint32_t h() const { return _h; }
+ void h(uint32_t h) { _h = h; }
+ void h(int32_t h);
private:
- int32_t _w;
- int32_t _h;
+ uint32_t _w;
+ uint32_t _h;
};
} // namespace luci
diff --git a/compiler/luci/lang/include/luci/IR/AttrFilter.h b/compiler/luci/lang/include/luci/IR/AttrFilter.h
index 7909fa523..af9d7519f 100644
--- a/compiler/luci/lang/include/luci/IR/AttrFilter.h
+++ b/compiler/luci/lang/include/luci/IR/AttrFilter.h
@@ -27,15 +27,17 @@ class Filter final
public:
Filter() : _w(1), _h(1) {}
- int32_t w() const { return _w; }
- void w(int32_t w) { _w = w; }
+ uint32_t w() const { return _w; }
+ void w(uint32_t w) { _w = w; }
+ void w(int32_t w);
- int32_t h() const { return _h; }
- void h(int32_t h) { _h = h; }
+ uint32_t h() const { return _h; }
+ void h(uint32_t h) { _h = h; }
+ void h(int32_t h);
private:
- int32_t _w;
- int32_t _h;
+ uint32_t _w;
+ uint32_t _h;
};
} // namespace luci
diff --git a/compiler/luci/lang/include/luci/IR/AttrStride.h b/compiler/luci/lang/include/luci/IR/AttrStride.h
index 654967d73..6be697975 100644
--- a/compiler/luci/lang/include/luci/IR/AttrStride.h
+++ b/compiler/luci/lang/include/luci/IR/AttrStride.h
@@ -27,15 +27,17 @@ class Stride final
public:
Stride() : _w(1), _h(1) {}
- int32_t w() const { return _w; }
- void w(int32_t w) { _w = w; }
+ uint32_t w() const { return _w; }
+ void w(uint32_t w) { _w = w; }
+ void w(int32_t w);
- int32_t h() const { return _h; }
- void h(int32_t h) { _h = h; }
+ uint32_t h() const { return _h; }
+ void h(uint32_t h) { _h = h; }
+ void h(int32_t h);
private:
- int32_t _w;
- int32_t _h;
+ uint32_t _w;
+ uint32_t _h;
};
} // namespace luci
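Stride, Filter, and Dilation now store unsigned values but keep signed setter overloads, so call sites that still pass int32_t (for example, fields read from serialized options) keep compiling and get a non-negativity assertion, as the new Attr*.cpp files later in this patch implement. A hypothetical caller sketch:

#include <luci/IR/AttrStride.h>

// Illustrative values standing in for int32_t option fields.
int32_t stride_w = 2;
int32_t stride_h = 2;

luci::Stride stride;
stride.w(stride_w); // signed overload: asserts stride_w >= 0, then stores as uint32_t
stride.h(stride_h);
uint32_t w = stride.w(); // getters now return uint32_t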
diff --git a/compiler/luci/lang/include/luci/IR/CircleShapeSignature.h b/compiler/luci/lang/include/luci/IR/CircleShapeSignature.h
index 970f1b521..18a260486 100644
--- a/compiler/luci/lang/include/luci/IR/CircleShapeSignature.h
+++ b/compiler/luci/lang/include/luci/IR/CircleShapeSignature.h
@@ -46,6 +46,8 @@ private:
std::vector<int32_t> _shape_signature{};
};
+bool operator==(const ShapeSignature &lhs, const ShapeSignature &rhs);
+
} // namespace luci
#endif // __LUCI_IR_SHAPE_SIGNATURE_H__
diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleFullyConnected.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleFullyConnected.h
index d78f39494..952befc87 100644
--- a/compiler/luci/lang/include/luci/IR/Nodes/CircleFullyConnected.h
+++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleFullyConnected.h
@@ -35,6 +35,16 @@ class CircleFullyConnected final
public LuciNodeMixin<LuciNodeTrait::Bias>
{
public:
+ enum class WeightsFormat
+ {
+    UNDEFINED, // This is not defined by Circle. This was added to prevent programming errors.
+
+ DEFAULT,
+ SHUFFLED4x16INT8,
+ SHUFFLED16x1FLOAT32,
+ };
+
+public:
loco::Node *input(void) const { return at(0)->node(); }
void input(loco::Node *node) { at(0)->node(node); }
@@ -43,6 +53,13 @@ public:
loco::Node *bias(void) const override { return at(2)->node(); }
void bias(loco::Node *node) override { at(2)->node(node); }
+
+public:
+ WeightsFormat weights_format(void) const { return _weights_format; }
+ void weights_format(WeightsFormat weights_format) { _weights_format = weights_format; }
+
+private:
+ WeightsFormat _weights_format{WeightsFormat::DEFAULT};
};
} // namespace luci
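With the WeightsFormat attribute on the node, the importer records the serialized format instead of rejecting it, and the exporter writes it back via to_circle_weightsformat (both changes appear elsewhere in this patch). An illustrative round-trip sketch, not part of the patch:

// Marks a FullyConnected node as using the shuffled 16x1 float32 weight layout.
void mark_shuffled(luci::CircleFullyConnected *fc)
{
  fc->weights_format(luci::CircleFullyConnected::WeightsFormat::SHUFFLED16x1FLOAT32);
  // On export, to_circle_weightsformat(fc->weights_format()) yields
  // circle::FullyConnectedOptionsWeightsFormat_SHUFFLED16x1FLOAT32.
}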
diff --git a/compiler/luci/lang/src/AttrDilation.cpp b/compiler/luci/lang/src/AttrDilation.cpp
new file mode 100644
index 000000000..a9f479502
--- /dev/null
+++ b/compiler/luci/lang/src/AttrDilation.cpp
@@ -0,0 +1,36 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/IR/AttrDilation.h"
+
+#include <cassert>
+
+namespace luci
+{
+
+void Dilation::w(int32_t w)
+{
+ assert(w >= 0);
+ _w = static_cast<uint32_t>(w);
+}
+
+void Dilation::h(int32_t h)
+{
+ assert(h >= 0);
+ _h = static_cast<uint32_t>(h);
+}
+
+} // namespace luci
diff --git a/compiler/luci/lang/src/AttrDilation.test.cpp b/compiler/luci/lang/src/AttrDilation.test.cpp
new file mode 100644
index 000000000..3e4658990
--- /dev/null
+++ b/compiler/luci/lang/src/AttrDilation.test.cpp
@@ -0,0 +1,36 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/IR/AttrDilation.h"
+
+#include <gtest/gtest.h>
+
+TEST(CircleAttrDilationTest, set)
+{
+ auto d = luci::Dilation();
+
+ d.h(10u);
+ d.w(10u);
+
+ ASSERT_EQ(d.h(), 10u);
+ ASSERT_EQ(d.w(), 10u);
+
+ d.h(10); // int32_t
+ d.w(10);
+
+ ASSERT_EQ(d.h(), 10u);
+ ASSERT_EQ(d.w(), 10u);
+}
diff --git a/compiler/luci/lang/src/AttrFilter.cpp b/compiler/luci/lang/src/AttrFilter.cpp
new file mode 100644
index 000000000..9c571e7f5
--- /dev/null
+++ b/compiler/luci/lang/src/AttrFilter.cpp
@@ -0,0 +1,36 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/IR/AttrFilter.h"
+
+#include <cassert>
+
+namespace luci
+{
+
+void Filter::w(int32_t w)
+{
+ assert(w >= 0);
+ _w = static_cast<uint32_t>(w);
+}
+
+void Filter::h(int32_t h)
+{
+ assert(h >= 0);
+ _h = static_cast<uint32_t>(h);
+}
+
+} // namespace luci
diff --git a/compiler/luci/lang/src/AttrFilter.test.cpp b/compiler/luci/lang/src/AttrFilter.test.cpp
new file mode 100644
index 000000000..06dbcacd5
--- /dev/null
+++ b/compiler/luci/lang/src/AttrFilter.test.cpp
@@ -0,0 +1,36 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/IR/AttrFilter.h"
+
+#include <gtest/gtest.h>
+
+TEST(CircleAttrFilterTest, set)
+{
+ auto f = luci::Filter();
+
+ f.h(10u);
+ f.w(10u);
+
+ ASSERT_EQ(f.h(), 10u);
+ ASSERT_EQ(f.w(), 10u);
+
+ f.h(10); // int32_t
+ f.w(10);
+
+ ASSERT_EQ(f.h(), 10u);
+ ASSERT_EQ(f.w(), 10u);
+}
diff --git a/compiler/luci/lang/src/AttrStride.cpp b/compiler/luci/lang/src/AttrStride.cpp
new file mode 100644
index 000000000..9720d12b5
--- /dev/null
+++ b/compiler/luci/lang/src/AttrStride.cpp
@@ -0,0 +1,36 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/IR/AttrStride.h"
+
+#include <cassert>
+
+namespace luci
+{
+
+void Stride::w(int32_t w)
+{
+ assert(w >= 0);
+ _w = static_cast<uint32_t>(w);
+}
+
+void Stride::h(int32_t h)
+{
+ assert(h >= 0);
+ _h = static_cast<uint32_t>(h);
+}
+
+} // namespace luci
diff --git a/compiler/luci/lang/src/AttrStride.test.cpp b/compiler/luci/lang/src/AttrStride.test.cpp
new file mode 100644
index 000000000..e91365bd5
--- /dev/null
+++ b/compiler/luci/lang/src/AttrStride.test.cpp
@@ -0,0 +1,36 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/IR/AttrStride.h"
+
+#include <gtest/gtest.h>
+
+TEST(CircleAttrStrideTest, set)
+{
+ auto s = luci::Stride();
+
+ s.h(10u);
+ s.w(10u);
+
+ ASSERT_EQ(s.h(), 10u);
+ ASSERT_EQ(s.w(), 10u);
+
+ s.h(10); // int32_t
+ s.w(10);
+
+ ASSERT_EQ(s.h(), 10u);
+ ASSERT_EQ(s.w(), 10u);
+}
diff --git a/compiler/luci/lang/src/CircleShapeSignature.cpp b/compiler/luci/lang/src/CircleShapeSignature.cpp
new file mode 100644
index 000000000..970000203
--- /dev/null
+++ b/compiler/luci/lang/src/CircleShapeSignature.cpp
@@ -0,0 +1,34 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/IR/CircleShapeSignature.h"
+
+namespace luci
+{
+
+bool operator==(const ShapeSignature &lhs, const ShapeSignature &rhs)
+{
+ if (lhs.rank() != rhs.rank())
+ return false;
+
+ for (uint32_t i = 0; i < lhs.rank(); ++i)
+ if (lhs.dim(i) != rhs.dim(i))
+ return false;
+
+ return true;
+}
+
+} // namespace luci
diff --git a/compiler/luci/pass/include/luci/CircleOptimizer.h b/compiler/luci/pass/include/luci/CircleOptimizer.h
index db5bdb501..906760e0a 100644
--- a/compiler/luci/pass/include/luci/CircleOptimizer.h
+++ b/compiler/luci/pass/include/luci/CircleOptimizer.h
@@ -19,6 +19,8 @@
#include <loco.h>
+#include <luci/IR/Module.h>
+
#include <string>
#include <vector>
@@ -47,6 +49,10 @@ public:
FusePreActivationBatchNorm,
MakeBatchNormGammaPositive,
FuseActivationFunction,
+ ShuffleWeightTo16x1Float32,
+ RemoveRedundantTranspose,
+ ReplaceMulAddWithDepthwiseConv,
+ SubstitutePackToReshape,
};
enum AlgorithmParameters
@@ -77,6 +83,8 @@ public:
Options *options(void);
public:
+ void optimize(luci::Module *) const;
+
void optimize(loco::Graph *) const;
void quantize(loco::Graph *) const;
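The optimizer now has a module-level entry point next to the per-graph one, and the new algorithms are toggled through Options before running. An illustrative driver sketch, not from the patch, assuming module is a loaded luci::Module * and that Options exposes enable() alongside the query() the optimizer checks internally:

luci::CircleOptimizer optimizer;
auto *options = optimizer.options();
options->enable(luci::CircleOptimizer::Options::Algorithm::RemoveRedundantTranspose);
options->enable(luci::CircleOptimizer::Options::Algorithm::SubstitutePackToReshape);

optimizer.optimize(module); // module-level phase: FuseBCQ plus shape/type/signature inference
for (size_t g = 0; g < module->size(); ++g)
  optimizer.optimize(module->graph(g)); // per-graph transformations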
diff --git a/compiler/luci/pass/include/luci/ModulePass.h b/compiler/luci/pass/include/luci/ModulePass.h
new file mode 100644
index 000000000..1835f6e0c
--- /dev/null
+++ b/compiler/luci/pass/include/luci/ModulePass.h
@@ -0,0 +1,37 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __MODULE_PASS_H__
+#define __MODULE_PASS_H__
+
+#include <loco.h>
+#include <logo/Pass.h>
+
+#include <luci/IR/Module.h>
+
+namespace luci
+{
+
+class Pass : public logo::Pass
+{
+public:
+  // Run the module pass and return false if nothing was changed
+ virtual bool run(luci::Module *) = 0;
+};
+
+} // namespace luci
+
+#endif // __MODULE_PASS_H__
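Module-level passes derive from luci::Pass and usually implement the module overload by folding the per-graph result over all graphs, as the inference passes in this patch do. A minimal sketch of a hypothetical pass (not part of the patch):

#include <luci/ModulePass.h>

struct MyModulePass final : public luci::Pass
{
  const char *name(void) const final { return "luci::MyModulePass"; }

  bool run(luci::Module *m) final
  {
    bool changed = false;
    for (size_t g = 0; g < m->size(); ++g)
    {
      if (run(m->graph(g)))
        changed = true;
    }
    return changed;
  }

  bool run(loco::Graph *) final { return false; } // hypothetical per-graph body
};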
diff --git a/compiler/luci/pass/include/luci/Pass/CircleTypeInferencePass.h b/compiler/luci/pass/include/luci/Pass/CircleTypeInferencePass.h
new file mode 100644
index 000000000..379b44ccd
--- /dev/null
+++ b/compiler/luci/pass/include/luci/Pass/CircleTypeInferencePass.h
@@ -0,0 +1,42 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LUCI_CIRCLE_TYPE_INFERENCE_PASS_H__
+#define __LUCI_CIRCLE_TYPE_INFERENCE_PASS_H__
+
+#include <loco.h>
+
+#include <luci/ModulePass.h>
+
+namespace luci
+{
+
+/**
+ * @brief Pass to infer type of circle nodes
+ */
+class CircleTypeInferencePass : public luci::Pass
+{
+public:
+ virtual const char *name(void) const { return "luci::CircleTypeInferencePass"; }
+
+public:
+ bool run(luci::Module *m);
+ bool run(loco::Graph *g);
+};
+
+} // namespace luci
+
+#endif //__LUCI_CIRCLE_TYPE_INFERENCE_PASS_H__
diff --git a/compiler/luci/pass/include/luci/Pass/FuseBCQPass.h b/compiler/luci/pass/include/luci/Pass/FuseBCQPass.h
index 4404a9fc9..912ad4225 100644
--- a/compiler/luci/pass/include/luci/Pass/FuseBCQPass.h
+++ b/compiler/luci/pass/include/luci/Pass/FuseBCQPass.h
@@ -17,7 +17,7 @@
#ifndef __LUCI_FUSE_BCQ_PASS_H__
#define __LUCI_FUSE_BCQ_PASS_H__
-#include <logo/Pass.h>
+#include <luci/ModulePass.h>
namespace luci
{
@@ -26,10 +26,11 @@ namespace luci
* @brief Class to fuse certain pattern of subgraph into CircleBCQFullyConnected or CircleBCQGather
*
*/
-struct FuseBCQPass final : public logo::Pass
+struct FuseBCQPass final : public luci::Pass
{
const char *name(void) const final { return "luci::FuseBCQPass"; }
+ bool run(luci::Module *m) final;
bool run(loco::Graph *g) final;
};
diff --git a/compiler/luci/pass/include/luci/Pass/MigrateLegacyShapeDtypePass.h b/compiler/luci/pass/include/luci/Pass/MigrateLegacyShapeDtypePass.h
new file mode 100644
index 000000000..c0ebc4e5d
--- /dev/null
+++ b/compiler/luci/pass/include/luci/Pass/MigrateLegacyShapeDtypePass.h
@@ -0,0 +1,44 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LUCI_MIGRATE_LEGACY_SHAPE_DTYPE_PASS_H__
+#define __LUCI_MIGRATE_LEGACY_SHAPE_DTYPE_PASS_H__
+
+#include <loco.h>
+
+#include <luci/ModulePass.h>
+
+namespace luci
+{
+
+/**
+ * @brief Pass to copy the loco shape/dtype information to Circle nodes
+ *
+ * CAUTION : This pass will be removed after refactoring is finished
+ */
+class MigrateLegacyShapeDtypePass : public luci::Pass
+{
+public:
+ virtual const char *name(void) const { return "luci::MigrateLegacyShapeDtypePass"; }
+
+public:
+ bool run(luci::Module *m);
+ bool run(loco::Graph *graph);
+};
+
+} // namespace luci
+
+#endif //__LUCI_MIGRATE_LEGACY_SHAPE_DTYPE_PASS_H__
diff --git a/compiler/luci/pass/include/luci/Pass/PropagateQuantParamPass.h b/compiler/luci/pass/include/luci/Pass/PropagateQuantParamPass.h
new file mode 100644
index 000000000..7e0c44b8c
--- /dev/null
+++ b/compiler/luci/pass/include/luci/Pass/PropagateQuantParamPass.h
@@ -0,0 +1,37 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LUCI_PROPAGATE_QUANT_PARAM_PASS_H__
+#define __LUCI_PROPAGATE_QUANT_PARAM_PASS_H__
+
+#include <logo/Pass.h>
+
+namespace luci
+{
+
+/**
+ * @brief Class to propagate quantization parameters of an operator's output to its input
+ */
+struct PropagateQuantParamPass final : public logo::Pass
+{
+ const char *name(void) const final { return "luci::PropagateQuantParamPass"; }
+
+ bool run(loco::Graph *g) final;
+};
+
+} // namespace luci
+
+#endif // __LUCI_PROPAGATE_QUANT_PARAM_PASS_H__
diff --git a/compiler/luci/pass/include/luci/Pass/RemoveRedundantTransposePass.h b/compiler/luci/pass/include/luci/Pass/RemoveRedundantTransposePass.h
new file mode 100644
index 000000000..ca20da5ac
--- /dev/null
+++ b/compiler/luci/pass/include/luci/Pass/RemoveRedundantTransposePass.h
@@ -0,0 +1,37 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LUCI_REMOVE_REDUNDANT_TRANSPOSE_H__
+#define __LUCI_REMOVE_REDUNDANT_TRANSPOSE_H__
+
+#include <logo/Pass.h>
+
+namespace luci
+{
+
+/**
+ * @brief Pass to fuse or remove consecutive Transpose operators
+ */
+struct RemoveRedundantTransposePass final : public logo::Pass
+{
+ const char *name(void) const final { return "luci::RemoveRedundantTransposePass"; }
+
+ bool run(loco::Graph *g) final;
+};
+
+} // namespace luci
+
+#endif // __LUCI_REMOVE_REDUNDANT_TRANSPOSE_H__
diff --git a/compiler/luci/pass/include/luci/Pass/ReplaceMulAddWithDepthwiseConvPass.h b/compiler/luci/pass/include/luci/Pass/ReplaceMulAddWithDepthwiseConvPass.h
new file mode 100644
index 000000000..5dbcc8f5b
--- /dev/null
+++ b/compiler/luci/pass/include/luci/Pass/ReplaceMulAddWithDepthwiseConvPass.h
@@ -0,0 +1,37 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LUCI_REPLACE_MUL_ADD_WITH_DEPTHWISE_CONV_PASS_H__
+#define __LUCI_REPLACE_MUL_ADD_WITH_DEPTHWISE_CONV_PASS_H__
+
+#include <logo/Pass.h>
+
+namespace luci
+{
+
+/**
+ * @brief Class to replace channel-wise mul/add with CircleDepthwiseConv2D
+ */
+struct ReplaceMulAddWithDepthwiseConvPass final : public logo::Pass
+{
+ const char *name(void) const final { return "luci::ReplaceMulAddWithDepthwiseConvPass"; }
+
+ bool run(loco::Graph *g) final;
+};
+
+} // namespace luci
+
+#endif // __LUCI_REPLACE_MUL_ADD_WITH_DEPTHWISE_CONV_PASS_H__
diff --git a/compiler/luci/pass/include/luci/Pass/ShapeInferencePass.h b/compiler/luci/pass/include/luci/Pass/ShapeInferencePass.h
index 86bb2ab42..e21ab4cce 100644
--- a/compiler/luci/pass/include/luci/Pass/ShapeInferencePass.h
+++ b/compiler/luci/pass/include/luci/Pass/ShapeInferencePass.h
@@ -19,7 +19,7 @@
#include <loco.h>
-#include <logo/Pass.h>
+#include <luci/ModulePass.h>
namespace luci
{
@@ -27,12 +27,13 @@ namespace luci
/**
* @brief Pass to infer shape of nodes
*/
-class ShapeInferencePass : public logo::Pass
+class ShapeInferencePass : public luci::Pass
{
public:
virtual const char *name(void) const { return "luci::ShapeInferencePass"; }
public:
+ bool run(luci::Module *m);
bool run(loco::Graph *graph);
};
diff --git a/compiler/luci/pass/include/luci/Pass/ShapeSignatureInferencePass.h b/compiler/luci/pass/include/luci/Pass/ShapeSignatureInferencePass.h
new file mode 100644
index 000000000..2c6ffcf4e
--- /dev/null
+++ b/compiler/luci/pass/include/luci/Pass/ShapeSignatureInferencePass.h
@@ -0,0 +1,42 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LUCI_SHAPE_SIGNATURE_INFERENCE_PASS_H__
+#define __LUCI_SHAPE_SIGNATURE_INFERENCE_PASS_H__
+
+#include <loco.h>
+
+#include <luci/ModulePass.h>
+
+namespace luci
+{
+
+/**
+ * @brief Pass to infer shape_signature of nodes
+ */
+class ShapeSignatureInferencePass : public luci::Pass
+{
+public:
+ virtual const char *name(void) const { return "luci::ShapeSignatureInferencePass"; }
+
+public:
+ bool run(luci::Module *m);
+ bool run(loco::Graph *graph);
+};
+
+} // namespace luci
+
+#endif //__LUCI_SHAPE_SIGNATURE_INFERENCE_PASS_H__
diff --git a/compiler/luci/pass/include/luci/Pass/ShuffleWeightTo16x1Float32Pass.h b/compiler/luci/pass/include/luci/Pass/ShuffleWeightTo16x1Float32Pass.h
new file mode 100644
index 000000000..3d84f5133
--- /dev/null
+++ b/compiler/luci/pass/include/luci/Pass/ShuffleWeightTo16x1Float32Pass.h
@@ -0,0 +1,37 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LUCI_SHUFFLE_WEIGHT_TO_16X1_FLOAT32_PASS_H__
+#define __LUCI_SHUFFLE_WEIGHT_TO_16X1_FLOAT32_PASS_H__
+
+#include <logo/Pass.h>
+
+namespace luci
+{
+
+/**
+ * @brief Class to convert weight format of FullyConnected to SHUFFLED16x1FLOAT32
+ */
+struct ShuffleWeightTo16x1Float32Pass final : public logo::Pass
+{
+ const char *name(void) const final { return "luci::ShuffleWeightTo16x1Float32Pass"; }
+
+ bool run(loco::Graph *g) final;
+};
+
+} // namespace luci
+
+#endif // __LUCI_SHUFFLE_WEIGHT_TO_16X1_FLOAT32_PASS_H__
diff --git a/compiler/luci/pass/include/luci/Pass/SubstitutePackToReshapePass.h b/compiler/luci/pass/include/luci/Pass/SubstitutePackToReshapePass.h
new file mode 100644
index 000000000..36d13f19f
--- /dev/null
+++ b/compiler/luci/pass/include/luci/Pass/SubstitutePackToReshapePass.h
@@ -0,0 +1,37 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LUCI_SUBSTITUTE_PACK_TO_RESHAPE_PASS_H__
+#define __LUCI_SUBSTITUTE_PACK_TO_RESHAPE_PASS_H__
+
+#include <logo/Pass.h>
+
+namespace luci
+{
+
+/**
+ * @brief Class to substitute a Pack node that has a single input with a Reshape node.
+ */
+struct SubstitutePackToReshapePass final : public logo::Pass
+{
+ const char *name(void) const final { return "luci::SubstitutePackToReshapePass"; }
+
+ bool run(loco::Graph *g) final;
+};
+
+} // namespace luci
+
+#endif // __LUCI_SUBSTITUTE_PACK_TO_RESHAPE_PASS_H__
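The substitution is sound because packing a single tensor along a new axis only inserts a unit dimension. A shape-level example (illustrative only):

// input  : shape [2, 3]
// Pack   : values_count = 1, axis = 1  ->  shape [2, 1, 3]
// Reshape: new shape {2, 1, 3}         ->  identical tensor, so the Pack can be rewritten as a Reshape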
diff --git a/compiler/luci/pass/include/luci/Pass/TypeInferencePass.h b/compiler/luci/pass/include/luci/Pass/TypeInferencePass.h
index c607ac63f..9d964bdd6 100644
--- a/compiler/luci/pass/include/luci/Pass/TypeInferencePass.h
+++ b/compiler/luci/pass/include/luci/Pass/TypeInferencePass.h
@@ -20,7 +20,7 @@
#include <loco.h>
-#include <logo/Pass.h>
+#include <luci/ModulePass.h>
namespace luci
{
@@ -28,12 +28,13 @@ namespace luci
/**
* @brief Pass to infer type of nodes
*/
-class TypeInferencePass : public logo::Pass
+class TypeInferencePass : public luci::Pass
{
public:
virtual const char *name(void) const { return "luci::TypeInferencePass"; }
public:
+ bool run(luci::Module *m);
bool run(loco::Graph *graph);
};
diff --git a/compiler/luci/pass/src/CircleOptimizer.cpp b/compiler/luci/pass/src/CircleOptimizer.cpp
index 34f647301..cc9fe481c 100644
--- a/compiler/luci/pass/src/CircleOptimizer.cpp
+++ b/compiler/luci/pass/src/CircleOptimizer.cpp
@@ -24,6 +24,9 @@
#include "luci/Pass/FuseInstanceNormPass.h"
#include "luci/Pass/FusePreActivationBatchNormPass.h"
#include "luci/Pass/MakeBatchNormGammaPositivePass.h"
+#include "luci/Pass/PropagateQuantParamPass.h"
+#include "luci/Pass/RemoveRedundantTransposePass.h"
+#include "luci/Pass/ReplaceMulAddWithDepthwiseConvPass.h"
#include "luci/Pass/ResolveCustomOpAddPass.h"
#include "luci/Pass/ResolveCustomOpBatchMatMulPass.h"
#include "luci/Pass/ResolveCustomOpMatMulPass.h"
@@ -31,14 +34,21 @@
#include "luci/Pass/QuantizeWithMinMaxPass.h"
#include "luci/Pass/QuantizeDequantizeWeightsPass.h"
#include "luci/Pass/SparsifyTensorPass.h"
+#include "luci/Pass/ShuffleWeightTo16x1Float32Pass.h"
+#include "luci/Pass/SubstitutePackToReshapePass.h"
// TODO add more passes
#include "luci/Pass/ShapeInferencePass.h"
+#include "luci/Pass/ShapeSignatureInferencePass.h"
#include "luci/Pass/TypeInferencePass.h"
+// The following passes will be removed after the refactoring is finished
+#include "luci/Pass/MigrateLegacyShapeDtypePass.h"
+
// logo passes
#include <logo/RemoveDeadNodeWithQueryPass.h>
+#include "ModulePhase.h"
#include "ProgressReporter.h"
#include "CircleOptimizerUtils.h"
@@ -124,11 +134,44 @@ CircleOptimizer::Options *CircleOptimizer::options(void)
return _options.get();
}
+void CircleOptimizer::optimize(luci::Module *m) const
+{
+ luci::Phase phase;
+
+  // The following passes will be deprecated after the refactoring is finished.
+ phase.emplace_back(std::make_unique<luci::MigrateLegacyShapeDtypePass>());
+
+  // The following passes are needed whenever other passes create new nodes or modify existing ones.
+ phase.emplace_back(std::make_unique<luci::ShapeInferencePass>());
+ phase.emplace_back(std::make_unique<luci::ShapeSignatureInferencePass>());
+ phase.emplace_back(std::make_unique<luci::TypeInferencePass>());
+
+ if (_options->query(Options::Algorithm::FuseBCQ))
+ {
+ phase.emplace_back(std::make_unique<FuseBCQPass>());
+ }
+
+ ModuleProgressReporter prog(m, logo::PhaseStrategy::Restart);
+ PhaseRunner<logo::PhaseStrategy::Restart> phase_runner{m};
+ phase_runner.attach(&prog);
+ phase_runner.run(phase);
+}
+
void CircleOptimizer::optimize(loco::Graph *g) const
{
logo::Phase phase;
/* TRANSFORM DECLARATION BEGIN */
+ phase.emplace_back(std::make_unique<logo::RemoveDeadNodeWithQueryPass>());
+
+  // The following passes will be deprecated after the refactoring is finished.
+ phase.emplace_back(std::make_unique<luci::MigrateLegacyShapeDtypePass>());
+
+  // The following passes are needed whenever other passes create new nodes or modify existing ones.
+ phase.emplace_back(std::make_unique<luci::TypeInferencePass>());
+ phase.emplace_back(std::make_unique<luci::ShapeInferencePass>());
+ phase.emplace_back(std::make_unique<luci::ShapeSignatureInferencePass>());
+
if (_options->query(Options::Algorithm::ResolveCustomOpAdd))
{
phase.emplace_back(std::make_unique<luci::ResolveCustomOpAddPass>());
@@ -145,10 +188,6 @@ void CircleOptimizer::optimize(loco::Graph *g) const
{
phase.emplace_back(std::make_unique<FuseInstanceNormPass>());
}
- if (_options->query(Options::Algorithm::FuseBCQ))
- {
- phase.emplace_back(std::make_unique<FuseBCQPass>());
- }
if (_options->query(Options::Algorithm::FuseBatchNormWithTConv))
{
phase.emplace_back(std::make_unique<FuseBatchNormWithTConvPass>());
@@ -173,15 +212,27 @@ void CircleOptimizer::optimize(loco::Graph *g) const
{
phase.emplace_back(std::make_unique<luci::MakeBatchNormGammaPositivePass>());
}
+ if (_options->query(Options::Algorithm::ShuffleWeightTo16x1Float32))
+ {
+ phase.emplace_back(std::make_unique<luci::ShuffleWeightTo16x1Float32Pass>());
+ }
+ if (_options->query(Options::Algorithm::RemoveRedundantTranspose))
+ {
+ phase.emplace_back(std::make_unique<luci::RemoveRedundantTransposePass>());
+ }
+ if (_options->query(Options::Algorithm::ReplaceMulAddWithDepthwiseConv))
+ {
+ phase.emplace_back(std::make_unique<luci::ReplaceMulAddWithDepthwiseConvPass>());
+ }
+ if (_options->query(Options::Algorithm::SubstitutePackToReshape))
+ {
+ phase.emplace_back(std::make_unique<luci::SubstitutePackToReshapePass>());
+ }
- // Shape inference is needed for added nodes doing above transformations
- phase.emplace_back(std::make_unique<luci::ShapeInferencePass>());
- phase.emplace_back(std::make_unique<luci::TypeInferencePass>());
- phase.emplace_back(std::make_unique<logo::RemoveDeadNodeWithQueryPass>());
/* TRANSFORM DECLARATION END */
- ProgressReporter prog(g, logo::PhaseStrategy::Saturate);
- logo::PhaseRunner<logo::PhaseStrategy::Saturate> phase_runner{g};
+ ProgressReporter prog(g, logo::PhaseStrategy::Restart);
+ logo::PhaseRunner<logo::PhaseStrategy::Restart> phase_runner{g};
phase_runner.attach(&prog);
phase_runner.run(phase);
}
@@ -258,6 +309,20 @@ void CircleOptimizer::quantize(loco::Graph *g) const
luci::QuantizeWithMinMaxPass quantizer(str_to_dtype(input_dtype), str_to_dtype(output_dtype),
str_to_granularity(granularity));
quantizer.run(g);
+
+ // Post-quantization optimizations
+ logo::Phase phase;
+
+ phase.emplace_back(std::make_unique<luci::PropagateQuantParamPass>());
+
+ phase.emplace_back(std::make_unique<luci::ShapeInferencePass>());
+ phase.emplace_back(std::make_unique<luci::TypeInferencePass>());
+ phase.emplace_back(std::make_unique<logo::RemoveDeadNodeWithQueryPass>());
+
+ ProgressReporter prog(g, logo::PhaseStrategy::Saturate);
+ logo::PhaseRunner<logo::PhaseStrategy::Saturate> phase_runner{g};
+ phase_runner.attach(&prog);
+ phase_runner.run(phase);
}
// Requantize
diff --git a/compiler/luci/pass/src/CircleTypeInferencePass.cpp b/compiler/luci/pass/src/CircleTypeInferencePass.cpp
new file mode 100644
index 000000000..67bd253e0
--- /dev/null
+++ b/compiler/luci/pass/src/CircleTypeInferencePass.cpp
@@ -0,0 +1,59 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Pass/CircleTypeInferencePass.h"
+
+#include <luci/Service/CircleTypeInference.h>
+
+#include <loco.h>
+
+namespace luci
+{
+
+bool CircleTypeInferencePass::run(luci::Module *m)
+{
+ bool changed = false;
+
+ for (size_t g = 0; g < m->size(); ++g)
+ {
+ if (run(m->graph(g)))
+ changed = true;
+ }
+
+ return changed;
+}
+
+bool CircleTypeInferencePass::run(loco::Graph *g)
+{
+ luci::tinf::Rule type_infer_rule;
+ bool changed = false;
+
+ for (auto node : loco::postorder_traversal(loco::output_nodes(g)))
+ {
+ loco::DataType dtype;
+ auto circle_node = loco::must_cast<luci::CircleNode *>(node);
+
+ if (type_infer_rule.infer(circle_node, dtype) && circle_node->dtype() != dtype)
+ {
+ circle_node->dtype(dtype);
+ changed = true;
+ }
+ }
+
+ return changed;
+}
+
+} // namespace luci
diff --git a/compiler/luci/pass/src/FuseBCQPass.cpp b/compiler/luci/pass/src/FuseBCQPass.cpp
index ebf28779b..c0583d848 100644
--- a/compiler/luci/pass/src/FuseBCQPass.cpp
+++ b/compiler/luci/pass/src/FuseBCQPass.cpp
@@ -25,6 +25,85 @@
namespace
{
+bool is_fusable_const(luci::CircleConst *before, luci::CircleConst *after, bool do_w_x)
+{
+ if (after->dtype() != loco::DataType::FLOAT32)
+ return false;
+
+ if (after->rank() != 2)
+ return false;
+
+ if (after->size<loco::DataType::FLOAT32>() != before->size<loco::DataType::FLOAT32>())
+ return false;
+
+ auto after_dim0 = after->dim(0).value();
+ auto after_dim1 = after->dim(1).value();
+
+ if (before->rank() == 2)
+ {
+ if (do_w_x)
+ {
+ // Check for [dim0, dim1] --> [dim0, dim1]
+ if (!(after->dim(0) == before->dim(0) && after->dim(1) == before->dim(1)))
+ return false;
+
+ for (uint32_t i = 0; i < after->size<loco::DataType::FLOAT32>(); ++i)
+ if (after->at<loco::DataType::FLOAT32>(i) != before->at<loco::DataType::FLOAT32>(i))
+ return false;
+ }
+ else
+ {
+ // Check for [dim0, dim1] --> [dim1, dim0]
+ if (!(after->dim(0) == before->dim(1) && after->dim(1) == before->dim(0)))
+ return false;
+
+ for (uint32_t i = 0; i < after_dim0; ++i)
+ for (uint32_t j = 0; j < after_dim1; ++j)
+ if (after->at<loco::DataType::FLOAT32>(i * after_dim1 + j) !=
+ before->at<loco::DataType::FLOAT32>(j * after_dim0 + i))
+ return false;
+ }
+
+ return true;
+ }
+ else if (before->rank() == 3)
+ {
+ if (do_w_x)
+ {
+      // This case has not been observed yet.
+ return false;
+ }
+ else
+ {
+      // When an Einsum op is converted to FullyConnected, the original rank can be 3.
+ auto before_dim0 = before->dim(0).value();
+ auto before_dim1 = before->dim(1).value();
+ auto before_dim2 = before->dim(2).value();
+
+ // Check if [dim0, dim1, dim2] --> [dim2, dim0 * dim1] or
+ // [dim0, dim1, dim2] --> [dim1 * dim2, dim0]
+ if ((after_dim0 == before_dim1 * before_dim2 && after_dim1 == before_dim0) ||
+ (after_dim0 == before_dim2 && after_dim1 == before_dim0 * before_dim1))
+ {
+ for (uint32_t i = 0; i < after_dim0; ++i)
+ for (uint32_t j = 0; j < after_dim1; ++j)
+ if (after->at<loco::DataType::FLOAT32>(i * after_dim1 + j) !=
+ before->at<loco::DataType::FLOAT32>(j * after_dim0 + i))
+ return false;
+ }
+ }
+
+ return true;
+ }
+
+ return false;
+}
+
+} // namespace
+
+namespace
+{
+
// V means the version of BCQ.
template <int32_t V> class BCQFuser;
@@ -38,11 +117,9 @@ public:
}
public:
- bool fuseBCQ(loco::Graph *g)
+ void register_bcq_info(loco::Graph *g)
{
-
- const auto output_nodes = loco::output_nodes(g);
- for (auto node : output_nodes)
+ for (auto node : loco::output_nodes(g))
{
auto output_node = loco::must_cast<luci::CircleOutput *>(node);
@@ -61,28 +138,29 @@ public:
add_BCQ_info_node(prefix, metadata_type, circle_node);
}
}
+ }
+ bool fuseBCQ(loco::Graph *g)
+ {
if (!is_bcqinfo_valid())
return false;
- for (auto f : _fusable_op)
+ for (auto node : loco::postorder_traversal(loco::output_nodes(g)))
{
- auto prefix = f.first;
- luci::CircleNode *node = f.second;
-
- if (!is_valid_prefix(prefix))
- continue;
-
// Fuse Gather to BCQGather
if (auto gather = dynamic_cast<luci::CircleGather *>(node))
{
if (auto params = dynamic_cast<luci::CircleConst *>(gather->params()))
{
+ auto prefix = get_prefix_of_const(params);
+ if (prefix == -1 || !is_valid_prefix(prefix))
+ continue;
+
auto bcq_gather = g->nodes()->create<luci::CircleBCQGather>();
bcq_gather->op_version(1);
- bcq_gather->input_scales(_alpha[prefix]);
- bcq_gather->input_binary(_packed_binary_code[prefix]);
+ bcq_gather->input_scales(alpha(g, prefix));
+ bcq_gather->input_binary(packed_binary_code(g, prefix));
bcq_gather->indices(gather->indices());
bcq_gather->input_clusters(packed_clusters(g, prefix));
@@ -122,29 +200,20 @@ public:
}
}
- // Einsum is unpacked to FullyConnected, Pack and Reshape
- if (auto reshape = dynamic_cast<luci::CircleReshape *>(node))
- {
- node = dynamic_cast<luci::CircleNode *>(reshape->tensor());
- }
- if (auto pack = dynamic_cast<luci::CirclePack *>(node))
- {
- if (pack->values_count() == 1 && pack->rank() == 3)
- {
- node = dynamic_cast<luci::CircleNode *>(pack->values(0));
- }
- }
-
// Fuse FullyConnected to BCQFullyConnected
if (auto fully_connected = dynamic_cast<luci::CircleFullyConnected *>(node))
{
if (auto weights = dynamic_cast<luci::CircleConst *>(fully_connected->weights()))
{
+ auto prefix = get_prefix_of_const(weights);
+ if (prefix == -1 || !is_valid_prefix(prefix))
+ continue;
+
auto bcq_fc = g->nodes()->create<luci::CircleBCQFullyConnected>();
bcq_fc->op_version(1);
- bcq_fc->weights_scales(_alpha[prefix]);
- bcq_fc->weights_binary(_packed_binary_code[prefix]);
+ bcq_fc->weights_scales(alpha(g, prefix));
+ bcq_fc->weights_binary(packed_binary_code(g, prefix));
bcq_fc->bias(fully_connected->bias());
bcq_fc->weights_clusters(packed_clusters(g, prefix));
bcq_fc->fusedActivationFunction(fully_connected->fusedActivationFunction());
@@ -179,43 +248,69 @@ public:
}
// If x_w formation, we should insert Transpose in front and back of BCQFullyConnected
- if (_do_w_x[prefix]->at<loco::DataType::BOOL>(0))
- {
- bcq_fc->weights_hidden_size(weights->dim(0).value());
- bcq_fc->input(bcq_input);
- loco::replace(fully_connected).with(bcq_fc);
- }
- else
- {
- bcq_fc->weights_hidden_size(weights->dim(1).value());
+ bcq_fc->weights_hidden_size(weights->dim(1).value());
- auto perm = g->nodes()->create<luci::CircleConst>();
- perm->dtype(loco::DataType::S32);
- perm->size<loco::DataType::S32>(2);
- perm->rank(1);
- perm->dim(0) = 2;
- perm->at<loco::DataType::S32>(0) = 1;
- perm->at<loco::DataType::S32>(1) = 0;
- perm->shape_status(luci::ShapeStatus::VALID);
+ auto perm = g->nodes()->create<luci::CircleConst>();
+ perm->dtype(loco::DataType::S32);
+ perm->size<loco::DataType::S32>(2);
+ perm->rank(1);
+ perm->dim(0) = 2;
+ perm->at<loco::DataType::S32>(0) = 1;
+ perm->at<loco::DataType::S32>(1) = 0;
+ perm->shape_status(luci::ShapeStatus::VALID);
- auto input_transpose = g->nodes()->create<luci::CircleTranspose>();
- input_transpose->a(bcq_input);
- input_transpose->perm(perm);
+ auto input_transpose = g->nodes()->create<luci::CircleTranspose>();
+ input_transpose->a(bcq_input);
+ input_transpose->perm(perm);
- bcq_fc->input(input_transpose);
+ bcq_fc->input(input_transpose);
- auto output_transpose = g->nodes()->create<luci::CircleTranspose>();
- output_transpose->a(bcq_fc);
- output_transpose->perm(perm);
+ auto output_transpose = g->nodes()->create<luci::CircleTranspose>();
+ output_transpose->a(bcq_fc);
+ output_transpose->perm(perm);
- loco::replace(fully_connected).with(output_transpose);
- }
+ loco::replace(fully_connected).with(output_transpose);
return true;
}
- else
+ else if (auto weights_as_input =
+ dynamic_cast<luci::CircleConst *>(fully_connected->input()))
{
- // TODO Is there any case that input() is constant, instead of weights()?
+ auto prefix = get_prefix_of_const(weights_as_input);
+ if (prefix == -1 || !is_valid_prefix(prefix))
+ continue;
+
+ assert(_do_w_x[prefix]->at<loco::DataType::BOOL>(0) == true);
+
+ auto perm = g->nodes()->create<luci::CircleConst>();
+ perm->dtype(loco::DataType::S32);
+ perm->size<loco::DataType::S32>(2);
+ perm->rank(1);
+ perm->dim(0) = 2;
+ perm->at<loco::DataType::S32>(0) = 1;
+ perm->at<loco::DataType::S32>(1) = 0;
+ perm->shape_status(luci::ShapeStatus::VALID);
+
+ auto input_transpose = g->nodes()->create<luci::CircleTranspose>();
+ input_transpose->a(fully_connected->weights());
+ input_transpose->perm(perm);
+
+ auto bcq_fc = g->nodes()->create<luci::CircleBCQFullyConnected>();
+
+ assert(dynamic_cast<luci::CircleOutputExclude *>(fully_connected->bias()) != nullptr);
+
+ bcq_fc->op_version(1);
+ bcq_fc->weights_scales(alpha(g, prefix));
+ bcq_fc->weights_binary(packed_binary_code(g, prefix));
+ bcq_fc->bias(fully_connected->bias());
+ bcq_fc->weights_clusters(packed_clusters(g, prefix));
+ bcq_fc->fusedActivationFunction(fully_connected->fusedActivationFunction());
+
+ bcq_fc->weights_hidden_size(weights_as_input->dim(1).value());
+ bcq_fc->input(input_transpose);
+ loco::replace(fully_connected).with(bcq_fc);
+
+ return true;
}
}
}
@@ -268,6 +363,19 @@ private:
_dequant_weight[prefix] = const_node;
}
+ int32_t get_prefix_of_const(luci::CircleConst *w_after)
+ {
+ for (auto n : _fusable_op)
+ {
+ auto prefix = n.first;
+ auto w_before = loco::must_cast<luci::CircleConst *>(n.second);
+ if (is_fusable_const(w_before, w_after, _do_w_x[prefix]->at<loco::DataType::BOOL>(0)))
+ return prefix;
+ }
+
+ return -1;
+ }
+
bool is_bcqinfo_valid()
{
LOGGER(l);
@@ -332,6 +440,16 @@ private:
}
}
+ for (auto n : _fusable_op)
+ {
+ // fusable_op should be FLOAT32 type
+ if (n.second->dtype() != loco::DataType::FLOAT32)
+ {
+ WARN(l) << "FuseBCQPass : fusable_op has wrong type" << std::endl;
+ return false;
+ }
+ }
+
// As dequant_weight is not used for fusing, skip validation.
return true;
@@ -377,12 +495,50 @@ private:
return false;
}
+ if (_fusable_op.find(prefix) == _fusable_op.end())
+ {
+ WARN(l) << "fusable_op is not found" << std::endl;
+ return false;
+ }
+
// As dequant_weight is not used for fusing, skip validation.
return true;
}
private:
+ luci::CircleConst *alpha(loco::Graph *graph, int32_t prefix)
+ {
+ auto new_alpha = graph->nodes()->create<luci::CircleConst>();
+
+ new_alpha->dtype(loco::DataType::FLOAT32);
+ new_alpha->size<loco::DataType::FLOAT32>(_alpha[prefix]->size<loco::DataType::FLOAT32>());
+ new_alpha->rank(1);
+ new_alpha->dim(0) = _alpha[prefix]->dim(0);
+ for (uint32_t i = 0; i < _alpha[prefix]->size<loco::DataType::FLOAT32>(); ++i)
+ new_alpha->at<loco::DataType::FLOAT32>(i) = _alpha[prefix]->at<loco::DataType::FLOAT32>(i);
+ new_alpha->shape_status(luci::ShapeStatus::VALID);
+
+ return new_alpha;
+ }
+
+ luci::CircleConst *packed_binary_code(loco::Graph *graph, int32_t prefix)
+ {
+ auto new_beta = graph->nodes()->create<luci::CircleConst>();
+
+ new_beta->dtype(loco::DataType::S32);
+ new_beta->size<loco::DataType::S32>(_packed_binary_code[prefix]->size<loco::DataType::S32>());
+ new_beta->rank(2);
+ new_beta->dim(0) = _packed_binary_code[prefix]->dim(0);
+ new_beta->dim(1) = _packed_binary_code[prefix]->dim(1);
+ for (uint32_t i = 0; i < _packed_binary_code[prefix]->size<loco::DataType::S32>(); ++i)
+ new_beta->at<loco::DataType::S32>(i) =
+ _packed_binary_code[prefix]->at<loco::DataType::S32>(i);
+ new_beta->shape_status(luci::ShapeStatus::VALID);
+
+ return new_beta;
+ }
+
luci::CircleConst *packed_clusters(loco::Graph *graph, int32_t prefix)
{
auto qbits_of_clusters = _qbits_of_clusters[prefix];
@@ -428,15 +584,17 @@ private:
namespace luci
{
-bool FuseBCQPass::run(loco::Graph *g)
+bool FuseBCQPass::run(luci::Module *m)
{
bool changed = false;
const int32_t start_magicnum = -2e9 + 27;
const int32_t end_magicnum = 2e9 - 27;
+ loco::Graph *main_graph = m->graph(0);
+
luci::CircleConst *metadata_node = nullptr;
- for (auto node : loco::output_nodes(g))
+ for (auto node : loco::output_nodes(main_graph))
{
auto output_node = loco::must_cast<luci::CircleOutput *>(node);
@@ -474,8 +632,11 @@ bool FuseBCQPass::run(loco::Graph *g)
const auto bundle_cnt = metadata_node->at<loco::DataType::S32>(3);
BCQFuser<1> fuser{original_output_cnt, bundle_cnt};
- if (fuser.fuseBCQ(g))
- changed = true;
+ fuser.register_bcq_info(main_graph);
+
+ for (size_t g = 0; g < m->size(); ++g)
+ if (fuser.fuseBCQ(m->graph(g)))
+ changed = true;
}
else
{
@@ -486,12 +647,12 @@ bool FuseBCQPass::run(loco::Graph *g)
// Remove all of BCQ information nodes iff there is no change
if (changed == false)
{
- for (auto node : loco::output_nodes(g))
+ for (auto node : loco::output_nodes(main_graph))
{
auto output_node = loco::must_cast<luci::CircleOutput *>(node);
if (output_node->index() == 0 || (int)output_node->index() > original_output_cnt)
{
- auto noOp = g->nodes()->create<luci::CircleOutputExclude>();
+ auto noOp = main_graph->nodes()->create<luci::CircleOutputExclude>();
noOp->dtype(loco::DataType::FLOAT32); // TODO Remove this setting
output_node->from(noOp);
changed = true;
@@ -503,4 +664,10 @@ bool FuseBCQPass::run(loco::Graph *g)
return changed;
}
+bool FuseBCQPass::run(loco::Graph *)
+{
+  // Nothing to do for a single graph; this pass operates at the module level
+ return false;
+}
+
} // namespace luci
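
The new is_fusable_const() helper above decides which registered BCQ prefix a converted weight constant belongs to by comparing element values across the expected layout change. A minimal standalone sketch of the rank-2, do_w_x == false branch, using plain std::vector instead of luci::CircleConst (is_transposed_copy is an illustrative name, not part of luci):

#include <cstdint>
#include <iostream>
#include <vector>

// Returns true when 'after' (rows x cols) is an exact transposed copy of 'before' (cols x rows),
// using the same flat index math as is_fusable_const().
bool is_transposed_copy(const std::vector<float> &before, const std::vector<float> &after,
                        uint32_t rows, uint32_t cols)
{
  if (before.size() != after.size() || after.size() != static_cast<size_t>(rows) * cols)
    return false;

  for (uint32_t i = 0; i < rows; ++i)
    for (uint32_t j = 0; j < cols; ++j)
      if (after[i * cols + j] != before[j * rows + i])
        return false;
  return true;
}

int main()
{
  std::vector<float> before = {1, 2, 3, 4, 5, 6}; // original 3 x 2 weight
  std::vector<float> after = {1, 3, 5, 2, 4, 6};  // converted 2 x 3 weight (transposed copy)
  std::cout << std::boolalpha << is_transposed_copy(before, after, 2, 3) << '\n'; // true
}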
diff --git a/compiler/luci/pass/src/MigrateLegacyShapeDtypePass.cpp b/compiler/luci/pass/src/MigrateLegacyShapeDtypePass.cpp
new file mode 100644
index 000000000..beb962a05
--- /dev/null
+++ b/compiler/luci/pass/src/MigrateLegacyShapeDtypePass.cpp
@@ -0,0 +1,112 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Pass/MigrateLegacyShapeDtypePass.h"
+
+#include <loco/Service/ShapeInference.h>
+#include <loco/Service/TypeInference.h>
+
+#include <luci/IR/CircleNodes.h>
+
+#include <loco.h>
+
+namespace
+{
+
+bool has_same_shape(luci::CircleNode *node, loco::TensorShape shape)
+{
+ if (node->rank() != shape.rank())
+ return false;
+
+ for (uint32_t i = 0; i < shape.rank(); ++i)
+ if (!(node->dim(i) == shape.dim(i)))
+ return false;
+
+ return true;
+}
+
+} // namespace
+
+namespace luci
+{
+
+bool MigrateLegacyShapeDtypePass::run(luci::Module *m)
+{
+ bool changed = false;
+
+ for (size_t g = 0; g < m->size(); ++g)
+ {
+ if (run(m->graph(g)))
+ changed = true;
+ }
+
+ return changed;
+}
+
+bool MigrateLegacyShapeDtypePass::run(loco::Graph *g)
+{
+ bool changed = false;
+
+ for (auto node : loco::all_nodes(g))
+ {
+ auto circle_node = loco::must_cast<luci::CircleNode *>(node);
+ if (loco::shape_known(node))
+ {
+ auto loco_shape = loco::shape_get(node).as<loco::TensorShape>();
+
+ assert(circle_node->shape_signature().rank() == 0 ||
+ circle_node->shape_signature().rank() == loco_shape.rank());
+
+      // When the shape inferred by loco is copied to the circle node, the ShapeSignature should be applied.
+ loco::TensorShape new_shape;
+ new_shape.rank(loco_shape.rank());
+ for (uint32_t i = 0; i < loco_shape.rank(); ++i)
+ {
+ if (circle_node->shape_signature().rank() > 0 &&
+ circle_node->shape_signature().dim(i) == -1)
+ new_shape.dim(i) = 1;
+ else
+ new_shape.dim(i) = loco_shape.dim(i);
+ }
+
+ if (circle_node->shape_status() == luci::ShapeStatus::UNDEFINED ||
+ !has_same_shape(circle_node, new_shape))
+ {
+ circle_node->rank(new_shape.rank());
+ for (uint32_t i = 0; i < new_shape.rank(); ++i)
+ circle_node->dim(i) = new_shape.dim(i);
+
+ if (circle_node->shape_status() == luci::ShapeStatus::UNDEFINED)
+ circle_node->shape_status(luci::ShapeStatus::VALID);
+
+ changed = true;
+ }
+ }
+
+ if (loco::dtype_known(node))
+ {
+ if (loco::dtype_get(node) != circle_node->dtype())
+ {
+ circle_node->dtype(loco::dtype_get(node));
+ changed = true;
+ }
+ }
+ }
+
+ return changed;
+}
+
+} // namespace luci
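
The pass above combines the shape inferred by loco with the node's shape signature: a signature entry of -1 marks a dynamic dimension, and the static dimension stored on the node is forced to 1 for it. A small standalone sketch of that mapping (std C++ only; apply_signature is an illustrative name, not a luci API):

#include <cstdint>
#include <iostream>
#include <vector>

std::vector<uint32_t> apply_signature(const std::vector<uint32_t> &inferred,
                                      const std::vector<int32_t> &signature)
{
  std::vector<uint32_t> result(inferred.size());
  for (size_t i = 0; i < inferred.size(); ++i)
  {
    // An empty signature means "no dynamic dimensions"; otherwise -1 overrides the static size
    const bool dynamic = !signature.empty() && signature[i] == -1;
    result[i] = dynamic ? 1u : inferred[i];
  }
  return result;
}

int main()
{
  for (auto d : apply_signature({8, 224, 224, 3}, {-1, 224, 224, 3}))
    std::cout << d << ' '; // 1 224 224 3
  std::cout << '\n';
}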
diff --git a/compiler/luci/pass/src/ModulePhase.cpp b/compiler/luci/pass/src/ModulePhase.cpp
new file mode 100644
index 000000000..46819a0f7
--- /dev/null
+++ b/compiler/luci/pass/src/ModulePhase.cpp
@@ -0,0 +1,71 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ModulePhase.h"
+
+namespace luci
+{
+
+void PhaseRunner<logo::PhaseStrategy::Saturate>::run(const Phase &phase) const
+{
+ notifyPhaseBegin();
+
+ for (bool changed = true; changed;)
+ {
+ changed = false;
+
+ for (auto &pass : phase)
+ {
+ notifyPassBegin(pass.get());
+
+ bool pass_changed = pass->run(_module);
+ changed = changed || pass_changed;
+
+ notifyPassEnd(pass.get(), pass_changed);
+ }
+ }
+
+ notifyPhaseEnd();
+}
+
+void PhaseRunner<logo::PhaseStrategy::Restart>::run(const Phase &phase) const
+{
+ notifyPhaseBegin();
+
+ for (bool changed = true; changed;)
+ {
+ changed = false;
+
+ for (auto &pass : phase)
+ {
+ notifyPassBegin(pass.get());
+
+ bool pass_changed = pass->run(_module);
+ changed = changed || pass_changed;
+
+ notifyPassEnd(pass.get(), pass_changed);
+
+ if (changed)
+ {
+ break;
+ }
+ }
+ }
+
+ notifyPhaseEnd();
+}
+
+} // namespace luci
diff --git a/compiler/luci/pass/src/ModulePhase.h b/compiler/luci/pass/src/ModulePhase.h
new file mode 100644
index 000000000..05966cc29
--- /dev/null
+++ b/compiler/luci/pass/src/ModulePhase.h
@@ -0,0 +1,67 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __MODULE_PHASE_H__
+#define __MODULE_PHASE_H__
+
+#include <luci/ModulePass.h>
+
+#include <logo/Phase.h>
+
+#include <vector>
+
+namespace luci
+{
+
+using Phase = std::vector<std::unique_ptr<Pass>>;
+
+template <logo::PhaseStrategy S> class PhaseRunner;
+
+template <>
+class PhaseRunner<logo::PhaseStrategy::Saturate> final : public logo::PhaseRunnerMixinObservable
+{
+public:
+ PhaseRunner(luci::Module *module) : _module{module}
+ {
+ // DO NOTHING
+ }
+
+public:
+ void run(const Phase &) const;
+
+private:
+ luci::Module *_module;
+};
+
+template <>
+class PhaseRunner<logo::PhaseStrategy::Restart> final : public logo::PhaseRunnerMixinObservable
+{
+public:
+ PhaseRunner(luci::Module *module) : _module{module}
+ {
+ // DO NOTHING
+ }
+
+public:
+ void run(const Phase &) const;
+
+private:
+ luci::Module *_module;
+};
+
+} // namespace luci
+
+#endif // __MODULE_PHASE_H__
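
A sketch of how the module-level Phase and PhaseRunner declared above are expected to be driven. It assumes the pass headers added elsewhere in this change and is illustrative only, not a translation unit that compiles outside the luci build tree:

#include "ModulePhase.h"

#include <luci/Pass/ShapeInferencePass.h>
#include <luci/Pass/ShapeSignatureInferencePass.h>

#include <memory>

void run_module_phase(luci::Module *m)
{
  luci::Phase phase;
  phase.emplace_back(std::make_unique<luci::ShapeInferencePass>());
  phase.emplace_back(std::make_unique<luci::ShapeSignatureInferencePass>());

  // Saturate re-runs the whole phase until no pass reports a change;
  // Restart goes back to the first pass as soon as any pass changes the module.
  luci::PhaseRunner<logo::PhaseStrategy::Saturate> runner{m};
  runner.run(phase);
}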
diff --git a/compiler/luci/pass/src/ProgressReporter.cpp b/compiler/luci/pass/src/ProgressReporter.cpp
index dcf47aba6..515739dc7 100644
--- a/compiler/luci/pass/src/ProgressReporter.cpp
+++ b/compiler/luci/pass/src/ProgressReporter.cpp
@@ -81,4 +81,46 @@ void ProgressReporter::notify(const logo::PhaseEventInfo<logo::PhaseEvent::PassE
INFO(prime) << luci::fmt(graph());
}
+void ModuleProgressReporter::notify(const logo::PhaseEventInfo<logo::PhaseEvent::PhaseBegin> *)
+{
+ LOGGER(prime);
+
+ INFO(prime) << "==============================================================";
+ INFO(prime) << "ModulePhaseRunner<" << to_str(strategy()) << ">";
+ INFO(prime) << "Initial graphs";
+ for (size_t g = 0; g < module()->size(); ++g)
+ {
+    INFO(prime) << "graph #" << g;
+ INFO(prime) << luci::fmt(module()->graph(g));
+ }
+}
+
+void ModuleProgressReporter::notify(const logo::PhaseEventInfo<logo::PhaseEvent::PhaseEnd> *)
+{
+ LOGGER(prime);
+
+ INFO(prime) << "ModulePhaseRunner<" << to_str(strategy()) << "> - done";
+}
+
+void ModuleProgressReporter::notify(const logo::PhaseEventInfo<logo::PhaseEvent::PassBegin> *info)
+{
+ LOGGER(prime);
+
+ INFO(prime) << "--------------------------------------------------------------";
+ INFO(prime) << "Before " << logo::pass_name(info->pass());
+}
+
+void ModuleProgressReporter::notify(const logo::PhaseEventInfo<logo::PhaseEvent::PassEnd> *info)
+{
+ LOGGER(prime);
+
+ INFO(prime) << "After " << logo::pass_name(info->pass())
+ << " (changed: " << to_char(info->changed()) << ")";
+ for (size_t g = 0; g < module()->size(); ++g)
+ {
+    INFO(prime) << "graph #" << g;
+ INFO(prime) << luci::fmt(module()->graph(g));
+ }
+}
+
} // namespace luci
diff --git a/compiler/luci/pass/src/ProgressReporter.h b/compiler/luci/pass/src/ProgressReporter.h
index bd2ba9849..cf30da735 100644
--- a/compiler/luci/pass/src/ProgressReporter.h
+++ b/compiler/luci/pass/src/ProgressReporter.h
@@ -21,6 +21,8 @@
#include <loco.h>
+#include <luci/IR/Module.h>
+
namespace luci
{
@@ -48,6 +50,30 @@ private:
logo::PhaseStrategy _strategy;
};
+class ModuleProgressReporter : public logo::PhaseEventListener
+{
+public:
+ ModuleProgressReporter(luci::Module *module, logo::PhaseStrategy strategy)
+ : _module{module}, _strategy{strategy}
+ {
+ // DO NOTHING
+ }
+
+public:
+ void notify(const logo::PhaseEventInfo<logo::PhaseEvent::PhaseBegin> *) override;
+ void notify(const logo::PhaseEventInfo<logo::PhaseEvent::PhaseEnd> *) override;
+ void notify(const logo::PhaseEventInfo<logo::PhaseEvent::PassBegin> *) override;
+ void notify(const logo::PhaseEventInfo<logo::PhaseEvent::PassEnd> *) override;
+
+public:
+ luci::Module *module(void) const { return _module; }
+ logo::PhaseStrategy strategy(void) const { return _strategy; }
+
+private:
+ luci::Module *_module;
+ logo::PhaseStrategy _strategy;
+};
+
} // namespace luci
#endif // __LUCI_PROGRESSREPORTER_H__
diff --git a/compiler/luci/pass/src/PropagateQuantParamPass.cpp b/compiler/luci/pass/src/PropagateQuantParamPass.cpp
new file mode 100644
index 000000000..af83cd83b
--- /dev/null
+++ b/compiler/luci/pass/src/PropagateQuantParamPass.cpp
@@ -0,0 +1,102 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Pass/PropagateQuantParamPass.h"
+
+#include <luci/IR/CircleNodes.h>
+#include <luci/IR/CircleNodeVisitor.h>
+#include <luci/Log.h>
+
+#include <iostream>
+
+namespace
+{
+
+bool copy_qparam(luci::CircleQuantParam *src, luci::CircleQuantParam *dst)
+{
+ assert(src->scale.size() == dst->scale.size());
+ assert(src->zerop.size() == dst->zerop.size());
+
+ // src and dst have the same qparam
+ if (std::equal(src->scale.begin(), src->scale.end(), dst->scale.begin()) &&
+ std::equal(src->zerop.begin(), src->zerop.end(), dst->zerop.begin()) &&
+ src->quantized_dimension == dst->quantized_dimension)
+ return false;
+
+ dst->scale.assign(src->scale.begin(), src->scale.end());
+ dst->zerop.assign(src->zerop.begin(), src->zerop.end());
+ dst->quantized_dimension = src->quantized_dimension;
+ return true;
+}
+
+bool copy_qparam(luci::CircleNode *src, luci::CircleNode *dst)
+{
+ // Skip nodes that do not have quantparams
+ auto src_qparam = src->quantparam();
+ if (not src_qparam)
+ return false;
+
+ auto dst_qparam = dst->quantparam();
+ if (not dst_qparam)
+ return false;
+
+ return copy_qparam(src_qparam, dst_qparam);
+}
+
+// Visitor to propagate quantization parameters
+struct PropagateQuantParam final : public luci::CircleNodeMutableVisitor<bool>
+{
+ PropagateQuantParam() = default;
+
+ bool visit(luci::CircleNode *) { return false; }
+
+ bool visit(luci::CircleReshape *node)
+ {
+ auto input = node->tensor();
+ if (loco::succs(input).size() != 1)
+ return false;
+
+ auto input_node = loco::must_cast<luci::CircleNode *>(input);
+ return copy_qparam(node, input_node);
+ }
+
+ // TODO : Add more Ops (e.g., Transpose)
+};
+
+} // namespace
+
+namespace luci
+{
+
+bool PropagateQuantParamPass::run(loco::Graph *g)
+{
+ bool changed = false;
+ LOGGER(l);
+ for (auto node : loco::active_nodes(loco::output_nodes(g)))
+ {
+ auto circle_node = loco::must_cast<luci::CircleNode *>(node);
+ INFO(l) << "PropagateQuantParamPass visit node: " << circle_node->name() << std::endl;
+
+ PropagateQuantParam pqp;
+ changed = circle_node->accept(&pqp);
+ if (changed)
+ break;
+ }
+
+ return changed;
+}
+
+} // namespace luci
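
Note that copy_qparam() reports "no change" when source and destination already hold the same parameters; this is what keeps a repeated-run driver (as in the test below) from looping forever. A minimal sketch of that idempotence check using plain structs rather than luci::CircleQuantParam:

#include <iostream>
#include <vector>

struct QParam
{
  std::vector<float> scale;
  std::vector<long long> zerop;
};

bool copy_qparam(const QParam &src, QParam &dst)
{
  if (src.scale == dst.scale && src.zerop == dst.zerop)
    return false; // already propagated; report no change

  dst = src;
  return true;
}

int main()
{
  QParam reshape{{0.2f}, {-10}};
  QParam conv{{0.1f}, {0}};
  std::cout << copy_qparam(reshape, conv) << ' '   // 1: qparam copied backward to the producer
            << copy_qparam(reshape, conv) << '\n'; // 0: second run is a no-op
}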
diff --git a/compiler/luci/pass/src/PropagateQuantParamPass.test.cpp b/compiler/luci/pass/src/PropagateQuantParamPass.test.cpp
new file mode 100644
index 000000000..15adbfc01
--- /dev/null
+++ b/compiler/luci/pass/src/PropagateQuantParamPass.test.cpp
@@ -0,0 +1,118 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Pass/PropagateQuantParamPass.h"
+
+#include <luci/IR/CircleNodes.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+void addQuantParam(luci::CircleNode *node, const std::vector<float> &scale,
+ const std::vector<int64_t> &zp)
+{
+ assert(node->quantparam() == nullptr);
+
+ auto quantparam = std::make_unique<luci::CircleQuantParam>();
+ quantparam->scale = scale;
+ quantparam->zerop = zp;
+ node->quantparam(std::move(quantparam));
+}
+
+/**
+ * Simple graph for test
+ *
+ * BEFORE
+ *
+ * [Conv] (qparam 1)
+ * |
+ * [Reshape] (qparam 2)
+ *
+ * AFTER
+ *
+ * [Conv] (qparam 2)
+ * |
+ * [Reshape] (qparam 2)
+ *
+ */
+class SimpleGraph
+{
+public:
+ SimpleGraph()
+ {
+ input = g.nodes()->create<luci::CircleInput>();
+ conv = g.nodes()->create<luci::CircleConv2D>();
+ reshape = g.nodes()->create<luci::CircleReshape>();
+ output = g.nodes()->create<luci::CircleOutput>();
+
+ auto graph_input = g.inputs()->create();
+ input->index(graph_input->index());
+ auto graph_output = g.outputs()->create();
+ output->index(graph_output->index());
+
+ addQuantParam(conv, {0.1, 0.2, 0.3}, {0, 10, 20});
+ addQuantParam(reshape, {0.2, 0.4, 0.6}, {-10, 0, 10});
+
+ conv->input(input);
+ reshape->tensor(conv);
+ output->from(reshape);
+ }
+
+public:
+ loco::Graph g;
+ luci::CircleInput *input;
+ luci::CircleConv2D *conv;
+ luci::CircleReshape *reshape;
+ luci::CircleOutput *output;
+};
+
+} // namespace
+
+TEST(PropagateQuantParam, simple)
+{
+ SimpleGraph g;
+
+ luci::PropagateQuantParamPass pass;
+ while (pass.run(&g.g))
+ ;
+
+ EXPECT_FLOAT_EQ(0.2, g.conv->quantparam()->scale[0]);
+ EXPECT_FLOAT_EQ(0.4, g.conv->quantparam()->scale[1]);
+ EXPECT_FLOAT_EQ(0.6, g.conv->quantparam()->scale[2]);
+ EXPECT_EQ(-10, g.conv->quantparam()->zerop[0]);
+ EXPECT_EQ(0, g.conv->quantparam()->zerop[1]);
+ EXPECT_EQ(10, g.conv->quantparam()->zerop[2]);
+}
+
+TEST(PropagateQuantParam, wrong_op_NEG)
+{
+ SimpleGraph g;
+ g.output->from(g.conv);
+ g.reshape->drop();
+
+ luci::PropagateQuantParamPass pass;
+ while (pass.run(&g.g))
+ ;
+
+ EXPECT_FLOAT_EQ(0.1, g.conv->quantparam()->scale[0]);
+ EXPECT_FLOAT_EQ(0.2, g.conv->quantparam()->scale[1]);
+ EXPECT_FLOAT_EQ(0.3, g.conv->quantparam()->scale[2]);
+ EXPECT_EQ(0, g.conv->quantparam()->zerop[0]);
+ EXPECT_EQ(10, g.conv->quantparam()->zerop[1]);
+ EXPECT_EQ(20, g.conv->quantparam()->zerop[2]);
+}
diff --git a/compiler/luci/pass/src/QuantizeWithMinMaxPass.cpp b/compiler/luci/pass/src/QuantizeWithMinMaxPass.cpp
index 0ecab008f..f6eebe3b9 100644
--- a/compiler/luci/pass/src/QuantizeWithMinMaxPass.cpp
+++ b/compiler/luci/pass/src/QuantizeWithMinMaxPass.cpp
@@ -86,6 +86,100 @@ void quant_const_values(luci::CircleConst *const_node, float scaling_factor, flo
}
}
+// Quantize const per channel
+//
+// The last dimension of the const is the channel dimension,
+// and all other dimensions of the const must be 1,
+// so a single value is quantized per channel
+//
+// Quantization spec (f: fp value, q: quantized value)
+//
+// uint8
+// Positive f: f = f * (q - 0) [q = 1, scale = f, zp = 0]
+// Negative f: f = (-f) * (q - 1) [q = 0, scale = -f, zp = 1]
+//
+// int16
+// Positive f: f = f * (q - 0) [q = 1, scale = f, zp = 0]
+// Negative f: f = (-f) * (q - 0) [q = -1, scale = -f, zp = 0]
+void quant_const_per_channel(CircleConst *node, loco::DataType quant_type)
+{
+ assert(node->dtype() == loco::DataType::FLOAT32);
+ assert(node->rank() > 0);
+
+ for (uint32_t i = 0; i < node->rank() - 1; i++)
+ {
+    // Callers must ensure that every non-channel dimension is 1 before calling this function
+ if (node->dim(i).value() != 1)
+ throw std::runtime_error("Non-channel dimension of const node must be 1");
+ }
+
+ uint32_t size = node->size<loco::DataType::FLOAT32>();
+ assert(size == node->dim(node->rank() - 1).value());
+
+ auto quantparam = std::make_unique<CircleQuantParam>();
+ quantparam->quantized_dimension = node->rank() - 1;
+ std::vector<int32_t> quantized_data(size);
+
+ for (uint32_t i = 0; i < size; ++i)
+ {
+ auto data = node->at<loco::DataType::FLOAT32>(i);
+ if (quant_type == loco::DataType::U8)
+ {
+ if (data >= 0)
+ {
+ quantparam->scale.push_back(data);
+ quantparam->zerop.push_back(0);
+ quantized_data[i] = 1;
+ }
+ else
+ {
+ quantparam->scale.push_back(-data);
+ quantparam->zerop.push_back(1);
+ quantized_data[i] = 0;
+ }
+ }
+ else if (quant_type == loco::DataType::S16)
+ {
+ if (data >= 0)
+ {
+ quantparam->scale.push_back(data);
+ quantized_data[i] = 1;
+ }
+ else
+ {
+ quantparam->scale.push_back(-data);
+ quantized_data[i] = -1;
+ }
+ quantparam->zerop.push_back(0);
+ }
+ }
+ node->quantparam(std::move(quantparam));
+
+ switch (quant_type)
+ {
+ case loco::DataType::U8:
+ node->dtype(loco::DataType::U8);
+ node->size<loco::DataType::U8>(size);
+ for (uint32_t i = 0; i < size; ++i)
+ {
+ assert(quantized_data[i] == 0 || quantized_data[i] == 1);
+ node->at<loco::DataType::U8>(i) = quantized_data[i];
+ }
+ break;
+ case loco::DataType::S16:
+ node->dtype(loco::DataType::S16);
+ node->size<loco::DataType::S16>(size);
+ for (uint32_t i = 0; i < size; ++i)
+ {
+ assert(quantized_data[i] == -1 || quantized_data[i] == 1);
+ node->at<loco::DataType::S16>(i) = quantized_data[i];
+ }
+ break;
+ default:
+ throw std::runtime_error("Unsupported data type");
+ }
+}
+
void quant_const(CircleConst *node, loco::DataType quant_type)
{
assert(node->dtype() == loco::DataType::FLOAT32);
@@ -612,10 +706,51 @@ struct QuantizeWeights final : public luci::CircleNodeMutableVisitor<bool>
}
};
+void quant_instnorm(luci::CircleInstanceNorm *node, loco::DataType output_type,
+ QuantizationGranularity granularity)
+{
+ auto gamma = loco::must_cast<luci::CircleConst *>(node->gamma());
+ auto beta = loco::must_cast<luci::CircleConst *>(node->beta());
+ assert(gamma->dtype() == loco::DataType::FLOAT32);
+ assert(beta->dtype() == loco::DataType::FLOAT32);
+
+ if (granularity == QuantizationGranularity::LayerWise)
+ {
+ quant_const(gamma, output_type);
+ quant_const(beta, output_type);
+ }
+ else if (granularity == QuantizationGranularity::ChannelWise)
+ {
+ quant_const_per_channel(gamma, output_type);
+ quant_const_per_channel(beta, output_type);
+ }
+ else
+ throw std::runtime_error("Quantization granularity must be either 'layer' or 'channel'");
+}
+
+void quant_prelu(luci::CirclePRelu *node, loco::DataType output_type,
+ QuantizationGranularity granularity)
+{
+ auto alpha = loco::must_cast<luci::CircleConst *>(node->alpha());
+ assert(alpha->dtype() == loco::DataType::FLOAT32);
+
+ if (granularity == QuantizationGranularity::LayerWise)
+ {
+ quant_const(alpha, output_type);
+ }
+ else if (granularity == QuantizationGranularity::ChannelWise)
+ {
+ quant_const_per_channel(alpha, output_type);
+ }
+ else
+ throw std::runtime_error("Quantization granularity must be either 'layer' or 'channel'");
+}
+
/**
* @brief Quantize const input tensors using min/max of const values
*/
-void quantize_const_inputs(luci::CircleNode *node, loco::DataType output_type)
+void quantize_const_inputs(luci::CircleNode *node, loco::DataType output_type,
+ QuantizationGranularity granularity)
{
auto opcode = node->opcode();
auto arity = node->arity();
@@ -660,20 +795,26 @@ void quantize_const_inputs(luci::CircleNode *node, loco::DataType output_type)
quant_const(const_node, output_type);
break;
+ case luci::CircleOpcode::INSTANCE_NORM:
+ quant_instnorm(loco::must_cast<luci::CircleInstanceNorm *>(node), output_type, granularity);
+ break;
+
+ case luci::CircleOpcode::PRELU:
+ quant_prelu(loco::must_cast<luci::CirclePRelu *>(node), output_type, granularity);
+ break;
+
case luci::CircleOpcode::ADD:
case luci::CircleOpcode::ADD_N:
case luci::CircleOpcode::DIV:
case luci::CircleOpcode::EQUAL:
case luci::CircleOpcode::GREATER:
case luci::CircleOpcode::GREATER_EQUAL:
- case luci::CircleOpcode::INSTANCE_NORM:
case luci::CircleOpcode::LESS:
case luci::CircleOpcode::LESS_EQUAL:
case luci::CircleOpcode::MAXIMUM:
case luci::CircleOpcode::MINIMUM:
case luci::CircleOpcode::MUL:
case luci::CircleOpcode::NOT_EQUAL:
- case luci::CircleOpcode::PRELU:
case luci::CircleOpcode::SUB:
// Quantize all const inputs using their values
for (uint32_t i = 0; i < arity; i++)
@@ -817,7 +958,7 @@ bool QuantizeWithMinMaxPass::run(loco::Graph *g)
for (auto node : loco::active_nodes(loco::output_nodes(g)))
{
auto circle_node = loco::must_cast<luci::CircleNode *>(node);
- quantize_const_inputs(circle_node, _output_dtype);
+ quantize_const_inputs(circle_node, _output_dtype, _granularity);
}
// Propagate quantization parameters of concat Op
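
A standalone worked example of the per-channel rule documented in quant_const_per_channel() for the uint8 path: each channel holds a single float f, encoded as scale = |f|, zerop = 0 for non-negative f and 1 otherwise, and a one-element code q, so that f = scale * (q - zerop). The helper name is illustrative only:

#include <cstdint>
#include <iostream>

struct ChannelQuant
{
  float scale;
  int64_t zerop;
  uint8_t q;
};

ChannelQuant quant_single_value_u8(float f)
{
  if (f >= 0.0f)
    return {f, 0, 1}; // f = f * (1 - 0)
  return {-f, 1, 0};  // f = (-f) * (0 - 1)
}

int main()
{
  for (float f : {0.5f, -0.25f})
  {
    auto c = quant_single_value_u8(f);
    // Reconstruct the float value to confirm the round trip
    float back = c.scale * (static_cast<int>(c.q) - static_cast<int>(c.zerop));
    std::cout << f << " -> scale=" << c.scale << " zerop=" << c.zerop << " q=" << int(c.q)
              << " (decodes to " << back << ")\n";
  }
}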
diff --git a/compiler/luci/pass/src/RemoveRedundantTranspose.cpp b/compiler/luci/pass/src/RemoveRedundantTranspose.cpp
new file mode 100644
index 000000000..33cb76520
--- /dev/null
+++ b/compiler/luci/pass/src/RemoveRedundantTranspose.cpp
@@ -0,0 +1,127 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Pass/RemoveRedundantTransposePass.h"
+
+#include <luci/IR/CircleNodes.h>
+
+namespace
+{
+
+/// @brief Return true if first_perm[second_perm[i]] == i
+bool check_perm(const luci::CircleConst *first_perm, const luci::CircleConst *second_perm)
+{
+ assert(first_perm->rank() == 1);
+ assert(second_perm->rank() == 1);
+ assert(second_perm->size<loco::DataType::S32>() == first_perm->size<loco::DataType::S32>());
+ for (int32_t i = 0; i < static_cast<int32_t>(first_perm->size<loco::DataType::S32>()); i++)
+ {
+ if (first_perm->at<loco::DataType::S32>(second_perm->at<loco::DataType::S32>(i)) != i)
+ return false;
+ }
+ return true;
+}
+
+bool remove_consecutive_transpose_function(luci::CircleNode *node)
+{
+ auto target_node = dynamic_cast<luci::CircleTranspose *>(node);
+ if (target_node == nullptr)
+ return false;
+ auto pred_node = dynamic_cast<luci::CircleTranspose *>(target_node->a());
+ if (pred_node == nullptr)
+ return false;
+ if (loco::succs(pred_node).size() != 1)
+ return false;
+
+ auto pred_perm = dynamic_cast<luci::CircleConst *>(target_node->perm());
+ if (pred_perm == nullptr)
+ return false;
+
+ auto main_perm = dynamic_cast<luci::CircleConst *>(pred_node->perm());
+ if (main_perm == nullptr)
+ return false;
+
+ auto main_node = loco::must_cast<luci::CircleNode *>(pred_node->a());
+ if (check_perm(pred_perm, main_perm))
+ {
+ replace(node).with(main_node);
+ }
+ else
+ {
+ auto g = main_perm->graph();
+ auto new_const_node = g->nodes()->create<luci::CircleConst>();
+
+ new_const_node->dtype(loco::DataType::S32);
+ new_const_node->rank(1);
+ new_const_node->dim(0) = main_perm->dim(0);
+ new_const_node->size<loco::DataType::S32>(main_perm->dim(0).value());
+ new_const_node->shape_status(luci::ShapeStatus::VALID);
+ for (uint32_t i = 0; i < main_perm->size<loco::DataType::S32>(); i++)
+ {
+ new_const_node->at<loco::DataType::S32>(i) =
+ pred_perm->at<loco::DataType::S32>(main_perm->at<loco::DataType::S32>(i));
+ }
+ pred_node->perm(new_const_node);
+ replace(node).with(pred_node);
+ }
+ return true;
+}
+
+} // namespace
+
+namespace luci
+{
+/**
+ * BEFORE
+ * |
+ * [CircleNode] [CircleConst]
+ * (main_node) (main_perm)
+ * \ /
+ * [CircleTranspose] [CircleConst]
+ * (pred_node) (pred_perm)
+ * \ /
+ * [CircleTranspose]
+ * (target_node)
+ * |
+ *
+ * AFTER
+ * <Optional Case>
+ *
+ * | | |
+ * [CircleNode] [CircleConst] |
+ * (main_node) (new_const_node) |
+ * \ / or [CircleNode]
+ * [CircleTranspose] (main_node)
+ * (pred_node) |
+ * | |
+ *
+ */
+bool RemoveRedundantTransposePass::run(loco::Graph *g)
+{
+ bool changed = false;
+ for (auto node : loco::active_nodes(loco::output_nodes(g)))
+ {
+ auto circle_node = loco::must_cast<luci::CircleNode *>(node);
+ if (remove_consecutive_transpose_function(circle_node))
+ {
+ changed = true;
+ break;
+ }
+ }
+ return changed;
+}
+
+} // namespace luci
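
A standalone sketch of the permutation arithmetic behind the pass: two back-to-back Transposes cancel when outer_perm[inner_perm[i]] == i for every i, and otherwise they collapse into one Transpose whose perm is the composition outer_perm[inner_perm[i]], the same expression as the new_const_node loop above. Function names are illustrative:

#include <cstdint>
#include <iostream>
#include <vector>

bool cancels(const std::vector<int32_t> &outer_perm, const std::vector<int32_t> &inner_perm)
{
  for (int32_t i = 0; i < static_cast<int32_t>(outer_perm.size()); ++i)
    if (outer_perm[inner_perm[i]] != i)
      return false;
  return true;
}

std::vector<int32_t> compose(const std::vector<int32_t> &outer_perm,
                             const std::vector<int32_t> &inner_perm)
{
  std::vector<int32_t> merged(inner_perm.size());
  for (size_t i = 0; i < inner_perm.size(); ++i)
    merged[i] = outer_perm[inner_perm[i]];
  return merged;
}

int main()
{
  // Type1 in the test below: identical perms cancel, so both Transposes can be removed
  std::cout << std::boolalpha << cancels({1, 0, 2, 3}, {1, 0, 2, 3}) << '\n'; // true

  // Type2: the perms do not cancel and are merged into the single perm {1, 0, 3, 2}
  for (auto v : compose({1, 0, 2, 3}, {0, 1, 3, 2}))
    std::cout << v << ' ';
  std::cout << '\n';
}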
diff --git a/compiler/luci/pass/src/RemoveRedundantTranspose.test.cpp b/compiler/luci/pass/src/RemoveRedundantTranspose.test.cpp
new file mode 100644
index 000000000..db608b674
--- /dev/null
+++ b/compiler/luci/pass/src/RemoveRedundantTranspose.test.cpp
@@ -0,0 +1,156 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "luci/Pass/RemoveRedundantTransposePass.h"
+
+#include <luci/IR/CircleNodes.h>
+
+#include <vector>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+void setValue(luci::CircleConst *node, const std::vector<int> &v)
+{
+ node->dtype(loco::DataType::S32);
+ node->size<loco::DataType::S32>(v.size());
+ node->rank(1);
+ node->dim(0).set(v.size());
+ for (int i = 0; i < v.size(); ++i)
+ {
+ node->at<loco::DataType::S32>(i) = v[i];
+ }
+}
+
+/**
+ * Type1
+ * BEFORE
+ * |
+ * [CircleNode] [CircleConst]
+ * \ /
+ * [CircleTranspose] [CircleConst]
+ * \ /
+ * [CircleTranspose]
+ * |
+ *
+ * AFTER
+ * |
+ * [CircleNode]
+ * | Remove Both
+ *
+ * --------------------------------------------
+ *
+ * Type2
+ * BEFORE
+ * |
+ * [CircleNode] [CircleConst]
+ * \ /
+ * [CircleTranspose] [CircleConst]
+ * \ /
+ * [CircleTranspose]
+ * |
+ *
+ * AFTER
+ * | |
+ * [CircleNode] [CircleConst]
+ * \ /
+ * [CircleTranspose]
+ * |
+ *
+ */
+void create_redundant_transpose(loco::Graph *g, const std::vector<int32_t> &perm1,
+ const std::vector<int32_t> &perm2)
+{
+ assert(g);
+
+ auto input = g->nodes()->create<luci::CircleInput>();
+ auto graph_input = g->inputs()->create();
+ input->index(graph_input->index());
+
+ // Create perm1
+ auto perm1_node = g->nodes()->create<luci::CircleConst>();
+ setValue(perm1_node, perm1);
+
+ auto transpose1 = g->nodes()->create<luci::CircleTranspose>();
+ transpose1->dtype(loco::DataType::FLOAT32);
+ transpose1->a(input);
+ transpose1->perm(perm1_node);
+
+ // Create perm2
+ auto perm2_node = g->nodes()->create<luci::CircleConst>();
+ setValue(perm2_node, perm2);
+
+ auto transpose2 = g->nodes()->create<luci::CircleTranspose>();
+ transpose2->dtype(loco::DataType::FLOAT32);
+ transpose2->a(transpose1);
+ transpose2->perm(perm2_node);
+
+ // Output
+ auto output = g->nodes()->create<luci::CircleOutput>();
+ output->from(transpose2);
+ auto graph_output = g->outputs()->create();
+ output->index(graph_output->index());
+}
+
+} // namespace
+
+TEST(RemoveRedundantTransposePass, remove_consecutive_transpose_function_type1)
+{
+ auto graph = loco::make_graph();
+  create_redundant_transpose(graph.get(), {1, 0, 2, 3}, {1, 0, 2, 3});
+
+ luci::RemoveRedundantTransposePass pass;
+ while (pass.run(graph.get()))
+ ;
+ luci::CircleTranspose *transpose_node = nullptr;
+ for (auto node : loco::active_nodes(loco::output_nodes(graph.get())))
+ {
+ auto trans = dynamic_cast<luci::CircleTranspose *>(node);
+ if (not trans)
+ continue;
+ transpose_node = trans;
+ break;
+ }
+  // No Transpose node remains in the graph.
+ ASSERT_EQ(nullptr, transpose_node);
+}
+
+TEST(RemoveRedundantTransposePass, remove_consecutive_transpose_function_type2)
+{
+ auto graph = loco::make_graph();
+  create_redundant_transpose(graph.get(), {0, 1, 3, 2}, {1, 0, 2, 3});
+
+ luci::RemoveRedundantTransposePass pass;
+ while (pass.run(graph.get()))
+ ;
+ luci::CircleTranspose *transpose_node = nullptr;
+ for (auto node : loco::active_nodes(loco::output_nodes(graph.get())))
+ {
+ auto trans = dynamic_cast<luci::CircleTranspose *>(node);
+ if (not trans)
+ continue;
+ transpose_node = trans;
+ break;
+ }
+ // Just one transpose node, with updated perm constant.
+ ASSERT_NE(nullptr, transpose_node);
+ auto perm = loco::must_cast<luci::CircleConst *>(transpose_node->perm());
+ ASSERT_EQ(1, perm->at<loco::DataType::S32>(0));
+ ASSERT_EQ(0, perm->at<loco::DataType::S32>(1));
+ ASSERT_EQ(3, perm->at<loco::DataType::S32>(2));
+ ASSERT_EQ(2, perm->at<loco::DataType::S32>(3));
+}
diff --git a/compiler/luci/pass/src/ReplaceMulAddWithDepthwiseConvPass.cpp b/compiler/luci/pass/src/ReplaceMulAddWithDepthwiseConvPass.cpp
new file mode 100644
index 000000000..7096c2591
--- /dev/null
+++ b/compiler/luci/pass/src/ReplaceMulAddWithDepthwiseConvPass.cpp
@@ -0,0 +1,223 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Pass/ReplaceMulAddWithDepthwiseConvPass.h"
+
+#include <luci/IR/CircleNodes.h>
+
+namespace
+{
+
+luci::CircleConst *create_weights_from_gamma(luci::CircleConst *gamma)
+{
+ assert(gamma->rank() == 1);
+ auto channel_size = gamma->dim(0).value();
+
+ // Channel-wise MUL is the same as DEPTHWISE_CONV2D with filter shape (1,1,1,channel_size)
+ auto weights = gamma->graph()->nodes()->create<luci::CircleConst>();
+ weights->dtype(loco::DataType::FLOAT32);
+ weights->rank(4);
+ weights->dim(0).set(1);
+ weights->dim(1).set(1);
+ weights->dim(2).set(1);
+ weights->dim(3).set(channel_size);
+ weights->shape_status(luci::ShapeStatus::VALID);
+ weights->size<loco::DataType::FLOAT32>(channel_size);
+ for (uint32_t i = 0; i < channel_size; i++)
+ {
+ weights->at<loco::DataType::FLOAT32>(i) = gamma->at<loco::DataType::FLOAT32>(i);
+ }
+
+ return weights;
+}
+
+luci::CircleConst *create_bias_from_beta(luci::CircleConst *beta)
+{
+ assert(beta->rank() == 1);
+ auto channel_size = beta->dim(0).value();
+
+ // Channel-wise ADD is the same as bias (shape = (channel_size)) of DEPTHWISE_CONV2D
+ auto bias = beta->graph()->nodes()->create<luci::CircleConst>();
+ bias->dtype(loco::DataType::FLOAT32);
+ bias->rank(1);
+ bias->dim(0).set(channel_size);
+ bias->size<loco::DataType::FLOAT32>(channel_size);
+ bias->shape_status(luci::ShapeStatus::VALID);
+ for (uint32_t i = 0; i < channel_size; i++)
+ {
+ bias->at<loco::DataType::FLOAT32>(i) = beta->at<loco::DataType::FLOAT32>(i);
+ }
+
+ return bias;
+}
+
+bool is_batchnorm_add(const luci::CircleAdd *add, luci::CircleMul *&mul, luci::CircleConst *&beta)
+{
+ auto x = loco::must_cast<luci::CircleNode *>(add->x());
+ auto y = loco::must_cast<luci::CircleNode *>(add->y());
+
+ luci::CircleMul *pred = nullptr;
+ luci::CircleConst *constant = nullptr;
+
+ if (x->opcode() == luci::CircleOpcode::CIRCLECONST && y->opcode() == luci::CircleOpcode::MUL)
+ {
+ pred = loco::must_cast<luci::CircleMul *>(y);
+ constant = loco::must_cast<luci::CircleConst *>(x);
+ }
+ else if (x->opcode() == luci::CircleOpcode::MUL && y->opcode() == luci::CircleOpcode::CIRCLECONST)
+ {
+ pred = loco::must_cast<luci::CircleMul *>(x);
+ constant = loco::must_cast<luci::CircleConst *>(y);
+ }
+ else
+ {
+ return false;
+ }
+
+ if (constant->rank() != 1)
+ return false;
+
+ auto channel_dim = constant->dim(0);
+ // Assumption: Layout is channel-last
+ if (!(channel_dim == add->dim(add->rank() - 1)))
+ return false;
+
+ mul = pred;
+ beta = constant;
+ return true;
+}
+
+// Check if mul is batchnorm mul
+bool is_batchnorm_mul(const luci::CircleMul *mul, luci::CircleNode *&pred_node,
+ luci::CircleConst *&gamma)
+{
+ auto x = dynamic_cast<luci::CircleConst *>(mul->x());
+ auto y = dynamic_cast<luci::CircleConst *>(mul->y());
+
+ luci::CircleNode *pred = nullptr;
+ luci::CircleConst *constant = nullptr;
+
+ if (x != nullptr && y == nullptr)
+ {
+ pred = loco::must_cast<luci::CircleNode *>(mul->y());
+ constant = x;
+ }
+ else if (x == nullptr && y != nullptr)
+ {
+ pred = loco::must_cast<luci::CircleNode *>(mul->x());
+ constant = y;
+ }
+ else
+ {
+ return false;
+ }
+
+ if (constant->rank() != 1)
+ return false;
+
+ auto channel_dim = constant->dim(0);
+ if (!(channel_dim == mul->dim(mul->rank() - 1)))
+ return false;
+
+ pred_node = pred;
+ gamma = constant;
+ return true;
+}
+
+/**
+ * Replace channel-wise Mul/Add with DepthwiseConv2D
+ *
+ * BEFORE
+ *
+ * [Node] [gamma]
+ * | /
+ * [Mul] [beta]
+ * | /
+ * [Add]
+ *
+ * AFTER
+ *
+ * [Node] [weights] [bias]
+ * \ / /
+ * [DepthwiseConv2D]
+ */
+bool replace_mul_add_with_dwconv(luci::CircleAdd *add)
+{
+ luci::CircleNode *pred_node = nullptr;
+ luci::CircleMul *mul = nullptr;
+ luci::CircleConst *beta = nullptr;
+ luci::CircleConst *gamma = nullptr;
+
+ if (!is_batchnorm_add(add, mul, beta))
+ return false;
+
+ if (loco::succs(mul).size() != 1)
+ return false;
+
+ if (!is_batchnorm_mul(mul, pred_node, gamma))
+ return false;
+
+ if (pred_node->rank() != 4)
+ return false;
+
+ if (pred_node->dtype() != loco::DataType::FLOAT32 || beta->dtype() != loco::DataType::FLOAT32 ||
+ gamma->dtype() != loco::DataType::FLOAT32)
+ return false;
+
+ auto weights = create_weights_from_gamma(gamma);
+ auto bias = create_bias_from_beta(beta);
+
+ auto dwconv = add->graph()->nodes()->create<luci::CircleDepthwiseConv2D>();
+ dwconv->input(pred_node);
+ dwconv->filter(weights);
+ dwconv->bias(bias);
+ dwconv->padding(luci::Padding::SAME);
+ dwconv->stride()->w(1);
+ dwconv->stride()->h(1);
+ dwconv->depthMultiplier(1);
+ dwconv->dilation()->w(1);
+ dwconv->dilation()->h(1);
+ dwconv->fusedActivationFunction(add->fusedActivationFunction());
+
+ loco::replace(add).with(dwconv);
+ return true;
+}
+
+} // namespace
+
+namespace luci
+{
+
+bool ReplaceMulAddWithDepthwiseConvPass::run(loco::Graph *g)
+{
+ bool changed = false;
+ for (auto node : loco::active_nodes(loco::output_nodes(g)))
+ {
+ auto add = dynamic_cast<luci::CircleAdd *>(node);
+ if (not add)
+ continue;
+
+ if (replace_mul_add_with_dwconv(add))
+ {
+ changed = true;
+ break;
+ }
+ }
+
+ return changed;
+}
+
+} // namespace luci
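
A standalone numeric check of the equivalence this pass relies on: a channel-wise y = x * gamma + beta over an NHWC tensor produces exactly the values of a 1x1 DepthwiseConv2D whose (1, 1, 1, C) filter holds gamma and whose bias holds beta (stride 1, depth multiplier 1, SAME padding):

#include <cstdint>
#include <iostream>
#include <vector>

int main()
{
  const uint32_t H = 2, W = 2, C = 3;
  std::vector<float> x(H * W * C);
  for (uint32_t i = 0; i < x.size(); ++i)
    x[i] = static_cast<float>(i); // NHWC input with N = 1

  std::vector<float> gamma = {0.5f, 2.0f, -1.0f}; // per-channel scale
  std::vector<float> beta = {1.0f, 0.0f, 3.0f};   // per-channel offset

  // Reference: channel-wise Mul followed by channel-wise Add
  std::vector<float> mul_add(x.size());
  for (uint32_t i = 0; i < x.size(); ++i)
    mul_add[i] = x[i] * gamma[i % C] + beta[i % C];

  // DepthwiseConv2D with a (1, 1, 1, C) filter: every output element sees only the
  // input element at the same (h, w, c), so the kernel loop degenerates to one tap
  std::vector<float> dwconv(x.size());
  for (uint32_t h = 0; h < H; ++h)
    for (uint32_t w = 0; w < W; ++w)
      for (uint32_t c = 0; c < C; ++c)
      {
        const uint32_t idx = (h * W + w) * C + c;
        dwconv[idx] = x[idx] * gamma[c] + beta[c];
      }

  std::cout << (mul_add == dwconv ? "equivalent" : "different") << '\n'; // equivalent
}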
diff --git a/compiler/luci/pass/src/ReplaceMulAddWithDepthwiseConvPass.test.cpp b/compiler/luci/pass/src/ReplaceMulAddWithDepthwiseConvPass.test.cpp
new file mode 100644
index 000000000..a90182aaa
--- /dev/null
+++ b/compiler/luci/pass/src/ReplaceMulAddWithDepthwiseConvPass.test.cpp
@@ -0,0 +1,142 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Pass/ReplaceMulAddWithDepthwiseConvPass.h"
+
+#include <luci/IR/CircleNodes.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+/**
+ * Simple graph for test
+ *
+ * BEFORE
+ *
+ * [Node] [gamma]
+ * | /
+ * [Mul] [beta]
+ * | /
+ * [Add]
+ *
+ * AFTER
+ *
+ * [Node] [weights] [bias]
+ * \ / /
+ * [DepthwiseConv2D]
+ */
+class SimpleGraph
+{
+public:
+ SimpleGraph()
+ {
+ input = g.nodes()->create<luci::CircleInput>();
+ mul = g.nodes()->create<luci::CircleMul>();
+ gamma = g.nodes()->create<luci::CircleConst>();
+ add = g.nodes()->create<luci::CircleAdd>();
+ beta = g.nodes()->create<luci::CircleConst>();
+ output = g.nodes()->create<luci::CircleOutput>();
+
+ auto graph_input = g.inputs()->create();
+ input->index(graph_input->index());
+ auto graph_output = g.outputs()->create();
+ output->index(graph_output->index());
+
+ input->dtype(loco::DataType::FLOAT32);
+ mul->dtype(loco::DataType::FLOAT32);
+ gamma->dtype(loco::DataType::FLOAT32);
+ add->dtype(loco::DataType::FLOAT32);
+ beta->dtype(loco::DataType::FLOAT32);
+ output->dtype(loco::DataType::FLOAT32);
+
+ uint32_t channel_size = 16;
+ input->shape({1, 4, 4, channel_size});
+ mul->shape({1, 4, 4, channel_size});
+ gamma->shape({channel_size});
+ add->shape({1, 4, 4, channel_size});
+ beta->shape({channel_size});
+ output->shape({1, 4, 4, channel_size});
+
+ gamma->size<loco::DataType::FLOAT32>(channel_size);
+ beta->size<loco::DataType::FLOAT32>(channel_size);
+ for (uint32_t i = 0; i < channel_size; i++)
+ {
+ gamma->at<loco::DataType::FLOAT32>(i) = i;
+ beta->at<loco::DataType::FLOAT32>(i) = i;
+ }
+
+ mul->x(input);
+ mul->y(gamma);
+ add->x(mul);
+ add->y(beta);
+ output->from(add);
+ }
+
+public:
+ loco::Graph g;
+ luci::CircleInput *input = nullptr;
+ luci::CircleMul *mul = nullptr;
+ luci::CircleConst *gamma = nullptr;
+ luci::CircleAdd *add = nullptr;
+ luci::CircleConst *beta = nullptr;
+ luci::CircleOutput *output = nullptr;
+};
+
+} // namespace
+
+TEST(ReplaceMulAddWithDepthwiseConv, simple)
+{
+ SimpleGraph g;
+
+ luci::ReplaceMulAddWithDepthwiseConvPass pass;
+ while (pass.run(&g.g))
+ ;
+
+ auto dwconv = dynamic_cast<luci::CircleDepthwiseConv2D *>(g.output->from());
+ EXPECT_NE(nullptr, dwconv);
+
+ uint32_t channel_size = 16;
+ auto weights = dynamic_cast<luci::CircleConst *>(dwconv->filter());
+ auto bias = dynamic_cast<luci::CircleConst *>(dwconv->bias());
+ EXPECT_NE(nullptr, weights);
+ EXPECT_EQ(4, weights->rank());
+ EXPECT_EQ(channel_size, weights->dim(3).value());
+ EXPECT_NE(nullptr, bias);
+ EXPECT_EQ(1, bias->rank());
+ EXPECT_EQ(channel_size, bias->dim(0).value());
+
+ for (int i = 0; i < channel_size; i++)
+ {
+ EXPECT_FLOAT_EQ(i, weights->at<loco::DataType::FLOAT32>(i));
+ EXPECT_FLOAT_EQ(i, bias->at<loco::DataType::FLOAT32>(i));
+ }
+}
+
+TEST(ReplaceMulAddWithDepthwiseConv, wrong_op_NEG)
+{
+ SimpleGraph g;
+ // swap mul/add (changed to add->mul)
+ g.add->x(g.input);
+ loco::replace(g.add).with(g.mul);
+ g.mul->x(g.add);
+
+ luci::ReplaceMulAddWithDepthwiseConvPass pass;
+ auto changed = pass.run(&g.g);
+
+ EXPECT_EQ(false, changed);
+}
diff --git a/compiler/luci/pass/src/ShapeInferencePass.cpp b/compiler/luci/pass/src/ShapeInferencePass.cpp
index f681b3d5f..4bd0aaed4 100644
--- a/compiler/luci/pass/src/ShapeInferencePass.cpp
+++ b/compiler/luci/pass/src/ShapeInferencePass.cpp
@@ -28,6 +28,19 @@
namespace luci
{
+bool ShapeInferencePass::run(luci::Module *m)
+{
+ bool changed = false;
+
+ for (size_t g = 0; g < m->size(); ++g)
+ {
+ if (run(m->graph(g)))
+ changed = true;
+ }
+
+ return changed;
+}
+
bool ShapeInferencePass::run(loco::Graph *g)
{
loco::CanonicalShapeInferenceRule canonical_rule;
diff --git a/compiler/luci/pass/src/ShapeSignatureInferencePass.cpp b/compiler/luci/pass/src/ShapeSignatureInferencePass.cpp
new file mode 100644
index 000000000..115b77a96
--- /dev/null
+++ b/compiler/luci/pass/src/ShapeSignatureInferencePass.cpp
@@ -0,0 +1,63 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Pass/ShapeSignatureInferencePass.h"
+
+#include <luci/IR/CircleShapeSignature.h>
+#include <luci/Service/CircleShapeSignatureInference.h>
+
+#include <loco.h>
+
+namespace luci
+{
+
+bool ShapeSignatureInferencePass::run(luci::Module *m)
+{
+ bool changed = false;
+
+ for (size_t g = 0; g < m->size(); ++g)
+ {
+ if (run(m->graph(g)))
+ changed = true;
+ }
+
+ return changed;
+}
+
+bool ShapeSignatureInferencePass::run(loco::Graph *g)
+{
+ luci::ssinf::Rule signature_inference_rule;
+ bool changed = false;
+
+ for (auto node : loco::postorder_traversal(loco::output_nodes(g)))
+ {
+ luci::ShapeSignature shape_signature;
+
+ auto circle_node = loco::must_cast<luci::CircleNode *>(node);
+ if (signature_inference_rule.infer(circle_node, shape_signature))
+ {
+ if (!(circle_node->shape_signature() == shape_signature))
+ {
+ circle_node->shape_signature(shape_signature);
+ changed = true;
+ }
+ }
+ }
+
+ return changed;
+}
+
+} // namespace luci
diff --git a/compiler/luci/pass/src/ShuffleWeightTo16x1Float32Pass.cpp b/compiler/luci/pass/src/ShuffleWeightTo16x1Float32Pass.cpp
new file mode 100644
index 000000000..6a58f18c5
--- /dev/null
+++ b/compiler/luci/pass/src/ShuffleWeightTo16x1Float32Pass.cpp
@@ -0,0 +1,139 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Pass/ShuffleWeightTo16x1Float32Pass.h"
+
+#include <luci/IR/CircleNodes.h>
+
+#include <cassert>
+#include <vector>
+
+namespace
+{
+
+bool satisfy_precondition(luci::CircleFullyConnected *fc)
+{
+ // check if it's already been shuffled
+ if (fc->weights_format() != luci::CircleFullyConnected::WeightsFormat::DEFAULT)
+ return false;
+
+ // check if its data type is FLOAT32
+ if (fc->dtype() != loco::DataType::FLOAT32)
+ return false;
+
+ auto weights = loco::must_cast<luci::CircleConst *>(fc->weights());
+ // rank must be 2
+ if (weights->rank() != 2)
+ return false;
+
+ // check if it has sparsity parameter
+ if (weights->sparsityparam())
+ return false;
+
+  // check if the number of rows of FullyConnected's weights is a multiple of 16
+ const uint32_t MULTIPLE = 16;
+ uint32_t rows = weights->dim(0).value();
+ if (rows % MULTIPLE)
+ return false;
+
+ return true;
+}
+
+// collect FullyConnected ops that share the same weights tensor
+void get_FCs_having_same_tensor(std::vector<luci::CircleFullyConnected *> &fc_vec, loco::Graph *g,
+ luci::CircleFullyConnected *fc)
+{
+ auto the_tensor = fc->weights();
+ for (auto node : loco::active_nodes(loco::output_nodes(g)))
+ {
+ auto fc = dynamic_cast<luci::CircleFullyConnected *>(node);
+ if (not fc)
+ continue;
+
+ if (fc->weights() == the_tensor)
+ fc_vec.push_back(fc);
+ }
+}
+
+luci::CircleConst *shuffle_weight(luci::CircleFullyConnected *fc)
+{
+ auto the_weights = loco::must_cast<luci::CircleConst *>(fc->weights());
+
+ // create CircleConst where shuffled data will be stored
+ luci::CircleConst *new_weights = fc->graph()->nodes()->create<luci::CircleConst>();
+ new_weights->dtype(loco::DataType::FLOAT32);
+ new_weights->size<loco::DataType::FLOAT32>(the_weights->size<loco::DataType::FLOAT32>());
+ new_weights->rank(the_weights->rank());
+ new_weights->shape_status(the_weights->shape_status());
+ for (uint32_t r = 0; r < new_weights->rank(); r++)
+ {
+ new_weights->dim(r).set(the_weights->dim(r).value());
+ }
+
+  // shuffle weight
+ const uint32_t MULTIPLE = 16;
+ const uint32_t rows = the_weights->dim(0).value();
+ const uint32_t cols = the_weights->dim(1).value();
+ const uint32_t r_step = rows / MULTIPLE;
+ uint32_t index = 0;
+ for (uint32_t r = 0; r < r_step; r++)
+ {
+ for (uint32_t c = 0; c < cols; c++)
+ {
+ for (uint32_t i = 0; i < MULTIPLE; i++)
+ {
+ new_weights->at<loco::DataType::FLOAT32>(index++) =
+ the_weights->at<loco::DataType::FLOAT32>((r * MULTIPLE + i) * cols + c);
+ }
+ }
+ }
+
+ return new_weights;
+}
+
+} // namespace
+
+namespace luci
+{
+
+bool ShuffleWeightTo16x1Float32Pass::run(loco::Graph *g)
+{
+ bool changed = false;
+ for (auto node : loco::active_nodes(loco::output_nodes(g)))
+ {
+ auto fc = dynamic_cast<luci::CircleFullyConnected *>(node);
+ if (not fc)
+ continue;
+
+ if (not satisfy_precondition(fc))
+ continue;
+
+ std::vector<luci::CircleFullyConnected *> fc_vec;
+ get_FCs_having_same_tensor(fc_vec, g, fc);
+ auto new_weights = shuffle_weight(fc);
+
+    // replace the weights of every collected FullyConnected with the shuffled ones
+    for (const auto fc : fc_vec)
+    {
+      fc->weights(new_weights);
+      fc->weights_format(luci::CircleFullyConnected::WeightsFormat::SHUFFLED16x1FLOAT32);
+    }
+    changed = true;
+ }
+
+ return changed;
+}
+
+} // namespace luci
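
A standalone sketch of the re-layout order used by shuffle_weight(): within each block of 16 rows the values are written out column by column, so for the 16x2 weight used in the test below the shuffled buffer starts with column 0 of the block (0, 2, 4, ...). shuffle_16x1 is an illustrative name:

#include <cstdint>
#include <iostream>
#include <vector>

std::vector<float> shuffle_16x1(const std::vector<float> &w, uint32_t rows, uint32_t cols)
{
  const uint32_t MULTIPLE = 16;
  std::vector<float> out(w.size());
  uint32_t index = 0;
  for (uint32_t r = 0; r < rows / MULTIPLE; ++r)
    for (uint32_t c = 0; c < cols; ++c)
      for (uint32_t i = 0; i < MULTIPLE; ++i)
        out[index++] = w[(r * MULTIPLE + i) * cols + c]; // same index math as shuffle_weight()
  return out;
}

int main()
{
  std::vector<float> w(16 * 2);
  for (uint32_t i = 0; i < w.size(); ++i)
    w[i] = static_cast<float>(i); // row-major 16x2 weight: 0, 1, 2, 3, ...

  auto out = shuffle_16x1(w, 16, 2);
  for (int i = 0; i < 4; ++i)
    std::cout << out[i] << ' '; // 0 2 4 6 : column 0 of the 16-row block comes first
  std::cout << '\n';
}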
diff --git a/compiler/luci/pass/src/ShuffleWeightTo16x1Float32Pass.test.cpp b/compiler/luci/pass/src/ShuffleWeightTo16x1Float32Pass.test.cpp
new file mode 100644
index 000000000..9745e5754
--- /dev/null
+++ b/compiler/luci/pass/src/ShuffleWeightTo16x1Float32Pass.test.cpp
@@ -0,0 +1,118 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Pass/ShuffleWeightTo16x1Float32Pass.h"
+
+#include <luci/IR/CircleNodes.h>
+
+#include <gtest/gtest.h>
+
+void create_fc_net(loco::Graph *g)
+{
+ assert(g);
+
+ const uint32_t ROW = 16;
+ const uint32_t COL = 2;
+ const uint32_t elements_num = ROW * COL;
+
+ // input
+ auto input = g->nodes()->create<luci::CircleInput>();
+ auto graph_input = g->inputs()->create();
+ input->index(graph_input->index());
+
+ // fc weights
+ auto weights = g->nodes()->create<luci::CircleConst>();
+ weights->dtype(loco::DataType::FLOAT32);
+ weights->size<loco::DataType::FLOAT32>(elements_num);
+ weights->rank(2);
+ weights->dim(0).set(ROW);
+ weights->dim(1).set(COL);
+ for (uint32_t idx = 0; idx < elements_num; idx++)
+ {
+ weights->at<loco::DataType::FLOAT32>(idx) = idx;
+ }
+
+ // fc
+ auto fc = g->nodes()->create<luci::CircleFullyConnected>();
+ fc->dtype(loco::DataType::FLOAT32);
+ fc->input(input);
+ fc->weights(weights);
+
+ // output
+ auto output = g->nodes()->create<luci::CircleOutput>();
+ output->from(fc);
+ auto graph_output = g->outputs()->create();
+ output->index(graph_output->index());
+}
+
+TEST(ShuffleWeightTo16x1Float32PassTest, SimpleTest1)
+{
+ auto graph = loco::make_graph();
+ create_fc_net(graph.get());
+
+ luci::CircleFullyConnected *fc_node = nullptr;
+ for (auto node : loco::active_nodes(loco::output_nodes(graph.get())))
+ {
+ auto fc = dynamic_cast<luci::CircleFullyConnected *>(node);
+ if (not fc)
+ continue;
+
+ fc_node = fc;
+ break;
+ }
+ ASSERT_NE(fc_node, nullptr);
+ auto weights = loco::must_cast<luci::CircleConst *>(fc_node->weights());
+ // before
+ ASSERT_EQ(0, weights->at<loco::DataType::FLOAT32>(0));
+ ASSERT_EQ(1, weights->at<loco::DataType::FLOAT32>(1));
+ ASSERT_EQ(2, weights->at<loco::DataType::FLOAT32>(2));
+ ASSERT_EQ(3, weights->at<loco::DataType::FLOAT32>(3));
+ ASSERT_EQ(4, weights->at<loco::DataType::FLOAT32>(4));
+ ASSERT_EQ(5, weights->at<loco::DataType::FLOAT32>(5));
+ ASSERT_EQ(6, weights->at<loco::DataType::FLOAT32>(6));
+ ASSERT_EQ(7, weights->at<loco::DataType::FLOAT32>(7));
+ ASSERT_EQ(8, weights->at<loco::DataType::FLOAT32>(8));
+ ASSERT_EQ(9, weights->at<loco::DataType::FLOAT32>(9));
+ ASSERT_EQ(10, weights->at<loco::DataType::FLOAT32>(10));
+ ASSERT_EQ(11, weights->at<loco::DataType::FLOAT32>(11));
+ ASSERT_EQ(12, weights->at<loco::DataType::FLOAT32>(12));
+ ASSERT_EQ(13, weights->at<loco::DataType::FLOAT32>(13));
+ ASSERT_EQ(14, weights->at<loco::DataType::FLOAT32>(14));
+ ASSERT_EQ(15, weights->at<loco::DataType::FLOAT32>(15));
+
+ luci::ShuffleWeightTo16x1Float32Pass pass;
+ while (pass.run(graph.get()))
+ ;
+
+ weights = loco::must_cast<luci::CircleConst *>(fc_node->weights());
+ // after
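+ // With ROW = 16 and COL = 2, the shuffled layout stores column 0 of the
+ // original weights first, so the first 16 entries are 0, 2, 4, ..., 30.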
+ ASSERT_EQ(0, weights->at<loco::DataType::FLOAT32>(0));
+ ASSERT_EQ(2, weights->at<loco::DataType::FLOAT32>(1));
+ ASSERT_EQ(4, weights->at<loco::DataType::FLOAT32>(2));
+ ASSERT_EQ(6, weights->at<loco::DataType::FLOAT32>(3));
+ ASSERT_EQ(8, weights->at<loco::DataType::FLOAT32>(4));
+ ASSERT_EQ(10, weights->at<loco::DataType::FLOAT32>(5));
+ ASSERT_EQ(12, weights->at<loco::DataType::FLOAT32>(6));
+ ASSERT_EQ(14, weights->at<loco::DataType::FLOAT32>(7));
+ ASSERT_EQ(16, weights->at<loco::DataType::FLOAT32>(8));
+ ASSERT_EQ(18, weights->at<loco::DataType::FLOAT32>(9));
+ ASSERT_EQ(20, weights->at<loco::DataType::FLOAT32>(10));
+ ASSERT_EQ(22, weights->at<loco::DataType::FLOAT32>(11));
+ ASSERT_EQ(24, weights->at<loco::DataType::FLOAT32>(12));
+ ASSERT_EQ(26, weights->at<loco::DataType::FLOAT32>(13));
+ ASSERT_EQ(28, weights->at<loco::DataType::FLOAT32>(14));
+ ASSERT_EQ(30, weights->at<loco::DataType::FLOAT32>(15));
+}
diff --git a/compiler/luci/pass/src/SubstitutePackToReshapePass.cpp b/compiler/luci/pass/src/SubstitutePackToReshapePass.cpp
new file mode 100644
index 000000000..44e974b91
--- /dev/null
+++ b/compiler/luci/pass/src/SubstitutePackToReshapePass.cpp
@@ -0,0 +1,107 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Pass/SubstitutePackToReshapePass.h"
+
+#include <luci/IR/CircleNodes.h>
+
+namespace
+{
+
+bool substitute_pack_to_reshape(luci::CircleNode *node)
+{
+ auto target_node = dynamic_cast<luci::CirclePack *>(node);
+ if (target_node == nullptr)
+ return false;
+ if (target_node->values_count() != 1)
+ return false;
+ auto value_node = loco::must_cast<luci::CircleNode *>(target_node->values(0));
+ if (value_node->shape_status() != luci::ShapeStatus::VALID)
+ return false;
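+ // A Pack with a single value is equivalent to inserting a new axis of size 1,
+ // so it can be replaced by a Reshape. A negative axis is normalized below;
+ // the output rank is the input rank + 1.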
+ int32_t axis = target_node->axis();
+ if (axis < 0)
+ axis = axis + static_cast<int32_t>(value_node->rank()) + 1;
+
+ auto graph = target_node->graph();
+ auto reshape_node = graph->nodes()->create<luci::CircleReshape>();
+ reshape_node->tensor(value_node);
+
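+ // Build the Reshape "shape" constant: the input shape with a 1 inserted at position axis.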
+ auto const_node = graph->nodes()->create<luci::CircleConst>();
+ const_node->dtype(loco::DataType::S32);
+ const_node->size<loco::DataType::S32>(value_node->rank() + 1);
+ const_node->shape_status(luci::ShapeStatus::VALID);
+ const_node->rank(1);
+ const_node->dim(0).set(value_node->rank() + 1);
+ for (int32_t i = 0; i < static_cast<int32_t>(value_node->rank()) + 1; i++)
+ {
+ if (i == axis)
+ {
+ const_node->at<loco::DataType::S32>(i) = 1;
+ }
+ else if (i < axis)
+ {
+ const_node->at<loco::DataType::S32>(i) = value_node->dim(i).value();
+ }
+ else
+ {
+ const_node->at<loco::DataType::S32>(i) = value_node->dim(i - 1).value();
+ }
+ }
+ reshape_node->shape(const_node);
+ replace(target_node).with(reshape_node);
+ return true;
+}
+
+} // namespace
+
+namespace luci
+{
+
+/**
+ * BEFORE
+ * |
+ * [CircleNode]
+ * |
+ * [CirclePack]
+ * |
+ * [CircleNode]
+ * |
+ *
+ * AFTER
+ * |
+ * [CircleNode] [CircleConst]
+ * \ /
+ * [CircleReshape]
+ * |
+ * [CircleNode]
+ * |
+ *
+ */
+bool SubstitutePackToReshapePass::run(loco::Graph *g)
+{
+ bool changed = false;
+ for (auto node : loco::active_nodes(loco::output_nodes(g)))
+ {
+ auto circle_node = loco::must_cast<luci::CircleNode *>(node);
+ if (substitute_pack_to_reshape(circle_node))
+ {
+ changed = true;
+ }
+ }
+ return changed;
+}
+
+} // namespace luci
diff --git a/compiler/luci/pass/src/SubstitutePackToReshapePass.test.cpp b/compiler/luci/pass/src/SubstitutePackToReshapePass.test.cpp
new file mode 100644
index 000000000..143b88896
--- /dev/null
+++ b/compiler/luci/pass/src/SubstitutePackToReshapePass.test.cpp
@@ -0,0 +1,124 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "luci/Pass/SubstitutePackToReshapePass.h"
+
+#include <luci/IR/CircleNodes.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+/**
+ * BEFORE
+ * |
+ * [CircleNode]
+ * |
+ * [CirclePack]
+ * |
+ * [CircleNode]
+ * |
+ *
+ * AFTER
+ * |
+ * [CircleNode] [CircleConst]
+ * \ /
+ * [CircleReshape]
+ * |
+ * [CircleNode]
+ * |
+ *
+ */
+void create_substitute_pack_to_reshape(loco::Graph *g, const std::initializer_list<uint32_t> shape,
+ int32_t axis)
+{
+ assert(g);
+
+ // Create input.
+ auto input = g->nodes()->create<luci::CircleInput>();
+ auto graph_input = g->inputs()->create();
+ input->index(graph_input->index());
+ input->shape_status(luci::ShapeStatus::VALID);
+ input->rank(shape.size());
+ input->shape(shape);
+
+ // Create Pack node.
+ auto pack = g->nodes()->create<luci::CirclePack>(1);
+ pack->values(0, input);
+ pack->axis(axis);
+
+ // Connect output.
+ auto output = g->nodes()->create<luci::CircleOutput>();
+ output->from(pack);
+ auto graph_output = g->outputs()->create();
+ output->index(graph_output->index());
+
+ return;
+}
+
+} // namespace
+
+TEST(SubstitutePackToReshapePass, simple_case)
+{
+ auto graph = loco::make_graph();
+ create_substitute_pack_to_reshape(graph.get(), {1, 2, 3, 4}, 0);
+ luci::SubstitutePackToReshapePass pass;
+ while (pass.run(graph.get()))
+ ;
+ luci::CircleReshape *reshape_node = nullptr;
+ luci::CirclePack *pack_node = nullptr;
+ for (auto node : loco::active_nodes(loco::output_nodes(graph.get())))
+ {
+ if (auto reshape = dynamic_cast<luci::CircleReshape *>(node))
+ reshape_node = reshape;
+ else if (auto pack = dynamic_cast<luci::CirclePack *>(node))
+ pack_node = pack;
+ }
+ ASSERT_NE(nullptr, reshape_node);
+ ASSERT_EQ(nullptr, pack_node);
+ auto new_shape = loco::must_cast<luci::CircleConst *>(reshape_node->shape());
+ ASSERT_EQ(1, new_shape->at<loco::DataType::S32>(0));
+ ASSERT_EQ(1, new_shape->at<loco::DataType::S32>(1));
+ ASSERT_EQ(2, new_shape->at<loco::DataType::S32>(2));
+ ASSERT_EQ(3, new_shape->at<loco::DataType::S32>(3));
+ ASSERT_EQ(4, new_shape->at<loco::DataType::S32>(4));
+}
+
+TEST(SubstitutePackToReshapePass, simple_case_neg_axis)
+{
+ auto graph = loco::make_graph();
+ create_substitute_pack_to_reshape(graph.get(), {1, 2, 3, 4}, -1);
+ luci::SubstitutePackToReshapePass pass;
+ while (pass.run(graph.get()))
+ ;
+ luci::CircleReshape *reshape_node = nullptr;
+ luci::CirclePack *pack_node = nullptr;
+ for (auto node : loco::active_nodes(loco::output_nodes(graph.get())))
+ {
+ if (auto reshape = dynamic_cast<luci::CircleReshape *>(node))
+ reshape_node = reshape;
+ else if (auto pack = dynamic_cast<luci::CirclePack *>(node))
+ pack_node = pack;
+ }
+ ASSERT_NE(nullptr, reshape_node);
+ ASSERT_EQ(nullptr, pack_node);
+ auto new_shape = loco::must_cast<luci::CircleConst *>(reshape_node->shape());
+ ASSERT_EQ(1, new_shape->at<loco::DataType::S32>(0));
+ ASSERT_EQ(2, new_shape->at<loco::DataType::S32>(1));
+ ASSERT_EQ(3, new_shape->at<loco::DataType::S32>(2));
+ ASSERT_EQ(4, new_shape->at<loco::DataType::S32>(3));
+ ASSERT_EQ(1, new_shape->at<loco::DataType::S32>(4));
+}
diff --git a/compiler/luci/pass/src/TypeInferencePass.cpp b/compiler/luci/pass/src/TypeInferencePass.cpp
index 2c7b3a897..63744045c 100644
--- a/compiler/luci/pass/src/TypeInferencePass.cpp
+++ b/compiler/luci/pass/src/TypeInferencePass.cpp
@@ -26,6 +26,19 @@
namespace luci
{
+bool TypeInferencePass::run(luci::Module *m)
+{
+ bool changed = false;
+
+ for (size_t g = 0; g < m->size(); ++g)
+ {
+ if (run(m->graph(g)))
+ changed = true;
+ }
+
+ return changed;
+}
+
bool TypeInferencePass::run(loco::Graph *g)
{
loco::CanonicalTypeInferenceRule canonical_rule;
diff --git a/compiler/luci/service/include/luci/Service/CircleShapeInference.h b/compiler/luci/service/include/luci/Service/CircleShapeInference.h
index fb934c2cf..c301db5f4 100644
--- a/compiler/luci/service/include/luci/Service/CircleShapeInference.h
+++ b/compiler/luci/service/include/luci/Service/CircleShapeInference.h
@@ -21,6 +21,10 @@
#include <loco/IR/Nodes.h>
+#include <luci/IR/CircleNodes.h>
+#include <luci/IR/CircleNodeVisitor.h>
+#include <luci/Service/CircleShapeInferenceHelper.h>
+
namespace luci
{
@@ -36,6 +40,155 @@ struct ShapeInference
static ShapeDescription get(loco::Node *node);
};
+namespace sinf // namespace for Shape Inference
+{
+
+struct Rule
+{
+ bool infer(const luci::CircleNode *, loco::TensorShape &) const;
+};
+
+class Algorithm final : public luci::CircleNodeVisitor<loco::TensorShape>
+{
+public:
+ // TODO Remove this when all visit functions are implemented
+ loco::TensorShape visit(const luci::CircleNode *node) final { return sinf::circle_shape(node); }
+
+ // loco::TensorShape visit(const luci::CircleAbs *node) final;
+ // loco::TensorShape visit(const luci::CircleAdd *node) final;
+ // loco::TensorShape visit(const luci::CircleAddN *node) final;
+ // loco::TensorShape visit(const luci::CircleArgMax *node) final;
+ // loco::TensorShape visit(const luci::CircleArgMin *node) final;
+ // loco::TensorShape visit(const luci::CircleAveragePool2D *node) final;
+ // loco::TensorShape visit(const luci::CircleBatchMatMul *node) final;
+ // loco::TensorShape visit(const luci::CircleBatchToSpaceND *node) final;
+ // loco::TensorShape visit(const luci::CircleCast *node) final;
+ // loco::TensorShape visit(const luci::CircleCeil *node) final;
+ // loco::TensorShape visit(const luci::CircleConcatenation *node) final;
+ // loco::TensorShape visit(const luci::CircleConst *node) final;
+ // loco::TensorShape visit(const luci::CircleConv2D *node) final;
+ // loco::TensorShape visit(const luci::CircleCos *node) final;
+ // loco::TensorShape visit(const luci::CircleCustom *node) final;
+ // loco::TensorShape visit(const luci::CircleDepthToSpace *node) final;
+ // loco::TensorShape visit(const luci::CircleDepthwiseConv2D *node) final;
+ // loco::TensorShape visit(const luci::CircleDequantize *node) final;
+ // loco::TensorShape visit(const luci::CircleDiv *node) final;
+ // loco::TensorShape visit(const luci::CircleElu *node) final;
+ // loco::TensorShape visit(const luci::CircleEqual *node) final;
+ // loco::TensorShape visit(const luci::CircleExp *node) final;
+ // loco::TensorShape visit(const luci::CircleExpandDims *node) final;
+ // loco::TensorShape visit(const luci::CircleFill *node) final;
+ // loco::TensorShape visit(const luci::CircleFloor *node) final;
+ // loco::TensorShape visit(const luci::CircleFloorDiv *node) final;
+ // loco::TensorShape visit(const luci::CircleFloorMod *node) final;
+ // loco::TensorShape visit(const luci::CircleFullyConnected *node) final;
+ // loco::TensorShape visit(const luci::CircleGather *node) final;
+ // loco::TensorShape visit(const luci::CircleGatherNd *node) final;
+ // loco::TensorShape visit(const luci::CircleGreater *node) final;
+ // loco::TensorShape visit(const luci::CircleGreaterEqual *node) final;
+ // loco::TensorShape visit(const luci::CircleIf *node) final;
+ // loco::TensorShape visit(const luci::CircleL2Normalize *node) final;
+ // loco::TensorShape visit(const luci::CircleL2Pool2D *node) final;
+ // loco::TensorShape visit(const luci::CircleLeakyRelu *node) final;
+ // loco::TensorShape visit(const luci::CircleLess *node) final;
+ // loco::TensorShape visit(const luci::CircleLessEqual *node) final;
+ // loco::TensorShape visit(const luci::CircleLocalResponseNormalization *node) final;
+ // loco::TensorShape visit(const luci::CircleLog *node) final;
+ // loco::TensorShape visit(const luci::CircleLogicalAnd *node) final;
+ // loco::TensorShape visit(const luci::CircleLogicalNot *node) final;
+ // loco::TensorShape visit(const luci::CircleLogicalOr *node) final;
+ // loco::TensorShape visit(const luci::CircleLogistic *node) final;
+ // loco::TensorShape visit(const luci::CircleLogSoftmax *node) final;
+ // loco::TensorShape visit(const luci::CircleMatrixDiag *node) final;
+ // loco::TensorShape visit(const luci::CircleMatrixSetDiag *node) final;
+ // loco::TensorShape visit(const luci::CircleMaximum *node) final;
+ // loco::TensorShape visit(const luci::CircleMaxPool2D *node) final;
+ // loco::TensorShape visit(const luci::CircleMean *node) final;
+ // loco::TensorShape visit(const luci::CircleMinimum *node) final;
+ // loco::TensorShape visit(const luci::CircleMirrorPad *node) final;
+ // loco::TensorShape visit(const luci::CircleNeg *node) final;
+ // loco::TensorShape visit(const luci::CircleNonMaxSuppressionV4 *node) final;
+ // loco::TensorShape visit(const luci::CircleNonMaxSuppressionV5 *node) final;
+ // loco::TensorShape visit(const luci::CircleNotEqual *node) final;
+ // loco::TensorShape visit(const luci::CirclePack *node) final;
+ // loco::TensorShape visit(const luci::CirclePad *node) final;
+ // loco::TensorShape visit(const luci::CirclePadV2 *node) final;
+ // loco::TensorShape visit(const luci::CirclePow *node) final;
+ // loco::TensorShape visit(const luci::CirclePRelu *node) final;
+ // loco::TensorShape visit(const luci::CircleRange *node) final;
+ // loco::TensorShape visit(const luci::CircleRank *node) final;
+ // loco::TensorShape visit(const luci::CircleMul *node) final;
+ // loco::TensorShape visit(const luci::CircleOneHot *node) final;
+ // loco::TensorShape visit(const luci::CircleReduceAny *node) final;
+ // loco::TensorShape visit(const luci::CircleReduceMax *node) final;
+ // loco::TensorShape visit(const luci::CircleReduceMin *node) final;
+ // loco::TensorShape visit(const luci::CircleReduceProd *node) final;
+ // loco::TensorShape visit(const luci::CircleRelu *node) final;
+ // loco::TensorShape visit(const luci::CircleRelu6 *node) final;
+ // loco::TensorShape visit(const luci::CircleReluN1To1 *node) final;
+ // loco::TensorShape visit(const luci::CircleReshape *node) final;
+ // loco::TensorShape visit(const luci::CircleResizeBilinear *node) final;
+ // loco::TensorShape visit(const luci::CircleResizeNearestNeighbor *node) final;
+ // loco::TensorShape visit(const luci::CircleReverseSequence *node) final;
+ // loco::TensorShape visit(const luci::CircleReverseV2 *node) final;
+ // loco::TensorShape visit(const luci::CircleRound *node) final;
+ // loco::TensorShape visit(const luci::CircleRsqrt *node) final;
+ // loco::TensorShape visit(const luci::CircleScatterNd *node) final;
+ // loco::TensorShape visit(const luci::CircleSegmentSum *node) final;
+ // loco::TensorShape visit(const luci::CircleSelect *node) final;
+ // loco::TensorShape visit(const luci::CircleSelectV2 *node) final;
+ // loco::TensorShape visit(const luci::CircleShape *node) final;
+ // loco::TensorShape visit(const luci::CircleSin *node) final;
+ // loco::TensorShape visit(const luci::CircleSlice *node) final;
+ // loco::TensorShape visit(const luci::CircleSoftmax *node) final;
+ // loco::TensorShape visit(const luci::CircleSpaceToBatchND *node) final;
+ // loco::TensorShape visit(const luci::CircleSpaceToDepth *node) final;
+ // loco::TensorShape visit(const luci::CircleSparseToDense *node) final;
+ // loco::TensorShape visit(const luci::CircleSplit *node) final;
+ // loco::TensorShape visit(const luci::CircleSplitV *node) final;
+ // loco::TensorShape visit(const luci::CircleSqrt *node) final;
+ // loco::TensorShape visit(const luci::CircleSquare *node) final;
+ // loco::TensorShape visit(const luci::CircleSquaredDifference *node) final;
+ // loco::TensorShape visit(const luci::CircleSqueeze *node) final;
+ // loco::TensorShape visit(const luci::CircleStridedSlice *node) final;
+ // loco::TensorShape visit(const luci::CircleSub *node) final;
+ // loco::TensorShape visit(const luci::CircleSum *node) final;
+ // loco::TensorShape visit(const luci::CircleTanh *node) final;
+ // loco::TensorShape visit(const luci::CircleTile *node) final;
+ // loco::TensorShape visit(const luci::CircleTopKV2 *node) final;
+ // loco::TensorShape visit(const luci::CircleTranspose *node) final;
+ // loco::TensorShape visit(const luci::CircleTransposeConv *node) final;
+ // loco::TensorShape visit(const luci::CircleUnidirectionalSequenceLSTM *node) final;
+ // loco::TensorShape visit(const luci::CircleUnique *node) final;
+ // loco::TensorShape visit(const luci::CircleUnpack *node) final;
+ // loco::TensorShape visit(const luci::CircleWhere *node) final;
+ // loco::TensorShape visit(const luci::CircleWhile *node) final;
+ // loco::TensorShape visit(const luci::CircleZerosLike *node) final;
+
+ // Circle Only
+ // loco::TensorShape visit(const luci::CircleBCQFullyConnected *node) final;
+ // loco::TensorShape visit(const luci::CircleBCQGather *node) final;
+ // loco::TensorShape visit(const luci::CircleInstanceNorm *node) final;
+
+ // Virtual
+ // loco::TensorShape visit(const luci::CircleInput *node) final;
+ // loco::TensorShape visit(const luci::CircleOutput *node) final;
+ // loco::TensorShape visit(const luci::CircleOutputDummy *node) final;
+ // loco::TensorShape visit(const luci::CircleOutputExclude *node) final;
+ // loco::TensorShape visit(const luci::CircleCustomOut *node) final;
+ // loco::TensorShape visit(const luci::CircleIfOut *node) final;
+ // loco::TensorShape visit(const luci::CircleNonMaxSuppressionV4Out *node) final;
+ // loco::TensorShape visit(const luci::CircleNonMaxSuppressionV5Out *node) final;
+ // loco::TensorShape visit(const luci::CircleSplitOut *node) final;
+ // loco::TensorShape visit(const luci::CircleSplitVOut *node) final;
+ // loco::TensorShape visit(const luci::CircleTopKV2Out *node) final;
+ // loco::TensorShape visit(const luci::CircleUniqueOut *node) final;
+ // loco::TensorShape visit(const luci::CircleUnpackOut *node) final;
+ // loco::TensorShape visit(const luci::CircleWhileOut *node) final;
+};
+
+} // namespace sinf
+
} // namespace luci
#endif // __LUCI_CIRCLE_SHAPE_INFERENCE_H__
diff --git a/compiler/luci/service/include/luci/Service/CircleShapeInferenceHelper.h b/compiler/luci/service/include/luci/Service/CircleShapeInferenceHelper.h
new file mode 100644
index 000000000..dd6a5a454
--- /dev/null
+++ b/compiler/luci/service/include/luci/Service/CircleShapeInferenceHelper.h
@@ -0,0 +1,36 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LUCI_CIRCLE_SHAPE_INFERENCE_HELPER_H__
+#define __LUCI_CIRCLE_SHAPE_INFERENCE_HELPER_H__
+
+#include <loco/IR/TensorShape.h>
+
+#include <luci/IR/CircleNodes.h>
+#include <luci/IR/CircleShapeSignature.h>
+
+namespace luci
+{
+namespace sinf // Namespace for Shape Inference
+{
+
+// Return shape of circle node as loco::TensorShape
+loco::TensorShape circle_shape(const luci::CircleNode *node);
+
+} // namespace sinf
+} // namespace luci
+
+#endif // __LUCI_CIRCLE_SHAPE_INFERENCE_HELPER_H__
diff --git a/compiler/luci/service/include/luci/Service/CircleShapeSignatureInferenceRule.h b/compiler/luci/service/include/luci/Service/CircleShapeSignatureInference.h
index 4d1d83012..f7ea89bb8 100644
--- a/compiler/luci/service/include/luci/Service/CircleShapeSignatureInferenceRule.h
+++ b/compiler/luci/service/include/luci/Service/CircleShapeSignatureInference.h
@@ -14,22 +14,26 @@
* limitations under the License.
*/
-#ifndef __LUCI_CIRCLE_SHAPE_SIGNATURE_INFERENCE_RULE_H__
-#define __LUCI_CIRCLE_SHAPE_SIGNATURE_INFERENCE_RULE_H__
+#ifndef __LUCI_CIRCLE_SHAPE_SIGNATURE_INFERENCE_H__
+#define __LUCI_CIRCLE_SHAPE_SIGNATURE_INFERENCE_H__
#include <luci/IR/CircleNodes.h>
#include <luci/IR/CircleNodeVisitor.h>
#include <luci/IR/CircleShapeSignature.h>
+#include <luci/Service/CircleShapeSignatureInferenceHelper.h>
namespace luci
{
-struct CircleShapeSignatureInferenceRule
+namespace ssinf // namespace for Shape Signature Inference
+{
+
+struct Rule
{
bool infer(const luci::CircleNode *, ShapeSignature &) const;
};
-class ShapeSignatureInferenceAlgorithm final : public luci::CircleNodeVisitor<ShapeSignature>
+class Algorithm final : public luci::CircleNodeVisitor<ShapeSignature>
{
public:
// TODO Remove this when visit function is implemented for all the operations.
@@ -84,7 +88,7 @@ public:
// ShapeSignature visit(const luci::CircleMatrixSetDiag *node) final;
// ShapeSignature visit(const luci::CircleMaximum *node) final;
// ShapeSignature visit(const luci::CircleMaxPool2D *node) final;
- // ShapeSignature visit(const luci::CircleMean *node) final;
+ ShapeSignature visit(const luci::CircleMean *node) final;
// ShapeSignature visit(const luci::CircleMinimum *node) final;
// ShapeSignature visit(const luci::CircleMirrorPad *node) final;
// ShapeSignature visit(const luci::CircleNeg *node) final;
@@ -100,13 +104,13 @@ public:
// ShapeSignature visit(const luci::CircleRank *node) final;
// ShapeSignature visit(const luci::CircleMul *node) final;
// ShapeSignature visit(const luci::CircleOneHot *node) final;
- // ShapeSignature visit(const luci::CircleReduceAny *node) final;
- // ShapeSignature visit(const luci::CircleReduceMax *node) final;
- // ShapeSignature visit(const luci::CircleReduceMin *node) final;
- // ShapeSignature visit(const luci::CircleReduceProd *node) final;
- // ShapeSignature visit(const luci::CircleRelu *node) final;
- // ShapeSignature visit(const luci::CircleRelu6 *node) final;
- // ShapeSignature visit(const luci::CircleReluN1To1 *node) final;
+ ShapeSignature visit(const luci::CircleReduceAny *node) final;
+ ShapeSignature visit(const luci::CircleReduceMax *node) final;
+ ShapeSignature visit(const luci::CircleReduceMin *node) final;
+ ShapeSignature visit(const luci::CircleReduceProd *node) final;
+ ShapeSignature visit(const luci::CircleRelu *node) final;
+ ShapeSignature visit(const luci::CircleRelu6 *node) final;
+ ShapeSignature visit(const luci::CircleReluN1To1 *node) final;
// ShapeSignature visit(const luci::CircleReshape *node) final;
// ShapeSignature visit(const luci::CircleResizeBilinear *node) final;
// ShapeSignature visit(const luci::CircleResizeNearestNeighbor *node) final;
@@ -133,7 +137,7 @@ public:
// ShapeSignature visit(const luci::CircleSqueeze *node) final;
// ShapeSignature visit(const luci::CircleStridedSlice *node) final;
// ShapeSignature visit(const luci::CircleSub *node) final;
- // ShapeSignature visit(const luci::CircleSum *node) final;
+ ShapeSignature visit(const luci::CircleSum *node) final;
// ShapeSignature visit(const luci::CircleTanh *node) final;
// ShapeSignature visit(const luci::CircleTile *node) final;
// ShapeSignature visit(const luci::CircleTopKV2 *node) final;
@@ -152,10 +156,10 @@ public:
// ShapeSignature visit(const luci::CircleInstanceNorm *node) final;
// Virtual
- // ShapeSignature visit(const luci::CircleInput *node) final;
- // ShapeSignature visit(const luci::CircleOutput *node) final;
- // ShapeSignature visit(const luci::CircleOutputDummy *node) final;
- // ShapeSignature visit(const luci::CircleOutputExclude *node) final;
+ ShapeSignature visit(const luci::CircleInput *node) final;
+ ShapeSignature visit(const luci::CircleOutput *node) final;
+ ShapeSignature visit(const luci::CircleOutputDummy *node) final;
+ ShapeSignature visit(const luci::CircleOutputExclude *node) final;
// ShapeSignature visit(const luci::CircleCustomOut *node) final;
// ShapeSignature visit(const luci::CircleIfOut *node) final;
// ShapeSignature visit(const luci::CircleNonMaxSuppressionV4Out *node) final;
@@ -168,6 +172,8 @@ public:
// ShapeSignature visit(const luci::CircleWhileOut *node) final;
};
+} // namespace ssinf
+
} // namespace luci
-#endif // __LUCI_CIRCLE_SHAPE_SIGNATURE_INFERENCE_RULE_H__
+#endif // __LUCI_CIRCLE_SHAPE_SIGNATURE_INFERENCE_H__
diff --git a/compiler/luci/service/include/luci/Service/CircleShapeSignatureInferenceHelper.h b/compiler/luci/service/include/luci/Service/CircleShapeSignatureInferenceHelper.h
new file mode 100644
index 000000000..fb5b3b302
--- /dev/null
+++ b/compiler/luci/service/include/luci/Service/CircleShapeSignatureInferenceHelper.h
@@ -0,0 +1,45 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LUCI_CIRCLE_SHAPE_SIGNATURE_INFERENCE_HELPER_H__
+#define __LUCI_CIRCLE_SHAPE_SIGNATURE_INFERENCE_HELPER_H__
+
+#include <luci/IR/CircleNodes.h>
+#include <luci/IR/CircleShapeSignature.h>
+
+namespace luci
+{
+
+namespace ssinf // Namespace for Shape Signature Inference
+{
+
+// Return an empty signature if all dimensions are known.
+// If at least one dimension is unknown, return the signature unchanged.
+ShapeSignature legalized_signature(const luci::ShapeSignature &signature);
+
+// Return reduced input_signature with indices and keep_dims.
+// - indices : reduction index
+// - keep_dims : If true, rank is not changed. If false, rank is reduced along indices.
+ShapeSignature reduced_signature(const loco::Node *node, const loco::Node *indices, bool keep_dims);
+
+// Return signature of index-th argument of node.
+ShapeSignature input_arg_signature(const luci::CircleNode *node, uint32_t index);
+
+} // namespace ssinf
+
+} // namespace luci
+
+#endif // __LUCI_CIRCLE_SHAPE_SIGNATURE_INFERENCE_HELPER_H__
diff --git a/compiler/luci/service/include/luci/Service/CircleTypeInference.h b/compiler/luci/service/include/luci/Service/CircleTypeInference.h
index ea7a3c5ed..342214887 100644
--- a/compiler/luci/service/include/luci/Service/CircleTypeInference.h
+++ b/compiler/luci/service/include/luci/Service/CircleTypeInference.h
@@ -21,6 +21,10 @@
#include <mio/circle/schema_generated.h>
+#include <luci/IR/CircleNodes.h>
+#include <luci/IR/CircleNodeVisitor.h>
+#include <luci/Service/CircleTypeInferenceHelper.h>
+
namespace luci
{
@@ -37,6 +41,155 @@ struct TypeInference
static circle::TensorType get(loco::Node *node);
};
+namespace tinf // namespace for Type Inference
+{
+
+struct Rule
+{
+ bool infer(const luci::CircleNode *, loco::DataType &) const;
+};
+
+class Algorithm final : public luci::CircleNodeVisitor<loco::DataType>
+{
+public:
+ // TODO Remove this when all visit functions are implemented
+ loco::DataType visit(const luci::CircleNode *node) final { return node->dtype(); }
+
+ // loco::DataType visit(const luci::CircleAbs *node) final;
+ // loco::DataType visit(const luci::CircleAdd *node) final;
+ // loco::DataType visit(const luci::CircleAddN *node) final;
+ // loco::DataType visit(const luci::CircleArgMax *node) final;
+ // loco::DataType visit(const luci::CircleArgMin *node) final;
+ // loco::DataType visit(const luci::CircleAveragePool2D *node) final;
+ // loco::DataType visit(const luci::CircleBatchMatMul *node) final;
+ // loco::DataType visit(const luci::CircleBatchToSpaceND *node) final;
+ // loco::DataType visit(const luci::CircleCast *node) final;
+ // loco::DataType visit(const luci::CircleCeil *node) final;
+ // loco::DataType visit(const luci::CircleConcatenation *node) final;
+ // loco::DataType visit(const luci::CircleConst *node) final;
+ // loco::DataType visit(const luci::CircleConv2D *node) final;
+ // loco::DataType visit(const luci::CircleCos *node) final;
+ // loco::DataType visit(const luci::CircleCustom *node) final;
+ // loco::DataType visit(const luci::CircleDepthToSpace *node) final;
+ // loco::DataType visit(const luci::CircleDepthwiseConv2D *node) final;
+ // loco::DataType visit(const luci::CircleDequantize *node) final;
+ // loco::DataType visit(const luci::CircleDiv *node) final;
+ // loco::DataType visit(const luci::CircleElu *node) final;
+ // loco::DataType visit(const luci::CircleEqual *node) final;
+ // loco::DataType visit(const luci::CircleExp *node) final;
+ // loco::DataType visit(const luci::CircleExpandDims *node) final;
+ // loco::DataType visit(const luci::CircleFill *node) final;
+ // loco::DataType visit(const luci::CircleFloor *node) final;
+ // loco::DataType visit(const luci::CircleFloorDiv *node) final;
+ // loco::DataType visit(const luci::CircleFloorMod *node) final;
+ // loco::DataType visit(const luci::CircleFullyConnected *node) final;
+ // loco::DataType visit(const luci::CircleGather *node) final;
+ // loco::DataType visit(const luci::CircleGatherNd *node) final;
+ // loco::DataType visit(const luci::CircleGreater *node) final;
+ // loco::DataType visit(const luci::CircleGreaterEqual *node) final;
+ // loco::DataType visit(const luci::CircleIf *node) final;
+ // loco::DataType visit(const luci::CircleL2Normalize *node) final;
+ // loco::DataType visit(const luci::CircleL2Pool2D *node) final;
+ // loco::DataType visit(const luci::CircleLeakyRelu *node) final;
+ // loco::DataType visit(const luci::CircleLess *node) final;
+ // loco::DataType visit(const luci::CircleLessEqual *node) final;
+ // loco::DataType visit(const luci::CircleLocalResponseNormalization *node) final;
+ // loco::DataType visit(const luci::CircleLog *node) final;
+ // loco::DataType visit(const luci::CircleLogicalAnd *node) final;
+ // loco::DataType visit(const luci::CircleLogicalNot *node) final;
+ // loco::DataType visit(const luci::CircleLogicalOr *node) final;
+ // loco::DataType visit(const luci::CircleLogistic *node) final;
+ // loco::DataType visit(const luci::CircleLogSoftmax *node) final;
+ // loco::DataType visit(const luci::CircleMatrixDiag *node) final;
+ // loco::DataType visit(const luci::CircleMatrixSetDiag *node) final;
+ // loco::DataType visit(const luci::CircleMaximum *node) final;
+ // loco::DataType visit(const luci::CircleMaxPool2D *node) final;
+ // loco::DataType visit(const luci::CircleMean *node) final;
+ // loco::DataType visit(const luci::CircleMinimum *node) final;
+ // loco::DataType visit(const luci::CircleMirrorPad *node) final;
+ // loco::DataType visit(const luci::CircleNeg *node) final;
+ // loco::DataType visit(const luci::CircleNonMaxSuppressionV4 *node) final;
+ // loco::DataType visit(const luci::CircleNonMaxSuppressionV5 *node) final;
+ // loco::DataType visit(const luci::CircleNotEqual *node) final;
+ // loco::DataType visit(const luci::CirclePack *node) final;
+ // loco::DataType visit(const luci::CirclePad *node) final;
+ // loco::DataType visit(const luci::CirclePadV2 *node) final;
+ // loco::DataType visit(const luci::CirclePow *node) final;
+ // loco::DataType visit(const luci::CirclePRelu *node) final;
+ // loco::DataType visit(const luci::CircleRange *node) final;
+ // loco::DataType visit(const luci::CircleRank *node) final;
+ // loco::DataType visit(const luci::CircleMul *node) final;
+ // loco::DataType visit(const luci::CircleOneHot *node) final;
+ // loco::DataType visit(const luci::CircleReduceAny *node) final;
+ // loco::DataType visit(const luci::CircleReduceMax *node) final;
+ // loco::DataType visit(const luci::CircleReduceMin *node) final;
+ // loco::DataType visit(const luci::CircleReduceProd *node) final;
+ // loco::DataType visit(const luci::CircleRelu *node) final;
+ // loco::DataType visit(const luci::CircleRelu6 *node) final;
+ // loco::DataType visit(const luci::CircleReluN1To1 *node) final;
+ // loco::DataType visit(const luci::CircleReshape *node) final;
+ // loco::DataType visit(const luci::CircleResizeBilinear *node) final;
+ // loco::DataType visit(const luci::CircleResizeNearestNeighbor *node) final;
+ // loco::DataType visit(const luci::CircleReverseSequence *node) final;
+ // loco::DataType visit(const luci::CircleReverseV2 *node) final;
+ // loco::DataType visit(const luci::CircleRound *node) final;
+ // loco::DataType visit(const luci::CircleRsqrt *node) final;
+ // loco::DataType visit(const luci::CircleScatterNd *node) final;
+ // loco::DataType visit(const luci::CircleSegmentSum *node) final;
+ // loco::DataType visit(const luci::CircleSelect *node) final;
+ // loco::DataType visit(const luci::CircleSelectV2 *node) final;
+ // loco::DataType visit(const luci::CircleShape *node) final;
+ // loco::DataType visit(const luci::CircleSin *node) final;
+ // loco::DataType visit(const luci::CircleSlice *node) final;
+ // loco::DataType visit(const luci::CircleSoftmax *node) final;
+ // loco::DataType visit(const luci::CircleSpaceToBatchND *node) final;
+ // loco::DataType visit(const luci::CircleSpaceToDepth *node) final;
+ // loco::DataType visit(const luci::CircleSparseToDense *node) final;
+ // loco::DataType visit(const luci::CircleSplit *node) final;
+ // loco::DataType visit(const luci::CircleSplitV *node) final;
+ // loco::DataType visit(const luci::CircleSqrt *node) final;
+ // loco::DataType visit(const luci::CircleSquare *node) final;
+ // loco::DataType visit(const luci::CircleSquaredDifference *node) final;
+ // loco::DataType visit(const luci::CircleSqueeze *node) final;
+ // loco::DataType visit(const luci::CircleStridedSlice *node) final;
+ // loco::DataType visit(const luci::CircleSub *node) final;
+ // loco::DataType visit(const luci::CircleSum *node) final;
+ // loco::DataType visit(const luci::CircleTanh *node) final;
+ // loco::DataType visit(const luci::CircleTile *node) final;
+ // loco::DataType visit(const luci::CircleTopKV2 *node) final;
+ // loco::DataType visit(const luci::CircleTranspose *node) final;
+ // loco::DataType visit(const luci::CircleTransposeConv *node) final;
+ // loco::DataType visit(const luci::CircleUnidirectionalSequenceLSTM *node) final;
+ // loco::DataType visit(const luci::CircleUnique *node) final;
+ // loco::DataType visit(const luci::CircleUnpack *node) final;
+ // loco::DataType visit(const luci::CircleWhere *node) final;
+ // loco::DataType visit(const luci::CircleWhile *node) final;
+ // loco::DataType visit(const luci::CircleZerosLike *node) final;
+
+ // Circle Only
+ // loco::DataType visit(const luci::CircleBCQFullyConnected *node) final;
+ // loco::DataType visit(const luci::CircleBCQGather *node) final;
+ // loco::DataType visit(const luci::CircleInstanceNorm *node) final;
+
+ // Virtual
+ // loco::DataType visit(const luci::CircleInput *node) final;
+ // loco::DataType visit(const luci::CircleOutput *node) final;
+ // loco::DataType visit(const luci::CircleOutputDummy *node) final;
+ // loco::DataType visit(const luci::CircleOutputExclude *node) final;
+ // loco::DataType visit(const luci::CircleCustomOut *node) final;
+ // loco::DataType visit(const luci::CircleIfOut *node) final;
+ // loco::DataType visit(const luci::CircleNonMaxSuppressionV4Out *node) final;
+ // loco::DataType visit(const luci::CircleNonMaxSuppressionV5Out *node) final;
+ // loco::DataType visit(const luci::CircleSplitOut *node) final;
+ // loco::DataType visit(const luci::CircleSplitVOut *node) final;
+ // loco::DataType visit(const luci::CircleTopKV2Out *node) final;
+ // loco::DataType visit(const luci::CircleUniqueOut *node) final;
+ // loco::DataType visit(const luci::CircleUnpackOut *node) final;
+ // loco::DataType visit(const luci::CircleWhileOut *node) final;
+};
+
+} // namespace tinf
+
} // namespace luci
#endif // __LUCI_CIRCLE_TYPE_INFERENCE_H__
diff --git a/compiler/luci/service/include/luci/Service/CircleTypeInferenceHelper.h b/compiler/luci/service/include/luci/Service/CircleTypeInferenceHelper.h
new file mode 100644
index 000000000..296f99355
--- /dev/null
+++ b/compiler/luci/service/include/luci/Service/CircleTypeInferenceHelper.h
@@ -0,0 +1,34 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LUCI_CIRCLE_TYPE_INFERENCE_HELPER_H__
+#define __LUCI_CIRCLE_TYPE_INFERENCE_HELPER_H__
+
+#include <luci/IR/CircleNodes.h>
+
+#include <loco/IR/DataType.h>
+
+namespace luci
+{
+namespace tinf // Namespace for Type Inference
+{
+
+// Helper function will be added
+
+} // namespace tinf
+} // namespace luci
+
+#endif // __LUCI_CIRCLE_TYPE_INFERENCE_HELPER_H__
diff --git a/compiler/luci/service/include/luci/Service/ShapeDescription.h b/compiler/luci/service/include/luci/Service/ShapeDescription.h
index 949cce535..4d92be13f 100644
--- a/compiler/luci/service/include/luci/Service/ShapeDescription.h
+++ b/compiler/luci/service/include/luci/Service/ShapeDescription.h
@@ -20,6 +20,8 @@
#include <loco/IR/PermutingCodec.h>
#include <loco/IR/NodeShape.h>
+#include <luci/IR/CircleNodes.h>
+
#include <cstdint>
#include <vector>
@@ -33,6 +35,7 @@ struct ShapeDescription
};
// TODO remove these when CircleDialect is fully functioal
+ShapeDescription to_shape_description(const luci::CircleNode *node);
ShapeDescription to_shape_description(const loco::TensorShape &shape);
ShapeDescription to_shape_description(const loco::FeatureShape &shape);
ShapeDescription to_shape_description(const loco::FilterShape &shape);
diff --git a/compiler/luci/service/src/CircleShapeInference.cpp b/compiler/luci/service/src/CircleShapeInference.cpp
index 0732849db..db8ffd8ad 100644
--- a/compiler/luci/service/src/CircleShapeInference.cpp
+++ b/compiler/luci/service/src/CircleShapeInference.cpp
@@ -20,7 +20,10 @@
#include <loco.h>
#include <loco/Service/ShapeInference.h>
+#include <luci/Log.h>
+
#include <cassert>
+#include <iostream>
namespace luci
{
@@ -32,3 +35,60 @@ ShapeDescription ShapeInference::get(loco::Node *node)
}
} // namespace luci
+
+namespace
+{
+
+std::ostream &operator<<(std::ostream &os, const loco::TensorShape &tensor_shape)
+{
+ os << "[";
+ for (uint32_t r = 0; r < tensor_shape.rank(); ++r)
+ {
+ if (r)
+ os << ",";
+ os << tensor_shape.dim(r).value();
+ }
+ os << "]";
+ return os;
+}
+
+bool inputs_shape_ready(const luci::CircleNode *node)
+{
+ for (uint32_t arity = 0; arity < node->arity(); ++arity)
+ {
+ auto node_input = loco::must_cast<luci::CircleNode *>(node->arg(arity));
+ if (node_input->shape_status() == luci::ShapeStatus::UNDEFINED)
+ return false;
+ }
+
+ return true;
+}
+
+} // namespace
+
+namespace luci
+{
+namespace sinf
+{
+
+bool Rule::infer(const luci::CircleNode *circle_node, loco::TensorShape &shape) const
+{
+ LOGGER(l);
+ VERBOSE(l, 1) << "[CircleShapeInference] " << circle_node->name();
+ VERBOSE(l, 1) << " before: " << circle_shape(circle_node);
+
+ if (!inputs_shape_ready(circle_node))
+ {
+ VERBOSE(l, 1) << " after: Some inputs are not ready for inference";
+ return false;
+ }
+
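+ // Dispatch to the per-operator visitor; operators without a dedicated visit()
+ // fall back to the shape currently annotated on the node.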
+ Algorithm alg;
+ shape = circle_node->accept(&alg);
+ VERBOSE(l, 1) << " after: " << shape;
+
+ return true;
+}
+
+} // namespace sinf
+} // namespace luci
diff --git a/compiler/luci/service/src/CircleShapeInferenceHelper.cpp b/compiler/luci/service/src/CircleShapeInferenceHelper.cpp
new file mode 100644
index 000000000..f7eb6c3ec
--- /dev/null
+++ b/compiler/luci/service/src/CircleShapeInferenceHelper.cpp
@@ -0,0 +1,34 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Service/CircleShapeInferenceHelper.h"
+
+namespace luci
+{
+namespace sinf
+{
+
+loco::TensorShape circle_shape(const luci::CircleNode *node)
+{
+ loco::TensorShape shape;
+ shape.rank(node->rank());
+ for (uint32_t r = 0; r < node->rank(); ++r)
+ shape.dim(r) = loco::Dimension(node->dim(r).value());
+ return shape;
+}
+
+} // namespace sinf
+} // namespace luci
diff --git a/compiler/luci/service/src/CircleShapeInferenceRule.cpp b/compiler/luci/service/src/CircleShapeInferenceRule.cpp
index a55f50b19..38ff619ab 100644
--- a/compiler/luci/service/src/CircleShapeInferenceRule.cpp
+++ b/compiler/luci/service/src/CircleShapeInferenceRule.cpp
@@ -102,7 +102,7 @@ private:
};
/**
- * @breif Expand shape x and y to same rank by align right and filling with 1
+ * @brief Expand shape x and y to same rank by align right and filling with 1
*/
void expand_rank(loco::TensorShape &x, loco::TensorShape &y)
{
@@ -122,7 +122,7 @@ void expand_rank(loco::TensorShape &x, loco::TensorShape &y)
}
/**
- * @breif Returns shape of expanded dimension of input x and y having same rank
+ * @brief Returns shape of expanded dimension of input x and y having same rank
*/
loco::TensorShape expand_dimension(const loco::TensorShape &x, const loco::TensorShape &y)
{
diff --git a/compiler/luci/service/src/CircleShapeSignatureInferenceRule.cpp b/compiler/luci/service/src/CircleShapeSignatureInference.cpp
index dc7df3e39..1ccaa19d5 100644
--- a/compiler/luci/service/src/CircleShapeSignatureInferenceRule.cpp
+++ b/compiler/luci/service/src/CircleShapeSignatureInference.cpp
@@ -14,7 +14,7 @@
* limitations under the License.
*/
-#include "luci/Service/CircleShapeSignatureInferenceRule.h"
+#include "luci/Service/CircleShapeSignatureInference.h"
#include <luci/Log.h>
@@ -39,14 +39,16 @@ std::ostream &operator<<(std::ostream &os, const luci::ShapeSignature &shape_sig
namespace luci
{
-bool CircleShapeSignatureInferenceRule::infer(const luci::CircleNode *circle_node,
- ShapeSignature &shape_signature) const
+namespace ssinf
+{
+
+bool Rule::infer(const luci::CircleNode *circle_node, ShapeSignature &shape_signature) const
{
LOGGER(l);
// There is nothing to check before ShapeSignatureInference.
- ShapeSignatureInferenceAlgorithm alg;
+ Algorithm alg;
shape_signature = circle_node->accept(&alg);
@@ -57,4 +59,6 @@ bool CircleShapeSignatureInferenceRule::infer(const luci::CircleNode *circle_nod
return true;
}
+} // namespace ssinf
+
} // namespace luci
diff --git a/compiler/luci/service/src/CircleShapeSignatureInferenceHelper.cpp b/compiler/luci/service/src/CircleShapeSignatureInferenceHelper.cpp
new file mode 100644
index 000000000..d7d1a24e8
--- /dev/null
+++ b/compiler/luci/service/src/CircleShapeSignatureInferenceHelper.cpp
@@ -0,0 +1,160 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Service/CircleShapeSignatureInferenceHelper.h"
+
+#include <loco.h>
+
+#include <luci/Log.h>
+
+#include <oops/InternalExn.h>
+
+namespace luci
+{
+
+namespace ssinf
+{
+
+luci::ShapeSignature legalized_signature(const luci::ShapeSignature &signature)
+{
+ // If shape signature has at least one -1, it is not static.
+ for (uint32_t i = 0; i < signature.rank(); ++i)
+ if (signature.dim(i) == -1)
+ return signature;
+
+ // If all dimensions are static, return empty shape signature.
+ return luci::ShapeSignature();
+}
+
+ShapeSignature reduced_signature(const loco::Node *node, const loco::Node *indices, bool keep_dims)
+{
+ LOGGER(l);
+
+ ShapeSignature input_signature;
+ ShapeSignature output_signature;
+
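+ // Use the node's explicit shape signature when it has one; otherwise fall back
+ // to its static shape.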
+ auto circle_node = loco::must_cast<const luci::CircleNode *>(node);
+ if (circle_node->shape_signature().rank() > 0)
+ input_signature = circle_node->shape_signature();
+ else
+ {
+ input_signature.rank(circle_node->rank());
+ for (uint32_t i = 0; i < circle_node->rank(); ++i)
+ input_signature.dim(i) = circle_node->dim(i).value();
+ }
+
+ // If the input rank is 0, one of the following cases has occurred.
+ // - The input is a scalar : the result is always a scalar
+ // - The input shape signature has not been inferred : the output shape signature cannot be inferred
+ // Therefore, when the input signature rank is 0, always return an empty signature.
+ if (input_signature.rank() == 0)
+ return output_signature;
+
+ // When reduction_indices is not constant
+ auto reduction_indices = dynamic_cast<const luci::CircleConst *>(indices);
+ if (reduction_indices == nullptr)
+ {
+ if (keep_dims)
+ {
+ // If keep_dims is true, rank is not changed.
+ output_signature.rank(input_signature.rank());
+ for (uint32_t i = 0; i < output_signature.rank(); ++i)
+ output_signature.dim(i) = -1;
+ }
+ else
+ {
+ // There is no way to infer the signature in this case.
+ // Do nothing and return the empty signature.
+ INFO(l) << "[CircleShapeSignatureInferenceHelper] " << circle_node->name() << std::endl;
+ INFO(l) << " reduced_signature : cannot infer because of non-constant node" << std::endl;
+ }
+
+ return output_signature;
+ }
+
+ std::vector<int32_t> reduction_values;
+ if (reduction_indices->dtype() == loco::DataType::S32)
+ {
+ auto reduction_size = reduction_indices->size<loco::DataType::S32>();
+ for (uint32_t i = 0; i < reduction_size; ++i)
+ {
+ int32_t axis = reduction_indices->at<loco::DataType::S32>(i);
+ if (axis < 0)
+ axis += input_signature.rank();
+
+ if (!(0 <= axis && axis < static_cast<int32_t>(input_signature.rank())))
+ INTERNAL_EXN_V("Invalid reduction axis for REDUCER", oops::to_uint32(axis));
+
+ reduction_values.push_back(axis);
+ }
+ }
+ else if (reduction_indices->dtype() == loco::DataType::S64)
+ {
+ auto reduction_size = reduction_indices->size<loco::DataType::S64>();
+ for (uint32_t i = 0; i < reduction_size; ++i)
+ {
+ int32_t axis = static_cast<int32_t>(reduction_indices->at<loco::DataType::S64>(i));
+ if (axis < 0)
+ axis += input_signature.rank();
+
+ if (!(0 <= axis && axis < static_cast<int32_t>(input_signature.rank())))
+ INTERNAL_EXN_V("Invalid reduction axis for REDUCER", oops::to_uint32(axis));
+
+ reduction_values.push_back(axis);
+ }
+ }
+ else
+ {
+ INTERNAL_EXN("Wrong reduction axis type, only INT32 and INT64 are supported.");
+ }
+
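+ // Build the output signature: with keep_dims the reduced axes become 1,
+ // otherwise they are dropped and the rank shrinks accordingly.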
+ if (keep_dims)
+ {
+ output_signature.rank(input_signature.rank());
+ for (uint32_t i = 0; i < input_signature.rank(); ++i)
+ output_signature.dim(i) = input_signature.dim(i);
+ for (uint32_t i = 0; i < reduction_values.size(); ++i)
+ output_signature.dim(reduction_values.at(i)) = 1;
+ }
+ else
+ {
+ std::vector<bool> check_reduce(input_signature.rank(), false);
+ for (uint32_t i = 0; i < reduction_values.size(); ++i)
+ check_reduce.at(reduction_values.at(i)) = true;
+
+ uint32_t reduce_cnt = 0;
+ for (uint32_t i = 0; i < check_reduce.size(); ++i)
+ if (check_reduce.at(i))
+ ++reduce_cnt;
+
+ output_signature.rank(input_signature.rank() - reduce_cnt);
+ for (uint32_t i = 0, j = 0; i < check_reduce.size(); ++i)
+ if (check_reduce.at(i) == false)
+ output_signature.dim(j++) = input_signature.dim(i);
+ }
+
+ return output_signature;
+}
+
+ShapeSignature input_arg_signature(const luci::CircleNode *node, uint32_t index)
+{
+ auto circle_input = loco::must_cast<luci::CircleNode *>(node->arg(index));
+ return circle_input->shape_signature();
+}
+
+} // namespace ssinf
+
+} // namespace luci
diff --git a/compiler/luci/service/src/CircleTypeInference.cpp b/compiler/luci/service/src/CircleTypeInference.cpp
index aa8524a55..b4755b51a 100644
--- a/compiler/luci/service/src/CircleTypeInference.cpp
+++ b/compiler/luci/service/src/CircleTypeInference.cpp
@@ -16,6 +16,8 @@
#include "luci/Service/CircleTypeInference.h"
+#include <luci/Log.h>
+
#include <loco.h>
#include <loco/Service/TypeInference.h>
@@ -70,3 +72,47 @@ circle::TensorType TypeInference::get(loco::Node *node)
}
} // namespace luci
+
+namespace
+{
+
+bool inputs_dtype_ready(const luci::CircleNode *node)
+{
+ for (uint32_t arity = 0; arity < node->arity(); ++arity)
+ {
+ auto node_input = loco::must_cast<luci::CircleNode *>(node->arg(arity));
+ if (node_input->dtype() == loco::DataType::Unknown)
+ return false;
+ }
+
+ return true;
+}
+
+} // namespace
+
+namespace luci
+{
+namespace tinf
+{
+
+bool Rule::infer(const luci::CircleNode *circle_node, loco::DataType &dtype) const
+{
+ LOGGER(l);
+ VERBOSE(l, 1) << "[CircleTypeInference] " << circle_node->name();
+ VERBOSE(l, 1) << " before: " << static_cast<int>(circle_node->dtype());
+
+ if (!inputs_dtype_ready(circle_node))
+ {
+ VERBOSE(l, 1) << " after: Some inputs are not ready for inference";
+ return false;
+ }
+
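+ // Dispatch to the per-operator visitor; operators without a dedicated visit()
+ // keep the dtype currently annotated on the node.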
+ Algorithm alg;
+ dtype = circle_node->accept(&alg);
+
+ VERBOSE(l, 1) << " after: " << static_cast<int>(dtype);
+
+ return true;
+}
+
+} // namespace tinf
+} // namespace luci
diff --git a/compiler/luci/service/src/CircleTypeInferenceHelper.cpp b/compiler/luci/service/src/CircleTypeInferenceHelper.cpp
new file mode 100644
index 000000000..75cd9f7b2
--- /dev/null
+++ b/compiler/luci/service/src/CircleTypeInferenceHelper.cpp
@@ -0,0 +1,27 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Service/CircleTypeInferenceHelper.h"
+
+namespace luci
+{
+namespace tinf
+{
+
+// Helper function will be added
+
+} // namespace tinf
+} // namespace luci
diff --git a/compiler/luci/service/src/Nodes/CircleInput.cpp b/compiler/luci/service/src/Nodes/CircleInput.cpp
new file mode 100644
index 000000000..24eab7bd6
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleInput.cpp
@@ -0,0 +1,27 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <luci/Service/CircleShapeSignatureInference.h>
+
+namespace luci
+{
+
+ShapeSignature ssinf::Algorithm::visit(const luci::CircleInput *node)
+{
+ return node->shape_signature();
+}
+
+} // namespace luci
diff --git a/compiler/luci/service/src/Nodes/CircleMean.cpp b/compiler/luci/service/src/Nodes/CircleMean.cpp
new file mode 100644
index 000000000..a78713698
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleMean.cpp
@@ -0,0 +1,28 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <luci/Service/CircleShapeSignatureInference.h>
+
+namespace luci
+{
+
+ShapeSignature ssinf::Algorithm::visit(const luci::CircleMean *node)
+{
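+ // Mean reduces its input along reduction_indices; legalized_signature() then
+ // drops the signature entirely when every remaining dimension is static.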
+ return legalized_signature(
+ reduced_signature(node->input(), node->reduction_indices(), node->keep_dims()));
+}
+
+} // namespace luci
diff --git a/compiler/luci/service/src/Nodes/CircleOutput.cpp b/compiler/luci/service/src/Nodes/CircleOutput.cpp
new file mode 100644
index 000000000..d4c8da2d8
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleOutput.cpp
@@ -0,0 +1,27 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <luci/Service/CircleShapeSignatureInference.h>
+
+namespace luci
+{
+
+ShapeSignature ssinf::Algorithm::visit(const luci::CircleOutput *node)
+{
+ return input_arg_signature(node, 0);
+}
+
+} // namespace luci
diff --git a/compiler/luci/service/src/Nodes/CircleOutputDummy.cpp b/compiler/luci/service/src/Nodes/CircleOutputDummy.cpp
new file mode 100644
index 000000000..e0f13c439
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleOutputDummy.cpp
@@ -0,0 +1,24 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <luci/Service/CircleShapeSignatureInference.h>
+
+namespace luci
+{
+
+ShapeSignature ssinf::Algorithm::visit(const luci::CircleOutputDummy *) { return ShapeSignature(); }
+
+} // namespace luci
diff --git a/compiler/luci/service/src/Nodes/CircleOutputExclude.cpp b/compiler/luci/service/src/Nodes/CircleOutputExclude.cpp
new file mode 100644
index 000000000..75bbbb3c0
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleOutputExclude.cpp
@@ -0,0 +1,27 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <luci/Service/CircleShapeSignatureInference.h>
+
+namespace luci
+{
+
+ShapeSignature ssinf::Algorithm::visit(const luci::CircleOutputExclude *)
+{
+ return ShapeSignature();
+}
+
+} // namespace luci
diff --git a/compiler/luci/service/src/Nodes/CircleReduceAny.cpp b/compiler/luci/service/src/Nodes/CircleReduceAny.cpp
new file mode 100644
index 000000000..27da81466
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleReduceAny.cpp
@@ -0,0 +1,28 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <luci/Service/CircleShapeSignatureInference.h>
+
+namespace luci
+{
+
+ShapeSignature ssinf::Algorithm::visit(const luci::CircleReduceAny *node)
+{
+ return legalized_signature(
+ reduced_signature(node->input(), node->reduction_indices(), node->keep_dims()));
+}
+
+} // namespace luci
diff --git a/compiler/luci/service/src/Nodes/CircleReduceMax.cpp b/compiler/luci/service/src/Nodes/CircleReduceMax.cpp
new file mode 100644
index 000000000..48d9cb970
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleReduceMax.cpp
@@ -0,0 +1,28 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <luci/Service/CircleShapeSignatureInference.h>
+
+namespace luci
+{
+
+ShapeSignature ssinf::Algorithm::visit(const luci::CircleReduceMax *node)
+{
+ return legalized_signature(
+ reduced_signature(node->input(), node->reduction_indices(), node->keep_dims()));
+}
+
+} // namespace luci
diff --git a/compiler/luci/service/src/Nodes/CircleReduceMin.cpp b/compiler/luci/service/src/Nodes/CircleReduceMin.cpp
new file mode 100644
index 000000000..9a9997118
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleReduceMin.cpp
@@ -0,0 +1,28 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <luci/Service/CircleShapeSignatureInference.h>
+
+namespace luci
+{
+
+ShapeSignature ssinf::Algorithm::visit(const luci::CircleReduceMin *node)
+{
+ return legalized_signature(
+ reduced_signature(node->input(), node->reduction_indices(), node->keep_dims()));
+}
+
+} // namespace luci
diff --git a/compiler/luci/service/src/Nodes/CircleReduceProd.cpp b/compiler/luci/service/src/Nodes/CircleReduceProd.cpp
new file mode 100644
index 000000000..a9d381a74
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleReduceProd.cpp
@@ -0,0 +1,28 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <luci/Service/CircleShapeSignatureInference.h>
+
+namespace luci
+{
+
+ShapeSignature ssinf::Algorithm::visit(const luci::CircleReduceProd *node)
+{
+ return legalized_signature(
+ reduced_signature(node->input(), node->reduction_indices(), node->keep_dims()));
+}
+
+} // namespace luci
diff --git a/compiler/luci/service/src/Nodes/CircleRelu.cpp b/compiler/luci/service/src/Nodes/CircleRelu.cpp
new file mode 100644
index 000000000..a7a7f6f0a
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleRelu.cpp
@@ -0,0 +1,27 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <luci/Service/CircleShapeSignatureInference.h>
+
+namespace luci
+{
+
+ShapeSignature ssinf::Algorithm::visit(const luci::CircleRelu *node)
+{
+ return input_arg_signature(node, 0);
+}
+
+} // namespace luci
diff --git a/compiler/luci/service/src/Nodes/CircleRelu6.cpp b/compiler/luci/service/src/Nodes/CircleRelu6.cpp
new file mode 100644
index 000000000..92a596d08
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleRelu6.cpp
@@ -0,0 +1,27 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <luci/Service/CircleShapeSignatureInference.h>
+
+namespace luci
+{
+
+ShapeSignature ssinf::Algorithm::visit(const luci::CircleRelu6 *node)
+{
+ return input_arg_signature(node, 0);
+}
+
+} // namespace luci
diff --git a/compiler/luci/service/src/Nodes/CircleReluN1To1.cpp b/compiler/luci/service/src/Nodes/CircleReluN1To1.cpp
new file mode 100644
index 000000000..1e8d9971d
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleReluN1To1.cpp
@@ -0,0 +1,27 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <luci/Service/CircleShapeSignatureInference.h>
+
+namespace luci
+{
+
+ShapeSignature ssinf::Algorithm::visit(const luci::CircleReluN1To1 *node)
+{
+ return input_arg_signature(node, 0);
+}
+
+} // namespace luci
diff --git a/compiler/luci/service/src/Nodes/CircleSum.cpp b/compiler/luci/service/src/Nodes/CircleSum.cpp
new file mode 100644
index 000000000..9ef90e8e0
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleSum.cpp
@@ -0,0 +1,28 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <luci/Service/CircleShapeSignatureInference.h>
+
+namespace luci
+{
+
+ShapeSignature ssinf::Algorithm::visit(const luci::CircleSum *node)
+{
+ return legalized_signature(
+ reduced_signature(node->input(), node->reduction_indices(), node->keep_dims()));
+}
+
+} // namespace luci
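The reduction nodes above (Mean, ReduceAny, ReduceMax, ReduceMin, ReduceProd, Sum) all derive their signature from reduced_signature() followed by legalized_signature(). Those helpers are declared in CircleShapeSignatureInference.h and are not part of this hunk, so the standalone sketch below only illustrates the semantics they appear to implement, under two assumptions: a reduced axis is dropped when keep_dims is false and pinned to 1 when it is true, and legalization clears a signature that no longer contains any -1 (the Validate.cpp change below requires every non-empty signature to carry at least one unknown dimension). It also assumes reduction_indices resolves to a constant list of non-negative axes.

// Standalone sketch with plain std::vector; hypothetical helper names mirror the patch, not the luci API.
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

using Signature = std::vector<int32_t>; // -1 marks an unknown dimension

// Reduce a signature along 'axes' (assumed non-negative), mimicking reduced_signature().
Signature reduced_signature(const Signature &input, const std::vector<int32_t> &axes, bool keep_dims)
{
  Signature out;
  for (int32_t d = 0; d < static_cast<int32_t>(input.size()); ++d)
  {
    bool reduced = std::find(axes.begin(), axes.end(), d) != axes.end();
    if (!reduced)
      out.push_back(input[d]);
    else if (keep_dims)
      out.push_back(1); // a reduced axis becomes a fixed dimension of 1
    // when keep_dims is false, the reduced axis is dropped entirely
  }
  return out;
}

// Drop signatures that no longer carry any unknown (-1) dimension.
Signature legalized_signature(const Signature &sig)
{
  bool has_unknown = std::find(sig.begin(), sig.end(), -1) != sig.end();
  return has_unknown ? sig : Signature{};
}

int main()
{
  // e.g. Mean over axis 1 with keep_dims=false: {-1, 10, 4} -> {-1, 4}
  Signature s = legalized_signature(reduced_signature({-1, 10, 4}, {1}, false));
  assert((s == Signature{-1, 4}));
  // reducing away the only unknown axis leaves no -1, so the signature is cleared
  Signature t = legalized_signature(reduced_signature({-1, 10, 4}, {0}, false));
  assert(t.empty());
  return 0;
}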
diff --git a/compiler/luci/service/src/ShapeDescription.cpp b/compiler/luci/service/src/ShapeDescription.cpp
index cbc302f70..01a638f8f 100644
--- a/compiler/luci/service/src/ShapeDescription.cpp
+++ b/compiler/luci/service/src/ShapeDescription.cpp
@@ -23,6 +23,19 @@
namespace luci
{
+ShapeDescription to_shape_description(const luci::CircleNode *circle_node)
+{
+ ShapeDescription res;
+
+ res._rank_known = true;
+
+ res._dims.resize(circle_node->rank());
+ for (uint32_t i = 0; i < circle_node->rank(); ++i)
+ res._dims.at(i) = circle_node->dim(i).value();
+
+ return res;
+}
+
ShapeDescription to_shape_description(const loco::TensorShape &shape)
{
ShapeDescription res;
diff --git a/compiler/luci/service/src/Validate.cpp b/compiler/luci/service/src/Validate.cpp
index d224fd172..3f732b6fe 100644
--- a/compiler/luci/service/src/Validate.cpp
+++ b/compiler/luci/service/src/Validate.cpp
@@ -42,6 +42,19 @@ std::ostream &operator<<(std::ostream &os, const loco::TensorShape &tensor_shape
return os;
}
+std::ostream &operator<<(std::ostream &os, const luci::CircleNode *circle_node)
+{
+ os << "[";
+ for (uint32_t r = 0; r < circle_node->rank(); ++r)
+ {
+ if (r)
+ os << ",";
+ os << circle_node->dim(r).value();
+ }
+ os << "]";
+ return os;
+}
+
/**
* @brief returns a node that is CircleOutput with index is out_index in nodes
*/
@@ -80,23 +93,28 @@ bool validate_shape_dtype(loco::Graph *g)
if (dynamic_cast<luci::CircleOutputExclude *>(circle_node))
continue;
- assert(loco::shape_known(circle_node));
+ assert(circle_node->shape_status() != luci::ShapeStatus::UNDEFINED);
// check if output node shape is same as graph output shape
- auto co_tensor_shape = loco::shape_get(circle_node).as<loco::TensorShape>();
auto go_tensor_shape = graph_out->shape();
assert(go_tensor_shape);
- if (!(co_tensor_shape == *go_tensor_shape))
+
+ bool is_shape_valid = (circle_node->rank() == go_tensor_shape->rank());
+ for (uint32_t i = 0; is_shape_valid && i < circle_node->rank(); ++i)
+ if (circle_node->dim(i).value() != go_tensor_shape->dim(i).value())
+ is_shape_valid = false;
+
+ if (is_shape_valid == false)
{
INFO(l) << "[luci] Shape for output #" << out_index << " not same " << std::endl;
- INFO(l) << "[luci] " << circle_node->name() << " " << co_tensor_shape << " vs "
+ INFO(l) << "[luci] " << circle_node->name() << " " << circle_node << " vs "
<< *go_tensor_shape << std::endl;
return false;
}
// check if data type match
- assert(loco::dtype_known(circle_node));
- if (graph_out->dtype() != loco::dtype_get(circle_node))
+ assert(circle_node->dtype() != loco::DataType::Unknown);
+ if (graph_out->dtype() != circle_node->dtype())
{
INFO(l) << "[luci] Type for output #" << out_index << " not same " << std::endl;
return false;
@@ -106,6 +124,55 @@ bool validate_shape_dtype(loco::Graph *g)
return true;
}
+bool validate_shape_signature(loco::Graph *g)
+{
+ LOGGER(l);
+
+ for (auto node : loco::postorder_traversal(loco::output_nodes(g)))
+ {
+ auto circle_node = loco::must_cast<luci::CircleNode *>(node);
+ const auto shape_signature = circle_node->shape_signature();
+
+ if (shape_signature.rank() == 0)
+ continue;
+
+ // Rank of shape and shape signature should be same
+ if (circle_node->rank() != shape_signature.rank())
+ {
+ INFO(l) << "[luci] Rank of shape signature for " << circle_node->name() << " do not match"
+ << std::endl;
+ return false;
+ }
+
+ bool has_unknown = false;
+
+ // If a shape signature dimension is not -1, it should match the shape dimension
+ for (uint32_t d = 0; d < shape_signature.rank(); ++d)
+ {
+ if (shape_signature.dim(d) != -1 &&
+ shape_signature.dim(d) != (int32_t)(circle_node->dim(d).value()))
+ {
+ INFO(l) << "[luci] Dimension " << d << "of shape signature for " << circle_node->name()
+ << " do not match" << std::endl;
+ return false;
+ }
+
+ if (shape_signature.dim(d) == -1)
+ has_unknown = true;
+ }
+
+ // Shape signature should have at least one -1 value.
+ if (!has_unknown)
+ {
+ INFO(l) << "[luci] Shape signature in " << circle_node->name()
+ << " do not have unknown dimension" << std::endl;
+ return false;
+ }
+ }
+
+ return true;
+}
+
} // namespace
namespace luci
@@ -119,6 +186,9 @@ bool validate(loco::Graph *g)
if (!validate_shape_dtype(g))
return false;
+ if (!validate_shape_signature(g))
+ return false;
+
// TODO add more validation
return true;
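As a concrete example of the added check: a node with shape [1, 299, 299, 3] and shape signature [-1, 299, 299, 3] validates (ranks match, every fixed signature dimension equals the corresponding shape dimension, and one dimension is unknown), whereas a signature of [2, 299, 299, 3] for the same shape is rejected because dimension 0 disagrees and, in addition, no -1 remains.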
diff --git a/compiler/luci/tester/src/ReadTester.cpp b/compiler/luci/tester/src/ReadTester.cpp
index a1aead1bd..f270a232c 100644
--- a/compiler/luci/tester/src/ReadTester.cpp
+++ b/compiler/luci/tester/src/ReadTester.cpp
@@ -21,6 +21,9 @@
#include <luci/Pass/ShapeInferencePass.h>
#include <luci/Pass/TypeInferencePass.h>
+// Following passes will be removed after refactoring is finished
+#include <luci/Pass/MigrateLegacyShapeDtypePass.h>
+
#include <iostream>
#include <map>
#include <string>
@@ -95,6 +98,12 @@ int entry(int argc, char **argv)
while (pass.run(graph) == true)
;
}
+ {
+ // This pass will be removed after refactoring is finished
+ luci::MigrateLegacyShapeDtypePass pass;
+ while (pass.run(graph) == true)
+ ;
+ }
if (!luci::validate(graph))
return 255;
diff --git a/compiler/luci/tester/src/WriteTester.cpp b/compiler/luci/tester/src/WriteTester.cpp
index aa7085c77..9a6e8de05 100644
--- a/compiler/luci/tester/src/WriteTester.cpp
+++ b/compiler/luci/tester/src/WriteTester.cpp
@@ -23,6 +23,9 @@
#include <luci/CircleExporter.h>
#include <oops/InternalExn.h>
+// Following passes will be removed after refactoring is finished
+#include <luci/Pass/MigrateLegacyShapeDtypePass.h>
+
#include <fstream>
#include <iostream>
#include <map>
@@ -139,6 +142,12 @@ int entry(int argc, char **argv)
while (pass.run(graph) == true)
;
}
+ {
+ // This pass will be removed after refactoring is finished
+ luci::MigrateLegacyShapeDtypePass pass;
+ while (pass.run(graph) == true)
+ ;
+ }
if (!luci::validate(graph))
return 255;
diff --git a/compiler/moco/support/src/TFShapeInferenceHelper.cpp b/compiler/moco/support/src/TFShapeInferenceHelper.cpp
index 13e514a78..605fb9c37 100644
--- a/compiler/moco/support/src/TFShapeInferenceHelper.cpp
+++ b/compiler/moco/support/src/TFShapeInferenceHelper.cpp
@@ -66,7 +66,7 @@ private:
};
/**
- * @breif Expand shape x and y to same rank by align right and filling with 1
+ * @brief Expand shapes x and y to the same rank by aligning them to the right and filling with 1
*/
void expand_rank(loco::TensorShape &x, loco::TensorShape &y)
{
@@ -86,7 +86,7 @@ void expand_rank(loco::TensorShape &x, loco::TensorShape &y)
}
/**
- * @breif Returns shape of expanded dimension of input x and y having same rank
+ * @brief Returns the expanded shape of inputs x and y, which have the same rank
*/
loco::TensorShape expand_dimension(const loco::TensorShape &x, const loco::TensorShape &y)
{
diff --git a/compiler/nnc/include/Definitions.h.in b/compiler/nnc/include/Definitions.h.in
index 070cdd201..bd8642956 100644
--- a/compiler/nnc/include/Definitions.h.in
+++ b/compiler/nnc/include/Definitions.h.in
@@ -7,12 +7,12 @@
*/
/**
- * @breif absolute path to installation directory of *nnc* project
+ * @brief absolute path to installation directory of *nnc* project
*/
#define NNC_ROOT_PATH "@NNC_INSTALL_PATH@"
/**
- * @breif absolute path to directory contains libraries
+ * @brief absolute path to the directory that contains libraries
*/
#define NNC_LIB_PATH "@NNC_INSTALL_LIB_PATH@"
diff --git a/compiler/one-cmds/how-to-use-one-commands.txt b/compiler/one-cmds/how-to-use-one-commands.txt
index 62a497828..d4e3269e8 100644
--- a/compiler/one-cmds/how-to-use-one-commands.txt
+++ b/compiler/one-cmds/how-to-use-one-commands.txt
@@ -161,6 +161,7 @@ Current transformation options are
- make_batchnorm_gamma_positive: This makes negative gamma of batch normalization into a small positive value (1e-10).
Note that this pass can change the execution result of the model.
So, use it only when the impact is known to be acceptable.
+- replace_cw_mul_add_with_depthwise_conv: This will replace channel-wise Mul/Add with DepthwiseConv2D (example below).
- resolve_customop_add: This will convert Custom(Add) to normal Add operator
- resolve_customop_batchmatmul: This will convert Custom(BatchMatMul) to
normal BatchMatMul operator
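A hedged illustration of replace_cw_mul_add_with_depthwise_conv, inferred from the pass name rather than spelled out in this patch: when a feature map of shape [1, H, W, C] is multiplied by a per-channel constant (C values) and the result is added to another per-channel constant, the two operators can in effect be folded into a single DepthwiseConv2D whose 1x1 kernel holds the Mul constant and whose bias holds the Add constant; see ReplaceMulAddWithDepthwiseConvPass for the exact matching rules.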
diff --git a/compiler/one-cmds/one-codegen b/compiler/one-cmds/one-codegen
index f2d82307c..fbe3d52d2 100644
--- a/compiler/one-cmds/one-codegen
+++ b/compiler/one-cmds/one-codegen
@@ -87,24 +87,19 @@ def main():
# verify arguments
_verify_arg(parser, args)
- # get file path to log
+ # make a command to run given backend driver
dir_path = os.path.dirname(os.path.realpath(__file__))
- logfile_path = os.path.realpath(args.output_path) + '.log'
-
- with open(logfile_path, 'wb') as f:
- # make a command to run given backend driver
- codegen_path = os.path.join(dir_path, getattr(args, 'backend') + '-compile')
- codegen_cmd = [codegen_path] + unknown_args
-
- f.write((' '.join(codegen_cmd) + '\n').encode())
-
- # run backend driver
- with subprocess.Popen(
- codegen_cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
- bufsize=1) as p:
- for line in p.stdout:
- sys.stdout.buffer.write(line)
- f.write(line)
+ codegen_path = os.path.join(dir_path, getattr(args, 'backend') + '-compile')
+ codegen_cmd = [codegen_path] + unknown_args
+ if _utils._is_valid_attr(args, 'command'):
+ codegen_cmd += getattr(args, 'command').split()
+
+ # run backend driver
+ with subprocess.Popen(
+ codegen_cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
+ bufsize=1) as p:
+ for line in p.stdout:
+ sys.stdout.buffer.write(line)
if __name__ == '__main__':
diff --git a/compiler/one-cmds/one-import-bcq b/compiler/one-cmds/one-import-bcq
index 5ea1f57fa..50f587946 100644
--- a/compiler/one-cmds/one-import-bcq
+++ b/compiler/one-cmds/one-import-bcq
@@ -43,13 +43,13 @@ def _get_parser():
converter_version.add_argument(
'--v1',
action='store_const',
- dest='converter_version',
+ dest='converter_version_cmd',
const='--v1',
help='use TensorFlow Lite Converter 1.x')
converter_version.add_argument(
'--v2',
action='store_const',
- dest='converter_version',
+ dest='converter_version_cmd',
const='--v2',
help='use TensorFlow Lite Converter 2.x')
diff --git a/compiler/one-cmds/one-import-tf b/compiler/one-cmds/one-import-tf
index 49009d331..3a7c69af3 100644
--- a/compiler/one-cmds/one-import-tf
+++ b/compiler/one-cmds/one-import-tf
@@ -52,8 +52,6 @@ def _get_parser():
const='--v2',
help='use TensorFlow Lite Converter 2.x')
- #converter_version.set_defaults(converter_version='--v1')
-
parser.add_argument('--converter_version', type=str, help=argparse.SUPPRESS)
# input model format
diff --git a/compiler/one-cmds/one-optimize b/compiler/one-cmds/one-optimize
index 4c5f10903..f03bb8dcc 100644
--- a/compiler/one-cmds/one-optimize
+++ b/compiler/one-cmds/one-optimize
@@ -73,6 +73,10 @@ def _get_parser():
circle2circle_group.add_argument(
'--fuse_instnorm', action='store_true', help='fuse ops to InstanceNorm operator')
circle2circle_group.add_argument(
+ '--replace_cw_mul_add_with_depthwise_conv',
+ action='store_true',
+ help='replace channel-wise Mul/Add with DepthwiseConv2D')
+ circle2circle_group.add_argument(
'--resolve_customop_add',
action='store_true',
help='convert Custom(Add) op to Add op')
diff --git a/compiler/one-cmds/tests/one-build_001.cfg b/compiler/one-cmds/tests/one-build_001.cfg
index 8524bbd1f..b022ba74b 100644
--- a/compiler/one-cmds/tests/one-build_001.cfg
+++ b/compiler/one-cmds/tests/one-build_001.cfg
@@ -13,7 +13,7 @@ output_path=inception_v3.circle
input_arrays=input
input_shapes=1,299,299,3
output_arrays=InceptionV3/Predictions/Reshape_1
-v2=True
+converter_version=v2
[one-optimize]
input_path=inception_v3.circle
diff --git a/compiler/one-cmds/tests/one-build_002.cfg b/compiler/one-cmds/tests/one-build_002.cfg
index 183077680..bbf09159b 100644
--- a/compiler/one-cmds/tests/one-build_002.cfg
+++ b/compiler/one-cmds/tests/one-build_002.cfg
@@ -13,7 +13,7 @@ output_path=inception_v3.circle
input_arrays=input
input_shapes=1,299,299,3
output_arrays=InceptionV3/Predictions/Reshape_1
-v2=True
+converter_version=v2
[one-optimize]
input_path=inception_v3.circle
diff --git a/compiler/one-cmds/tests/one-build_neg_002.cfg b/compiler/one-cmds/tests/one-build_neg_002.cfg
index 360c601e0..99db96651 100644
--- a/compiler/one-cmds/tests/one-build_neg_002.cfg
+++ b/compiler/one-cmds/tests/one-build_neg_002.cfg
@@ -13,7 +13,7 @@ output_path=inception_v3.circle
input_arrays=input
input_shapes=1,299,299,3
output_arrays=InceptionV3/Predictions/Reshape_1
-v2=True
+converter_version=v2
[one-optimize]
input_path=inception_v3.circle
diff --git a/compiler/one-cmds/tests/one-build_neg_003.cfg b/compiler/one-cmds/tests/one-build_neg_003.cfg
index 91e7875ac..fa027cb95 100644
--- a/compiler/one-cmds/tests/one-build_neg_003.cfg
+++ b/compiler/one-cmds/tests/one-build_neg_003.cfg
@@ -4,7 +4,7 @@ output_path=inception_v3.circle
input_arrays=input
input_shapes=1,299,299,3
output_arrays=InceptionV3/Predictions/Reshape_1
-v2=True
+converter_version=v2
[one-optimize]
input_path=inception_v3.circle
diff --git a/compiler/one-cmds/tests/one-build_neg_004.cfg b/compiler/one-cmds/tests/one-build_neg_004.cfg
index 4d312c47c..571077b42 100644
--- a/compiler/one-cmds/tests/one-build_neg_004.cfg
+++ b/compiler/one-cmds/tests/one-build_neg_004.cfg
@@ -13,7 +13,7 @@ output_path=inception_v3.circle
input_arrays=input
input_shapes=1,299,299,3
output_arrays=InceptionV3/Predictions/Reshape_1
-v2=True
+converter_version=v2
[one-optimize]
input_path=inception_v3.circle
diff --git a/compiler/one-cmds/tests/one-import_002.cfg b/compiler/one-cmds/tests/one-import_002.cfg
index 9a90abecd..8d6ae2c35 100644
--- a/compiler/one-cmds/tests/one-import_002.cfg
+++ b/compiler/one-cmds/tests/one-import_002.cfg
@@ -13,4 +13,4 @@ output_path=inception_v3.circle
input_arrays=input
input_shapes=1,299,299,3
output_arrays=InceptionV3/Predictions/Reshape_1
-v2=True
+converter_version=v2
diff --git a/compiler/one-cmds/tests/one-import_003.cfg b/compiler/one-cmds/tests/one-import_003.cfg
new file mode 100644
index 000000000..b679ebdb3
--- /dev/null
+++ b/compiler/one-cmds/tests/one-import_003.cfg
@@ -0,0 +1,13 @@
+[one-build]
+one-import-tf=True
+one-import-tflite=False
+one-import-bcq=False
+one-optimize=False
+one-quantize=False
+one-pack=False
+one-codegen=False
+
+[one-import-tf]
+model_format=saved_model
+input_path=test_saved_model
+output_path=test_saved_model.circle
diff --git a/compiler/one-cmds/tests/one-import_003.test b/compiler/one-cmds/tests/one-import_003.test
new file mode 100644
index 000000000..6093f1422
--- /dev/null
+++ b/compiler/one-cmds/tests/one-import_003.test
@@ -0,0 +1,42 @@
+#!/bin/bash
+
+# Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# import of TF 2.x saved model
+
+filename_ext="$(basename -- $0)"
+filename="${filename_ext%.*}"
+
+trap_err_onexit()
+{
+ echo "${filename_ext} FAILED"
+ exit 255
+}
+
+trap trap_err_onexit ERR
+
+configfile="one-import_003.cfg"
+outputfile="test_saved_model.circle"
+
+rm -f ${outputfile}
+
+# run test
+one-import tf -C ${configfile} > /dev/null
+
+if [[ ! -s "${outputfile}" ]]; then
+ trap_err_onexit
+fi
+
+echo "${filename_ext} SUCCESS"
diff --git a/compiler/one-cmds/tests/one-import_004.cfg b/compiler/one-cmds/tests/one-import_004.cfg
new file mode 100644
index 000000000..d28c8dff6
--- /dev/null
+++ b/compiler/one-cmds/tests/one-import_004.cfg
@@ -0,0 +1,13 @@
+[one-build]
+one-import-tf=True
+one-import-tflite=False
+one-import-bcq=False
+one-optimize=False
+one-quantize=False
+one-pack=False
+one-codegen=False
+
+[one-import-tf]
+model_format=keras_model
+input_path=test_keras_model.h5
+output_path=test_keras_model.circle
diff --git a/compiler/one-cmds/tests/one-import_004.test b/compiler/one-cmds/tests/one-import_004.test
new file mode 100644
index 000000000..9d10c431a
--- /dev/null
+++ b/compiler/one-cmds/tests/one-import_004.test
@@ -0,0 +1,42 @@
+#!/bin/bash
+
+# Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# import of TF 2.x keras model
+
+filename_ext="$(basename -- $0)"
+filename="${filename_ext%.*}"
+
+trap_err_onexit()
+{
+ echo "${filename_ext} FAILED"
+ exit 255
+}
+
+trap trap_err_onexit ERR
+
+configfile="one-import_004.cfg"
+outputfile="test_keras_model.circle"
+
+rm -f ${outputfile}
+
+# run test
+one-import tf -C ${configfile} > /dev/null
+
+if [[ ! -s "${outputfile}" ]]; then
+ trap_err_onexit
+fi
+
+echo "${filename_ext} SUCCESS"
diff --git a/compiler/one-cmds/tests/prepare_test_materials.sh b/compiler/one-cmds/tests/prepare_test_materials.sh
index cb1067e28..bc3d65d92 100644
--- a/compiler/one-cmds/tests/prepare_test_materials.sh
+++ b/compiler/one-cmds/tests/prepare_test_materials.sh
@@ -63,6 +63,20 @@ if [[ ! -s "inception_v3_test_data.h5" ]]; then
--output_path inception_v3_test_data.h5
fi
+if [[ ! -d "test_saved_model" ]]; then
+ rm -rf test_saved_model.zip
+ wget https://github.com/Samsung/ONE/files/5516226/test_saved_model.zip
+ unzip test_saved_model.zip
+ # https://github.com/Samsung/ONE/issues/4268#issuecomment-724578237
+fi
+
+if [[ ! -s "test_keras_model.h5" ]]; then
+ rm -rf test_keras_model.zip
+ wget https://github.com/Samsung/ONE/files/5520777/test_keras_model.zip
+ unzip test_keras_model.zip
+ # https://github.com/Samsung/ONE/issues/4268#issuecomment-725025805
+fi
+
# prepare 'inception_v3.circle' file used for quantization test
inputfile="./inception_v3.pb"
outputfile="./inception_v3.circle"
diff --git a/compiler/oops/include/oops/InternalExn.h b/compiler/oops/include/oops/InternalExn.h
index 0e11085c0..e14332bb2 100644
--- a/compiler/oops/include/oops/InternalExn.h
+++ b/compiler/oops/include/oops/InternalExn.h
@@ -40,20 +40,20 @@ class InternalExn : public std::exception
{
public:
InternalExn(const char *filename, const int line, const std::string &msg)
- : _filename(filename), _line(line), _msg(msg)
+ : _filename(filename), _line(to_uint32(line)), _msg(msg)
{
construct_full_msg();
}
explicit InternalExn(const char *filename, const int line, const std::string &msg, uint32_t val)
- : _filename(filename), _line(line), _msg(msg + ": " + std::to_string(val))
+ : _filename(filename), _line(to_uint32(line)), _msg(msg + ": " + std::to_string(val))
{
construct_full_msg();
}
explicit InternalExn(const char *filename, const int line, const std::string &msg,
const std::string &val)
- : _filename(filename), _line(line), _msg(msg + ": " + val)
+ : _filename(filename), _line(to_uint32(line)), _msg(msg + ": " + val)
{
construct_full_msg();
}
diff --git a/compiler/pota-quantization-value-test/CMakeLists.txt b/compiler/pota-quantization-value-test/CMakeLists.txt
index 73b9ead73..80661e566 100644
--- a/compiler/pota-quantization-value-test/CMakeLists.txt
+++ b/compiler/pota-quantization-value-test/CMakeLists.txt
@@ -1,6 +1,12 @@
unset(QUANTIZATION_VALUE_TEST)
unset(QUANTIZATION_VALUE_TEST_WITH_PARAM)
+nnas_find_package(FlatBuffers QUIET)
+if(NOT FlatBuffers_FOUND)
+ message(STATUS "Build pota-quantization-value-test: FAILED (missing FlatBuffers)")
+ return()
+endif(NOT FlatBuffers_FOUND)
+
macro(addTest NAME GRANULARITY DTYPE)
list(APPEND QUANTIZATION_VALUE_TEST ${NAME})
list(APPEND QUANTIZATION_VALUE_TEST_WITH_PARAM ${NAME} ${GRANULARITY} ${DTYPE})
@@ -14,8 +20,12 @@ include("test.local.lst" OPTIONAL)
unset(TEST_DEPS)
get_target_property(ARTIFACTS_BIN_PATH testDataGenerator BINARY_DIR)
+get_target_property(SCHEMA_BIN_PATH mio_circle BINARY_DIR)
+
+configure_file("${CMAKE_CURRENT_SOURCE_DIR}/gen_h5_explicit_inputs.py"
+ "${CMAKE_CURRENT_BINARY_DIR}/gen_h5_explicit_inputs.py" COPYONLY)
-set(VIRTUALENV "${NNCC_OVERLAY_DIR}/venv_1_13_2")
+set(VIRTUALENV "${NNCC_OVERLAY_DIR}/venv_2_3_0")
###
### Generate test.config
@@ -35,7 +45,21 @@ add_custom_command(
COMMENT "Generate test configuration"
)
-list(APPEND TEST_DEPS "${TEST_CONFIG}")
+###
+### Generate python interface for circle schema
+###
+set(CIRCLE_SCHEMA_PYTHON_DIR "${CMAKE_CURRENT_BINARY_DIR}/circle")
+
+add_custom_command(
+ OUTPUT ${CIRCLE_SCHEMA_PYTHON_DIR}
+ COMMAND ${CMAKE_COMMAND} -E remove_directory "${CIRCLE_SCHEMA_PYTHON_DIR}"
+ COMMAND "$<TARGET_FILE:flatbuffers::flatc>" --python
+ -o "${CMAKE_CURRENT_BINARY_DIR}" "${SCHEMA_BIN_PATH}/schema.fbs"
+ DEPENDS flatbuffers::flatc
+ COMMENT "Generate python interface for circle schema"
+)
+
+list(APPEND TEST_DEPS "${TEST_CONFIG}" "${CIRCLE_SCHEMA_PYTHON_DIR}")
# This enforces CMake to generate all the dependencies during "build" phase
add_custom_target(pota_quantization_value_test_deps ALL DEPENDS ${TEST_DEPS})
diff --git a/compiler/pota-quantization-value-test/expected_outputs/InstanceNorm_001/channel/uint8/quantization/beta.json b/compiler/pota-quantization-value-test/expected_outputs/InstanceNorm_001/channel/uint8/quantization/beta.json
new file mode 100644
index 000000000..fa2cdae3d
--- /dev/null
+++ b/compiler/pota-quantization-value-test/expected_outputs/InstanceNorm_001/channel/uint8/quantization/beta.json
@@ -0,0 +1,20 @@
+{
+ "weights": [
+ 1,
+ 0,
+ 1,
+ 1
+ ],
+ "scale": [
+ 0.7023000121116638,
+ 0.3091999888420105,
+ 0.7552000284194946,
+ 0.2728999853134155
+ ],
+ "zero_point": [
+ 0,
+ 1,
+ 0,
+ 0
+ ]
+}
diff --git a/compiler/pota-quantization-value-test/expected_outputs/InstanceNorm_001/channel/uint8/quantization/gamma.json b/compiler/pota-quantization-value-test/expected_outputs/InstanceNorm_001/channel/uint8/quantization/gamma.json
new file mode 100644
index 000000000..393a44ab0
--- /dev/null
+++ b/compiler/pota-quantization-value-test/expected_outputs/InstanceNorm_001/channel/uint8/quantization/gamma.json
@@ -0,0 +1,20 @@
+{
+ "weights": [
+ 1,
+ 0,
+ 1,
+ 0
+ ],
+ "scale": [
+ 0.012299999594688416,
+ 0.33239999413490295,
+ 0.23240000009536743,
+ 3.3359999656677246
+ ],
+ "zero_point": [
+ 0,
+ 1,
+ 0,
+ 1
+ ]
+}
diff --git a/compiler/pota-quantization-value-test/expected_outputs/InstanceNorm_001/channel/uint8/quantization/ifm.json b/compiler/pota-quantization-value-test/expected_outputs/InstanceNorm_001/channel/uint8/quantization/ifm.json
new file mode 100644
index 000000000..94c4e0f06
--- /dev/null
+++ b/compiler/pota-quantization-value-test/expected_outputs/InstanceNorm_001/channel/uint8/quantization/ifm.json
@@ -0,0 +1,4 @@
+{
+ "scale": 0.003919127397239208,
+ "zero_point": 0.0
+}
diff --git a/compiler/pota-quantization-value-test/expected_outputs/InstanceNorm_001/channel/uint8/quantization/ofm.json b/compiler/pota-quantization-value-test/expected_outputs/InstanceNorm_001/channel/uint8/quantization/ofm.json
new file mode 100644
index 000000000..27a1c8547
--- /dev/null
+++ b/compiler/pota-quantization-value-test/expected_outputs/InstanceNorm_001/channel/uint8/quantization/ofm.json
@@ -0,0 +1,4 @@
+{
+ "scale": 0.051219820976257324,
+ "zero_point": 104.0
+}
diff --git a/compiler/pota-quantization-value-test/expected_outputs/InstanceNorm_001/channel/uint8/record_minmax/ifm.json b/compiler/pota-quantization-value-test/expected_outputs/InstanceNorm_001/channel/uint8/record_minmax/ifm.json
new file mode 100644
index 000000000..910e855c3
--- /dev/null
+++ b/compiler/pota-quantization-value-test/expected_outputs/InstanceNorm_001/channel/uint8/record_minmax/ifm.json
@@ -0,0 +1,4 @@
+{
+ "min": 0.006417479291558266,
+ "max": 0.9993774032592774
+}
diff --git a/compiler/pota-quantization-value-test/expected_outputs/InstanceNorm_001/channel/uint8/record_minmax/ofm.json b/compiler/pota-quantization-value-test/expected_outputs/InstanceNorm_001/channel/uint8/record_minmax/ofm.json
new file mode 100644
index 000000000..190da3048
--- /dev/null
+++ b/compiler/pota-quantization-value-test/expected_outputs/InstanceNorm_001/channel/uint8/record_minmax/ofm.json
@@ -0,0 +1,4 @@
+{
+ "min": -5.316554107666015,
+ "max": 7.744499607086182
+}
diff --git a/compiler/pota-quantization-value-test/expected_outputs/InstanceNorm_001/layer/uint8/quantization/beta.json b/compiler/pota-quantization-value-test/expected_outputs/InstanceNorm_001/layer/uint8/quantization/beta.json
new file mode 100644
index 000000000..9dcefd552
--- /dev/null
+++ b/compiler/pota-quantization-value-test/expected_outputs/InstanceNorm_001/layer/uint8/quantization/beta.json
@@ -0,0 +1,10 @@
+{
+ "weights": [
+ 242,
+ 0,
+ 255,
+ 139
+ ],
+ "scale": 0.004174117464572191,
+ "zero_point": 74.0
+}
diff --git a/compiler/pota-quantization-value-test/expected_outputs/InstanceNorm_001/layer/uint8/quantization/gamma.json b/compiler/pota-quantization-value-test/expected_outputs/InstanceNorm_001/layer/uint8/quantization/gamma.json
new file mode 100644
index 000000000..6d85a1ebb
--- /dev/null
+++ b/compiler/pota-quantization-value-test/expected_outputs/InstanceNorm_001/layer/uint8/quantization/gamma.json
@@ -0,0 +1,10 @@
+{
+ "weights": [
+ 239,
+ 214,
+ 255,
+ 0
+ ],
+ "scale": 0.013993725180625916,
+ "zero_point": 238.0
+}
diff --git a/compiler/pota-quantization-value-test/expected_outputs/InstanceNorm_001/layer/uint8/quantization/ifm.json b/compiler/pota-quantization-value-test/expected_outputs/InstanceNorm_001/layer/uint8/quantization/ifm.json
new file mode 100644
index 000000000..df3df56cc
--- /dev/null
+++ b/compiler/pota-quantization-value-test/expected_outputs/InstanceNorm_001/layer/uint8/quantization/ifm.json
@@ -0,0 +1,4 @@
+{
+ "scale": 0.003914226312190294,
+ "zero_point": 0.0
+}
diff --git a/compiler/pota-quantization-value-test/expected_outputs/InstanceNorm_001/layer/uint8/quantization/ofm.json b/compiler/pota-quantization-value-test/expected_outputs/InstanceNorm_001/layer/uint8/quantization/ofm.json
new file mode 100644
index 000000000..098816af9
--- /dev/null
+++ b/compiler/pota-quantization-value-test/expected_outputs/InstanceNorm_001/layer/uint8/quantization/ofm.json
@@ -0,0 +1,4 @@
+{
+ "scale": 0.04870154336094856,
+ "zero_point": 122.0
+}
diff --git a/compiler/pota-quantization-value-test/expected_outputs/InstanceNorm_001/layer/uint8/record_minmax/ifm.json b/compiler/pota-quantization-value-test/expected_outputs/InstanceNorm_001/layer/uint8/record_minmax/ifm.json
new file mode 100644
index 000000000..d2e7923b5
--- /dev/null
+++ b/compiler/pota-quantization-value-test/expected_outputs/InstanceNorm_001/layer/uint8/record_minmax/ifm.json
@@ -0,0 +1,4 @@
+{
+ "min": 0.011221568882465362,
+ "max": 0.9981276893615723
+}
diff --git a/compiler/pota-quantization-value-test/expected_outputs/InstanceNorm_001/layer/uint8/record_minmax/ofm.json b/compiler/pota-quantization-value-test/expected_outputs/InstanceNorm_001/layer/uint8/record_minmax/ofm.json
new file mode 100644
index 000000000..b4ea58647
--- /dev/null
+++ b/compiler/pota-quantization-value-test/expected_outputs/InstanceNorm_001/layer/uint8/record_minmax/ofm.json
@@ -0,0 +1,4 @@
+{
+ "min": -5.94246238708496,
+ "max": 6.4764308166503906
+}
diff --git a/compiler/pota-quantization-value-test/expected_outputs/PRelu_001/channel/int16/quantization/alpha.json b/compiler/pota-quantization-value-test/expected_outputs/PRelu_001/channel/int16/quantization/alpha.json
index 5f6db8d72..6f99899d5 100644
--- a/compiler/pota-quantization-value-test/expected_outputs/PRelu_001/channel/int16/quantization/alpha.json
+++ b/compiler/pota-quantization-value-test/expected_outputs/PRelu_001/channel/int16/quantization/alpha.json
@@ -2,12 +2,20 @@
"weights": [
[
[
- 6553,
- 19660,
- 32767
+ 1,
+ 1,
+ 1
]
]
],
- "scale": 1.5259254723787308e-05,
- "zero_point": 0.0
+ "scale": [
+ 0.10000000149011612,
+ 0.30000001192092896,
+ 0.5
+ ],
+ "zero_point": [
+ 0,
+ 0,
+ 0
+ ]
}
diff --git a/compiler/pota-quantization-value-test/expected_outputs/PRelu_001/channel/int16/quantization/ifm.json b/compiler/pota-quantization-value-test/expected_outputs/PRelu_001/channel/int16/quantization/ifm.json
index e75377c9e..7d1f4c795 100644
--- a/compiler/pota-quantization-value-test/expected_outputs/PRelu_001/channel/int16/quantization/ifm.json
+++ b/compiler/pota-quantization-value-test/expected_outputs/PRelu_001/channel/int16/quantization/ifm.json
@@ -1,4 +1,4 @@
{
- "scale": 0.0001509107678430155,
+ "scale": 0.00015214986342471093,
"zero_point": 0.0
}
diff --git a/compiler/pota-quantization-value-test/expected_outputs/PRelu_001/channel/int16/quantization/ofm.json b/compiler/pota-quantization-value-test/expected_outputs/PRelu_001/channel/int16/quantization/ofm.json
index e4a89e2c0..533c1e3e0 100644
--- a/compiler/pota-quantization-value-test/expected_outputs/PRelu_001/channel/int16/quantization/ofm.json
+++ b/compiler/pota-quantization-value-test/expected_outputs/PRelu_001/channel/int16/quantization/ofm.json
@@ -1,4 +1,4 @@
{
- "scale": 0.00015084103506524116,
+ "scale": 0.00015159364556893706,
"zero_point": 0.0
}
diff --git a/compiler/pota-quantization-value-test/expected_outputs/PRelu_001/channel/int16/record_minmax/ifm.json b/compiler/pota-quantization-value-test/expected_outputs/PRelu_001/channel/int16/record_minmax/ifm.json
index a34d48c2a..edbbff9cb 100644
--- a/compiler/pota-quantization-value-test/expected_outputs/PRelu_001/channel/int16/record_minmax/ifm.json
+++ b/compiler/pota-quantization-value-test/expected_outputs/PRelu_001/channel/int16/record_minmax/ifm.json
@@ -1,4 +1,4 @@
{
- "min": -4.944893226623535,
- "max": 4.942608108520508
+ "min": -4.985494499206543,
+ "max": 4.967269058227539
}
diff --git a/compiler/pota-quantization-value-test/expected_outputs/PRelu_001/channel/int16/record_minmax/ofm.json b/compiler/pota-quantization-value-test/expected_outputs/PRelu_001/channel/int16/record_minmax/ofm.json
index 640397c4d..954d5eff1 100644
--- a/compiler/pota-quantization-value-test/expected_outputs/PRelu_001/channel/int16/record_minmax/ofm.json
+++ b/compiler/pota-quantization-value-test/expected_outputs/PRelu_001/channel/int16/record_minmax/ofm.json
@@ -1,4 +1,4 @@
{
- "min": -2.451441249847412,
- "max": 4.942608108520508
+ "min": -2.4895002365112306,
+ "max": 4.967269058227539
}
diff --git a/compiler/pota-quantization-value-test/expected_outputs/PRelu_001/channel/uint8/quantization/alpha.json b/compiler/pota-quantization-value-test/expected_outputs/PRelu_001/channel/uint8/quantization/alpha.json
new file mode 100644
index 000000000..6f99899d5
--- /dev/null
+++ b/compiler/pota-quantization-value-test/expected_outputs/PRelu_001/channel/uint8/quantization/alpha.json
@@ -0,0 +1,21 @@
+{
+ "weights": [
+ [
+ [
+ 1,
+ 1,
+ 1
+ ]
+ ]
+ ],
+ "scale": [
+ 0.10000000149011612,
+ 0.30000001192092896,
+ 0.5
+ ],
+ "zero_point": [
+ 0,
+ 0,
+ 0
+ ]
+}
diff --git a/compiler/pota-quantization-value-test/expected_outputs/PRelu_001/channel/uint8/quantization/ifm.json b/compiler/pota-quantization-value-test/expected_outputs/PRelu_001/channel/uint8/quantization/ifm.json
new file mode 100644
index 000000000..d661df363
--- /dev/null
+++ b/compiler/pota-quantization-value-test/expected_outputs/PRelu_001/channel/uint8/quantization/ifm.json
@@ -0,0 +1,4 @@
+{
+ "scale": 0.03893596678972244,
+ "zero_point": 128.0
+}
diff --git a/compiler/pota-quantization-value-test/expected_outputs/PRelu_001/channel/uint8/quantization/ofm.json b/compiler/pota-quantization-value-test/expected_outputs/PRelu_001/channel/uint8/quantization/ofm.json
new file mode 100644
index 000000000..6dfffd563
--- /dev/null
+++ b/compiler/pota-quantization-value-test/expected_outputs/PRelu_001/channel/uint8/quantization/ofm.json
@@ -0,0 +1,4 @@
+{
+ "scale": 0.029139429330825806,
+ "zero_point": 85.0
+}
diff --git a/compiler/pota-quantization-value-test/expected_outputs/PRelu_001/channel/uint8/record_minmax/ifm.json b/compiler/pota-quantization-value-test/expected_outputs/PRelu_001/channel/uint8/record_minmax/ifm.json
new file mode 100644
index 000000000..8de6b3dc2
--- /dev/null
+++ b/compiler/pota-quantization-value-test/expected_outputs/PRelu_001/channel/uint8/record_minmax/ifm.json
@@ -0,0 +1,4 @@
+{
+ "min": -4.977406520843505,
+ "max": 4.951265411376953
+}
diff --git a/compiler/pota-quantization-value-test/expected_outputs/PRelu_001/channel/uint8/record_minmax/ofm.json b/compiler/pota-quantization-value-test/expected_outputs/PRelu_001/channel/uint8/record_minmax/ofm.json
new file mode 100644
index 000000000..c88f6ca92
--- /dev/null
+++ b/compiler/pota-quantization-value-test/expected_outputs/PRelu_001/channel/uint8/record_minmax/ofm.json
@@ -0,0 +1,4 @@
+{
+ "min": -2.4792890548706055,
+ "max": 4.951265411376953
+}
diff --git a/compiler/pota-quantization-value-test/gen_h5_explicit_inputs.py b/compiler/pota-quantization-value-test/gen_h5_explicit_inputs.py
index 9863c807a..a00cbeba3 100755
--- a/compiler/pota-quantization-value-test/gen_h5_explicit_inputs.py
+++ b/compiler/pota-quantization-value-test/gen_h5_explicit_inputs.py
@@ -1,16 +1,17 @@
#!/usr/bin/env python3
import h5py as h5
import numpy as np
-import tensorflow as tf
+from circle.Model import Model
+from circle.TensorType import TensorType
import argparse
import glob
#
-# This script generates a pack of random input data (.h5) expected by the input tflite model
+# This script generates a pack of random input data (.h5) expected by the input circle model
#
# Basic usage:
# gen_h5_explicit_inputs.py --model <path/to/model/file> --input <path/to/input/directory> --output <path/to/output/file>
-# ex: gen_h5_explicit_inputs.py --model Add_000.tflite --input Add_000 --output Add_000.input.h5
+# ex: gen_h5_explicit_inputs.py --model Add_000.circle --input Add_000 --output Add_000.input.h5
# (This will create Add_000.input.h5)
#
# The input directory should be organized as follows
@@ -33,15 +34,30 @@ model = args.model
input = args.input
output = args.output
-# Build TFLite interpreter. (to get the information of model input)
-interpreter = tf.lite.Interpreter(model)
-input_details = interpreter.get_input_details()
+with open(model, 'rb') as f:
+ buf = f.read()
+ circle_model = Model.GetRootAsModel(buf, 0)
+
+# Assume one subgraph
+assert (circle_model.SubgraphsLength() == 1)
+graph = circle_model.Subgraphs(0)
+inputs = graph.InputsAsNumpy()
# Create h5 file
h5_file = h5.File(output, 'w')
group = h5_file.create_group("value")
group.attrs['desc'] = "Input data for " + model
+
+def toNumpyType(circle_type):
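+ # Map a circle TensorType to the matching numpy dtype; types other than UINT8/FLOAT32/INT16 fall through and return None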
+ if circle_type == TensorType.UINT8:
+ return np.uint8
+ if circle_type == TensorType.FLOAT32:
+ return np.float32
+ if circle_type == TensorType.INT16:
+ return np.int16
+
+
# Input files
records = sorted(glob.glob(input + "/*.txt"))
for i, record in enumerate(records):
@@ -51,9 +67,10 @@ for i, record in enumerate(records):
lines = f.readlines()
for j, line in enumerate(lines):
data = np.array(line.split(','))
- input_detail = input_details[j]
- input_data = np.array(
- data.reshape(input_detail["shape"]), input_detail["dtype"])
+ input_index = inputs[j]
+ tensor = graph.Tensors(input_index)
+ np_type = toNumpyType(tensor.Type())
+ input_data = np.array(data.reshape(tensor.ShapeAsNumpy()), np_type)
sample.create_dataset(str(j), data=input_data)
h5_file.close()
diff --git a/compiler/pota-quantization-value-test/test.lst b/compiler/pota-quantization-value-test/test.lst
index 15606b8e4..dd1640428 100644
--- a/compiler/pota-quantization-value-test/test.lst
+++ b/compiler/pota-quantization-value-test/test.lst
@@ -13,6 +13,8 @@ addTest(DepthwiseConv2D_002 layer uint8)
addTest(FullyConnected_003 channel uint8)
addTest(FullyConnected_003 channel int16)
addTest(FullyConnected_003 layer uint8)
+addTest(InstanceNorm_001 layer uint8)
+addTest(InstanceNorm_001 channel uint8)
addTest(Mean_000 layer uint8)
addTest(Mean_000 channel int16)
addTest(MaxPool2D_000 layer uint8)
@@ -20,6 +22,7 @@ addTest(MaxPool2D_000 channel int16)
addTest(Mul_001 layer uint8)
addTest(Mul_001 channel int16)
addTest(PRelu_001 layer uint8)
+addTest(PRelu_001 channel uint8)
addTest(PRelu_001 channel int16)
addTest(ReLU_000 layer uint8)
addTest(ReLU_000 channel int16)
diff --git a/compiler/pota-quantization-value-test/test_inputs/InstanceNorm_001/channel/uint8/0.txt b/compiler/pota-quantization-value-test/test_inputs/InstanceNorm_001/channel/uint8/0.txt
new file mode 100644
index 000000000..5e926a2d9
--- /dev/null
+++ b/compiler/pota-quantization-value-test/test_inputs/InstanceNorm_001/channel/uint8/0.txt
@@ -0,0 +1 @@
+0.15500909,0.32379007,0.12717001,0.60674316,0.07691418,0.437071 ,0.3737046 ,0.798342 ,0.65901846,0.40579247,0.15460491,0.80063623,0.591834 ,0.6617658 ,0.5617774 ,0.44884747,0.7996519 ,0.75895494,0.6239346 ,0.56500244,0.8955974 ,0.32503998,0.05756519,0.11889575,0.19635268,0.33958906,0.916527 ,0.16366032,0.51954055,0.2615102 ,0.07677322,0.6970092 ,0.27848312,0.97694606,0.73990864,0.96292055
diff --git a/compiler/pota-quantization-value-test/test_inputs/InstanceNorm_001/channel/uint8/1.txt b/compiler/pota-quantization-value-test/test_inputs/InstanceNorm_001/channel/uint8/1.txt
new file mode 100644
index 000000000..eb5de0c0e
--- /dev/null
+++ b/compiler/pota-quantization-value-test/test_inputs/InstanceNorm_001/channel/uint8/1.txt
@@ -0,0 +1 @@
+0.85332185,0.03102963,0.54344934,0.6300742 ,0.3323267 ,0.1701224 ,0.36199054,0.23949413,0.11960976,0.668403 ,0.7907452 ,0.4377144 ,0.87145853,0.75605077,0.37314144,0.3622036 ,0.4321453 ,0.8770253 ,0.10936793,0.0734281 ,0.2922192 ,0.5829591 ,0.5422962 ,0.84274834,0.48475483,0.23154257,0.20037153,0.27911612,0.30018023,0.23753181,0.98804647,0.61455756,0.90376633,0.8255312 ,0.21020697,0.6272272
diff --git a/compiler/pota-quantization-value-test/test_inputs/InstanceNorm_001/channel/uint8/2.txt b/compiler/pota-quantization-value-test/test_inputs/InstanceNorm_001/channel/uint8/2.txt
new file mode 100644
index 000000000..16561ef0d
--- /dev/null
+++ b/compiler/pota-quantization-value-test/test_inputs/InstanceNorm_001/channel/uint8/2.txt
@@ -0,0 +1 @@
+0.29736656,0.5712386 ,0.55447775,0.9014779 ,0.6208391 ,0.3413809 ,0.043885 ,0.5474101 ,0.8642339 ,0.05225753,0.36101478,0.15561381,0.776422 ,0.9997885 ,0.35188794,0.23418508,0.0882741 ,0.5797471 ,0.99945694,0.22190607,0.12337059,0.3701574 ,0.65161157,0.9830193 ,0.46270686,0.10077237,0.23681253,0.8734158 ,0.8358533 ,0.08817147,0.3845248 ,0.12799203,0.66830546,0.14838815,0.90201443,0.21123447
diff --git a/compiler/pota-quantization-value-test/test_inputs/InstanceNorm_001/channel/uint8/3.txt b/compiler/pota-quantization-value-test/test_inputs/InstanceNorm_001/channel/uint8/3.txt
new file mode 100644
index 000000000..deba38b2d
--- /dev/null
+++ b/compiler/pota-quantization-value-test/test_inputs/InstanceNorm_001/channel/uint8/3.txt
@@ -0,0 +1 @@
+0.92424273,0.35776526,0.0776509 ,0.93697083,0.6559925 ,0.78421926,0.7511033 ,0.71389145,0.52217877,0.41876563,0.3560251 ,0.5862293 ,0.53027606,0.32203177,0.24654935,0.55851364,0.35312092,0.38102064,0.21245371,0.87299466,0.94972914,0.54950166,0.3445233 ,0.98951054,0.37458083,0.3778964 ,0.64035404,0.10410193,0.18511558,0.1942945 ,0.07018933,0.6113747 ,0.38076922,0.08337755,0.98258 ,0.91440874
diff --git a/compiler/pota-quantization-value-test/test_inputs/InstanceNorm_001/channel/uint8/4.txt b/compiler/pota-quantization-value-test/test_inputs/InstanceNorm_001/channel/uint8/4.txt
new file mode 100644
index 000000000..78b783a74
--- /dev/null
+++ b/compiler/pota-quantization-value-test/test_inputs/InstanceNorm_001/channel/uint8/4.txt
@@ -0,0 +1 @@
+0.3790198 ,0.6347678 ,0.42544237,0.37033263,0.08057033,0.49041638,0.61705315,0.15411597,0.6455052 ,0.6857795 ,0.9613043 ,0.60357374,0.57679754,0.22550431,0.05105425,0.8641173 ,0.65559083,0.18274343,0.8963692 ,0.22369736,0.3133119 ,0.27507883,0.00539197,0.6846556 ,0.5969273 ,0.78488904,0.87746257,0.15459861,0.23133573,0.59048635,0.07172906,0.28935516,0.02084327,0.09926946,0.02687503,0.7306079
diff --git a/compiler/pota-quantization-value-test/test_inputs/InstanceNorm_001/layer/uint8/0.txt b/compiler/pota-quantization-value-test/test_inputs/InstanceNorm_001/layer/uint8/0.txt
new file mode 100644
index 000000000..25b600c5f
--- /dev/null
+++ b/compiler/pota-quantization-value-test/test_inputs/InstanceNorm_001/layer/uint8/0.txt
@@ -0,0 +1 @@
+0.641226 ,0.68639857,0.87044334,0.9448475 ,0.21544299,0.5202749 ,0.5077167 ,0.23931624,0.5712026 ,0.4167988 ,0.56711906,0.52392703,0.42762014,0.5277072 ,0.03028643,0.18017273,0.8823869 ,0.5752544 ,0.09368648,0.50277 ,0.784248 ,0.04220072,0.55217946,0.75145644,0.7957966 ,0.6563401 ,0.54975605,0.17231019,0.4219812 ,0.27839735,0.5850074 ,0.24070603,0.00957893,0.3669335 ,0.03722228,0.8705231
diff --git a/compiler/pota-quantization-value-test/test_inputs/InstanceNorm_001/layer/uint8/1.txt b/compiler/pota-quantization-value-test/test_inputs/InstanceNorm_001/layer/uint8/1.txt
new file mode 100644
index 000000000..caadfed22
--- /dev/null
+++ b/compiler/pota-quantization-value-test/test_inputs/InstanceNorm_001/layer/uint8/1.txt
@@ -0,0 +1 @@
+0.76871806,0.65729177,0.946514 ,0.4308198 ,0.65200335,0.5745432 ,0.2990488 ,0.3156028 ,0.3218111 ,0.44709972,0.9411461 ,0.4828708 ,0.5707792 ,0.10645963,0.74497086,0.3563156 ,0.07986172,0.64869064,0.73329425,0.8848129 ,0.3027897 ,0.8753744 ,0.8884493 ,0.3606782 ,0.88617206,0.20232914,0.10251648,0.6366529 ,0.20422891,0.24426484,0.6952833 ,0.21889713,0.11477511,0.40650114,0.9637219 ,0.9751801
diff --git a/compiler/pota-quantization-value-test/test_inputs/InstanceNorm_001/layer/uint8/2.txt b/compiler/pota-quantization-value-test/test_inputs/InstanceNorm_001/layer/uint8/2.txt
new file mode 100644
index 000000000..bc4a49454
--- /dev/null
+++ b/compiler/pota-quantization-value-test/test_inputs/InstanceNorm_001/layer/uint8/2.txt
@@ -0,0 +1 @@
+0.5773043 ,0.6733178 ,0.22994593,0.32895002,0.74122405,0.6671442 ,0.1899878 ,0.35264668,0.31084946,0.3864719 ,0.7035006 ,0.46563607,0.44263086,0.2414678 ,0.7430625 ,0.72898006,0.9982008 ,0.8989132 ,0.45622516,0.17876478,0.9356994 ,0.85493064,0.73729265,0.9804242 ,0.8735895 ,0.14825071,0.33990774,0.76397645,0.14657325,0.2492199 ,0.43957144,0.20367876,0.43692476,0.28123745,0.24346785,0.21133597
diff --git a/compiler/pota-quantization-value-test/test_inputs/InstanceNorm_001/layer/uint8/3.txt b/compiler/pota-quantization-value-test/test_inputs/InstanceNorm_001/layer/uint8/3.txt
new file mode 100644
index 000000000..18f8666a0
--- /dev/null
+++ b/compiler/pota-quantization-value-test/test_inputs/InstanceNorm_001/layer/uint8/3.txt
@@ -0,0 +1 @@
+0.74837255,0.7530814 ,0.05257462,0.06676125,0.26824346,0.05064487,0.23974492,0.5355457 ,0.97374374,0.38518724,0.3781766 ,0.7047476 ,0.95856845,0.09918232,0.36570287,0.5659468 ,0.8793284 ,0.7967468 ,0.99486005,0.11670698,0.42955273,0.25254622,0.06959745,0.5107888 ,0.88106513,0.3649466 ,0.7039582 ,0.8535825 ,0.3979168 ,0.9560912 ,0.17733434,0.69954944,0.35459924,0.28516313,0.75249106,0.7197228
diff --git a/compiler/pota-quantization-value-test/test_inputs/InstanceNorm_001/layer/uint8/4.txt b/compiler/pota-quantization-value-test/test_inputs/InstanceNorm_001/layer/uint8/4.txt
new file mode 100644
index 000000000..b51c5ebd0
--- /dev/null
+++ b/compiler/pota-quantization-value-test/test_inputs/InstanceNorm_001/layer/uint8/4.txt
@@ -0,0 +1 @@
+0.73320377,0.33635676,0.05811058,0.7032399 ,0.26380542,0.99637365,0.36622 ,0.47471517,0.5940316 ,0.39782768,0.46486765,0.5167471 ,0.61612487,0.93076104,0.8955697 ,0.5320168 ,0.41166067,0.29174343,0.07476811,0.60023075,0.0961028 ,0.77073896,0.17360727,0.48763612,0.31430086,0.37943754,0.7456216 ,0.16767363,0.9368368 ,0.09397154,0.68992966,0.5829225 ,0.7521187 ,0.06086114,0.13137193,0.22886442
diff --git a/compiler/pota-quantization-value-test/test_inputs/PRelu_001/channel/int16/0.txt b/compiler/pota-quantization-value-test/test_inputs/PRelu_001/channel/int16/0.txt
index 107491f8e..081a1e6ee 100644
--- a/compiler/pota-quantization-value-test/test_inputs/PRelu_001/channel/int16/0.txt
+++ b/compiler/pota-quantization-value-test/test_inputs/PRelu_001/channel/int16/0.txt
@@ -1 +1 @@
- 0.5590226 ,-0.2806683 ,-1.6237477 ,-0.9041292 ,-2.2877202 , 3.4275887 , 0.7413508 ,-2.4284103 ,-0.39940628, 2.431437 ,-3.681079 ,-0.24288087, 3.3011584 ,-4.9507365 , 0.63297826, 3.0742207 ,-4.407745 ,-3.1469536 , 0.28014645, 1.7506292 ,-2.2447422 ,-0.5647249 , 4.763762 ,-1.9554822 ,-1.0236452 , 1.4784483 ,-0.15040281, 3.009691 , 4.0685706 ,-4.3577633 , 3.9074588 , 3.3200462 , 0.7937705 ,-4.491444 ,-1.5227276 ,-4.907054 , 3.0078046 ,-3.3134713 ,-4.180262 , 0.42208448,-4.764361 , 1.7373432 ,-2.4944234 , 1.3338212 , 0.5318029 , 2.0201192 , 1.274291 ,-3.891372
+-1.9927613e+00,-1.7386111e+00, 4.0895696e+00, 3.7818990e+00, 1.9420158e+00, 2.8482721e+00, 1.9165717e+00, 3.0059583e+00, 1.8346788e+00,-1.9055414e-03, 4.9277787e+00,-2.2794118e+00, 4.4005270e+00, 4.9703922e+00,-4.5275192e+00,-4.0446317e-01,-4.9363256e+00, 4.9506269e+00, 5.5874938e-01, 3.9949589e+00,-3.8152415e-01,-4.1024357e-01,-3.8472393e+00, 4.2956004e+00, 4.8097472e+00, 1.7960385e+00, 1.6767026e+00,-2.2773645e+00, 2.6808765e+00,-3.7214172e+00, 4.0978761e+00, 3.6202488e+00,-3.3211513e+00, 3.6200387e+00,-3.6106458e+00,-3.9778764e+00, 3.8779631e+00,-4.8502750e+00,-2.1901150e+00, 3.1800017e+00, 4.6261444e+00, 3.5151103e+00, 2.8659137e-02, 4.5340648e+00, 1.9836371e+00,-2.1751235e+00,-4.6762753e+00,-3.6951694e+00
diff --git a/compiler/pota-quantization-value-test/test_inputs/PRelu_001/channel/int16/1.txt b/compiler/pota-quantization-value-test/test_inputs/PRelu_001/channel/int16/1.txt
index f95a6c3ba..f6b31db38 100644
--- a/compiler/pota-quantization-value-test/test_inputs/PRelu_001/channel/int16/1.txt
+++ b/compiler/pota-quantization-value-test/test_inputs/PRelu_001/channel/int16/1.txt
@@ -1 +1 @@
--2.5172353 , 1.8682998 , 2.6845884 , 1.8813597 ,-4.6693754 ,-3.2414548 ,-3.1801097 ,-1.5670214 , 1.9862102 , 3.857179 ,-3.0402668 ,-1.4183347 ,-2.7983398 ,-4.087585 ,-1.1274861 , 1.8738103 ,-2.563316 ,-2.973781 ,-0.872552 ,-4.4504313 ,-0.9188538 , 4.5734954 , 1.3559026 , 4.943204 ,-3.6803703 , 4.577067 ,-0.6116983 , 4.5055084 , 2.5480487 , 3.7308915 ,-0.3163238 ,-0.00772368, 3.0286303 ,-0.43645218, 0.87748104,-2.6953583 , 0.21743219, 2.431181 ,-1.2284794 , 0.35975334, 0.87034357,-2.5191767 , 4.030477 ,-1.2849646 ,-4.537441 ,-0.8822066 , 4.5059347 ,-0.9273924
+-4.7488093 , 4.805902 ,-0.29828382, 0.57486725,-4.864297 , 1.1832287 ,-1.7611881 ,-2.7058024 , 2.707353 ,-3.9832466 , 3.1243927 ,-4.795229 , 1.9835415 , 3.2291937 , 2.4303932 ,-3.556881 , 4.316894 ,-0.6444627 ,-3.8289468 , 4.012964 , 0.7878584 ,-1.8921386 , 2.779619 ,-3.762597 , 3.4239094 ,-0.9103423 ,-3.9791772 ,-2.5613685 ,-4.4910364 , 0.19411987, 4.6296096 ,-0.6827259 , 3.7645729 , 1.5309091 , 3.5163064 , 3.4726381 , 3.5372822 , 1.7671971 , 1.4374614 , 3.5783768 ,-2.4927518 , 3.9427729 , 2.431568 , 2.6959393 , 3.8100271 ,-2.099064 , 3.3663592 ,-2.0818436
diff --git a/compiler/pota-quantization-value-test/test_inputs/PRelu_001/channel/int16/2.txt b/compiler/pota-quantization-value-test/test_inputs/PRelu_001/channel/int16/2.txt
index 106889e6b..acc01cb55 100644
--- a/compiler/pota-quantization-value-test/test_inputs/PRelu_001/channel/int16/2.txt
+++ b/compiler/pota-quantization-value-test/test_inputs/PRelu_001/channel/int16/2.txt
@@ -1 +1 @@
- 4.523605 ,-2.1303053 , 2.7449381 ,-4.449816 ,-1.4482541 , 4.643309 ,-2.5644886 , 4.3115034 ,-4.7736797 ,-1.9451635 ,-2.1877592 , 2.3639698 ,-1.8480709 ,-4.560132 ,-0.40588248, 4.368528 ,-0.25666243, 1.1258887 , 2.33142 ,-3.8270295 ,-4.337086 ,-0.6709232 , 4.9283085 ,-3.5181348 , 2.225021 ,-0.0831629 , 2.0482597 , 3.161154 ,-0.49435407, 2.9382129 ,-1.248886 ,-3.7053974 , 1.6736145 ,-1.3524985 ,-1.4007242 ,-4.291275 ,-3.391911 , 4.803692 , 1.631321 , 0.13381048,-2.9587808 , 3.9878602 ,-3.3585925 , 4.6802793 ,-1.7605352 , 3.4168313 , 1.2318416 ,-4.40287
+ 4.279912 ,-2.2746763 , 4.0609813 , 4.5353827 , 3.624241 ,-3.9593613 , 4.189409 ,-3.9370356 ,-2.7063863 ,-1.9987059 , 4.172294 ,-4.5454354 , 4.362368 , 2.2204642 ,-4.9866576 , 3.31571 , 0.12623785, 4.7834573 ,-1.3521448 ,-1.5408021 ,-4.6578984 ,-2.93307 ,-1.5684534 ,-1.6875995 ,-0.4278419 , 1.1314197 ,-2.9655704 ,-0.48032767,-1.9200082 , 1.3321692 , 0.87586147,-0.1761448 , 3.939337 ,-1.0270193 ,-4.807054 , 2.8373904 ,-1.1184337 ,-0.8979197 , 2.1442132 ,-2.8509672 ,-3.3741531 , 3.6592414 , 0.7632272 ,-4.11465 , 4.892313 , 4.715815 ,-4.6481915 , 0.24676175
diff --git a/compiler/pota-quantization-value-test/test_inputs/PRelu_001/channel/int16/3.txt b/compiler/pota-quantization-value-test/test_inputs/PRelu_001/channel/int16/3.txt
index 488c3483a..0f0b7a939 100644
--- a/compiler/pota-quantization-value-test/test_inputs/PRelu_001/channel/int16/3.txt
+++ b/compiler/pota-quantization-value-test/test_inputs/PRelu_001/channel/int16/3.txt
@@ -1 +1 @@
- 1.249105 ,-3.2594535 ,-1.7899538 ,-4.804654 ,-2.0324056 ,-1.9959925 , 3.5215054 , 0.5371311 , 1.9365969 ,-3.130136 ,-2.3590457 ,-4.653209 ,-2.0184708 , 3.5759254 ,-1.3521014 , 1.910826 , 3.8221822 ,-2.8988552 , 0.6571995 , 1.0839036 , 3.5422468 , 2.4680734 , 0.6148754 ,-3.4008195 , 4.558109 , 2.0105803 , 0.58087206, 1.3398736 , 2.770545 , 0.29666626, 4.1851935 , 0.04321287, 2.7680604 , 4.5661645 , 4.0127945 ,-4.8027678 , 4.1711125 ,-0.24452859, 0.4101852 , 1.5963763 ,-2.8356924 , 1.2876563 , 0.90424466, 2.965566 ,-1.9058269 , 4.759825 ,-2.2063546 ,-1.1309439
+-2.0949495 ,-1.1370499 , 4.6457314 ,-2.243915 ,-1.7996464 , 1.2268789 ,-4.938172 ,-3.2802615 , 1.8788282 , 4.4162655 ,-4.8805113 , 3.1269526 , 3.2644348 , 0.89842725,-1.4484432 ,-0.28381723, 3.046261 ,-1.0718596 ,-3.996107 ,-4.9575796 ,-2.2279077 , 1.5326967 , 4.4588428 ,-2.042381 , 4.6604958 , 4.6422915 ,-1.097833 , 3.666126 , 0.4735639 ,-4.480704 ,-4.831033 ,-0.27288163, 4.588138 , 4.5297036 , 4.3675694 ,-1.6098841 ,-3.4147859 , 2.1168516 ,-1.9529305 ,-0.12548867, 3.4388335 ,-1.4071734 , 0.9507897 , 4.8206787 , 1.676873 ,-1.7102181 , 1.7746873 , 0.02711739
diff --git a/compiler/pota-quantization-value-test/test_inputs/PRelu_001/channel/int16/4.txt b/compiler/pota-quantization-value-test/test_inputs/PRelu_001/channel/int16/4.txt
index a59688e23..d23450db6 100644
--- a/compiler/pota-quantization-value-test/test_inputs/PRelu_001/channel/int16/4.txt
+++ b/compiler/pota-quantization-value-test/test_inputs/PRelu_001/channel/int16/4.txt
@@ -1 +1 @@
--3.0078897 , 1.6800234 , 4.350201 , 0.22538732, 2.9894316 ,-4.234071 , 2.733158 ,-3.8551323 , 3.9647048 , 1.4266169 , 0.78519976,-0.5334222 , 0.6681823 , 2.8409274 , 2.335872 ,-3.757666 ,-3.321705 , 2.9423573 , 1.3080943 , 1.0453726 , 3.222387 , 3.1813147 ,-1.8588669 ,-3.2523947 ,-4.4175825 , 3.7631783 ,-3.4176416 , 1.2141145 , 1.3725096 ,-1.2283872 ,-2.9829195 ,-3.6383085 ,-2.0126016 ,-3.7627625 , 4.916868 , 0.73052526,-0.02047114,-3.9506733 , 2.3569562 ,-4.247723 ,-1.8913685 , 1.7365774 , 4.59158 , 3.654596 ,-4.2133813 ,-4.6193404 ,-1.3968121 ,-3.580963
+-4.707647 ,-4.0921726 , 3.5813692 ,-4.71081 , 3.157816 ,-3.0034213 ,-0.21858999,-1.1736552 ,-1.6042249 ,-3.93102 ,-4.0407577 , 3.7350774 ,-4.9545655 ,-1.5413756 , 0.34996858, 2.0339615 , 0.99290746,-3.9916334 ,-4.149016 ,-3.2332835 , 3.6728513 , 2.4537466 ,-3.103485 ,-0.4829316 , 4.8046784 ,-1.753812 , 4.878712 ,-1.4039769 , 1.6640003 ,-1.2041731 , 0.8046477 , 0.9196048 ,-0.6475092 , 1.1409346 , 2.0324717 ,-0.04227797,-0.5379897 , 3.205104 , 3.3556423 , 4.8447986 ,-1.9695646 ,-2.6304977 ,-3.7261262 ,-4.725599 , 2.1162436 ,-0.5631174 ,-0.5820323 , 0.8398242
diff --git a/compiler/pota-quantization-value-test/test_inputs/PRelu_001/channel/uint8/0.txt b/compiler/pota-quantization-value-test/test_inputs/PRelu_001/channel/uint8/0.txt
new file mode 100644
index 000000000..bcda22cb6
--- /dev/null
+++ b/compiler/pota-quantization-value-test/test_inputs/PRelu_001/channel/uint8/0.txt
@@ -0,0 +1 @@
+ 0.29413325,-0.5246354 , 2.5049045 , 4.9534087 , 0.9885207 ,-4.9603324 ,-2.534284 ,-1.2587626 ,-4.6054525 ,-4.0071754 , 3.204513 , 1.9254771 ,-3.0781755 ,-2.225973 , 3.3524523 , 3.817767 , 3.4921055 , 4.3435416 , 3.0849605 ,-1.4030998 ,-1.0506575 ,-0.42979953,-2.2500112 , 3.4057455 , 4.5414543 , 2.9366746 , 4.8639297 ,-0.1028097 , 2.3421814 , 0.6463296 ,-4.906506 ,-0.7544193 ,-4.0089574 , 2.3837643 ,-0.62171113,-3.349577 , 0.63758767,-3.6872568 ,-2.4398334 ,-1.1556609 ,-3.116043 ,-1.9698795 , 0.7246678 , 2.1801088 ,-2.5762403 , 2.5748649 ,-2.8637013 , 2.8755338
diff --git a/compiler/pota-quantization-value-test/test_inputs/PRelu_001/channel/uint8/1.txt b/compiler/pota-quantization-value-test/test_inputs/PRelu_001/channel/uint8/1.txt
new file mode 100644
index 000000000..937e08f69
--- /dev/null
+++ b/compiler/pota-quantization-value-test/test_inputs/PRelu_001/channel/uint8/1.txt
@@ -0,0 +1 @@
+-3.5664022e+00, 3.7696166e+00,-2.0404069e+00,-3.2197843e+00, 2.0149478e-01, 4.1116104e+00, 1.9678035e+00,-7.5975507e-01,-2.1460054e+00, 4.6308274e+00,-1.8927828e+00, 3.0689645e+00,-7.0773923e-01,-6.7477709e-01,-1.6248076e+00, 2.7095401e+00, 2.9545853e+00, 8.5142839e-01,-2.7683893e-01,-2.0586762e+00,-3.5001924e+00,-1.7622359e+00, 2.2262762e+00,-4.0617161e+00,-2.4704919e+00,-3.6333869e+00, 2.3401244e+00,-4.6641917e+00,-4.0812837e-03, 1.1013873e+00, 1.4518824e-01, 2.4135842e+00, 4.1183419e+00, 3.0343807e+00,-3.7195799e-01,-9.7189492e-01,-3.0425618e+00, 4.6822820e+00,-1.7649661e+00, 3.9648254e+00,-3.1084957e+00,-7.3071235e-01,-5.1578474e-01,-3.5188673e+00,-4.7018051e+00,-4.1592669e+00,-3.5443991e-01, 1.3961188e+00
diff --git a/compiler/pota-quantization-value-test/test_inputs/PRelu_001/channel/uint8/2.txt b/compiler/pota-quantization-value-test/test_inputs/PRelu_001/channel/uint8/2.txt
new file mode 100644
index 000000000..fb30491cd
--- /dev/null
+++ b/compiler/pota-quantization-value-test/test_inputs/PRelu_001/channel/uint8/2.txt
@@ -0,0 +1 @@
+ 4.2618856 , 0.4364266 , 0.5258691 , 3.5147502 ,-4.025428 , 3.143039 , 1.3707066 , 4.7792606 , 1.1539228 , 3.785161 ,-1.9495047 , 2.7047534 , 0.5673139 ,-0.5191105 ,-2.5284607 , 4.076998 , 2.9433093 ,-2.1924984 , 1.1020935 ,-2.126009 , 0.7586875 , 1.1708144 ,-4.594603 ,-3.252912 ,-3.057344 , 3.8008513 ,-4.9164753 ,-4.560891 , 1.724639 ,-3.0877826 , 0.55354726,-3.969067 , 4.17461 ,-1.901139 ,-4.8903475 , 4.7866077 ,-1.3506653 ,-4.2624874 , 0.8842832 , 4.672003 ,-2.5649548 ,-3.6606123 ,-1.6794366 ,-2.0534387 ,-2.9902222 , 3.078469 , 2.846819 , 1.2788221
diff --git a/compiler/pota-quantization-value-test/test_inputs/PRelu_001/channel/uint8/3.txt b/compiler/pota-quantization-value-test/test_inputs/PRelu_001/channel/uint8/3.txt
new file mode 100644
index 000000000..fb9d40ae0
--- /dev/null
+++ b/compiler/pota-quantization-value-test/test_inputs/PRelu_001/channel/uint8/3.txt
@@ -0,0 +1 @@
+-2.6751792 ,-2.5436802 , 0.30533552, 1.0443643 ,-4.4327927 , 2.813772 ,-4.27514 , 2.5894637 , 2.8684394 ,-2.2010357 , 1.5827026 , 0.01609957, 0.38605672,-4.978118 ,-0.30794173, 0.7372266 ,-1.2931277 , 2.8435483 , 2.8204155 , 1.5801594 , 0.853025 , 1.0665054 ,-2.3281817 ,-4.2512784 , 2.379218 , 2.6335719 , 0.17575608,-2.7761426 ,-2.8164017 , 1.8392245 , 2.6495574 , 0.82702005, 3.8548648 ,-3.179834 , 0.25908127, 2.4930098 , 0.71019745,-3.193962 ,-1.1381371 ,-3.5847874 ,-1.3353258 , 2.942422 , 0.11944559,-3.0676606 , 3.534187 , 0.86664987,-1.4781127 , 4.8873277
diff --git a/compiler/pota-quantization-value-test/test_inputs/PRelu_001/channel/uint8/4.txt b/compiler/pota-quantization-value-test/test_inputs/PRelu_001/channel/uint8/4.txt
new file mode 100644
index 000000000..aeecd56c3
--- /dev/null
+++ b/compiler/pota-quantization-value-test/test_inputs/PRelu_001/channel/uint8/4.txt
@@ -0,0 +1 @@
+ 4.2327642 , 4.644095 ,-2.8978996 , 4.39419 , 2.897952 ,-3.330613 ,-3.9131684 ,-1.4672462 ,-3.9219787 , 2.1286428 ,-4.313653 , 2.65426 ,-4.201722 , 2.5390174 ,-3.821772 ,-1.9420135 , 3.3508427 ,-1.2804624 , 4.899826 ,-4.165279 ,-0.38920662, 3.594253 ,-2.367396 , 3.8604352 , 0.40077925, 3.7654843 ,-2.7208197 , 3.4325044 ,-2.921729 , 2.0519714 ,-0.6181836 ,-0.12342291,-4.1059036 ,-3.653849 ,-3.5340316 ,-0.2782715 , 0.32330513, 3.360021 , 2.5673623 , 2.1614027 ,-4.438277 , 3.3010736 , 0.3992392 , 0.82871836,-2.8720777 , 0.29633927, 0.25286415,-4.191315
diff --git a/compiler/pota-quantization-value-test/test_record_minmax.sh b/compiler/pota-quantization-value-test/test_record_minmax.sh
index acb7574c0..fa8f506d4 100755
--- a/compiler/pota-quantization-value-test/test_record_minmax.sh
+++ b/compiler/pota-quantization-value-test/test_record_minmax.sh
@@ -9,11 +9,11 @@
# work_dir : build directory of quantization-value-test (ex: build/compiler/quantization-value-test)
SOURCE_PATH="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
-GEN_SCRIPT_PATH="${SOURCE_PATH}/gen_h5_explicit_inputs.py"
COMPARE_SCRIPT_PATH="${SOURCE_PATH}/compare_tensors.py"
CONFIG_PATH="$1"; shift
BIN_PATH=$(dirname "${CONFIG_PATH}")
TEST_INPUT_PATH="${SOURCE_PATH}/test_inputs"
+GEN_SCRIPT_PATH="${BIN_PATH}/gen_h5_explicit_inputs.py"
WORKDIR="$1"; shift
source "${CONFIG_PATH}"
@@ -48,7 +48,7 @@ while [ "$1" != "" ]; do
# Generate h5 input data
source "${VIRTUALENV}/bin/activate"
"${VIRTUALENV}/bin/python" "${GEN_SCRIPT_PATH}" \
- --model "${WORKDIR}/${MODELNAME}.tflite" \
+ --model "${WORKDIR}/${MODELNAME}.circle" \
--input "${TEST_INPUT_PATH}/${MODELNAME}/${GRANULARITY}/${DTYPE}" \
--output "${TESTCASE_FILE}.input.h5"
diff --git a/compiler/tflchef/core/src/CustomOp/MaxPoolWithArgMax.cpp b/compiler/tflchef/core/src/CustomOp/MaxPoolWithArgMax.cpp
index b1c92ecbd..13bf2e5e9 100644
--- a/compiler/tflchef/core/src/CustomOp/MaxPoolWithArgMax.cpp
+++ b/compiler/tflchef/core/src/CustomOp/MaxPoolWithArgMax.cpp
@@ -65,13 +65,13 @@ MaxPoolWithArgMaxChef::custom_value(flatbuffers::FlatBufferBuilder &fbb) const
flex_buffers->Add(1);
flex_buffers->EndVector(start, /*typed=*/true, /*fixed=*/false);
auto output_type = operation.max_pool_with_argmax_options().output_type();
- assert(output_type == tflite::TensorType_INT64 || output_type == tflite::TensorType_INT32);
+ assert(output_type == tflchef::INT64 || output_type == tflchef::INT32);
flex_buffers->Int("Targmax", output_type);
std::string padding = operation.max_pool_with_argmax_options().padding() ? "VALID" : "SAME";
flex_buffers->String("padding", padding);
flex_buffers->Bool("include_batch_in_index",
operation.max_pool_with_argmax_options().include_batch_in_index());
- flex_buffers->Int("T", tflite::TensorType_FLOAT32);
+ flex_buffers->Int("T", tflchef::FLOAT32);
flex_buffers->EndMap(map_start);
flex_buffers->Finish();
diff --git a/compiler/tfldump/src/Dump.cpp b/compiler/tfldump/src/Dump.cpp
index 8c8178f93..20e1343e6 100644
--- a/compiler/tfldump/src/Dump.cpp
+++ b/compiler/tfldump/src/Dump.cpp
@@ -349,6 +349,7 @@ void dump_model(std::ostream &os, const tflite::Model *model)
auto opcodes = reader.opcodes();
auto buffers = reader.buffers();
+ auto metadata = reader.metadata();
// dump operator_codes
os << "Operator Codes: [order] OpCodeName (OpCode Enum)" << std::endl;
@@ -382,6 +383,17 @@ void dump_model(std::ostream &os, const tflite::Model *model)
}
os << std::endl;
+ // dump metadata
+ if (metadata != nullptr)
+ {
+ os << "metadata : B(index) name" << std::endl;
+ for (uint32_t i = 0; i < metadata->Length(); ++i)
+ {
+ os << "B(" << metadata->Get(i)->buffer() << ") " << metadata->Get(i)->name()->c_str();
+ }
+ os << std::endl;
+ }
+
for (uint32_t sg = 0; sg < num_subgraph; ++sg)
{
reader.select_subgraph(sg);
diff --git a/compiler/tfldump/src/OpPrinter.cpp b/compiler/tfldump/src/OpPrinter.cpp
index 5d279632c..c35848047 100644
--- a/compiler/tfldump/src/OpPrinter.cpp
+++ b/compiler/tfldump/src/OpPrinter.cpp
@@ -694,6 +694,7 @@ OpPrinterRegistry::OpPrinterRegistry()
// There is no Option for LOGISTIC
// There is no Option for LOG_SOFTMAX
_op_map[tflite::BuiltinOperator_MAX_POOL_2D] = make_unique<Pool2DPrinter>();
+ _op_map[tflite::BuiltinOperator_MEAN] = make_unique<ReducerPrinter>();
_op_map[tflite::BuiltinOperator_MIRROR_PAD] = make_unique<MirrorPadPrinter>();
_op_map[tflite::BuiltinOperator_MUL] = make_unique<MulPrinter>();
// There is no Option for NON_MAX_SUPPRESSION_V4
diff --git a/compiler/tfldump/src/Read.cpp b/compiler/tfldump/src/Read.cpp
index f9782d9ef..856cc5699 100644
--- a/compiler/tfldump/src/Read.cpp
+++ b/compiler/tfldump/src/Read.cpp
@@ -81,6 +81,7 @@ Reader::Reader(const tflite::Model *model)
_version = model->version();
_subgraphs = model->subgraphs();
_buffers = model->buffers();
+ _metadata = model->metadata();
auto opcodes = model->operator_codes();
for (const ::tflite::OperatorCode *opcode : *opcodes)
diff --git a/compiler/tfldump/src/Read.h b/compiler/tfldump/src/Read.h
index 7af2fa59b..f835be140 100644
--- a/compiler/tfldump/src/Read.h
+++ b/compiler/tfldump/src/Read.h
@@ -52,6 +52,7 @@ private:
using TFliteBuffers_t = flatbuffers::Vector<flatbuffers::Offset<tflite::Buffer>>;
using TFliteTensors_t = flatbuffers::Vector<flatbuffers::Offset<tflite::Tensor>>;
using TFliteOperators_t = flatbuffers::Vector<flatbuffers::Offset<tflite::Operator>>;
+ using TFliteMetadata_t = flatbuffers::Vector<flatbuffers::Offset<tflite::Metadata>>;
public:
Reader(const tflite::Model *model);
@@ -67,6 +68,7 @@ public:
const TFliteOperators_t *operators() { return _operators; }
const std::vector<int32_t> &inputs() const { return _inputs; }
const std::vector<int32_t> &outputs() const { return _outputs; }
+ const TFliteMetadata_t *metadata() const { return _metadata; }
uint32_t num_subgraph() const { return _subgraphs->Length(); }
@@ -86,6 +88,7 @@ private:
const TFliteBuffers_t *_buffers{nullptr};
const TFliteTensors_t *_tensors{nullptr};
const TFliteOperators_t *_operators{nullptr};
+ const TFliteMetadata_t *_metadata{nullptr};
uint32_t _subgraph_index;
std::string _subgraph_name;
diff --git a/compiler/vconone/CMakeLists.txt b/compiler/vconone/CMakeLists.txt
index 905515401..595bbfd99 100644
--- a/compiler/vconone/CMakeLists.txt
+++ b/compiler/vconone/CMakeLists.txt
@@ -1,5 +1,5 @@
if (NOT VCONONE_VERSION)
- set(VCONONE_VERSION 0x00000000000b0001)
+ set(VCONONE_VERSION 0x00000000000c0001)
# NOTE order is [build patch minor major]
# if VCONONE_VERSION is set with -D option, it will be cached
# you may have to remove cache file if you remove -D option
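Aside (not part of the commit above): the vconone hunk bumps VCONONE_VERSION from 0x00000000000b0001 to 0x00000000000c0001, and the NOTE comment states the field order is [build patch minor major]. The sketch below shows one plausible reading of that packing, assuming four 16-bit fields with "major" in the lowest bits; the field names and the decoding are illustrative only and are not taken from vconone's actual implementation.

    // Hypothetical decoder for a packed version value, assuming the
    // "[build patch minor major]" note means 4 x 16-bit fields,
    // major in bits 0..15, minor in 16..31, patch in 32..47, build in 48..63.
    #include <cstdint>
    #include <iostream>

    int main()
    {
      const uint64_t version = 0x00000000000c0001; // value set in this commit
      const uint16_t major = static_cast<uint16_t>(version & 0xFFFF);
      const uint16_t minor = static_cast<uint16_t>((version >> 16) & 0xFFFF);
      const uint16_t patch = static_cast<uint16_t>((version >> 32) & 0xFFFF);
      const uint16_t build = static_cast<uint16_t>((version >> 48) & 0xFFFF);
      std::cout << major << "." << minor << "." << patch
                << " (build " << build << ")" << std::endl;
      // Under this assumed layout the new value reads as 1.12.0 (build 0),
      // i.e. the 0x0b -> 0x0c change is a minor-version bump.
      return 0;
    }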