summaryrefslogtreecommitdiff
path: root/compiler/tf2tfliteV2/tf2tfliteV2.py
blob: 8b6ba0dc4549b437d82f9693b70c96217f624745 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
# Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
# Copyright (C) 2018 The TensorFlow Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import tensorflow as tf
import argparse
import sys

from google.protobuf.message import DecodeError
from google.protobuf import text_format as _text_format


def wrap_frozen_graph(graph_def, inputs, outputs):
    """Wrap a frozen GraphDef in a function pruned to the given tensors.

    The GraphDef is imported (with an empty name prefix) inside a function
    built by ``tf.compat.v1.wrap_function``. That function is then pruned so
    it takes ``inputs`` as feeds and produces ``outputs`` as fetches.

    Args:
        graph_def: the GraphDef to import.
        inputs: tensor name(s) to feed, looked up in the imported graph via
            ``as_graph_element``; mapped element-wise with ``tf.nest``.
        outputs: tensor name(s) to fetch, handled the same way as ``inputs``.

    Returns:
        The result of ``prune(feeds, fetches)`` on the wrapped function.
    """
    def _import_into_wrapped_graph():
        tf.compat.v1.import_graph_def(graph_def, name="")

    wrapped = tf.compat.v1.wrap_function(_import_into_wrapped_graph, [])
    lookup = wrapped.graph.as_graph_element
    feeds = tf.nest.map_structure(lookup, inputs)
    fetches = tf.nest.map_structure(lookup, outputs)
    return wrapped.prune(feeds, fetches)


def _get_parser():
    """
  Returns an ArgumentParser for TensorFlow Lite Converter.
  """
    parser = argparse.ArgumentParser(
        description=("Command line tool to run TensorFlow Lite Converter."))

    # Converter version.
    converter_version = parser.add_mutually_exclusive_group(required=True)
    converter_version.add_argument(
        "--v1", action="store_true", help="Use TensorFlow Lite Converter 1.x")
    converter_version.add_argument(
        "--v2", action="store_true", help="Use TensorFlow Lite Converter 2.x")

    # Input and output path.
    parser.add_argument(
        "--input_path", type=str, help="Full filepath of the input file.", required=True)
    parser.add_argument(
        "--output_path",
        type=str,
        help="Full filepath of the output file.",
        required=True)

    # Input and output arrays.
    parser.add_argument(
        "--input_arrays",
        type=str,
        help="Names of the input arrays, comma-separated.",
        required=True)
    parser.add_argument(
        "--input_shapes",
        type=str,
        help="Shapes corresponding to --input_arrays, colon-separated.")
    parser.add_argument(
        "--output_arrays",
        type=str,
        help="Names of the output arrays, comma-separated.",
        required=True)

    return parser


def _check_flags(flags):
    """
  Checks the parsed flags to ensure they are valid.
  """
    if flags.v1:
        invalid = ""
        # To be filled

        if invalid:
            raise ValueError(invalid + " options must be used with v2")

    if flags.v2:
        if tf.__version__.find("2.") != 0:
            raise ValueError(
                "Imported TensorFlow should have version >= 2.0 but you have " +
                tf.__version__)

        invalid = ""
        # To be filled

        if invalid:
            raise ValueError(invalid + " options must be used with v1")

    if flags.input_shapes:
        if not flags.input_arrays:
            raise ValueError("--input_shapes must be used with --input_arrays")
        if flags.input_shapes.count(":") != flags.input_arrays.count(","):
            raise ValueError("--input_shapes and --input_arrays must have the same "
                             "number of items")


def _parse_array(arrays, type_fn=str):
    return list(map(type_fn, arrays.split(",")))


def _v1_convert(flags):
    """Convert a frozen graph with the TensorFlow 1.x TFLite converter.

    Args:
        flags: parsed command-line flags. This function reads ``input_path``,
            ``output_path``, ``input_arrays``, ``output_arrays`` and the
            optional ``input_shapes``.

    The converted model bytes are written to ``flags.output_path``.
    """
    input_arrays = _parse_array(flags.input_arrays)
    output_arrays = _parse_array(flags.output_arrays)

    input_shapes = None
    if flags.input_shapes:
        # One ':'-separated entry per input array, each a ','-separated list
        # of integer dimensions, paired with input_arrays by position.
        input_shapes_list = [
            _parse_array(shape, type_fn=int) for shape in flags.input_shapes.split(":")
        ]
        input_shapes = dict(zip(input_arrays, input_shapes_list))

    converter = tf.compat.v1.lite.TFLiteConverter.from_frozen_graph(
        flags.input_path, input_arrays, output_arrays, input_shapes)

    converter.allow_custom_ops = True

    tflite_model = converter.convert()
    # Close the file deterministically instead of relying on GC to flush it.
    with open(flags.output_path, "wb") as output_file:
        output_file.write(tflite_model)


def _v2_convert(flags):
    """Convert a frozen GraphDef with the TensorFlow 2.x TFLite converter.

    The input file is parsed as a binary GraphDef first. If that fails, it is
    parsed as a text-format GraphDef. The graph is wrapped as a concrete
    function pruned to the requested input/output tensors and converted with
    custom ops allowed and only TFLite builtin op sets targeted.

    Args:
        flags: parsed command-line flags. This function reads ``input_path``,
            ``output_path``, ``input_arrays`` and ``output_arrays``.

    Raises:
        IOError: if the input file parses neither as binary nor as text.
    """
    with open(flags.input_path, 'rb') as input_file:
        file_content = input_file.read()
    try:
        graph_def = tf.compat.v1.GraphDef()
        graph_def.ParseFromString(file_content)
    except (_text_format.ParseError, DecodeError):
        # Start from a fresh message so nothing a failed binary parse may
        # have left behind is merged into the text-format result.
        graph_def = tf.compat.v1.GraphDef()
        try:
            _text_format.Merge(file_content, graph_def)
        except (_text_format.ParseError, DecodeError) as err:
            raise IOError(
                "Unable to parse input file '{}'.".format(flags.input_path)) from err

    def _to_tensor_name(name):
        # Op names cannot contain ':', so a name that has one already selects
        # an output tensor ("op:1"). Otherwise use the op's first output.
        return name if ":" in name else name + ":0"

    wrap_func = wrap_frozen_graph(
        graph_def,
        inputs=[_to_tensor_name(_str) for _str in _parse_array(flags.input_arrays)],
        outputs=[_to_tensor_name(_str) for _str in _parse_array(flags.output_arrays)])
    converter = tf.lite.TFLiteConverter.from_concrete_functions([wrap_func])

    converter.allow_custom_ops = True
    converter.experimental_new_converter = True

    converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS]

    tflite_model = converter.convert()
    # Close the file deterministically instead of relying on GC to flush it.
    with open(flags.output_path, "wb") as output_file:
        output_file.write(tflite_model)


def _convert(flags):
    """Run the v1 converter when ``flags.v1`` is set, otherwise the v2 one."""
    run_conversion = _v1_convert if flags.v1 else _v2_convert
    run_conversion(flags)


"""
Input frozen graph must be from TensorFlow 1.13.1
"""


def main():
    # Parse argument.
    parser = _get_parser()

    # Check if the flags are valid.
    flags = parser.parse_known_args(args=sys.argv[1:])
    _check_flags(flags[0])

    # Convert
    _convert(flags[0])


if __name__ == "__main__":
    main()