summaryrefslogtreecommitdiff
path: root/tools/nnpackage_tool/tf2tfliteV2/tf2tfliteV2.py
diff options
context:
space:
mode:
Diffstat (limited to 'tools/nnpackage_tool/tf2tfliteV2/tf2tfliteV2.py')
-rwxr-xr-xtools/nnpackage_tool/tf2tfliteV2/tf2tfliteV2.py173
1 files changed, 173 insertions, 0 deletions
diff --git a/tools/nnpackage_tool/tf2tfliteV2/tf2tfliteV2.py b/tools/nnpackage_tool/tf2tfliteV2/tf2tfliteV2.py
new file mode 100755
index 000000000..ebd5a3afa
--- /dev/null
+++ b/tools/nnpackage_tool/tf2tfliteV2/tf2tfliteV2.py
@@ -0,0 +1,173 @@
+# Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+# Copyright (C) 2018 The TensorFlow Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import tensorflow as tf
+import argparse
+import sys
+
+
+def wrap_frozen_graph(graph_def, inputs, outputs):
+ def _imports_graph_def():
+ tf.compat.v1.import_graph_def(graph_def, name="")
+
+ wrapped_import = tf.compat.v1.wrap_function(_imports_graph_def, [])
+ import_graph = wrapped_import.graph
+ return wrapped_import.prune(
+ tf.nest.map_structure(import_graph.as_graph_element, inputs),
+ tf.nest.map_structure(import_graph.as_graph_element, outputs))
+
+
+def _get_parser():
+ """
+ Returns an ArgumentParser for TensorFlow Lite Converter.
+ """
+ parser = argparse.ArgumentParser(
+ description=("Command line tool to run TensorFlow Lite Converter."))
+
+ # Converter version.
+ converter_version = parser.add_mutually_exclusive_group(required=True)
+ converter_version.add_argument(
+ "--v1", action="store_true", help="Use TensorFlow Lite Converter 1.x")
+ converter_version.add_argument(
+ "--v2", action="store_true", help="Use TensorFlow Lite Converter 2.x")
+
+ # Input and output path.
+ parser.add_argument(
+ "--input_path", type=str, help="Full filepath of the input file.", required=True)
+ parser.add_argument(
+ "--output_path",
+ type=str,
+ help="Full filepath of the output file.",
+ required=True)
+
+ # Input and output arrays.
+ parser.add_argument(
+ "--input_arrays",
+ type=str,
+ help="Names of the input arrays, comma-separated.",
+ required=True)
+ parser.add_argument(
+ "--input_shapes",
+ type=str,
+ help="Shapes corresponding to --input_arrays, colon-separated.")
+ parser.add_argument(
+ "--output_arrays",
+ type=str,
+ help="Names of the output arrays, comma-separated.",
+ required=True)
+
+ return parser
+
+
+def _check_flags(flags):
+ """
+ Checks the parsed flags to ensure they are valid.
+ """
+ if flags.v1:
+ invalid = ""
+ # To be filled
+
+ if invalid:
+ raise ValueError(invalid + " options must be used with v2")
+
+ if flags.v2:
+ if tf.__version__.find("2.") != 0:
+ raise ValueError(
+ "Imported TensorFlow should have version >= 2.0 but you have " +
+ tf.__version__)
+
+ invalid = ""
+ # To be filled
+
+ if invalid:
+ raise ValueError(invalid + " options must be used with v1")
+
+ if flags.input_shapes:
+ if not flags.input_arrays:
+ raise ValueError("--input_shapes must be used with --input_arrays")
+ if flags.input_shapes.count(":") != flags.input_arrays.count(","):
+ raise ValueError("--input_shapes and --input_arrays must have the same "
+ "number of items")
+
+
+def _parse_array(arrays, type_fn=str):
+ return list(map(type_fn, arrays.split(",")))
+
+
+def _v1_convert(flags):
+ input_shapes = None
+ if flags.input_shapes:
+ input_arrays = _parse_array(flags.input_arrays)
+ input_shapes_list = [
+ _parse_array(shape, type_fn=int) for shape in flags.input_shapes.split(":")
+ ]
+ input_shapes = dict(list(zip(input_arrays, input_shapes_list)))
+
+ converter = tf.compat.v1.lite.TFLiteConverter.from_frozen_graph(
+ flags.input_path, _parse_array(flags.input_arrays),
+ _parse_array(flags.output_arrays), input_shapes)
+
+ converter.allow_custom_ops = True
+
+ tflite_model = converter.convert()
+ open(flags.output_path, "wb").write(tflite_model)
+
+
+def _v2_convert(flags):
+ graph_def = tf.compat.v1.GraphDef()
+ graph_def.ParseFromString(open(flags.input_path, 'rb').read())
+
+ wrap_func = wrap_frozen_graph(
+ graph_def,
+ inputs=[_str + ":0" for _str in _parse_array(flags.input_arrays)],
+ # TODO What if multiple outputs come in?
+ outputs=[_str + ":0" for _str in _parse_array(flags.output_arrays)])
+ converter = tf.lite.TFLiteConverter.from_concrete_functions([wrap_func])
+
+ converter.allow_custom_ops = True
+ converter.experimental_new_converter = True
+
+ converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS]
+
+ tflite_model = converter.convert()
+ open(flags.output_path, "wb").write(tflite_model)
+
+
+def _convert(flags):
+ if (flags.v1):
+ _v1_convert(flags)
+ else:
+ _v2_convert(flags)
+
+
+"""
+Input frozen graph must be from TensorFlow 1.13.1
+"""
+
+
+def main():
+ # Parse argument.
+ parser = _get_parser()
+
+ # Check if the flags are valid.
+ flags = parser.parse_known_args(args=sys.argv[1:])
+ _check_flags(flags[0])
+
+ # Convert
+ _convert(flags[0])
+
+
+if __name__ == "__main__":
+ main()