#!/usr/bin/env bash
''''export SCRIPT_PATH="$(cd "$(dirname "$(readlink -f "${BASH_SOURCE[0]}")")" && pwd)" # '''
''''export PY_PATH=${SCRIPT_PATH}/venv/bin/python                                       # '''
''''test -f ${PY_PATH} && exec ${PY_PATH} "$0" "$@"                                     # '''
''''echo "Error: Virtual environment not found. Please run 'one-prepare-venv' command." # '''
''''exit 255                                                                            # '''

# Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import os
import subprocess
import sys
import tempfile

import onelib.make_cmd as _make_cmd
import utils as _utils


def get_driver_cfg_section():
    """Return the configuration-file section name used by this driver."""
    return "one-import-tf"


def _get_parser():
    """Build the argument parser for the TensorFlow-to-circle importer.

    Returns:
        argparse.ArgumentParser: parser with the converter version, model
        format, input/output path, input/output array and
        ``--save_intermediate`` options registered.
    """
    parser = argparse.ArgumentParser(
        description='command line tool to convert TensorFlow to circle')

    # NOTE(review): presumably registers the common options (e.g. the
    # `version` attribute read in _parse_arg) - confirm against utils.
    _utils._add_default_arg(parser)

    ## tf2tfliteV2 arguments
    tf2tfliteV2_group = parser.add_argument_group('converter arguments')

    # converter version: --v1 and --v2 are mutually exclusive and both store
    # their own flag text into `converter_version_cmd`
    converter_version = tf2tfliteV2_group.add_mutually_exclusive_group()
    converter_version.add_argument(
        '--v1',
        action='store_const',
        dest='converter_version_cmd',
        const='--v1',
        help='use TensorFlow Lite Converter 1.x')
    converter_version.add_argument(
        '--v2',
        action='store_const',
        dest='converter_version_cmd',
        const='--v2',
        help='use TensorFlow Lite Converter 2.x')

    # hidden from --help
    parser.add_argument('--converter_version', type=str, help=argparse.SUPPRESS)

    # input model format: the three flags are mutually exclusive and store
    # their own flag text into `model_format_cmd`
    model_format_arg = tf2tfliteV2_group.add_mutually_exclusive_group()
    model_format_arg.add_argument(
        '--graph_def',
        action='store_const',
        dest='model_format_cmd',
        const='--graph_def',
        help='use graph def file(default)')
    model_format_arg.add_argument(
        '--saved_model',
        action='store_const',
        dest='model_format_cmd',
        const='--saved_model',
        help='use saved model')
    model_format_arg.add_argument(
        '--keras_model',
        action='store_const',
        dest='model_format_cmd',
        const='--keras_model',
        help='use keras model')

    # hidden from --help
    parser.add_argument('--model_format', type=str, help=argparse.SUPPRESS)

    # input and output path.
    tf2tfliteV2_group.add_argument(
        '-i', '--input_path', type=str, help='full filepath of the input file')
    tf2tfliteV2_group.add_argument(
        '-o', '--output_path', type=str, help='full filepath of the output file')

    # input and output arrays.
    tf2tfliteV2_group.add_argument(
        '-I',
        '--input_arrays',
        type=str,
        help='names of the input arrays, comma-separated')
    tf2tfliteV2_group.add_argument(
        '-s',
        '--input_shapes',
        type=str,
        help=
        'shapes corresponding to --input_arrays, colon-separated (ex:"1,4,4,3:1,20,20,3")'
    )
    tf2tfliteV2_group.add_argument(
        '-O',
        '--output_arrays',
        type=str,
        help='names of the output arrays, comma-separated')

    # save intermediate file(s)
    parser.add_argument(
        '--save_intermediate',
        action='store_true',
        help='Save intermediate files to output folder')

    return parser


def _verify_arg(parser, args):
    """Verify that the required arguments were given.

    The input and output paths are not marked ``required`` on the parser
    (main() calls _utils._parse_cfg between parsing and this check), so they
    are validated here instead.

    Args:
        parser: the argparse parser, used to report the error.
        args: the parsed argument namespace.

    Calls ``parser.error`` (which exits) when either path is not valid
    according to ``_utils._is_valid_attr``.
    """
    # check if required arguments is given
    missing = []
    if not _utils._is_valid_attr(args, 'input_path'):
        missing.append('-i/--input_path')
    if not _utils._is_valid_attr(args, 'output_path'):
        missing.append('-o/--output_path')
    if missing:
        parser.error('the following arguments are required: ' + ' '.join(missing))


def _parse_arg(parser):
    """Parse the command line and handle the version request.

    Args:
        parser: the argparse parser built by _get_parser().

    Returns:
        The parsed argument namespace.
    """
    args = parser.parse_args()
    # print version
    # NOTE(review): `version` is not registered in this file; presumably it
    # comes from _utils._add_default_arg - confirm.
    if args.version:
        _utils._print_version_and_exit(__file__)

    return args


def _convert(args):
    """Convert the TensorFlow model to circle via an intermediate tflite file.

    Two commands are built with ``onelib.make_cmd`` and run in order through
    ``_utils._run``: ``tf2tfliteV2.py`` (input -> ``.tflite``) and
    ``tflite2circle`` (``.tflite`` -> ``args.output_path``). Both tools are
    looked up in the directory containing this script. Each command line is
    written to ``<output_path>.log`` before it is run.

    Args:
        args: parsed argument namespace; ``input_path`` and ``output_path``
            must be set.
    """
    # get file path to log
    dir_path = os.path.dirname(os.path.realpath(__file__))
    logfile_path = os.path.realpath(args.output_path) + '.log'

    with open(logfile_path, 'wb') as f, tempfile.TemporaryDirectory() as tmpdir:
        # save intermediate: keep the .tflite next to the log file instead of
        # in the temporary directory that is removed when this block exits
        if _utils._is_valid_attr(args, 'save_intermediate'):
            work_dir = os.path.dirname(logfile_path)
        else:
            work_dir = tmpdir

        # make a command to convert from tf to tflite
        tf2tfliteV2_path = os.path.join(dir_path, 'tf2tfliteV2.py')
        output_stem = os.path.splitext(os.path.basename(args.output_path))[0]
        tf2tfliteV2_output_path = os.path.join(work_dir, output_stem) + '.tflite'
        tf2tfliteV2_cmd = _make_cmd.make_tf2tfliteV2_cmd(args, tf2tfliteV2_path,
                                                         args.input_path,
                                                         tf2tfliteV2_output_path)

        # the log is opened in binary mode, so the command text is encoded
        f.write((' '.join(tf2tfliteV2_cmd) + '\n').encode())

        # convert tf to tflite
        _utils._run(tf2tfliteV2_cmd, logfile=f)

        # make a command to convert from tflite to circle
        tflite2circle_path = os.path.join(dir_path, 'tflite2circle')
        tflite2circle_cmd = _make_cmd.make_tflite2circle_cmd(tflite2circle_path,
                                                             tf2tfliteV2_output_path,
                                                             args.output_path)

        f.write((' '.join(tflite2circle_cmd) + '\n').encode())

        # convert tflite to circle
        _utils._run(tflite2circle_cmd, err_prefix="tflite2circle", logfile=f)


def main():
    """Entry point: parse arguments, apply the config file, verify, convert."""
    # parse arguments
    parser = _get_parser()
    args = _parse_arg(parser)

    # parse configuration file (section name comes from the single source of
    # truth rather than a duplicated literal)
    _utils._parse_cfg(args, get_driver_cfg_section())

    # verify arguments
    _verify_arg(parser, args)

    # convert
    _convert(args)


if __name__ == '__main__':
    _utils._safemain(main, __file__)