summaryrefslogtreecommitdiff
path: root/tools/tflitefile_tool/model_parser.py
blob: 0edabbba1b212380f6a4f41c41c089fa79520097 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
#!/usr/bin/python

# Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import sys
import numpy

# Extend the module search path with the sibling 'tflite' directory and the
# flatbuffers 'python' directory; both are resolved relative to this file so
# the result does not depend on the current working directory.
flatbuffersPath = '../../externals/flatbuffers'
for _relative_dir in ('tflite', flatbuffersPath + '/python'):
    sys.path.append(
        os.path.join(os.path.dirname(os.path.abspath(__file__)), _relative_dir))

import flatbuffers
import tflite.Model
import tflite.SubGraph
import argparse
from operator_parser import OperatorParser
from model_printer import ModelPrinter
from perf_predictor import PerfPredictor


class TFLiteModelFileParser(object):
    """Read a tflite flatbuffer file and print every subgraph it contains.

    The constructor only records options taken from the parsed command-line
    arguments; ``main`` does the actual reading, parsing and printing.
    """

    def __init__(self, args):
        """Store the input file and the print options found in ``args``.

        Args:
            args: Object exposing ``input_file`` (a readable binary file
                object), ``verbose`` (int), and ``tensor`` / ``operator``
                (``None`` or a sequence of values accepted by ``int``).

        Raises:
            ValueError: If a tensor or operator index is not a valid integer.
        """
        # Read flatbuffer file descriptor using argument
        self.tflite_file = args.input_file

        # Set print level, clamped to the supported range (0 ~ 2)
        self.print_level = max(0, min(2, args.verbose))

        # Set tensor index list to print information.
        # No list, or an empty one, means "print every tensor".
        self.print_tensor_index = self._ParseIndexList(args.tensor)
        self.print_all_tensor = not self.print_tensor_index

        # Set operator index list to print information (same convention).
        self.print_operator_index = self._ParseIndexList(args.operator)
        self.print_all_operator = not self.print_operator_index

    @staticmethod
    def _ParseIndexList(indices):
        """Return ``indices`` converted to a list of ints (``[]`` for None)."""
        if indices is None:
            return []
        return [int(index) for index in indices]

    def PrintModel(self, model_name, op_parser):
        """Print one parsed subgraph through ``ModelPrinter``.

        Args:
            model_name: Label handed to the printer for this subgraph.
            op_parser: Operator parser handed to the printer; ``main`` calls
                its ``Parse()`` before passing it here.
        """
        printer = ModelPrinter(self.print_level, op_parser, model_name)

        # Restrict the output only when specific indices were requested.
        if not self.print_all_tensor:
            printer.SetPrintSpecificTensors(self.print_tensor_index)

        if not self.print_all_operator:
            printer.SetPrintSpecificOperators(self.print_operator_index)

        printer.PrintInfo()

    def main(self):
        """Read the model file, then parse and print each of its subgraphs.

        The input file is closed once its content has been read.
        """
        # Generate Model: top structure of tflite model file
        try:
            buf = bytearray(self.tflite_file.read())
        finally:
            # The whole file is in memory now, so release the handle instead
            # of leaving it open until interpreter exit.
            self.tflite_file.close()
        tf_model = tflite.Model.Model.GetRootAsModel(buf, 0)

        # Model file can have many models
        # 1st subgraph is main model
        for subgraph_index in range(tf_model.SubgraphsLength()):
            tf_subgraph = tf_model.Subgraphs(subgraph_index)
            if subgraph_index == 0:
                model_name = "Main model"
            else:
                model_name = "Model #" + str(subgraph_index)

            # Parse Operators
            op_parser = OperatorParser(tf_model, tf_subgraph, PerfPredictor())
            op_parser.Parse()

            # print all of operators or requested objects
            self.PrintModel(model_name, op_parser)


if __name__ == '__main__':
    # Define argument and read
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument(
        "input_file", type=argparse.FileType('rb'), help="tflite file to read")
    arg_parser.add_argument(
        '-v', '--verbose', type=int, default=1, help="set print level (0~2, default: 1)")
    # type=int makes argparse reject a non-numeric ID with a usage error
    # instead of letting it reach int() later and end in a traceback.
    arg_parser.add_argument(
        '-t',
        '--tensor',
        nargs='*',
        type=int,
        help="tensor ID to print information (default: all)")
    arg_parser.add_argument(
        '-o',
        '--operator',
        nargs='*',
        type=int,
        help="operator ID to print information (default: all)")
    args = arg_parser.parse_args()

    # Call main function
    TFLiteModelFileParser(args).main()