diff options
Diffstat (limited to 'compiler/one-cmds/one-import-tflite')
-rw-r--r-- | compiler/one-cmds/one-import-tflite | 155 |
1 files changed, 91 insertions, 64 deletions
diff --git a/compiler/one-cmds/one-import-tflite b/compiler/one-cmds/one-import-tflite index 053489c92..8eba46dc5 100644 --- a/compiler/one-cmds/one-import-tflite +++ b/compiler/one-cmds/one-import-tflite @@ -1,4 +1,9 @@ -#!/bin/bash +#!/usr/bin/env bash +''''export SCRIPT_PATH="$(cd "$(dirname "$(readlink -f "${BASH_SOURCE[0]}")")" && pwd)" # ''' +''''export PY_PATH=${SCRIPT_PATH}/venv/bin/python # ''' +''''test -f ${PY_PATH} && exec ${PY_PATH} "$0" "$@" # ''' +''''echo "Error: Virtual environment not found. Please run 'one-prepare-venv' command." # ''' +''''exit 255 # ''' # Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved # @@ -14,70 +19,92 @@ # See the License for the specific language governing permissions and # limitations under the License. -set -e +import argparse +import os +import sys -DRIVER_PATH="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +import onelib.make_cmd as _make_cmd +import onelib.utils as oneutils -usage() -{ - echo "Convert TensorFlow lite model to circle." - echo "Usage: one-import-tflite" - echo " --version Show version information and exit" - echo " --input_path <path/to/tflitemodel>" - echo " --output_path <path/to/circle>" - exit 255 -} +# TODO Find better way to suppress trackback on error +sys.tracebacklimit = 0 -version() -{ - $DRIVER_PATH/one-version one-import-tflite - exit 255 -} -# Parse command-line arguments -# -while [ "$#" -ne 0 ]; do - CUR="$1" - - case $CUR in - '--help') - usage - ;; - '--version') - version - ;; - '--input_path') - export INPUT_PATH="$2" - shift 2 - ;; - '--output_path') - export OUTPUT_PATH="$2" - shift 2 - ;; - *) - echo "Unknown parameter: ${CUR}" - shift - ;; - esac -done - -if [ -z ${INPUT_PATH} ] || [ ! -e ${INPUT_PATH} ]; then - echo "Error: input model not found" - echo "" - usage -fi - -# remove previous log -rm -rf "${OUTPUT_PATH}.log" - -show_err_onexit() -{ - cat "${OUTPUT_PATH}.log" -} - -trap show_err_onexit ERR - -# convert .tflite to .circle -echo "${DRIVER_PATH}/tflite2circle" "${INPUT_PATH}" "${OUTPUT_PATH}" > "${OUTPUT_PATH}.log" - -"${DRIVER_PATH}/tflite2circle" "${INPUT_PATH}" "${OUTPUT_PATH}" >> "${OUTPUT_PATH}.log" 2>&1 +def get_driver_cfg_section(): + return "one-import-tflite" + + +def _get_parser(): + parser = argparse.ArgumentParser( + description='command line tool to convert TensorFlow lite to circle') + + oneutils.add_default_arg(parser) + + ## tflite2circle arguments + tflite2circle_group = parser.add_argument_group('converter arguments') + + # input and output path. + tflite2circle_group.add_argument( + '-i', '--input_path', type=str, help='full filepath of the input file') + tflite2circle_group.add_argument( + '-o', '--output_path', type=str, help='full filepath of the output file') + + return parser + + +def _verify_arg(parser, args): + """verify given arguments""" + # check if required arguments is given + missing = [] + if not oneutils.is_valid_attr(args, 'input_path'): + missing.append('-i/--input_path') + if not oneutils.is_valid_attr(args, 'output_path'): + missing.append('-o/--output_path') + if len(missing): + parser.error('the following arguments are required: ' + ' '.join(missing)) + + +def _parse_arg(parser): + args = parser.parse_args() + # print version + if args.version: + oneutils.print_version_and_exit(__file__) + + return args + + +def _convert(args): + # get file path to log + dir_path = os.path.dirname(os.path.realpath(__file__)) + logfile_path = os.path.realpath(args.output_path) + '.log' + + with open(logfile_path, 'wb') as f: + # make a command to convert from tflite to circle + tflite2circle_path = os.path.join(dir_path, 'tflite2circle') + tflite2circle_cmd = _make_cmd.make_tflite2circle_cmd(tflite2circle_path, + getattr(args, 'input_path'), + getattr(args, 'output_path')) + + f.write((' '.join(tflite2circle_cmd) + '\n').encode()) + + # convert tflite to circle + oneutils.run(tflite2circle_cmd, err_prefix="tflite2circle", logfile=f) + + +def main(): + # parse arguments + parser = _get_parser() + args = _parse_arg(parser) + + # parse configuration file + oneutils.parse_cfg(args.config, 'one-import-tflite', args) + + # verify arguments + _verify_arg(parser, args) + + # convert + _convert(args) + + +if __name__ == '__main__': + oneutils.safemain(main, __file__) |