diff options
author | Chunseok Lee <chunseok.lee@samsung.com> | 2019-01-08 17:36:34 +0900 |
---|---|---|
committer | Chunseok Lee <chunseok.lee@samsung.com> | 2019-01-08 17:36:34 +0900 |
commit | bd11b24234d7d43dfe05a81c520aa01ffad06e42 (patch) | |
tree | 57d0d4044977e4fa0e50cd9ba40b32006dff19eb /tools/tflitefile_tool/model_parser.py | |
parent | 91f4ba45449f700a047a4aeea00b1a7c84e94c75 (diff) | |
download | nnfw-bd11b24234d7d43dfe05a81c520aa01ffad06e42.tar.gz nnfw-bd11b24234d7d43dfe05a81c520aa01ffad06e42.tar.bz2 nnfw-bd11b24234d7d43dfe05a81c520aa01ffad06e42.zip |
Imported Upstream version 0.3upstream/0.3
Diffstat (limited to 'tools/tflitefile_tool/model_parser.py')
-rwxr-xr-x | tools/tflitefile_tool/model_parser.py | 69 |
1 files changed, 37 insertions, 32 deletions
diff --git a/tools/tflitefile_tool/model_parser.py b/tools/tflitefile_tool/model_parser.py index b8967d33f..0edabbba1 100755 --- a/tools/tflitefile_tool/model_parser.py +++ b/tools/tflitefile_tool/model_parser.py @@ -1,4 +1,19 @@ #!/usr/bin/python + +# Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import os import sys import numpy @@ -13,6 +28,7 @@ import tflite.Model import tflite.SubGraph import argparse from operator_parser import OperatorParser +from model_printer import ModelPrinter from perf_predictor import PerfPredictor @@ -22,7 +38,6 @@ class TFLiteModelFileParser(object): self.tflite_file = args.input_file # Set print level (0 ~ 2) - # TODO: print information based on level self.print_level = args.verbose if (args.verbose > 2): self.print_level = 2 @@ -30,33 +45,34 @@ class TFLiteModelFileParser(object): self.print_level = 0 # Set tensor index list to print information - # TODO: - # Print tensors in list only - # Print all tensors if argument used and not specified index number + self.print_all_tensor = True if (args.tensor != None): - if (len(args.tensor) == 0): - self.print_all_tensor = True - else: + if (len(args.tensor) != 0): self.print_all_tensor = False self.print_tensor_index = [] - for tensor_index in args.tensor: self.print_tensor_index.append(int(tensor_index)) # Set operator index list to print information - # TODO: - # Print operators in list only - # Print all operators if argument used and not specified index number + self.print_all_operator = True if (args.operator != None): - if (len(args.operator) == 0): - self.print_all_oeprator = True - else: - self.print_all_oeprator = False + if (len(args.operator) != 0): + self.print_all_operator = False self.print_operator_index = [] - for operator_index in args.operator: self.print_operator_index.append(int(operator_index)) + def PrintModel(self, model_name, op_parser): + printer = ModelPrinter(self.print_level, op_parser, model_name) + + if self.print_all_tensor == False: + printer.SetPrintSpecificTensors(self.print_tensor_index) + + if self.print_all_operator == False: + printer.SetPrintSpecificOperators(self.print_operator_index) + + printer.PrintInfo() + def main(self): # Generate Model: top structure of tflite model file buf = self.tflite_file.read() @@ -71,19 +87,12 @@ class TFLiteModelFileParser(object): if (subgraph_index != 0): model_name = "Model #" + str(subgraph_index) - print("[" + model_name + "]\n") - - # Model inputs & outputs - model_inputs = tf_subgraph.InputsAsNumpy() - model_outputs = tf_subgraph.OutputsAsNumpy() - - print(model_name + " input tensors: " + str(model_inputs)) - print(model_name + " output tensors: " + str(model_outputs)) - - # Parse Operators and print all of operators + # Parse Operators op_parser = OperatorParser(tf_model, tf_subgraph, PerfPredictor()) op_parser.Parse() - op_parser.PrintAll() + + # print all of operators or requested objects + self.PrintModel(model_name, op_parser) if __name__ == '__main__': @@ -92,11 +101,7 @@ if __name__ == '__main__': arg_parser.add_argument( "input_file", type=argparse.FileType('rb'), help="tflite file to read") arg_parser.add_argument( - '-v', - '--verbose', - action='count', - default=0, - help="set print level (0~2, default: 0)") + '-v', '--verbose', type=int, default=1, help="set print level (0~2, default: 1)") arg_parser.add_argument( '-t', '--tensor', nargs='*', help="tensor ID to print information (default: all)") arg_parser.add_argument( |