diff options
Diffstat (limited to 'tools/nnpackage_tool/tf2tfliteV2/tf2tfliteV2.py')
-rwxr-xr-x | tools/nnpackage_tool/tf2tfliteV2/tf2tfliteV2.py | 173 |
1 files changed, 173 insertions, 0 deletions
diff --git a/tools/nnpackage_tool/tf2tfliteV2/tf2tfliteV2.py b/tools/nnpackage_tool/tf2tfliteV2/tf2tfliteV2.py new file mode 100755 index 000000000..ebd5a3afa --- /dev/null +++ b/tools/nnpackage_tool/tf2tfliteV2/tf2tfliteV2.py @@ -0,0 +1,173 @@ +# Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved +# Copyright (C) 2018 The TensorFlow Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import tensorflow as tf +import argparse +import sys + + +def wrap_frozen_graph(graph_def, inputs, outputs): + def _imports_graph_def(): + tf.compat.v1.import_graph_def(graph_def, name="") + + wrapped_import = tf.compat.v1.wrap_function(_imports_graph_def, []) + import_graph = wrapped_import.graph + return wrapped_import.prune( + tf.nest.map_structure(import_graph.as_graph_element, inputs), + tf.nest.map_structure(import_graph.as_graph_element, outputs)) + + +def _get_parser(): + """ + Returns an ArgumentParser for TensorFlow Lite Converter. + """ + parser = argparse.ArgumentParser( + description=("Command line tool to run TensorFlow Lite Converter.")) + + # Converter version. + converter_version = parser.add_mutually_exclusive_group(required=True) + converter_version.add_argument( + "--v1", action="store_true", help="Use TensorFlow Lite Converter 1.x") + converter_version.add_argument( + "--v2", action="store_true", help="Use TensorFlow Lite Converter 2.x") + + # Input and output path. + parser.add_argument( + "--input_path", type=str, help="Full filepath of the input file.", required=True) + parser.add_argument( + "--output_path", + type=str, + help="Full filepath of the output file.", + required=True) + + # Input and output arrays. + parser.add_argument( + "--input_arrays", + type=str, + help="Names of the input arrays, comma-separated.", + required=True) + parser.add_argument( + "--input_shapes", + type=str, + help="Shapes corresponding to --input_arrays, colon-separated.") + parser.add_argument( + "--output_arrays", + type=str, + help="Names of the output arrays, comma-separated.", + required=True) + + return parser + + +def _check_flags(flags): + """ + Checks the parsed flags to ensure they are valid. + """ + if flags.v1: + invalid = "" + # To be filled + + if invalid: + raise ValueError(invalid + " options must be used with v2") + + if flags.v2: + if tf.__version__.find("2.") != 0: + raise ValueError( + "Imported TensorFlow should have version >= 2.0 but you have " + + tf.__version__) + + invalid = "" + # To be filled + + if invalid: + raise ValueError(invalid + " options must be used with v1") + + if flags.input_shapes: + if not flags.input_arrays: + raise ValueError("--input_shapes must be used with --input_arrays") + if flags.input_shapes.count(":") != flags.input_arrays.count(","): + raise ValueError("--input_shapes and --input_arrays must have the same " + "number of items") + + +def _parse_array(arrays, type_fn=str): + return list(map(type_fn, arrays.split(","))) + + +def _v1_convert(flags): + input_shapes = None + if flags.input_shapes: + input_arrays = _parse_array(flags.input_arrays) + input_shapes_list = [ + _parse_array(shape, type_fn=int) for shape in flags.input_shapes.split(":") + ] + input_shapes = dict(list(zip(input_arrays, input_shapes_list))) + + converter = tf.compat.v1.lite.TFLiteConverter.from_frozen_graph( + flags.input_path, _parse_array(flags.input_arrays), + _parse_array(flags.output_arrays), input_shapes) + + converter.allow_custom_ops = True + + tflite_model = converter.convert() + open(flags.output_path, "wb").write(tflite_model) + + +def _v2_convert(flags): + graph_def = tf.compat.v1.GraphDef() + graph_def.ParseFromString(open(flags.input_path, 'rb').read()) + + wrap_func = wrap_frozen_graph( + graph_def, + inputs=[_str + ":0" for _str in _parse_array(flags.input_arrays)], + # TODO What if multiple outputs come in? + outputs=[_str + ":0" for _str in _parse_array(flags.output_arrays)]) + converter = tf.lite.TFLiteConverter.from_concrete_functions([wrap_func]) + + converter.allow_custom_ops = True + converter.experimental_new_converter = True + + converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS] + + tflite_model = converter.convert() + open(flags.output_path, "wb").write(tflite_model) + + +def _convert(flags): + if (flags.v1): + _v1_convert(flags) + else: + _v2_convert(flags) + + +""" +Input frozen graph must be from TensorFlow 1.13.1 +""" + + +def main(): + # Parse argument. + parser = _get_parser() + + # Check if the flags are valid. + flags = parser.parse_known_args(args=sys.argv[1:]) + _check_flags(flags[0]) + + # Convert + _convert(flags[0]) + + +if __name__ == "__main__": + main() |