summaryrefslogtreecommitdiff
path: root/tools/tflitefile_tool/model_parser.py
diff options
context:
space:
mode:
Diffstat (limited to 'tools/tflitefile_tool/model_parser.py')
-rwxr-xr-xtools/tflitefile_tool/model_parser.py35
1 files changed, 31 insertions, 4 deletions
diff --git a/tools/tflitefile_tool/model_parser.py b/tools/tflitefile_tool/model_parser.py
index 0edabbba1..6f9e1c616 100755
--- a/tools/tflitefile_tool/model_parser.py
+++ b/tools/tflitefile_tool/model_parser.py
@@ -29,6 +29,7 @@ import tflite.SubGraph
import argparse
from operator_parser import OperatorParser
from model_printer import ModelPrinter
+from model_saver import ModelSaver
from perf_predictor import PerfPredictor
@@ -62,6 +63,15 @@ class TFLiteModelFileParser(object):
for operator_index in args.operator:
self.print_operator_index.append(int(operator_index))
+ # Set config option
+ self.save = False
+ if args.config:
+ self.save = True
+ self.save_config = True
+
+ if self.save == True:
+ self.save_prefix = args.prefix
+
def PrintModel(self, model_name, op_parser):
printer = ModelPrinter(self.print_level, op_parser, model_name)
@@ -73,6 +83,12 @@ class TFLiteModelFileParser(object):
printer.PrintInfo()
+ def SaveModel(self, model_name, op_parser):
+ saver = ModelSaver(model_name, op_parser)
+
+ if self.save_config == True:
+ saver.SaveConfigInfo(self.save_prefix)
+
def main(self):
# Generate Model: top structure of tflite model file
buf = self.tflite_file.read()
@@ -81,18 +97,22 @@ class TFLiteModelFileParser(object):
# Model file can have many models
# 1st subgraph is main model
- model_name = "Main model"
+ model_name = "Main_model"
for subgraph_index in range(tf_model.SubgraphsLength()):
tf_subgraph = tf_model.Subgraphs(subgraph_index)
if (subgraph_index != 0):
- model_name = "Model #" + str(subgraph_index)
+ model_name = "Model_#" + str(subgraph_index)
# Parse Operators
op_parser = OperatorParser(tf_model, tf_subgraph, PerfPredictor())
op_parser.Parse()
- # print all of operators or requested objects
- self.PrintModel(model_name, op_parser)
+ if self.save == False:
+ # print all of operators or requested objects
+ self.PrintModel(model_name, op_parser)
+ else:
+ # save all of operators in this model
+ self.SaveModel(model_name, op_parser)
if __name__ == '__main__':
@@ -109,6 +129,13 @@ if __name__ == '__main__':
'--operator',
nargs='*',
help="operator ID to print information (default: all)")
+ arg_parser.add_argument(
+ '-c',
+ '--config',
+ action='store_true',
+ help="Save the configuration file per operator")
+ arg_parser.add_argument(
+ '-p', '--prefix', help="file prefix to be saved (with -c/--config option)")
args = arg_parser.parse_args()
# Call main function