summaryrefslogtreecommitdiff
path: root/tools/tflitefile_tool/model_parser.py
diff options
context:
space:
mode:
authorChunseok Lee <chunseok.lee@samsung.com>2019-01-08 17:36:34 +0900
committerChunseok Lee <chunseok.lee@samsung.com>2019-01-08 17:36:34 +0900
commitbd11b24234d7d43dfe05a81c520aa01ffad06e42 (patch)
tree57d0d4044977e4fa0e50cd9ba40b32006dff19eb /tools/tflitefile_tool/model_parser.py
parent91f4ba45449f700a047a4aeea00b1a7c84e94c75 (diff)
downloadnnfw-bd11b24234d7d43dfe05a81c520aa01ffad06e42.tar.gz
nnfw-bd11b24234d7d43dfe05a81c520aa01ffad06e42.tar.bz2
nnfw-bd11b24234d7d43dfe05a81c520aa01ffad06e42.zip
Imported Upstream version 0.3upstream/0.3
Diffstat (limited to 'tools/tflitefile_tool/model_parser.py')
-rwxr-xr-xtools/tflitefile_tool/model_parser.py69
1 files changed, 37 insertions, 32 deletions
diff --git a/tools/tflitefile_tool/model_parser.py b/tools/tflitefile_tool/model_parser.py
index b8967d33f..0edabbba1 100755
--- a/tools/tflitefile_tool/model_parser.py
+++ b/tools/tflitefile_tool/model_parser.py
@@ -1,4 +1,19 @@
#!/usr/bin/python
+
+# Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import os
import sys
import numpy
@@ -13,6 +28,7 @@ import tflite.Model
import tflite.SubGraph
import argparse
from operator_parser import OperatorParser
+from model_printer import ModelPrinter
from perf_predictor import PerfPredictor
@@ -22,7 +38,6 @@ class TFLiteModelFileParser(object):
self.tflite_file = args.input_file
# Set print level (0 ~ 2)
- # TODO: print information based on level
self.print_level = args.verbose
if (args.verbose > 2):
self.print_level = 2
@@ -30,33 +45,34 @@ class TFLiteModelFileParser(object):
self.print_level = 0
# Set tensor index list to print information
- # TODO:
- # Print tensors in list only
- # Print all tensors if argument used and not specified index number
+ self.print_all_tensor = True
if (args.tensor != None):
- if (len(args.tensor) == 0):
- self.print_all_tensor = True
- else:
+ if (len(args.tensor) != 0):
self.print_all_tensor = False
self.print_tensor_index = []
-
for tensor_index in args.tensor:
self.print_tensor_index.append(int(tensor_index))
# Set operator index list to print information
- # TODO:
- # Print operators in list only
- # Print all operators if argument used and not specified index number
+ self.print_all_operator = True
if (args.operator != None):
- if (len(args.operator) == 0):
- self.print_all_oeprator = True
- else:
- self.print_all_oeprator = False
+ if (len(args.operator) != 0):
+ self.print_all_operator = False
self.print_operator_index = []
-
for operator_index in args.operator:
self.print_operator_index.append(int(operator_index))
+ def PrintModel(self, model_name, op_parser):
+ printer = ModelPrinter(self.print_level, op_parser, model_name)
+
+ if self.print_all_tensor == False:
+ printer.SetPrintSpecificTensors(self.print_tensor_index)
+
+ if self.print_all_operator == False:
+ printer.SetPrintSpecificOperators(self.print_operator_index)
+
+ printer.PrintInfo()
+
def main(self):
# Generate Model: top structure of tflite model file
buf = self.tflite_file.read()
@@ -71,19 +87,12 @@ class TFLiteModelFileParser(object):
if (subgraph_index != 0):
model_name = "Model #" + str(subgraph_index)
- print("[" + model_name + "]\n")
-
- # Model inputs & outputs
- model_inputs = tf_subgraph.InputsAsNumpy()
- model_outputs = tf_subgraph.OutputsAsNumpy()
-
- print(model_name + " input tensors: " + str(model_inputs))
- print(model_name + " output tensors: " + str(model_outputs))
-
- # Parse Operators and print all of operators
+ # Parse Operators
op_parser = OperatorParser(tf_model, tf_subgraph, PerfPredictor())
op_parser.Parse()
- op_parser.PrintAll()
+
+ # print all of operators or requested objects
+ self.PrintModel(model_name, op_parser)
if __name__ == '__main__':
@@ -92,11 +101,7 @@ if __name__ == '__main__':
arg_parser.add_argument(
"input_file", type=argparse.FileType('rb'), help="tflite file to read")
arg_parser.add_argument(
- '-v',
- '--verbose',
- action='count',
- default=0,
- help="set print level (0~2, default: 0)")
+ '-v', '--verbose', type=int, default=1, help="set print level (0~2, default: 1)")
arg_parser.add_argument(
'-t', '--tensor', nargs='*', help="tensor ID to print information (default: all)")
arg_parser.add_argument(