diff options
Diffstat (limited to 'tools/tflitefile_tool')
35 files changed, 1363 insertions, 218 deletions
diff --git a/tools/tflitefile_tool/model_parser.py b/tools/tflitefile_tool/model_parser.py index b8967d33f..0edabbba1 100755 --- a/tools/tflitefile_tool/model_parser.py +++ b/tools/tflitefile_tool/model_parser.py @@ -1,4 +1,19 @@ #!/usr/bin/python + +# Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import os import sys import numpy @@ -13,6 +28,7 @@ import tflite.Model import tflite.SubGraph import argparse from operator_parser import OperatorParser +from model_printer import ModelPrinter from perf_predictor import PerfPredictor @@ -22,7 +38,6 @@ class TFLiteModelFileParser(object): self.tflite_file = args.input_file # Set print level (0 ~ 2) - # TODO: print information based on level self.print_level = args.verbose if (args.verbose > 2): self.print_level = 2 @@ -30,33 +45,34 @@ class TFLiteModelFileParser(object): self.print_level = 0 # Set tensor index list to print information - # TODO: - # Print tensors in list only - # Print all tensors if argument used and not specified index number + self.print_all_tensor = True if (args.tensor != None): - if (len(args.tensor) == 0): - self.print_all_tensor = True - else: + if (len(args.tensor) != 0): self.print_all_tensor = False self.print_tensor_index = [] - for tensor_index in args.tensor: self.print_tensor_index.append(int(tensor_index)) # Set operator index list to print information - # TODO: - # Print operators in list only - # Print all operators if argument used and not specified index number + self.print_all_operator = True if (args.operator != None): - if (len(args.operator) == 0): - self.print_all_oeprator = True - else: - self.print_all_oeprator = False + if (len(args.operator) != 0): + self.print_all_operator = False self.print_operator_index = [] - for operator_index in args.operator: self.print_operator_index.append(int(operator_index)) + def PrintModel(self, model_name, op_parser): + printer = ModelPrinter(self.print_level, op_parser, model_name) + + if self.print_all_tensor == False: + printer.SetPrintSpecificTensors(self.print_tensor_index) + + if self.print_all_operator == False: + printer.SetPrintSpecificOperators(self.print_operator_index) + + printer.PrintInfo() + def main(self): # Generate Model: top structure of tflite model file buf = self.tflite_file.read() @@ -71,19 +87,12 @@ class TFLiteModelFileParser(object): if (subgraph_index != 0): model_name = "Model #" + str(subgraph_index) - print("[" + model_name + "]\n") - - # Model inputs & outputs - model_inputs = tf_subgraph.InputsAsNumpy() - model_outputs = tf_subgraph.OutputsAsNumpy() - - print(model_name + " input tensors: " + str(model_inputs)) - print(model_name + " output tensors: " + str(model_outputs)) - - # Parse Operators and print all of operators + # Parse Operators op_parser = OperatorParser(tf_model, tf_subgraph, PerfPredictor()) op_parser.Parse() - op_parser.PrintAll() + + # print all of operators or requested objects + self.PrintModel(model_name, op_parser) if __name__ == '__main__': @@ -92,11 +101,7 @@ if __name__ == '__main__': arg_parser.add_argument( "input_file", type=argparse.FileType('rb'), help="tflite file to read") arg_parser.add_argument( - '-v', - '--verbose', - action='count', - default=0, - help="set print level (0~2, default: 0)") + '-v', '--verbose', type=int, default=1, help="set print level (0~2, default: 1)") arg_parser.add_argument( '-t', '--tensor', nargs='*', help="tensor ID to print information (default: all)") arg_parser.add_argument( diff --git a/tools/tflitefile_tool/model_printer.py b/tools/tflitefile_tool/model_printer.py new file mode 100644 index 000000000..ad06fa7a7 --- /dev/null +++ b/tools/tflitefile_tool/model_printer.py @@ -0,0 +1,142 @@ +#!/usr/bin/python + +# Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from operator_printer import OperatorPrinter +from tensor_printer import TensorPrinter + + +class ModelPrinter(object): + def __init__(self, verbose, op_parser, model_name): + self.verbose = verbose + self.op_parser = op_parser + self.model_name = model_name + self.print_all_tensor = True + self.print_tensor_index_list = None + self.print_all_operator = True + self.print_operator_index_list = None + + def SetPrintSpecificTensors(self, tensor_indices): + if len(tensor_indices) != 0: + self.print_all_tensor = False + self.print_tensor_index_list = tensor_indices + + def SetPrintSpecificOperators(self, operator_indices): + if len(operator_indices) != 0: + self.print_all_operator = False + self.print_operator_index_list = operator_indices + + def PrintInfo(self): + if self.print_all_tensor == True and self.print_all_operator == True: + self.PrintModelInfo() + self.PrintAllOperatorsInList() + self.PrintAllTypesInfo() + self.PrintTotalMemory() + + if self.print_all_tensor == False: + print('') + self.PrintSpecificTensors() + + if self.print_all_operator == False: + print('') + self.PrintSpecificOperators() + + def PrintModelInfo(self): + print("[" + self.model_name + "]\n") + if self.verbose > 0: + model_inputs = self.op_parser.tf_subgraph.InputsAsNumpy() + model_outputs = self.op_parser.tf_subgraph.OutputsAsNumpy() + print(self.model_name + " input tensors: " + str(model_inputs)) + print(self.model_name + " output tensors: " + str(model_outputs)) + print('') + + def PrintAllOperatorsInList(self): + if (self.verbose < 1): + return + + for operator in self.op_parser.operators_in_list: + printer = OperatorPrinter(self.verbose, operator) + printer.PrintInfo(self.op_parser.perf_predictor) + print('') + + print('') + + def PrintAllTypesInfo(self): + print("Number of all operator types: {0}".format( + len(self.op_parser.operators_per_type))) + + # number of instructions of all operator types to print if verbose level is 2 + total_instrs = 0 + + # (a string of the operator type, a list of operators which are the same operator type) + for type_str, oper_list in self.op_parser.operators_per_type.items(): + # number of occurrence of this operator type + occur = len(oper_list) + + optype_info_str = "\t{type_str:38}: {occur:4}".format( + type_str=type_str, occur=occur) + + if self.verbose == 2: + # this operator type can be computed? + can_compute = oper_list[0].operation.can_compute + + # total number of instructions of the same operator types + if can_compute: + instrs = sum( + operator.operation.TotalInstrNum() for operator in oper_list) + total_instrs = total_instrs + instrs + instrs = "{:,}".format(instrs) + else: + instrs = "???" + + optype_info_str = optype_info_str + " \t (instrs: {instrs})".format( + instrs=instrs) + + print(optype_info_str) + + summary_str = "{0:46}: {1:4}".format("Number of all operators", + len(self.op_parser.operators_in_list)) + if self.verbose == 2: + total_instrs = "{:,}".format(total_instrs) + summary_str = summary_str + " \t (total instrs: {0})".format(total_instrs) + + print(summary_str) + print('') + + def PrintSpecificTensors(self): + for tensor in self.op_parser.GetAllTensors(): + if tensor.tensor_idx in self.print_tensor_index_list: + printer = TensorPrinter(self.verbose, tensor) + printer.PrintInfo() + print('') + print('') + + def PrintSpecificOperators(self): + for operator in self.op_parser.operators_in_list: + if operator.operator_idx in self.print_operator_index_list: + printer = OperatorPrinter(self.verbose, operator) + printer.PrintInfo(self.op_parser.perf_predictor) + print('') + + print('') + + def PrintTotalMemory(self): + total_memory = 0 + for tensor in self.op_parser.GetAllTensors(): + total_memory += tensor.memory_size + + from tensor_printer import ConvertBytesToHuman + print("Expected total memory for allocating all tensors: {0}".format( + ConvertBytesToHuman(total_memory))) diff --git a/tools/tflitefile_tool/operation.py b/tools/tflitefile_tool/operation.py index 77fc5db9a..127d6c566 100755 --- a/tools/tflitefile_tool/operation.py +++ b/tools/tflitefile_tool/operation.py @@ -1,5 +1,19 @@ #!/usr/bin/python +# Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import tflite.Conv2DOptions import tflite.Pool2DOptions import tflite.BuiltinOptions diff --git a/tools/tflitefile_tool/operator_parser.py b/tools/tflitefile_tool/operator_parser.py index 9728d53b7..71b1a6d93 100755 --- a/tools/tflitefile_tool/operator_parser.py +++ b/tools/tflitefile_tool/operator_parser.py @@ -1,5 +1,19 @@ #!/usr/bin/python +# Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import tflite.Model import tflite.SubGraph import tflite.Operator @@ -62,6 +76,18 @@ class OperatorParser(object): return_list.append(Tensor(tensor_idx, tf_tensor, tf_buffer)) return return_list + def GetAllTensors(self): + return_list = list() + for tensor_idx in range(self.tf_subgraph.TensorsLength()): + if (tensor_idx < 0): + return_list.append(Tensor(tensor_idx, 0, 0)) + continue + tf_tensor = self.tf_subgraph.Tensors(tensor_idx) + buffer_idx = tf_tensor.Buffer() + tf_buffer = self.tf_model.Buffers(buffer_idx) + return_list.append(Tensor(tensor_idx, tf_tensor, tf_buffer)) + return return_list + def AppendOperator(self, operator): self.operators_in_list.append(operator) @@ -69,45 +95,3 @@ class OperatorParser(object): if opcode_str not in self.operators_per_type: self.operators_per_type[opcode_str] = list() self.operators_per_type[opcode_str].append(operator) - - def PrintAll(self): - print('') - self.PrintAllOperatorsInList() - print('') - self.PrintAllTypesInfo() - print('') - - def PrintAllOperatorsInList(self): - for operator in self.operators_in_list: - operator.PrintInfo(self.perf_predictor) - print('') - - def PrintAllTypesInfo(self): - print("Number of all operator types: {0}".format(len(self.operators_per_type))) - - # number of instructions of all operator types - total_instrs = 0 - - # (a string of the operator type, a list of operators which are the same operator type) - for type_str, oper_list in self.operators_per_type.items(): - # this operator type can be computed? - can_compute = oper_list[0].operation.can_compute - - # number of occurrence of this operator type - occur = len(oper_list) - - # total number of instructions of the same operator types - if can_compute: - instrs = sum(operator.operation.TotalInstrNum() for operator in oper_list) - total_instrs = total_instrs + instrs - instrs = "{:,}".format(instrs) - else: - instrs = "???" - - print("\t{type_str:38}: {occur:4} \t (instrs: {instrs})".format( - type_str=type_str, occur=occur, instrs=instrs)) - - total_instrs = "{:,}".format(total_instrs) - print("{0:46}: {1:4} \t (total instrs: {2})".format("Number of all operators", - len(self.operators_in_list), - total_instrs)) diff --git a/tools/tflitefile_tool/operator_printer.py b/tools/tflitefile_tool/operator_printer.py new file mode 100644 index 000000000..9b6f97d24 --- /dev/null +++ b/tools/tflitefile_tool/operator_printer.py @@ -0,0 +1,72 @@ +#!/usr/bin/python + +# Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from operator_wrapping import Operator +from tensor_printer import TensorPrinter +from option_printer import OptionPrinter +from perf_predictor import PerfPredictor + + +def GetStrTensorIndex(tensors): + return_string = "[" + for idx in range(len(tensors)): + if idx != 0: + return_string += ", " + return_string += str(tensors[idx].tensor_idx) + return_string += "]" + return return_string + + +class OperatorPrinter(object): + def __init__(self, verbose, operator): + self.verbose = verbose + self.operator = operator + + def PrintInfo(self, perf_predictor=None): + if (self.verbose < 1): + return + + op_str = "Operator {0}: {1}".format(self.operator.operator_idx, + self.operator.opcode_str) + + if self.verbose == 2: + # total instruction num + instrs = "{:,}".format(self.operator.operation.TotalInstrNum() + ) if self.operator.operation.can_compute else "???" + + # total operation cycles + cycles = "{:,}".format( + (perf_predictor.PredictCycles(self.operator.operation)) + ) if self.operator.operation.can_compute and perf_predictor != None else "???" + + op_str = op_str + "(instrs: {0}, cycls: {1})".format(instrs, cycles) + + print(op_str) + print("\tFused Activation: " + self.operator.fused_activation) + self.PrintTensors() + + def PrintTensors(self): + print("\tInput Tensors" + GetStrTensorIndex(self.operator.inputs)) + for tensor in self.operator.inputs: + TensorPrinter(self.verbose, tensor).PrintInfo("\t\t") + print("\tOutput Tensors" + GetStrTensorIndex(self.operator.outputs)) + for tensor in self.operator.outputs: + TensorPrinter(self.verbose, tensor).PrintInfo("\t\t") + + # operator option + # Some operations does not have option. In such case no option is printed + OptionPrinter(self.verbose, self.operator.opcode_str, + self.operator.options).PrintInfo("\t") diff --git a/tools/tflitefile_tool/operator_wrapping.py b/tools/tflitefile_tool/operator_wrapping.py index 1b7f55a4c..64bad1f08 100755 --- a/tools/tflitefile_tool/operator_wrapping.py +++ b/tools/tflitefile_tool/operator_wrapping.py @@ -1,12 +1,24 @@ #!/usr/bin/python +# Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import tflite.Operator import tflite.OperatorCode import tflite.BuiltinOperator import tflite.ActivationFunctionType -from tensor_wrapping import Tensor from operation import Operation -from perf_predictor import PerfPredictor # Match enum value integer to name string @@ -30,16 +42,6 @@ class EnumStrMaps(): BuiltinOptions = BuildEnumClassStrMap(tflite.BuiltinOptions.BuiltinOptions()) -def GetStrTensorIndex(tensors): - return_string = "[" - for idx in range(len(tensors)): - if idx != 0: - return_string += ", " - return_string += str(tensors[idx].tensor_idx) - return_string += "]" - return return_string - - def GetAttribute(o, *args): import functools return functools.reduce(getattr, args, o) @@ -64,6 +66,11 @@ class OptionLoader: @staticmethod def GetBuiltinOptions(options_type, options_table): + if (options_table == None) and (options_type != 0): + print( + "Bad flatbuffer file: undefined builtin option table with defined option type" + ) + exit(1) options = OptionLoader.builtinOptionGen[options_type]() options.Init(options_table.Bytes, options_table.Pos) return options @@ -79,30 +86,19 @@ class Operator(object): self.opcode_str = opcode_str self.operation = Operation(self.tf_operator, self.opcode_str, self.inputs, self.outputs) + self.fused_activation = "NONE" + self.SetupBuiltinOption() + self.SetupFusedActivation() - def PrintInfo(self, perf_predictor=None): - # total instruction num - instrs = "{:,}".format( - self.operation.TotalInstrNum()) if self.operation.can_compute else "???" - - # total operation cycles - cycles = "{:,}".format( - (perf_predictor.PredictCycles(self.operation) - )) if self.operation.can_compute and perf_predictor != None else "???" - - print("Operator {0}: {1} (instrs: {2}, cycls: {3})".format( - self.operator_idx, self.opcode_str, instrs, cycles)) - - self.PrintOptionInfo() - - print("\tInput Tensors" + GetStrTensorIndex(self.inputs)) - for tensor in self.inputs: - tensor.PrintInfo("\t\t") - print("\tOutput Tensors" + GetStrTensorIndex(self.outputs)) - for tensor in self.outputs: - tensor.PrintInfo("\t\t") + def SetupBuiltinOption(self): + try: + self.options = OptionLoader.GetBuiltinOptions( + self.tf_operator.BuiltinOptionsType(), self.tf_operator.BuiltinOptions()) + except KeyError: + self.options = 0 + return - def PrintOptionInfo(self): + def SetupFusedActivation(self): # FIXME: workaround for ops such as custom try: options = OptionLoader.GetBuiltinOptions( @@ -113,8 +109,7 @@ class Operator(object): # fused activation function try: activation_code = options.FusedActivationFunction() - fused_activation = EnumStrMaps.ActivationFunctionType[activation_code] - print("\tFused Activation: " + fused_activation) + self.fused_activation = EnumStrMaps.ActivationFunctionType[activation_code] except AttributeError: # This operator does not support FusedActivationFunction pass diff --git a/tools/tflitefile_tool/option_printer.py b/tools/tflitefile_tool/option_printer.py new file mode 100644 index 000000000..08754f1ce --- /dev/null +++ b/tools/tflitefile_tool/option_printer.py @@ -0,0 +1,69 @@ +#!/usr/bin/python + +# Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +class OptionPrinter(object): + def __init__(self, verbose, op_name, options): + self.verbose = verbose + self.op_name = op_name + self.options = options + + def GetPadding(self): + if self.options.Padding() == 0: + return "SAME" + elif self.options.Padding() == 1: + return "VALID" + else: + return "** wrong padding value **" + + def PrintInfo(self, tab=""): + if (self.verbose < 1): + pass + if (self.options == 0): + return + + if (self.op_name == "AVERAGE_POOL_2D" or self.op_name == "MAX_POOL_2D"): + print("{}Options".format(tab)) + + print("{}\t{}, {}, {}".format( + tab, "Filter W:H = {}:{}".format(self.options.FilterWidth(), + self.options.FilterHeight()), + "Stride W:H = {}:{}".format(self.options.StrideW(), + self.options.StrideH()), + "Padding = {}".format(self.GetPadding()))) + + elif (self.op_name == "CONV_2D"): + print("{}Options".format(tab)) + + print("{}\t{}, {}, {}".format( + tab, "Stride W:H = {}:{}".format(self.options.StrideW(), + self.options.StrideH()), + "Dilation W:H = {}:{}".format(self.options.DilationWFactor(), + self.options.DilationHFactor()), + "Padding = {}".format(self.GetPadding()))) + + elif (self.op_name == "DEPTHWISE_CONV_2D"): + print("{}Options".format(tab)) + + # yapf: disable + print("{}\t{}, {}, {}, {}".format( + tab, "Stride W:H = {}:{}".format(self.options.StrideW(), + self.options.StrideH()), + "Dilation W:H = {}:{}".format(self.options.DilationWFactor(), + self.options.DilationHFactor()), + "Padding = {}".format(self.GetPadding()), + "DepthMultiplier = {}".format(self.options.DepthMultiplier()))) + # yapf: enable diff --git a/tools/tflitefile_tool/perf_predictor.py b/tools/tflitefile_tool/perf_predictor.py index 8880c8e71..ea5c15a33 100755 --- a/tools/tflitefile_tool/perf_predictor.py +++ b/tools/tflitefile_tool/perf_predictor.py @@ -1,5 +1,18 @@ #!/usr/bin/python +# Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from operation import Operation diff --git a/tools/tflitefile_tool/select_operator.py b/tools/tflitefile_tool/select_operator.py index 55ca1acd9..c5d311d59 100755..100644 --- a/tools/tflitefile_tool/select_operator.py +++ b/tools/tflitefile_tool/select_operator.py @@ -1,4 +1,19 @@ #!/usr/bin/python + +# Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import os import sys import numpy @@ -61,8 +76,8 @@ def GenerateOperatorCodes(new_builder, sample_model, used_operators_dic): if operator_code_idx in used_operators_dic: operator_code = sample_model.OperatorCodes(operator_code_idx) operator_code_string = operator_code.CustomCode() - if (operator_code_string != - "") and (not operator_code_string in new_operator_code_string_list): + if operator_code_string and (operator_code_string != "") and ( + not operator_code_string in new_operator_code_string_list): new_operator_code_string_list[ operator_code_string] = new_builder.CreateString(operator_code_string) @@ -209,26 +224,10 @@ def GenerateTensors(new_builder, selected_subgraph, used_tensors_dic, used_buffe return new_builder.EndVector(new_tensor_num) -import tflite.Conv2DOptions -import tflite.DepthwiseConv2DOptions -import tflite.Pool2DOptions -import tflite.FullyConnectedOptions -import tflite.SoftmaxOptions -import tflite.ConcatenationOptions -import tflite.ReshapeOptions -import tflite.AddOptions -import tflite.SubOptions -import tflite.MulOptions -import tflite.DivOptions -import tflite.ResizeBilinearOptions -import tflite.StridedSliceOptions -import tflite.CastOptions -import tflite.TopKV2Options -import tflite.GatherOptions - - def GenerateBuiltinOption(new_builder, selected_builtin_option, builtin_option_type): + # Conv2D option + import tflite.Conv2DOptions if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().Conv2DOptions: conv2d_options = tflite.Conv2DOptions.Conv2DOptions() @@ -245,6 +244,8 @@ def GenerateBuiltinOption(new_builder, selected_builtin_option, builtin_option_t new_builder, conv2d_options.FusedActivationFunction()) return tflite.Conv2DOptions.Conv2DOptionsEnd(new_builder) + # DepthwiseConv2D option + import tflite.DepthwiseConv2DOptions if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions( ).DepthwiseConv2DOptions: @@ -263,8 +264,17 @@ def GenerateBuiltinOption(new_builder, selected_builtin_option, builtin_option_t new_builder, depthconv2d_option.DepthMultiplier()) tflite.DepthwiseConv2DOptions.DepthwiseConv2DOptionsAddFusedActivationFunction( new_builder, depthconv2d_option.FusedActivationFunction()) + tflite.DepthwiseConv2DOptions.DepthwiseConv2DOptionsAddDilationWFactor( + new_builder, depthconv2d_option.DilationWFactor()) + tflite.DepthwiseConv2DOptions.DepthwiseConv2DOptionsAddDilationHFactor( + new_builder, depthconv2d_option.DilationHFactor()) return tflite.DepthwiseConv2DOptions.DepthwiseConv2DOptionsEnd(new_builder) + # ConcatEmbeddingsOptions: not supported + # LSHProjectionOptions: not supported + + # Pool2DPOption + import tflite.Pool2DOptions if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().Pool2DOptions: pool2d_option = tflite.Pool2DOptions.Pool2DOptions() @@ -282,6 +292,22 @@ def GenerateBuiltinOption(new_builder, selected_builtin_option, builtin_option_t new_builder, pool2d_option.FusedActivationFunction()) return tflite.Pool2DOptions.Pool2DOptionsEnd(new_builder) + # SVDFOptions: not supported + + # RNNOptions + import tflite.RNNOptions + if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().RNNOptions: + + rnn_option = tflite.RNNOptions.RNNOptions() + rnn_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos) + + tflite.RNNOptions.RNNOptionsStart(new_builder) + tflite.RNNOptions.RNNOptionsAddFusedActivationFunction( + new_builder, rnn_option.FusedActivationFunction()) + return tflite.RNNOptions.RNNOptionsEnd(new_builder) + + # FullyConnectedOptions + import tflite.FullyConnectedOptions if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions( ).FullyConnectedOptions: @@ -293,6 +319,8 @@ def GenerateBuiltinOption(new_builder, selected_builtin_option, builtin_option_t new_builder, fc_option.FusedActivationFunction()) return tflite.FullyConnectedOptions.FullyConnectedOptionsEnd(new_builder) + # SoftmaxOptions + import tflite.SoftmaxOptions if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().SoftmaxOptions: softmax_option = tflite.SoftmaxOptions.SoftmaxOptions() @@ -302,6 +330,8 @@ def GenerateBuiltinOption(new_builder, selected_builtin_option, builtin_option_t tflite.SoftmaxOptions.SoftmaxOptionsAddBeta(new_builder, softmax_option.Beta()) return tflite.SoftmaxOptions.SoftmaxOptionsEnd(new_builder) + # ConcatenationOptions + import tflite.ConcatenationOptions if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().ConcatenationOptions: concat_option = tflite.ConcatenationOptions.ConcatenationOptions() @@ -314,6 +344,72 @@ def GenerateBuiltinOption(new_builder, selected_builtin_option, builtin_option_t new_builder, concat_option.FusedActivationFunction()) return tflite.ConcatenationOptions.ConcatenationOptionsEnd(new_builder) + # AddOptions + import tflite.AddOptions + if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().AddOptions: + + add_option = tflite.AddOptions.AddOptions() + add_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos) + + tflite.AddOptions.AddOptionsStart(new_builder) + tflite.AddOptions.AddOptionsAddFusedActivationFunction( + new_builder, add_option.FusedActivationFunction()) + return tflite.AddOptions.AddOptionsEnd(new_builder) + + # L2NormOptions + import tflite.L2NormOptions + if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().L2NormOptions: + + l2norm_option = tflite.L2NormOptions.L2NormOptions() + l2norm_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos) + + tflite.L2NormOptions.L2NormOptionsStart(new_builder) + tflite.L2NormOptions.L2NormOptionsAddFusedActivationFunction( + new_builder, l2norm_option.FusedActivationFunction()) + return tflite.L2NormOptions.L2NormOptionsEnd(new_builder) + + # LocalResponseNormalizationOptions + import tflite.LocalResponseNormalizationOptions + if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions( + ).LocalResponseNormalizationOptions: + + lrn_option = tflite.LocalResponseNormalizationOptions.LocalResponseNormalizationOptions( + ) + lrn_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos) + + tflite.LocalResponseNormalizationOptions.LocalResponseNormalizationOptionsStart( + new_builder) + tflite.LocalResponseNormalizationOptions.LocalResponseNormalizationOptionsAddRadius( + new_builder, lrn_option.Radius()) + tflite.LocalResponseNormalizationOptions.LocalResponseNormalizationOptionsAddBias( + new_builder, lrn_option.Bias()) + tflite.LocalResponseNormalizationOptions.LocalResponseNormalizationOptionsAddAlpha( + new_builder, lrn_option.Alpha()) + tflite.LocalResponseNormalizationOptions.LocalResponseNormalizationOptionsAddBeta( + new_builder, lrn_option.Beta()) + return tflite.LocalResponseNormalizationOptions.LocalResponseNormalizationOptionsEnd( + new_builder) + + # LSTMOptions: not supported + + # ResizeBilinearOptions + import tflite.ResizeBilinearOptions + if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions( + ).ResizeBilinearOptions: + + resize_bilinear_option = tflite.ResizeBilinearOptions.ResizeBilinearOptions() + resize_bilinear_option.Init(selected_builtin_option.Bytes, + selected_builtin_option.Pos) + + tflite.ResizeBilinearOptions.ResizeBilinearOptionsStart(new_builder) + tflite.ResizeBilinearOptions.ResizeBilinearOptionsAddAlignCorners( + new_builder, resize_bilinear_option.AlignCorners()) + return tflite.ResizeBilinearOptions.ResizeBilinearOptionsEnd(new_builder) + + # CallOptions: not supported + + # ReshapeOptions + import tflite.ReshapeOptions if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().ReshapeOptions: reshape_option = tflite.ReshapeOptions.ReshapeOptions() @@ -333,26 +429,25 @@ def GenerateBuiltinOption(new_builder, selected_builtin_option, builtin_option_t tflite.ReshapeOptions.ReshapeOptionsAddNewShape(new_builder, new_shape) return tflite.ReshapeOptions.ReshapeOptionsEnd(new_builder) - if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().AddOptions: - - add_option = tflite.AddOptions.AddOptions() - add_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos) + # SkipGramOptions: not supported - tflite.AddOptions.AddOptionsStart(new_builder) - tflite.AddOptions.AddOptionsAddFusedActivationFunction( - new_builder, add_option.FusedActivationFunction()) - return tflite.AddOptions.AddOptionsEnd(new_builder) + # SpaceToDepthOptions + import tflite.SpaceToDepthOptions + if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().SpaceToDepthOptions: - if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().SubOptions: + space_to_depth_option = tflite.SpaceToDepthOptions.SpaceToDepthOptions() + space_to_depth_option.Init(selected_builtin_option.Bytes, + selected_builtin_option.Pos) - sub_option = tflite.SubOptions.SubOptions() - sub_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos) + tflite.SpaceToDepthOptions.SpaceToDepthOptionsStart(new_builder) + tflite.SpaceToDepthOptions.SpaceToDepthOptionsAddBlockSize( + new_builder, space_to_depth_option.BlockSize()) + return tflite.SpaceToDepthOptions.SpaceToDepthOptionsEnd(new_builder) - tflite.SubOptions.SubOptionsStart(new_builder) - tflite.SubOptions.SubOptionsAddFusedActivationFunction( - new_builder, sub_option.FusedActivationFunction()) - return tflite.SubOptions.SubOptionsEnd(new_builder) + # EmbeddingLookupSparseOptions: not supported + # MulOptions + import tflite.MulOptions if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().MulOptions: mul_option = tflite.MulOptions.MulOptions() @@ -363,6 +458,85 @@ def GenerateBuiltinOption(new_builder, selected_builtin_option, builtin_option_t new_builder, mul_option.FusedActivationFunction()) return tflite.MulOptions.MulOptionsEnd(new_builder) + # PadOptions + import tflite.PadOptions + if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().PadOptions: + + pad_option = tflite.PadOptions.PadOptions() + pad_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos) + + tflite.PadOptions.PadOptionsStart(new_builder) + return tflite.PadOptions.PadOptionsEnd(new_builder) + + # GatherOptions + import tflite.GatherOptions + if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().GatherOptions: + + gather_option = tflite.GatherOptions.GatherOptions() + gather_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos) + + tflite.GatherOptions.GatherOptionsStart(new_builder) + tflite.GatherOptions.GatherOptionsAddAxis(new_builder, gather_option.Axis()) + return tflite.GatherOptions.GatherOptionsEnd(new_builder) + + # BatchToSpaceNDOptions + import tflite.BatchToSpaceNDOptions + if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions( + ).BatchToSpaceNDOptions: + + btsnd_option = tflite.BatchToSpaceNDOptions.BatchToSpaceNDOptions() + btsnd_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos) + + tflite.BatchToSpaceNDOptions.BatchToSpaceNDOptionsStart(new_builder) + return tflite.BatchToSpaceNDOptions.BatchToSpaceNDOptionsEnd(new_builder) + + # SpaceToBatchNDOptions + import tflite.SpaceToBatchNDOptions + if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions( + ).SpaceToBatchNDOptions: + + stbnd_option = tflite.SpaceToBatchNDOptions.SpaceToBatchNDOptions() + stbnd_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos) + + tflite.SpaceToBatchNDOptions.SpaceToBatchNDOptionsStart(new_builder) + return tflite.SpaceToBatchNDOptions.SpaceToBatchNDOptionsEnd(new_builder) + + # TransposeOptions: + import tflite.TransposeOptions + if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().TransposeOptions: + + transpose_option = tflite.TransposeOptions.TransposeOptions() + transpose_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos) + + tflite.TransposeOptions.TransposeOptionsStart(new_builder) + return tflite.TransposeOptions.TransposeOptionsEnd(new_builder) + + # ReducerOptions + import tflite.ReducerOptions + if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().ReducerOptions: + + reducer_option = tflite.ReducerOptions.ReducerOptions() + reducer_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos) + + tflite.ReducerOptions.ReducerOptionsStart(new_builder) + tflite.ReducerOptions.ReducerOptionsAddKeepDims(new_builder, + reducer_option.KeepDims()) + return tflite.ReducerOptions.ReducerOptionsEnd(new_builder) + + # SubOptions + import tflite.SubOptions + if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().SubOptions: + + sub_option = tflite.SubOptions.SubOptions() + sub_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos) + + tflite.SubOptions.SubOptionsStart(new_builder) + tflite.SubOptions.SubOptionsAddFusedActivationFunction( + new_builder, sub_option.FusedActivationFunction()) + return tflite.SubOptions.SubOptionsEnd(new_builder) + + # DivOptions + import tflite.DivOptions if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().DivOptions: div_option = tflite.DivOptions.DivOptions() @@ -373,18 +547,32 @@ def GenerateBuiltinOption(new_builder, selected_builtin_option, builtin_option_t new_builder, div_option.FusedActivationFunction()) return tflite.DivOptions.DivOptionsEnd(new_builder) - if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions( - ).ResizeBilinearOptions: + # SqueezeOptions + import tflite.SqueezeOptions + if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().SqueezeOptions: - resize_bilinear_option = tflite.ResizeBilinearOptions.ResizeBilinearOptions() - resize_bilinear_option.Init(selected_builtin_option.Bytes, - selected_builtin_option.Pos) + squeeze_option = tflite.SqueezeOptions.SqueezeOptions() + squeeze_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos) - tflite.ResizeBilinearOptions.ResizeBilinearOptionsStart(new_builder) - tflite.ResizeBilinearOptions.ResizeBilinearOptionsAddAlignCorners( - new_builder, resize_bilinear_option.AlignCorners()) - return tflite.ResizeBilinearOptions.ResizeBilinearOptionsEnd(new_builder) + squeeze_dims_num = squeeze_option.SqueezeDimsLength() + if squeeze_dims_num != 0: + tflite.SqueezeOptions.SqueezeOptionsStartSqueezeDimsVector( + new_builder, squeeze_dims_num) + for squeeze_dims_idx in reversed(range(squeeze_dims_num)): + squeeze_dims_val = squeeze_option.SqueezeDims(squeeze_dims_idx) + new_builder.PrependInt32(squeeze_dims_val) + new_squeeze_dims = new_builder.EndVector(squeeze_dims_num) + + tflite.SqueezeOptions.SqueezeOptionsStart(new_builder) + if squeeze_dims_num != 0: + tflite.SqueezeOptions.SqueezeOptionsAddSqueezeDims(new_builder, + new_squeeze_dims) + return tflite.SqueezeOptions.SqueezeOptionsEnd(new_builder) + + # SequenceRNNOptions: not supported + # StridedSliceOptions + import tflite.StridedSliceOptions if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().StridedSliceOptions: stride_slice_option = tflite.StridedSliceOptions.StridedSliceOptions() @@ -405,14 +593,18 @@ def GenerateBuiltinOption(new_builder, selected_builtin_option, builtin_option_t return tflite.StridedSliceOptions.StridedSliceOptionsEnd(new_builder) - if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().CastOptions: + # ExpOptions + import tflite.ExpOptions + if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().ExpOptions: - cast_option = tflite.CastOptions.CastOptions() - cast_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos) + exp_option = tflite.ExpOptions.ExpOptions() + exp_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos) - tflite.CastOptions.CastOptionsStart(new_builder) - return tflite.CastOptions.CastOptionsEnd(new_builder) + tflite.ExpOptions.ExpOptionsStart(new_builder) + return tflite.ExpOptions.ExpOptionsEnd(new_builder) + # TopKV2Options + import tflite.TopKV2Options if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().TopKV2Options: topkv2_option = tflite.TopKV2Options.TopKV2Options() @@ -421,17 +613,185 @@ def GenerateBuiltinOption(new_builder, selected_builtin_option, builtin_option_t tflite.TopKV2Options.TopKV2OptionsStart(new_builder) return tflite.TopKV2Options.TopKV2OptionsEnd(new_builder) - if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().GatherOptions: + # SplitOptions + import tflite.SplitOptions + if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().SplitOptions: - gather_option = tflite.GatherOptions.GatherOptions() - gather_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos) + split_option = tflite.SplitOptions.SplitOptions() + split_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos) - tflite.GatherOptions.GatherOptionsStart(new_builder) - tflite.GatherOptions.GatherOptionsAddAxis(new_builder, gather_option.Axis()) - return tflite.GatherOptions.GatherOptionsEnd(new_builder) + tflite.SplitOptions.SplitOptionsStart(new_builder) + tflite.SplitOptions.SplitOptionsAddNumSplits(new_builder, + split_option.NumSplits()) + return tflite.SplitOptions.SplitOptionsEnd(new_builder) + + # LogSoftmaxOptions: not supported + + # CastOptions: not supported + import tflite.CastOptions + if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().CastOptions: + + cast_option = tflite.CastOptions.CastOptions() + cast_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos) + + tflite.CastOptions.CastOptionsStart(new_builder) + return tflite.CastOptions.CastOptionsEnd(new_builder) + + # DequantizeOptions: + import tflite.DequantizeOptions + if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().DequantizeOptions: + + dequantize_option = tflite.DequantizeOptions.DequantizeOptions() + dequantize_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos) + + tflite.EqualOptions.DequantizeOptionsStart(new_builder) + return tflite.DequantizeOptions.DequantizeOptionsEnd(new_builder) + + # MaximumMinimumOptions: not supported + + # ArgMaxOptions + import tflite.ArgMaxOptions + if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().ArgMaxOptions: + + arg_max_option = tflite.ArgMaxOptions.ArgMaxOptions() + arg_max_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos) + + tflite.ArgMaxOptions.ArgMaxOptionsStart(new_builder) + tflite.ArgMaxOptions.ArgMaxOptionsAddOutputType(new_builder, + arg_max_option.OutputType()) + return tflite.ArgMaxOptions.ArgMaxOptionsEnd(new_builder) + + # LessOptions: not supported + + # NegOptions + import tflite.NegOptions + if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().NegOptions: + + neg_option = tflite.NegOptions.NegOptions() + neg_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos) + + tflite.NegOptions.NegOptionsStart(new_builder) + return tflite.NegOptions.NegOptionsEnd(new_builder) + + # EqualOptions + import tflite.EqualOptions + if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().EqualOptions: + + equal_option = tflite.EqualOptions.EqualOptions() + equal_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos) + + tflite.EqualOptions.EqualOptionsStart(new_builder) + return tflite.EqualOptions.EqualOptionsEnd(new_builder) + + # PadV2Options: not supported + # GreaterOptions: not supported + # GreaterEqualOptions: not supported + # LessEqualOptions: not supported + # SelectOptions: not supported + # SliceOptions: not supported + + # TransposeConvOptions + import tflite.TransposeConvOptions + if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().TransposeConvOptions: + + transposeconv_option = tflite.TransposeConvOptions.TransposeConvOptions() + transposeconv_option.Init(selected_builtin_option.Bytes, + selected_builtin_option.Pos) + + tflite.TransposeConvOptions.TransposeConvOptionsStart(new_builder) + tflite.TransposeConvOptions.TransposeConvOptionsAddPadding( + new_builder, transposeconv_option.Padding()) + tflite.TransposeConvOptions.TransposeConvOptionsAddStrideW( + new_builder, transposeconv_option.StrideW()) + tflite.TransposeConvOptions.TransposeConvOptionsAddStrideH( + new_builder, transposeconv_option.StrideH()) + return tflite.TransposeConvOptions.TransposeConvOptionsEnd(new_builder) + + # SparseToDenseOptions: not supported + # TileOptions: not supported + # ExpandDimsOptions: not supported + + # NotEqualOptions: + import tflite.NotEqualOptions + if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().NotEqualOptions: + + notequal_option = tflite.NotEqualOptions.NotEqualOptions() + notequal_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos) + + tflite.NotEqualOptions.NotEqualOptionsStart(new_builder) + return tflite.NotEqualOptions.NotEqualOptionsEnd(new_builder) + + # ShapeOptions: not supported + # PowOptions: not supported + # ArgMinOptions: not supported + # FakeQuantOptions: not supported + + # PackOptions: + import tflite.PackOptions + if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().PackOptions: + + pack_option = tflite.PackOptions.PackOptions() + pack_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos) + + tflite.PackOptions.PackOptionsStart(new_builder) + tflite.PackOptions.PackOptionsAddValuesCount(new_builder, + pack_option.ValuesCount()) + tflite.PackOptions.PackOptionsAddAxis(new_builder, pack_option.Axis()) + return tflite.PackOptions.PackOptionsEnd(new_builder) + + # LogicalOrOptions: + import tflite.LogicalOrOptions + if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().LogicalOrOptions: + + logical_or_option = tflite.LogicalAndOptions.LogicalOrOptions() + logical_or_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos) + + tflite.LogicalOrOptions.LogicalOrOptionsStart(new_builder) + return tflite.LogicalOrOptions.LogicalOrOptionsEnd(new_builder) + + # OneHotOptions: not supported + + # LogicalNotOptions + import tflite.LogicalNotOptions + if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().LogicalNotOptions: + + equal_option = tflite.LogicalNotOptions.LogicalNotOptions() + equal_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos) + + tflite.LogicalNotOptions.LogicalNotOptionsStart(new_builder) + return tflite.LogicalNotOptions.LogicalNotOptionsEnd(new_builder) + + # UnpackOptions: + import tflite.UnpackOptions + if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().UnpackOptions: + + unpack_option = tflite.unpackOptions.unpackOptions() + unpack_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos) + + tflite.unpackOptions.UnpackOptionsStart(new_builder) + tflite.unpackOptions.UnpackOptionsAddNum(new_builder, unpack_option.Num()) + tflite.PackOptions.UnpackOptionsAddAxis(new_builder, unpack_option.Axis()) + return tflite.UnpackOptions.UnpackOptionsEnd(new_builder) + + # FloorDivOptions: not supported + # SquareOptions: not supported + # ZerosLikeOptions: not supported + # FillOptions: not supported + + # LogicalAndOptions + import tflite.LogicalAndOptions + if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().LogicalAndOptions: + + logical_and_option = tflite.LogicalAndOptions.LogicalAndOptions() + logical_and_option.Init(selected_builtin_option.Bytes, + selected_builtin_option.Pos) + + tflite.LogicalAndOptions.LogicalAndOptionsStart(new_builder) + return tflite.LogicalAndOptions.LogicalAndOptionsEnd(new_builder) # Cannot handle builtin option type yet - return 0 + print("Cannot handle this option yet") + exit(1) def GenerateOperator(new_builder, selected_operator, used_tensors_dic, @@ -556,7 +916,7 @@ def GenerateSubgraph(new_builder, selected_subgraph, opcode_list, new_input_tens # Name subgraph_name = selected_subgraph.Name() have_name = False - if subgraph_name != "": + if subgraph_name and subgraph_name != "": have_name = True new_subgraph_name = new_builder.CreateString(subgraph_name) diff --git a/tools/tflitefile_tool/tensor_printer.py b/tools/tflitefile_tool/tensor_printer.py new file mode 100644 index 000000000..f566a6e10 --- /dev/null +++ b/tools/tflitefile_tool/tensor_printer.py @@ -0,0 +1,80 @@ +#!/usr/bin/python + +# Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from tensor_wrapping import Tensor + +SYMBOLS = ['B', 'K', 'M', 'G', 'T'] + + +def ConvertBytesToHuman(n): + n = int(n) + if n < 0: + return 0 + + format_str = "%(val)3.1f%(symb)s" + prefix = {} + for i, s in enumerate(SYMBOLS[1:]): + prefix[s] = 1 << (i + 1) * 10 + + for symbol in reversed(SYMBOLS[1:]): + if n >= prefix[symbol]: + v = float(n) / prefix[symbol] + return format_str % dict(symb=symbol, val=v) + + return format_str % dict(symb=SYMBOLS[0], val=n) + + +class TensorPrinter(object): + def __init__(self, verbose, tensor): + self.verbose = verbose + self.tensor = tensor + + def PrintInfo(self, depth_str=""): + if (self.verbose < 1): + pass + + print_str = "" + if self.tensor.tensor_idx < 0: + print_str = "Tensor {0:4}".format(self.tensor.tensor_idx) + else: + buffer_idx = self.tensor.tf_tensor.Buffer() + isEmpty = "Filled" + if (self.tensor.tf_buffer.DataLength() == 0): + isEmpty = " Empty" + shape_str = self.GetShapeString() + type_name = self.tensor.type_name + + shape_name = "" + if self.tensor.tf_tensor.Name() != 0: + shape_name = self.tensor.tf_tensor.Name() + + memory_size = ConvertBytesToHuman(self.tensor.memory_size) + + print_str = "Tensor {0:4} : buffer {1:4} | {2} | {3:7} | Memory {4:6} | Shape {5} ({6})".format( + self.tensor.tensor_idx, buffer_idx, isEmpty, type_name, memory_size, + shape_str, shape_name) + print(depth_str + print_str) + + def GetShapeString(self): + if self.tensor.tf_tensor.ShapeLength() == 0: + return "Scalar" + return_string = "[" + for shape_idx in range(self.tensor.tf_tensor.ShapeLength()): + if (shape_idx != 0): + return_string += ", " + return_string += str(self.tensor.tf_tensor.Shape(shape_idx)) + return_string += "]" + return return_string diff --git a/tools/tflitefile_tool/tensor_wrapping.py b/tools/tflitefile_tool/tensor_wrapping.py index b1fba57d2..a32a573ce 100755 --- a/tools/tflitefile_tool/tensor_wrapping.py +++ b/tools/tflitefile_tool/tensor_wrapping.py @@ -1,5 +1,19 @@ #!/usr/bin/python +# Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import tflite.Tensor import tflite.TensorType @@ -16,39 +30,47 @@ def SetTensorTypeStr(): TensorTypeList[fieldValue] = fieldName +TYPES = { + 'BOOL': 1, + 'COMPLEX64': 8, + 'FLOAT16': 2, + 'FLOAT32': 4, + 'INT16': 2, + 'INT32': 4, + 'INT64': 8, + 'UINT8': 1 +} + + +def GetTypeSize(type_name): + try: + return TYPES[type_name] + + except KeyError as error: + return 0 + + class Tensor(object): def __init__(self, tensor_idx, tf_tensor, tf_buffer): self.tensor_idx = tensor_idx self.tf_tensor = tf_tensor self.tf_buffer = tf_buffer + self.type_name = TensorTypeList[self.tf_tensor.Type()] + self.memory_size = self.GetMemorySize() + + def GetMemorySize(self): + type_size = GetTypeSize(self.type_name) + if type_size == 0: + return 0 + + # memory size in bytes + size = int(type_size) + shape_length = self.tf_tensor.ShapeLength() + if shape_length == 0: + return size + + for shape_idx in range(shape_length): + shape_size = int(self.tf_tensor.Shape(shape_idx)) + size *= shape_size - def PrintInfo(self, depth_str=""): - print_str = "" - if self.tensor_idx < 0: - print_str = "Tensor {0:4}".format(self.tensor_idx) - else: - buffer_idx = self.tf_tensor.Buffer() - isEmpty = "Filled" - if (self.tf_buffer.DataLength() == 0): - isEmpty = " Empty" - shape_str = self.GetShapeString() - type_name = TensorTypeList[self.tf_tensor.Type()] - - shape_name = "" - if self.tf_tensor.Name() != 0: - shape_name = self.tf_tensor.Name() - - print_str = "Tensor {0:4} : buffer {1:4} | {2} | {3:7} | Shape {4} ({5})".format( - self.tensor_idx, buffer_idx, isEmpty, type_name, shape_str, shape_name) - print(depth_str + print_str) - - def GetShapeString(self): - if self.tf_tensor.ShapeLength() == 0: - return "Scalar" - return_string = "[" - for shape_idx in range(self.tf_tensor.ShapeLength()): - if (shape_idx != 0): - return_string += ", " - return_string += str(self.tf_tensor.Shape(shape_idx)) - return_string += "]" - return return_string + return size diff --git a/tools/tflitefile_tool/tflite/BidirectionalSequenceRNNOptions.py b/tools/tflitefile_tool/tflite/BidirectionalSequenceRNNOptions.py index 5c057b6bf..474ee4ba2 100644 --- a/tools/tflitefile_tool/tflite/BidirectionalSequenceRNNOptions.py +++ b/tools/tflitefile_tool/tflite/BidirectionalSequenceRNNOptions.py @@ -23,8 +23,9 @@ class BidirectionalSequenceRNNOptions(object): def TimeMajor(self): o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) if o != 0: - return self._tab.Get(flatbuffers.number_types.BoolFlags, o + self._tab.Pos) - return 0 + return bool( + self._tab.Get(flatbuffers.number_types.BoolFlags, o + self._tab.Pos)) + return False # BidirectionalSequenceRNNOptions def FusedActivationFunction(self): diff --git a/tools/tflitefile_tool/tflite/BuiltinOperator.py b/tools/tflitefile_tool/tflite/BuiltinOperator.py index 2beda098e..8e2f9c680 100644 --- a/tools/tflitefile_tool/tflite/BuiltinOperator.py +++ b/tools/tflitefile_tool/tflite/BuiltinOperator.py @@ -84,3 +84,17 @@ class BuiltinOperator(object): POW = 78 ARG_MIN = 79 FAKE_QUANT = 80 + REDUCE_PROD = 81 + REDUCE_MAX = 82 + PACK = 83 + LOGICAL_OR = 84 + ONE_HOT = 85 + LOGICAL_AND = 86 + LOGICAL_NOT = 87 + UNPACK = 88 + REDUCE_MIN = 89 + FLOOR_DIV = 90 + REDUCE_ANY = 91 + SQUARE = 92 + ZEROS_LIKE = 93 + FILL = 94 diff --git a/tools/tflitefile_tool/tflite/BuiltinOptions.py b/tools/tflitefile_tool/tflite/BuiltinOptions.py index 5d3040839..7e1eb34ac 100644 --- a/tools/tflitefile_tool/tflite/BuiltinOptions.py +++ b/tools/tflitefile_tool/tflite/BuiltinOptions.py @@ -63,3 +63,13 @@ class BuiltinOptions(object): PowOptions = 56 ArgMinOptions = 57 FakeQuantOptions = 58 + PackOptions = 59 + LogicalOrOptions = 60 + OneHotOptions = 61 + LogicalAndOptions = 62 + LogicalNotOptions = 63 + UnpackOptions = 64 + FloorDivOptions = 65 + SquareOptions = 66 + ZerosLikeOptions = 67 + FillOptions = 68 diff --git a/tools/tflitefile_tool/tflite/DepthwiseConv2DOptions.py b/tools/tflitefile_tool/tflite/DepthwiseConv2DOptions.py index 9f0b3388f..786f7c53d 100644 --- a/tools/tflitefile_tool/tflite/DepthwiseConv2DOptions.py +++ b/tools/tflitefile_tool/tflite/DepthwiseConv2DOptions.py @@ -54,9 +54,23 @@ class DepthwiseConv2DOptions(object): return self._tab.Get(flatbuffers.number_types.Int8Flags, o + self._tab.Pos) return 0 + # DepthwiseConv2DOptions + def DilationWFactor(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(14)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos) + return 1 + + # DepthwiseConv2DOptions + def DilationHFactor(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(16)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos) + return 1 + def DepthwiseConv2DOptionsStart(builder): - builder.StartObject(5) + builder.StartObject(7) def DepthwiseConv2DOptionsAddPadding(builder, padding): @@ -79,5 +93,13 @@ def DepthwiseConv2DOptionsAddFusedActivationFunction(builder, fusedActivationFun builder.PrependInt8Slot(4, fusedActivationFunction, 0) +def DepthwiseConv2DOptionsAddDilationWFactor(builder, dilationWFactor): + builder.PrependInt32Slot(5, dilationWFactor, 1) + + +def DepthwiseConv2DOptionsAddDilationHFactor(builder, dilationHFactor): + builder.PrependInt32Slot(6, dilationHFactor, 1) + + def DepthwiseConv2DOptionsEnd(builder): return builder.EndObject() diff --git a/tools/tflitefile_tool/tflite/FakeQuantOptions.py b/tools/tflitefile_tool/tflite/FakeQuantOptions.py index fc8023e60..c266bfc9d 100644 --- a/tools/tflitefile_tool/tflite/FakeQuantOptions.py +++ b/tools/tflitefile_tool/tflite/FakeQuantOptions.py @@ -44,8 +44,9 @@ class FakeQuantOptions(object): def NarrowRange(self): o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10)) if o != 0: - return self._tab.Get(flatbuffers.number_types.BoolFlags, o + self._tab.Pos) - return 0 + return bool( + self._tab.Get(flatbuffers.number_types.BoolFlags, o + self._tab.Pos)) + return False def FakeQuantOptionsStart(builder): diff --git a/tools/tflitefile_tool/tflite/FillOptions.py b/tools/tflitefile_tool/tflite/FillOptions.py new file mode 100644 index 000000000..ee6273514 --- /dev/null +++ b/tools/tflitefile_tool/tflite/FillOptions.py @@ -0,0 +1,28 @@ +# automatically generated by the FlatBuffers compiler, do not modify + +# namespace: tflite + +import flatbuffers + + +class FillOptions(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAsFillOptions(cls, buf, offset): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = FillOptions() + x.Init(buf, n + offset) + return x + + # FillOptions + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + +def FillOptionsStart(builder): + builder.StartObject(0) + + +def FillOptionsEnd(builder): + return builder.EndObject() diff --git a/tools/tflitefile_tool/tflite/FloorDivOptions.py b/tools/tflitefile_tool/tflite/FloorDivOptions.py new file mode 100644 index 000000000..90b797112 --- /dev/null +++ b/tools/tflitefile_tool/tflite/FloorDivOptions.py @@ -0,0 +1,28 @@ +# automatically generated by the FlatBuffers compiler, do not modify + +# namespace: tflite + +import flatbuffers + + +class FloorDivOptions(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAsFloorDivOptions(cls, buf, offset): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = FloorDivOptions() + x.Init(buf, n + offset) + return x + + # FloorDivOptions + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + +def FloorDivOptionsStart(builder): + builder.StartObject(0) + + +def FloorDivOptionsEnd(builder): + return builder.EndObject() diff --git a/tools/tflitefile_tool/tflite/LogicalAndOptions.py b/tools/tflitefile_tool/tflite/LogicalAndOptions.py new file mode 100644 index 000000000..84cdfd92a --- /dev/null +++ b/tools/tflitefile_tool/tflite/LogicalAndOptions.py @@ -0,0 +1,28 @@ +# automatically generated by the FlatBuffers compiler, do not modify + +# namespace: tflite + +import flatbuffers + + +class LogicalAndOptions(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAsLogicalAndOptions(cls, buf, offset): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = LogicalAndOptions() + x.Init(buf, n + offset) + return x + + # LogicalAndOptions + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + +def LogicalAndOptionsStart(builder): + builder.StartObject(0) + + +def LogicalAndOptionsEnd(builder): + return builder.EndObject() diff --git a/tools/tflitefile_tool/tflite/LogicalNotOptions.py b/tools/tflitefile_tool/tflite/LogicalNotOptions.py new file mode 100644 index 000000000..966a419b7 --- /dev/null +++ b/tools/tflitefile_tool/tflite/LogicalNotOptions.py @@ -0,0 +1,28 @@ +# automatically generated by the FlatBuffers compiler, do not modify + +# namespace: tflite + +import flatbuffers + + +class LogicalNotOptions(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAsLogicalNotOptions(cls, buf, offset): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = LogicalNotOptions() + x.Init(buf, n + offset) + return x + + # LogicalNotOptions + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + +def LogicalNotOptionsStart(builder): + builder.StartObject(0) + + +def LogicalNotOptionsEnd(builder): + return builder.EndObject() diff --git a/tools/tflitefile_tool/tflite/LogicalOrOptions.py b/tools/tflitefile_tool/tflite/LogicalOrOptions.py new file mode 100644 index 000000000..0a820cdaa --- /dev/null +++ b/tools/tflitefile_tool/tflite/LogicalOrOptions.py @@ -0,0 +1,28 @@ +# automatically generated by the FlatBuffers compiler, do not modify + +# namespace: tflite + +import flatbuffers + + +class LogicalOrOptions(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAsLogicalOrOptions(cls, buf, offset): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = LogicalOrOptions() + x.Init(buf, n + offset) + return x + + # LogicalOrOptions + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + +def LogicalOrOptionsStart(builder): + builder.StartObject(0) + + +def LogicalOrOptionsEnd(builder): + return builder.EndObject() diff --git a/tools/tflitefile_tool/tflite/Model.py b/tools/tflitefile_tool/tflite/Model.py index 4d1e01f44..b5072b171 100644 --- a/tools/tflitefile_tool/tflite/Model.py +++ b/tools/tflitefile_tool/tflite/Model.py @@ -71,7 +71,7 @@ class Model(object): o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10)) if o != 0: return self._tab.String(o + self._tab.Pos) - return "" + return None # Model def Buffers(self, j): diff --git a/tools/tflitefile_tool/tflite/OneHotOptions.py b/tools/tflitefile_tool/tflite/OneHotOptions.py new file mode 100644 index 000000000..fba03f85e --- /dev/null +++ b/tools/tflitefile_tool/tflite/OneHotOptions.py @@ -0,0 +1,39 @@ +# automatically generated by the FlatBuffers compiler, do not modify + +# namespace: tflite + +import flatbuffers + + +class OneHotOptions(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAsOneHotOptions(cls, buf, offset): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = OneHotOptions() + x.Init(buf, n + offset) + return x + + # OneHotOptions + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + # OneHotOptions + def Axis(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos) + return 0 + + +def OneHotOptionsStart(builder): + builder.StartObject(1) + + +def OneHotOptionsAddAxis(builder, axis): + builder.PrependInt32Slot(0, axis, 0) + + +def OneHotOptionsEnd(builder): + return builder.EndObject() diff --git a/tools/tflitefile_tool/tflite/OperatorCode.py b/tools/tflitefile_tool/tflite/OperatorCode.py index 0f945b901..ca0b49ef3 100644 --- a/tools/tflitefile_tool/tflite/OperatorCode.py +++ b/tools/tflitefile_tool/tflite/OperatorCode.py @@ -31,7 +31,7 @@ class OperatorCode(object): o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) if o != 0: return self._tab.String(o + self._tab.Pos) - return "" + return None # OperatorCode def Version(self): diff --git a/tools/tflitefile_tool/tflite/PackOptions.py b/tools/tflitefile_tool/tflite/PackOptions.py new file mode 100644 index 000000000..c1d5579fd --- /dev/null +++ b/tools/tflitefile_tool/tflite/PackOptions.py @@ -0,0 +1,50 @@ +# automatically generated by the FlatBuffers compiler, do not modify + +# namespace: tflite + +import flatbuffers + + +class PackOptions(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAsPackOptions(cls, buf, offset): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = PackOptions() + x.Init(buf, n + offset) + return x + + # PackOptions + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + # PackOptions + def ValuesCount(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos) + return 0 + + # PackOptions + def Axis(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos) + return 0 + + +def PackOptionsStart(builder): + builder.StartObject(2) + + +def PackOptionsAddValuesCount(builder, valuesCount): + builder.PrependInt32Slot(0, valuesCount, 0) + + +def PackOptionsAddAxis(builder, axis): + builder.PrependInt32Slot(1, axis, 0) + + +def PackOptionsEnd(builder): + return builder.EndObject() diff --git a/tools/tflitefile_tool/tflite/ReducerOptions.py b/tools/tflitefile_tool/tflite/ReducerOptions.py index 5b6fa1acf..1f1a1b173 100644 --- a/tools/tflitefile_tool/tflite/ReducerOptions.py +++ b/tools/tflitefile_tool/tflite/ReducerOptions.py @@ -23,8 +23,9 @@ class ReducerOptions(object): def KeepDims(self): o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) if o != 0: - return self._tab.Get(flatbuffers.number_types.BoolFlags, o + self._tab.Pos) - return 0 + return bool( + self._tab.Get(flatbuffers.number_types.BoolFlags, o + self._tab.Pos)) + return False def ReducerOptionsStart(builder): diff --git a/tools/tflitefile_tool/tflite/ResizeBilinearOptions.py b/tools/tflitefile_tool/tflite/ResizeBilinearOptions.py index 66512bb1e..76948948e 100644 --- a/tools/tflitefile_tool/tflite/ResizeBilinearOptions.py +++ b/tools/tflitefile_tool/tflite/ResizeBilinearOptions.py @@ -23,8 +23,9 @@ class ResizeBilinearOptions(object): def AlignCorners(self): o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) if o != 0: - return self._tab.Get(flatbuffers.number_types.BoolFlags, o + self._tab.Pos) - return 0 + return bool( + self._tab.Get(flatbuffers.number_types.BoolFlags, o + self._tab.Pos)) + return False def ResizeBilinearOptionsStart(builder): diff --git a/tools/tflitefile_tool/tflite/SequenceRNNOptions.py b/tools/tflitefile_tool/tflite/SequenceRNNOptions.py index bee7a0fc6..2681296bb 100644 --- a/tools/tflitefile_tool/tflite/SequenceRNNOptions.py +++ b/tools/tflitefile_tool/tflite/SequenceRNNOptions.py @@ -23,8 +23,9 @@ class SequenceRNNOptions(object): def TimeMajor(self): o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) if o != 0: - return self._tab.Get(flatbuffers.number_types.BoolFlags, o + self._tab.Pos) - return 0 + return bool( + self._tab.Get(flatbuffers.number_types.BoolFlags, o + self._tab.Pos)) + return False # SequenceRNNOptions def FusedActivationFunction(self): diff --git a/tools/tflitefile_tool/tflite/SkipGramOptions.py b/tools/tflitefile_tool/tflite/SkipGramOptions.py index 50738b924..9eb5059ea 100644 --- a/tools/tflitefile_tool/tflite/SkipGramOptions.py +++ b/tools/tflitefile_tool/tflite/SkipGramOptions.py @@ -37,8 +37,9 @@ class SkipGramOptions(object): def IncludeAllNgrams(self): o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) if o != 0: - return self._tab.Get(flatbuffers.number_types.BoolFlags, o + self._tab.Pos) - return 0 + return bool( + self._tab.Get(flatbuffers.number_types.BoolFlags, o + self._tab.Pos)) + return False def SkipGramOptionsStart(builder): diff --git a/tools/tflitefile_tool/tflite/SparseToDenseOptions.py b/tools/tflitefile_tool/tflite/SparseToDenseOptions.py index 2782ae573..952d08fc1 100644 --- a/tools/tflitefile_tool/tflite/SparseToDenseOptions.py +++ b/tools/tflitefile_tool/tflite/SparseToDenseOptions.py @@ -23,8 +23,9 @@ class SparseToDenseOptions(object): def ValidateIndices(self): o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) if o != 0: - return self._tab.Get(flatbuffers.number_types.BoolFlags, o + self._tab.Pos) - return 0 + return bool( + self._tab.Get(flatbuffers.number_types.BoolFlags, o + self._tab.Pos)) + return False def SparseToDenseOptionsStart(builder): diff --git a/tools/tflitefile_tool/tflite/SquareOptions.py b/tools/tflitefile_tool/tflite/SquareOptions.py new file mode 100644 index 000000000..0f9f5af9e --- /dev/null +++ b/tools/tflitefile_tool/tflite/SquareOptions.py @@ -0,0 +1,28 @@ +# automatically generated by the FlatBuffers compiler, do not modify + +# namespace: tflite + +import flatbuffers + + +class SquareOptions(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAsSquareOptions(cls, buf, offset): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = SquareOptions() + x.Init(buf, n + offset) + return x + + # SquareOptions + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + +def SquareOptionsStart(builder): + builder.StartObject(0) + + +def SquareOptionsEnd(builder): + return builder.EndObject() diff --git a/tools/tflitefile_tool/tflite/SubGraph.py b/tools/tflitefile_tool/tflite/SubGraph.py index c20880a36..df9acd8ce 100644 --- a/tools/tflitefile_tool/tflite/SubGraph.py +++ b/tools/tflitefile_tool/tflite/SubGraph.py @@ -112,7 +112,7 @@ class SubGraph(object): o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(12)) if o != 0: return self._tab.String(o + self._tab.Pos) - return "" + return None def SubGraphStart(builder): diff --git a/tools/tflitefile_tool/tflite/Tensor.py b/tools/tflitefile_tool/tflite/Tensor.py index 468b120f4..e5f13301c 100644 --- a/tools/tflitefile_tool/tflite/Tensor.py +++ b/tools/tflitefile_tool/tflite/Tensor.py @@ -62,7 +62,7 @@ class Tensor(object): o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10)) if o != 0: return self._tab.String(o + self._tab.Pos) - return "" + return None # Tensor def Quantization(self): @@ -79,8 +79,9 @@ class Tensor(object): def IsVariable(self): o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(14)) if o != 0: - return self._tab.Get(flatbuffers.number_types.BoolFlags, o + self._tab.Pos) - return 0 + return bool( + self._tab.Get(flatbuffers.number_types.BoolFlags, o + self._tab.Pos)) + return False def TensorStart(builder): diff --git a/tools/tflitefile_tool/tflite/UnpackOptions.py b/tools/tflitefile_tool/tflite/UnpackOptions.py new file mode 100644 index 000000000..f580418e6 --- /dev/null +++ b/tools/tflitefile_tool/tflite/UnpackOptions.py @@ -0,0 +1,50 @@ +# automatically generated by the FlatBuffers compiler, do not modify + +# namespace: tflite + +import flatbuffers + + +class UnpackOptions(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAsUnpackOptions(cls, buf, offset): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = UnpackOptions() + x.Init(buf, n + offset) + return x + + # UnpackOptions + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + # UnpackOptions + def Num(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos) + return 0 + + # UnpackOptions + def Axis(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos) + return 0 + + +def UnpackOptionsStart(builder): + builder.StartObject(2) + + +def UnpackOptionsAddNum(builder, num): + builder.PrependInt32Slot(0, num, 0) + + +def UnpackOptionsAddAxis(builder, axis): + builder.PrependInt32Slot(1, axis, 0) + + +def UnpackOptionsEnd(builder): + return builder.EndObject() diff --git a/tools/tflitefile_tool/tflite/ZerosLikeOptions.py b/tools/tflitefile_tool/tflite/ZerosLikeOptions.py new file mode 100644 index 000000000..ca0880ab0 --- /dev/null +++ b/tools/tflitefile_tool/tflite/ZerosLikeOptions.py @@ -0,0 +1,28 @@ +# automatically generated by the FlatBuffers compiler, do not modify + +# namespace: tflite + +import flatbuffers + + +class ZerosLikeOptions(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAsZerosLikeOptions(cls, buf, offset): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = ZerosLikeOptions() + x.Init(buf, n + offset) + return x + + # ZerosLikeOptions + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + +def ZerosLikeOptionsStart(builder): + builder.StartObject(0) + + +def ZerosLikeOptionsEnd(builder): + return builder.EndObject() |