diff options
Diffstat (limited to 'tools/tflitefile_tool/model_parser.py')
-rwxr-xr-x | tools/tflitefile_tool/model_parser.py | 35 |
1 files changed, 31 insertions, 4 deletions
diff --git a/tools/tflitefile_tool/model_parser.py b/tools/tflitefile_tool/model_parser.py index 0edabbba1..6f9e1c616 100755 --- a/tools/tflitefile_tool/model_parser.py +++ b/tools/tflitefile_tool/model_parser.py @@ -29,6 +29,7 @@ import tflite.SubGraph import argparse from operator_parser import OperatorParser from model_printer import ModelPrinter +from model_saver import ModelSaver from perf_predictor import PerfPredictor @@ -62,6 +63,15 @@ class TFLiteModelFileParser(object): for operator_index in args.operator: self.print_operator_index.append(int(operator_index)) + # Set config option + self.save = False + if args.config: + self.save = True + self.save_config = True + + if self.save == True: + self.save_prefix = args.prefix + def PrintModel(self, model_name, op_parser): printer = ModelPrinter(self.print_level, op_parser, model_name) @@ -73,6 +83,12 @@ class TFLiteModelFileParser(object): printer.PrintInfo() + def SaveModel(self, model_name, op_parser): + saver = ModelSaver(model_name, op_parser) + + if self.save_config == True: + saver.SaveConfigInfo(self.save_prefix) + def main(self): # Generate Model: top structure of tflite model file buf = self.tflite_file.read() @@ -81,18 +97,22 @@ class TFLiteModelFileParser(object): # Model file can have many models # 1st subgraph is main model - model_name = "Main model" + model_name = "Main_model" for subgraph_index in range(tf_model.SubgraphsLength()): tf_subgraph = tf_model.Subgraphs(subgraph_index) if (subgraph_index != 0): - model_name = "Model #" + str(subgraph_index) + model_name = "Model_#" + str(subgraph_index) # Parse Operators op_parser = OperatorParser(tf_model, tf_subgraph, PerfPredictor()) op_parser.Parse() - # print all of operators or requested objects - self.PrintModel(model_name, op_parser) + if self.save == False: + # print all of operators or requested objects + self.PrintModel(model_name, op_parser) + else: + # save all of operators in this model + self.SaveModel(model_name, op_parser) if __name__ == '__main__': @@ -109,6 +129,13 @@ if __name__ == '__main__': '--operator', nargs='*', help="operator ID to print information (default: all)") + arg_parser.add_argument( + '-c', + '--config', + action='store_true', + help="Save the configuration file per operator") + arg_parser.add_argument( + '-p', '--prefix', help="file prefix to be saved (with -c/--config option)") args = arg_parser.parse_args() # Call main function |