diff options
Diffstat (limited to 'tools/tflitefile_tool/parser')
-rw-r--r-- | tools/tflitefile_tool/parser/__init__.py | 0 | ||||
-rwxr-xr-x | tools/tflitefile_tool/parser/model_parser.py | 31 | ||||
-rw-r--r-- | tools/tflitefile_tool/parser/tflite/tflite_enum_str_maps.py | 40 | ||||
-rwxr-xr-x | tools/tflitefile_tool/parser/tflite/tflite_operator.py | 63 | ||||
-rw-r--r-- | tools/tflitefile_tool/parser/tflite/tflite_option.py | 96 | ||||
-rwxr-xr-x | tools/tflitefile_tool/parser/tflite/tflite_parser.py | 112 | ||||
-rwxr-xr-x | tools/tflitefile_tool/parser/tflite/tflite_subgraph.py | 30 | ||||
-rwxr-xr-x | tools/tflitefile_tool/parser/tflite/tflite_tensor.py | 124 |
8 files changed, 496 insertions, 0 deletions
diff --git a/tools/tflitefile_tool/parser/__init__.py b/tools/tflitefile_tool/parser/__init__.py new file mode 100644 index 000000000..e69de29bb --- /dev/null +++ b/tools/tflitefile_tool/parser/__init__.py diff --git a/tools/tflitefile_tool/parser/model_parser.py b/tools/tflitefile_tool/parser/model_parser.py new file mode 100755 index 000000000..68cd31a23 --- /dev/null +++ b/tools/tflitefile_tool/parser/model_parser.py @@ -0,0 +1,31 @@ +#!/usr/bin/env python + +# Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from parser.tflite.tflite_parser import TFLiteParser + + +class ModelParser(object): + def __init__(self, model_file): + self.parser = None + # model_file: _io.BufferedReader + if model_file.name.endswith("tflite"): + self.parser = TFLiteParser(model_file) + # TODO: Add more parser + + def Parse(self): + if self.parser is None: + raise NotImplementedError + return self.parser.Parse() diff --git a/tools/tflitefile_tool/parser/tflite/tflite_enum_str_maps.py b/tools/tflitefile_tool/parser/tflite/tflite_enum_str_maps.py new file mode 100644 index 000000000..6a3a2054f --- /dev/null +++ b/tools/tflitefile_tool/parser/tflite/tflite_enum_str_maps.py @@ -0,0 +1,40 @@ +#!/usr/bin/python + +# Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import tflite.BuiltinOperator +import tflite.ActivationFunctionType +import tflite.BuiltinOptions + + +# Match enum value integer to name string +# Assumption 1: enum value is defined by old style (can be used on python 2) +# Assumption 2: when class define enum value, only constant value is defined and methods are not defined +# Assumption 3: only integer value is set by constant definition +def BuildEnumClassStrMap(obj): + ret = {} + for fieldName in dir(obj): + if (not fieldName.startswith('_')): + fieldValue = getattr(obj, fieldName) + if (isinstance(fieldValue, (int))): + ret[fieldValue] = fieldName + return ret + + +class EnumStrMaps(): + BuiltinOpcode = BuildEnumClassStrMap(tflite.BuiltinOperator.BuiltinOperator()) + ActivationFunctionType = BuildEnumClassStrMap( + tflite.ActivationFunctionType.ActivationFunctionType()) + BuiltinOptions = BuildEnumClassStrMap(tflite.BuiltinOptions.BuiltinOptions()) diff --git a/tools/tflitefile_tool/parser/tflite/tflite_operator.py b/tools/tflitefile_tool/parser/tflite/tflite_operator.py new file mode 100755 index 000000000..211007e1c --- /dev/null +++ b/tools/tflitefile_tool/parser/tflite/tflite_operator.py @@ -0,0 +1,63 @@ +#!/usr/bin/python + +# Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ir.operator import Operator +from .tflite_enum_str_maps import EnumStrMaps +from .tflite_option import OptionLoader, GetStringOptions + + +class TFLiteOperator(Operator): + def __init__(self, operator_idx, tf_operator, input_tensors, output_tensors, + opcode_str): + super(TFLiteOperator, self).__init__() + + self.index = operator_idx + self.inputs = input_tensors + self.outputs = output_tensors + self.op_name = opcode_str + self.activation = "NONE" + self.options = "" + + self.tf_operator = tf_operator + self.tf_options = None + self.SetupBuiltinOption() + self.SetupFusedActivation() + + def SetupBuiltinOption(self): + # FIXME: workaround for ops such as custom + try: + self.tf_options = OptionLoader.GetBuiltinOptions( + self.tf_operator.BuiltinOptionsType(), self.tf_operator.BuiltinOptions()) + if self.tf_options == None: + return + + option_str = GetStringOptions(self.op_name, self.tf_options) + if option_str is None: + return + + self.options = option_str + except KeyError: + return + + def SetupFusedActivation(self): + if self.tf_options == None: + return + try: + activation_code = self.tf_options.FusedActivationFunction() + self.activation = EnumStrMaps.ActivationFunctionType[activation_code] + except AttributeError: + # This operator does not support FusedActivationFunction + pass diff --git a/tools/tflitefile_tool/parser/tflite/tflite_option.py b/tools/tflitefile_tool/parser/tflite/tflite_option.py new file mode 100644 index 000000000..b85fbae90 --- /dev/null +++ b/tools/tflitefile_tool/parser/tflite/tflite_option.py @@ -0,0 +1,96 @@ +#!/usr/bin/python + +# Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .tflite_enum_str_maps import EnumStrMaps + + +def GetAttribute(o, *args): + import functools + return functools.reduce(getattr, args, o) + + +def BuildBuiltinOptionGen(): + bo_gen = {} + for val_enum in EnumStrMaps.BuiltinOptions: + val_str = EnumStrMaps.BuiltinOptions[val_enum] + try: + # Dynamically import Builtin Option classes + # 0 (NONE) is the only exception that does not have no corresponding flatbuffer-generated class + module = __import__("tflite." + val_str) + bo_gen[val_enum] = GetAttribute(module, val_str, val_str) + except ImportError as e: + assert val_enum == 0 and val_str == "NONE" + return bo_gen + + +class OptionLoader: + builtinOptionGen = BuildBuiltinOptionGen() + + @staticmethod + def GetBuiltinOptions(options_type, options_table): + if (options_table == None) and (options_type != 0): + print( + "Bad flatbuffer file: undefined builtin option table with defined option type" + ) + exit(1) + options = OptionLoader.builtinOptionGen[options_type]() + options.Init(options_table.Bytes, options_table.Pos) + return options + + +def GetStringPadding(options): + if options.Padding() == 0: + return "SAME" + elif options.Padding() == 1: + return "VALID" + else: + return "** wrong padding value **" + + +def GetStringOptions(op_name, options): + if (op_name == "AVERAGE_POOL_2D" or op_name == "MAX_POOL_2D"): + return "{}, {}, {}".format( + "Filter W:H = {}:{}".format(options.FilterWidth(), options.FilterHeight()), + "Stride W:H = {}:{}".format(options.StrideW(), + options.StrideH()), "Padding = {}".format( + GetStringPadding(options))) + elif (op_name == "CONV_2D"): + return "{}, {}, {}".format( + "Stride W:H = {}:{}".format(options.StrideW(), options.StrideH()), + "Dilation W:H = {}:{}".format(options.DilationWFactor(), + options.DilationHFactor()), + "Padding = {}".format(GetStringPadding(options))) + elif (op_name == "DEPTHWISE_CONV_2D"): + # yapf: disable + return "{}, {}, {}, {}".format( + "Stride W:H = {}:{}".format(options.StrideW(), + options.StrideH()), + "Dilation W:H = {}:{}".format(options.DilationWFactor(), + options.DilationHFactor()), + "Padding = {}".format(GetStringPadding(options)), + "DepthMultiplier = {}".format(options.DepthMultiplier())) + # yapf: enable + elif (op_name == "STRIDED_SLICE"): + # yapf: disable + return "{}, {}, {}, {}, {}".format( + "begin_mask({})".format(options.BeginMask()), + "end_mask({})".format(options.EndMask()), + "ellipsis_mask({})".format(options.EllipsisMask()), + "new_axis_mask({})".format(options.NewAxisMask()), + "shrink_axis_mask({})".format(options.ShrinkAxisMask())) + # yapf: enable + else: + return None diff --git a/tools/tflitefile_tool/parser/tflite/tflite_parser.py b/tools/tflitefile_tool/parser/tflite/tflite_parser.py new file mode 100755 index 000000000..6a8f2b8ab --- /dev/null +++ b/tools/tflitefile_tool/parser/tflite/tflite_parser.py @@ -0,0 +1,112 @@ +#!/usr/bin/env python + +# Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import tflite.Model +from .tflite_subgraph import TFLiteSubgraph +from .tflite_operator import TFLiteOperator, EnumStrMaps +from .tflite_tensor import TFLiteTensor, SetTensorTypeStr + + +def HasOptionalTensor(tf_subgraph): + for operator_idx in range(tf_subgraph.OperatorsLength()): + tf_operator = tf_subgraph.Operators(operator_idx) + if -1 in tf_operator.InputsAsNumpy(): + return True + output_tensors = tf_operator.OutputsAsNumpy() + if -1 in tf_operator.OutputsAsNumpy(): + return True + + return False + + +class TFLiteSubgraphParser(object): + def __init__(self, tf_model, subgraph_index): + self.tf_model = tf_model + self.tf_subgraph = tf_model.Subgraphs(subgraph_index) + self.subg = TFLiteSubgraph(subgraph_index, self.tf_subgraph) + + # Tensor type string table + SetTensorTypeStr() + + def Parse(self): + if HasOptionalTensor(self.tf_subgraph): + # Prepare that optional input and output tensors are indicated by -1 + self.subg.tensors_map[-1] = TFLiteTensor(-1, None, None) + + # tensors + for tensor_idx in range(self.tf_subgraph.TensorsLength()): + tf_tensor = self.tf_subgraph.Tensors(tensor_idx) + buffer_idx = tf_tensor.Buffer() + tf_buffer = self.tf_model.Buffers(buffer_idx) + t = TFLiteTensor(tensor_idx, tf_tensor, tf_buffer) + self.subg.tensors_map[tensor_idx] = t + + # operators + for operator_idx in range(self.tf_subgraph.OperatorsLength()): + tf_operator = self.tf_subgraph.Operators(operator_idx) + op_name = self.GetOpcodeStr(tf_operator) + input_tensors = self.GetTensors(tf_operator.InputsAsNumpy()) + output_tensors = self.GetTensors(tf_operator.OutputsAsNumpy()) + + op = TFLiteOperator(operator_idx, tf_operator, input_tensors, output_tensors, + op_name) + self.subg.operators_map[op.index] = op + self.subg.optypes_map[op.op_name] = op + + self.subg.inputs = self.GetTensors(self.tf_subgraph.InputsAsNumpy()) + self.subg.outputs = self.GetTensors(self.tf_subgraph.OutputsAsNumpy()) + + return self.subg + + def GetOpcodeStr(self, tf_operator): + opcode_list_idx = tf_operator.OpcodeIndex() + opcode_id = self.tf_model.OperatorCodes(opcode_list_idx).BuiltinCode() + opcode_str = EnumStrMaps.BuiltinOpcode[opcode_id] + if opcode_id == 32: + # Custom operator + custom_operator = self.tf_model.OperatorCodes(tf_operator.OpcodeIndex()) + custom_op_name = custom_operator.CustomCode().decode('utf-8') + opcode_str = opcode_str + "(" + custom_op_name + ")" + return opcode_str + + def GetTensors(self, tf_tensors_index): + assert len(self.subg.tensors_map.keys()) > 0 + + return_list = [] + for tensor_idx in tf_tensors_index: + return_list.append(self.subg.tensors_map[tensor_idx]) + return return_list + + +class TFLiteParser(object): + def __init__(self, model_file): + self.model_file = model_file + + def Parse(self): + # Generate Model: top structure of tflite model file + buf = self.model_file.read() + buf = bytearray(buf) + tf_model = tflite.Model.Model.GetRootAsModel(buf, 0) + + # Model file can have many models + subg_list = [] + for subgraph_index in range(tf_model.SubgraphsLength()): + # Parse Subgraphs + subg_parser = TFLiteSubgraphParser(tf_model, subgraph_index) + subg = subg_parser.Parse() + subg_list.append(subg) + + return subg_list diff --git a/tools/tflitefile_tool/parser/tflite/tflite_subgraph.py b/tools/tflitefile_tool/parser/tflite/tflite_subgraph.py new file mode 100755 index 000000000..0c6338ec6 --- /dev/null +++ b/tools/tflitefile_tool/parser/tflite/tflite_subgraph.py @@ -0,0 +1,30 @@ +#!/usr/bin/python + +# Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ir.subgraph import Subgraph + + +class TFLiteSubgraph(Subgraph): + def __init__(self, subg_idx, tf_subgraph): + super(TFLiteSubgraph, self).__init__() + self.tf_subgraph = tf_subgraph + + self.index = subg_idx + if tf_subgraph.Name() is not None: + self.subg_name = str(tf_subgraph.Name()) + self.model_name = "#{0} {1}".format(subg_idx, self.subg_name) + if (subg_idx == 0): # 0th subgraph is main subgraph + self.model_name += " (MAIN)" diff --git a/tools/tflitefile_tool/parser/tflite/tflite_tensor.py b/tools/tflitefile_tool/parser/tflite/tflite_tensor.py new file mode 100755 index 000000000..5eb35e63e --- /dev/null +++ b/tools/tflitefile_tool/parser/tflite/tflite_tensor.py @@ -0,0 +1,124 @@ +#!/usr/bin/python + +# Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import tflite.Tensor +import tflite.TensorType +from ir.tensor import Tensor + +TensorTypeList = {} + + +def SetTensorTypeStr(): + tensorTypeObj = tflite.TensorType.TensorType() + + for fieldName in dir(tensorTypeObj): + if (not fieldName.startswith('_')): + fieldValue = getattr(tensorTypeObj, fieldName) + if (isinstance(fieldValue, (int))): + TensorTypeList[fieldValue] = fieldName + + +TYPES_SIZE = { + 'BOOL': 1, + 'COMPLEX64': 8, + 'FLOAT16': 2, + 'FLOAT32': 4, + 'INT16': 2, + 'INT32': 4, + 'INT64': 8, + 'UINT8': 1, + 'NONE': 0, +} + + +def GetTypeSize(type_name): + try: + return TYPES_SIZE[type_name] + + except KeyError as error: + return 0 + + +TYPE_TO_NPTYPE = { + 'BOOL': np.bool_, + 'COMPLEX64': np.cdouble, + 'FLOAT16': np.float16, + 'FLOAT32': np.float32, + 'INT16': np.int16, + 'INT32': np.int32, + 'INT64': np.int64, + 'UINT8': np.uint8, +} + + +def ConvertProperNPArrayType(np_arr, np_shape, type_name): + try: + return np_arr.view(TYPE_TO_NPTYPE[type_name]).reshape(np_shape) + except KeyError as error: + return np_arr.view().reshape(np_shape) + + +class TFLiteTensor(Tensor): + def __init__(self, tensor_idx, tf_tensor, tf_buffer): + super(TFLiteTensor, self).__init__() + self.tf_tensor = tf_tensor + self.tf_buffer = tf_buffer + + self.index = int(tensor_idx) + self.tensor = tf_tensor + + # optional input + if self.index == -1: + self.type_name = "NONE" + # general input + else: + assert tf_tensor is not None + assert tf_buffer is not None + self.tensor_name = str(tf_tensor.Name()) + self.type_name = TensorTypeList[tf_tensor.Type()] + self.buffer_index = tf_tensor.Buffer() + if (tf_buffer.DataLength() > 0): + self.buffer = ConvertProperNPArrayType(tf_buffer.DataAsNumpy(), + tf_tensor.ShapeAsNumpy(), + self.type_name) + + # shape: Empty list([]) will mean Scalar + for shape_idx in range(tf_tensor.ShapeLength()): + # when shape signature is -1, that means unknown dim + if tf_tensor.ShapeSignature(shape_idx) != -1: + self.shape.append(int(tf_tensor.Shape(shape_idx))) + else: + self.shape.append(-1) + + self.memory_size = self.GetMemorySize() + + def GetMemorySize(self): + type_size = GetTypeSize(self.type_name) + if type_size == 0: + return 0 + + # memory size in bytes + size = int(type_size) + shape_length = len(self.shape) + if shape_length == 0: + return size + + for shape_idx in range(shape_length): + shape_size = int(self.shape[shape_idx]) + size *= shape_size + + return size |