summaryrefslogtreecommitdiff
path: root/tools/tflitefile_tool/parser
diff options
context:
space:
mode:
Diffstat (limited to 'tools/tflitefile_tool/parser')
-rw-r--r--tools/tflitefile_tool/parser/__init__.py0
-rwxr-xr-xtools/tflitefile_tool/parser/model_parser.py31
-rw-r--r--tools/tflitefile_tool/parser/tflite/tflite_enum_str_maps.py40
-rwxr-xr-xtools/tflitefile_tool/parser/tflite/tflite_operator.py63
-rw-r--r--tools/tflitefile_tool/parser/tflite/tflite_option.py96
-rwxr-xr-xtools/tflitefile_tool/parser/tflite/tflite_parser.py112
-rwxr-xr-xtools/tflitefile_tool/parser/tflite/tflite_subgraph.py30
-rwxr-xr-xtools/tflitefile_tool/parser/tflite/tflite_tensor.py124
8 files changed, 496 insertions, 0 deletions
diff --git a/tools/tflitefile_tool/parser/__init__.py b/tools/tflitefile_tool/parser/__init__.py
new file mode 100644
index 000000000..e69de29bb
--- /dev/null
+++ b/tools/tflitefile_tool/parser/__init__.py
diff --git a/tools/tflitefile_tool/parser/model_parser.py b/tools/tflitefile_tool/parser/model_parser.py
new file mode 100755
index 000000000..68cd31a23
--- /dev/null
+++ b/tools/tflitefile_tool/parser/model_parser.py
@@ -0,0 +1,31 @@
+#!/usr/bin/env python
+
+# Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from parser.tflite.tflite_parser import TFLiteParser
+
+
+class ModelParser(object):
+ def __init__(self, model_file):
+ self.parser = None
+ # model_file: _io.BufferedReader
+ if model_file.name.endswith("tflite"):
+ self.parser = TFLiteParser(model_file)
+ # TODO: Add more parser
+
+ def Parse(self):
+ if self.parser is None:
+ raise NotImplementedError
+ return self.parser.Parse()
diff --git a/tools/tflitefile_tool/parser/tflite/tflite_enum_str_maps.py b/tools/tflitefile_tool/parser/tflite/tflite_enum_str_maps.py
new file mode 100644
index 000000000..6a3a2054f
--- /dev/null
+++ b/tools/tflitefile_tool/parser/tflite/tflite_enum_str_maps.py
@@ -0,0 +1,40 @@
+#!/usr/bin/python
+
+# Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import tflite.BuiltinOperator
+import tflite.ActivationFunctionType
+import tflite.BuiltinOptions
+
+
+# Match enum value integer to name string
+# Assumption 1: enum value is defined by old style (can be used on python 2)
+# Assumption 2: when class define enum value, only constant value is defined and methods are not defined
+# Assumption 3: only integer value is set by constant definition
+def BuildEnumClassStrMap(obj):
+ ret = {}
+ for fieldName in dir(obj):
+ if (not fieldName.startswith('_')):
+ fieldValue = getattr(obj, fieldName)
+ if (isinstance(fieldValue, (int))):
+ ret[fieldValue] = fieldName
+ return ret
+
+
+class EnumStrMaps():
+ BuiltinOpcode = BuildEnumClassStrMap(tflite.BuiltinOperator.BuiltinOperator())
+ ActivationFunctionType = BuildEnumClassStrMap(
+ tflite.ActivationFunctionType.ActivationFunctionType())
+ BuiltinOptions = BuildEnumClassStrMap(tflite.BuiltinOptions.BuiltinOptions())
diff --git a/tools/tflitefile_tool/parser/tflite/tflite_operator.py b/tools/tflitefile_tool/parser/tflite/tflite_operator.py
new file mode 100755
index 000000000..211007e1c
--- /dev/null
+++ b/tools/tflitefile_tool/parser/tflite/tflite_operator.py
@@ -0,0 +1,63 @@
+#!/usr/bin/python
+
+# Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ir.operator import Operator
+from .tflite_enum_str_maps import EnumStrMaps
+from .tflite_option import OptionLoader, GetStringOptions
+
+
+class TFLiteOperator(Operator):
+ def __init__(self, operator_idx, tf_operator, input_tensors, output_tensors,
+ opcode_str):
+ super(TFLiteOperator, self).__init__()
+
+ self.index = operator_idx
+ self.inputs = input_tensors
+ self.outputs = output_tensors
+ self.op_name = opcode_str
+ self.activation = "NONE"
+ self.options = ""
+
+ self.tf_operator = tf_operator
+ self.tf_options = None
+ self.SetupBuiltinOption()
+ self.SetupFusedActivation()
+
+ def SetupBuiltinOption(self):
+ # FIXME: workaround for ops such as custom
+ try:
+ self.tf_options = OptionLoader.GetBuiltinOptions(
+ self.tf_operator.BuiltinOptionsType(), self.tf_operator.BuiltinOptions())
+ if self.tf_options == None:
+ return
+
+ option_str = GetStringOptions(self.op_name, self.tf_options)
+ if option_str is None:
+ return
+
+ self.options = option_str
+ except KeyError:
+ return
+
+ def SetupFusedActivation(self):
+ if self.tf_options == None:
+ return
+ try:
+ activation_code = self.tf_options.FusedActivationFunction()
+ self.activation = EnumStrMaps.ActivationFunctionType[activation_code]
+ except AttributeError:
+ # This operator does not support FusedActivationFunction
+ pass
diff --git a/tools/tflitefile_tool/parser/tflite/tflite_option.py b/tools/tflitefile_tool/parser/tflite/tflite_option.py
new file mode 100644
index 000000000..b85fbae90
--- /dev/null
+++ b/tools/tflitefile_tool/parser/tflite/tflite_option.py
@@ -0,0 +1,96 @@
+#!/usr/bin/python
+
+# Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .tflite_enum_str_maps import EnumStrMaps
+
+
+def GetAttribute(o, *args):
+ import functools
+ return functools.reduce(getattr, args, o)
+
+
+def BuildBuiltinOptionGen():
+ bo_gen = {}
+ for val_enum in EnumStrMaps.BuiltinOptions:
+ val_str = EnumStrMaps.BuiltinOptions[val_enum]
+ try:
+ # Dynamically import Builtin Option classes
+ # 0 (NONE) is the only exception that does not have no corresponding flatbuffer-generated class
+ module = __import__("tflite." + val_str)
+ bo_gen[val_enum] = GetAttribute(module, val_str, val_str)
+ except ImportError as e:
+ assert val_enum == 0 and val_str == "NONE"
+ return bo_gen
+
+
+class OptionLoader:
+ builtinOptionGen = BuildBuiltinOptionGen()
+
+ @staticmethod
+ def GetBuiltinOptions(options_type, options_table):
+ if (options_table == None) and (options_type != 0):
+ print(
+ "Bad flatbuffer file: undefined builtin option table with defined option type"
+ )
+ exit(1)
+ options = OptionLoader.builtinOptionGen[options_type]()
+ options.Init(options_table.Bytes, options_table.Pos)
+ return options
+
+
+def GetStringPadding(options):
+ if options.Padding() == 0:
+ return "SAME"
+ elif options.Padding() == 1:
+ return "VALID"
+ else:
+ return "** wrong padding value **"
+
+
+def GetStringOptions(op_name, options):
+ if (op_name == "AVERAGE_POOL_2D" or op_name == "MAX_POOL_2D"):
+ return "{}, {}, {}".format(
+ "Filter W:H = {}:{}".format(options.FilterWidth(), options.FilterHeight()),
+ "Stride W:H = {}:{}".format(options.StrideW(),
+ options.StrideH()), "Padding = {}".format(
+ GetStringPadding(options)))
+ elif (op_name == "CONV_2D"):
+ return "{}, {}, {}".format(
+ "Stride W:H = {}:{}".format(options.StrideW(), options.StrideH()),
+ "Dilation W:H = {}:{}".format(options.DilationWFactor(),
+ options.DilationHFactor()),
+ "Padding = {}".format(GetStringPadding(options)))
+ elif (op_name == "DEPTHWISE_CONV_2D"):
+ # yapf: disable
+ return "{}, {}, {}, {}".format(
+ "Stride W:H = {}:{}".format(options.StrideW(),
+ options.StrideH()),
+ "Dilation W:H = {}:{}".format(options.DilationWFactor(),
+ options.DilationHFactor()),
+ "Padding = {}".format(GetStringPadding(options)),
+ "DepthMultiplier = {}".format(options.DepthMultiplier()))
+ # yapf: enable
+ elif (op_name == "STRIDED_SLICE"):
+ # yapf: disable
+ return "{}, {}, {}, {}, {}".format(
+ "begin_mask({})".format(options.BeginMask()),
+ "end_mask({})".format(options.EndMask()),
+ "ellipsis_mask({})".format(options.EllipsisMask()),
+ "new_axis_mask({})".format(options.NewAxisMask()),
+ "shrink_axis_mask({})".format(options.ShrinkAxisMask()))
+ # yapf: enable
+ else:
+ return None
diff --git a/tools/tflitefile_tool/parser/tflite/tflite_parser.py b/tools/tflitefile_tool/parser/tflite/tflite_parser.py
new file mode 100755
index 000000000..6a8f2b8ab
--- /dev/null
+++ b/tools/tflitefile_tool/parser/tflite/tflite_parser.py
@@ -0,0 +1,112 @@
+#!/usr/bin/env python
+
+# Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import tflite.Model
+from .tflite_subgraph import TFLiteSubgraph
+from .tflite_operator import TFLiteOperator, EnumStrMaps
+from .tflite_tensor import TFLiteTensor, SetTensorTypeStr
+
+
+def HasOptionalTensor(tf_subgraph):
+ for operator_idx in range(tf_subgraph.OperatorsLength()):
+ tf_operator = tf_subgraph.Operators(operator_idx)
+ if -1 in tf_operator.InputsAsNumpy():
+ return True
+ output_tensors = tf_operator.OutputsAsNumpy()
+ if -1 in tf_operator.OutputsAsNumpy():
+ return True
+
+ return False
+
+
+class TFLiteSubgraphParser(object):
+ def __init__(self, tf_model, subgraph_index):
+ self.tf_model = tf_model
+ self.tf_subgraph = tf_model.Subgraphs(subgraph_index)
+ self.subg = TFLiteSubgraph(subgraph_index, self.tf_subgraph)
+
+ # Tensor type string table
+ SetTensorTypeStr()
+
+ def Parse(self):
+ if HasOptionalTensor(self.tf_subgraph):
+ # Prepare that optional input and output tensors are indicated by -1
+ self.subg.tensors_map[-1] = TFLiteTensor(-1, None, None)
+
+ # tensors
+ for tensor_idx in range(self.tf_subgraph.TensorsLength()):
+ tf_tensor = self.tf_subgraph.Tensors(tensor_idx)
+ buffer_idx = tf_tensor.Buffer()
+ tf_buffer = self.tf_model.Buffers(buffer_idx)
+ t = TFLiteTensor(tensor_idx, tf_tensor, tf_buffer)
+ self.subg.tensors_map[tensor_idx] = t
+
+ # operators
+ for operator_idx in range(self.tf_subgraph.OperatorsLength()):
+ tf_operator = self.tf_subgraph.Operators(operator_idx)
+ op_name = self.GetOpcodeStr(tf_operator)
+ input_tensors = self.GetTensors(tf_operator.InputsAsNumpy())
+ output_tensors = self.GetTensors(tf_operator.OutputsAsNumpy())
+
+ op = TFLiteOperator(operator_idx, tf_operator, input_tensors, output_tensors,
+ op_name)
+ self.subg.operators_map[op.index] = op
+ self.subg.optypes_map[op.op_name] = op
+
+ self.subg.inputs = self.GetTensors(self.tf_subgraph.InputsAsNumpy())
+ self.subg.outputs = self.GetTensors(self.tf_subgraph.OutputsAsNumpy())
+
+ return self.subg
+
+ def GetOpcodeStr(self, tf_operator):
+ opcode_list_idx = tf_operator.OpcodeIndex()
+ opcode_id = self.tf_model.OperatorCodes(opcode_list_idx).BuiltinCode()
+ opcode_str = EnumStrMaps.BuiltinOpcode[opcode_id]
+ if opcode_id == 32:
+ # Custom operator
+ custom_operator = self.tf_model.OperatorCodes(tf_operator.OpcodeIndex())
+ custom_op_name = custom_operator.CustomCode().decode('utf-8')
+ opcode_str = opcode_str + "(" + custom_op_name + ")"
+ return opcode_str
+
+ def GetTensors(self, tf_tensors_index):
+ assert len(self.subg.tensors_map.keys()) > 0
+
+ return_list = []
+ for tensor_idx in tf_tensors_index:
+ return_list.append(self.subg.tensors_map[tensor_idx])
+ return return_list
+
+
+class TFLiteParser(object):
+ def __init__(self, model_file):
+ self.model_file = model_file
+
+ def Parse(self):
+ # Generate Model: top structure of tflite model file
+ buf = self.model_file.read()
+ buf = bytearray(buf)
+ tf_model = tflite.Model.Model.GetRootAsModel(buf, 0)
+
+ # Model file can have many models
+ subg_list = []
+ for subgraph_index in range(tf_model.SubgraphsLength()):
+ # Parse Subgraphs
+ subg_parser = TFLiteSubgraphParser(tf_model, subgraph_index)
+ subg = subg_parser.Parse()
+ subg_list.append(subg)
+
+ return subg_list
diff --git a/tools/tflitefile_tool/parser/tflite/tflite_subgraph.py b/tools/tflitefile_tool/parser/tflite/tflite_subgraph.py
new file mode 100755
index 000000000..0c6338ec6
--- /dev/null
+++ b/tools/tflitefile_tool/parser/tflite/tflite_subgraph.py
@@ -0,0 +1,30 @@
+#!/usr/bin/python
+
+# Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ir.subgraph import Subgraph
+
+
+class TFLiteSubgraph(Subgraph):
+ def __init__(self, subg_idx, tf_subgraph):
+ super(TFLiteSubgraph, self).__init__()
+ self.tf_subgraph = tf_subgraph
+
+ self.index = subg_idx
+ if tf_subgraph.Name() is not None:
+ self.subg_name = str(tf_subgraph.Name())
+ self.model_name = "#{0} {1}".format(subg_idx, self.subg_name)
+ if (subg_idx == 0): # 0th subgraph is main subgraph
+ self.model_name += " (MAIN)"
diff --git a/tools/tflitefile_tool/parser/tflite/tflite_tensor.py b/tools/tflitefile_tool/parser/tflite/tflite_tensor.py
new file mode 100755
index 000000000..5eb35e63e
--- /dev/null
+++ b/tools/tflitefile_tool/parser/tflite/tflite_tensor.py
@@ -0,0 +1,124 @@
+#!/usr/bin/python
+
+# Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+import tflite.Tensor
+import tflite.TensorType
+from ir.tensor import Tensor
+
+TensorTypeList = {}
+
+
+def SetTensorTypeStr():
+ tensorTypeObj = tflite.TensorType.TensorType()
+
+ for fieldName in dir(tensorTypeObj):
+ if (not fieldName.startswith('_')):
+ fieldValue = getattr(tensorTypeObj, fieldName)
+ if (isinstance(fieldValue, (int))):
+ TensorTypeList[fieldValue] = fieldName
+
+
+TYPES_SIZE = {
+ 'BOOL': 1,
+ 'COMPLEX64': 8,
+ 'FLOAT16': 2,
+ 'FLOAT32': 4,
+ 'INT16': 2,
+ 'INT32': 4,
+ 'INT64': 8,
+ 'UINT8': 1,
+ 'NONE': 0,
+}
+
+
+def GetTypeSize(type_name):
+ try:
+ return TYPES_SIZE[type_name]
+
+ except KeyError as error:
+ return 0
+
+
+TYPE_TO_NPTYPE = {
+ 'BOOL': np.bool_,
+ 'COMPLEX64': np.cdouble,
+ 'FLOAT16': np.float16,
+ 'FLOAT32': np.float32,
+ 'INT16': np.int16,
+ 'INT32': np.int32,
+ 'INT64': np.int64,
+ 'UINT8': np.uint8,
+}
+
+
+def ConvertProperNPArrayType(np_arr, np_shape, type_name):
+ try:
+ return np_arr.view(TYPE_TO_NPTYPE[type_name]).reshape(np_shape)
+ except KeyError as error:
+ return np_arr.view().reshape(np_shape)
+
+
+class TFLiteTensor(Tensor):
+ def __init__(self, tensor_idx, tf_tensor, tf_buffer):
+ super(TFLiteTensor, self).__init__()
+ self.tf_tensor = tf_tensor
+ self.tf_buffer = tf_buffer
+
+ self.index = int(tensor_idx)
+ self.tensor = tf_tensor
+
+ # optional input
+ if self.index == -1:
+ self.type_name = "NONE"
+ # general input
+ else:
+ assert tf_tensor is not None
+ assert tf_buffer is not None
+ self.tensor_name = str(tf_tensor.Name())
+ self.type_name = TensorTypeList[tf_tensor.Type()]
+ self.buffer_index = tf_tensor.Buffer()
+ if (tf_buffer.DataLength() > 0):
+ self.buffer = ConvertProperNPArrayType(tf_buffer.DataAsNumpy(),
+ tf_tensor.ShapeAsNumpy(),
+ self.type_name)
+
+ # shape: Empty list([]) will mean Scalar
+ for shape_idx in range(tf_tensor.ShapeLength()):
+ # when shape signature is -1, that means unknown dim
+ if tf_tensor.ShapeSignature(shape_idx) != -1:
+ self.shape.append(int(tf_tensor.Shape(shape_idx)))
+ else:
+ self.shape.append(-1)
+
+ self.memory_size = self.GetMemorySize()
+
+ def GetMemorySize(self):
+ type_size = GetTypeSize(self.type_name)
+ if type_size == 0:
+ return 0
+
+ # memory size in bytes
+ size = int(type_size)
+ shape_length = len(self.shape)
+ if shape_length == 0:
+ return size
+
+ for shape_idx in range(shape_length):
+ shape_size = int(self.shape[shape_idx])
+ size *= shape_size
+
+ return size