summaryrefslogtreecommitdiff
path: root/compiler/one-cmds/one-import-tflite
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/one-cmds/one-import-tflite')
-rw-r--r--compiler/one-cmds/one-import-tflite155
1 files changed, 91 insertions, 64 deletions
diff --git a/compiler/one-cmds/one-import-tflite b/compiler/one-cmds/one-import-tflite
index 053489c92..8eba46dc5 100644
--- a/compiler/one-cmds/one-import-tflite
+++ b/compiler/one-cmds/one-import-tflite
@@ -1,4 +1,9 @@
-#!/bin/bash
+#!/usr/bin/env bash
+''''export SCRIPT_PATH="$(cd "$(dirname "$(readlink -f "${BASH_SOURCE[0]}")")" && pwd)" # '''
+''''export PY_PATH=${SCRIPT_PATH}/venv/bin/python # '''
+''''test -f ${PY_PATH} && exec ${PY_PATH} "$0" "$@" # '''
+''''echo "Error: Virtual environment not found. Please run 'one-prepare-venv' command." # '''
+''''exit 255 # '''
# Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
#
@@ -14,70 +19,92 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-set -e
+import argparse
+import os
+import sys
-DRIVER_PATH="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
+import onelib.make_cmd as _make_cmd
+import onelib.utils as oneutils
-usage()
-{
- echo "Convert TensorFlow lite model to circle."
- echo "Usage: one-import-tflite"
- echo " --version Show version information and exit"
- echo " --input_path <path/to/tflitemodel>"
- echo " --output_path <path/to/circle>"
- exit 255
-}
+# TODO Find better way to suppress trackback on error
+sys.tracebacklimit = 0
-version()
-{
- $DRIVER_PATH/one-version one-import-tflite
- exit 255
-}
-# Parse command-line arguments
-#
-while [ "$#" -ne 0 ]; do
- CUR="$1"
-
- case $CUR in
- '--help')
- usage
- ;;
- '--version')
- version
- ;;
- '--input_path')
- export INPUT_PATH="$2"
- shift 2
- ;;
- '--output_path')
- export OUTPUT_PATH="$2"
- shift 2
- ;;
- *)
- echo "Unknown parameter: ${CUR}"
- shift
- ;;
- esac
-done
-
-if [ -z ${INPUT_PATH} ] || [ ! -e ${INPUT_PATH} ]; then
- echo "Error: input model not found"
- echo ""
- usage
-fi
-
-# remove previous log
-rm -rf "${OUTPUT_PATH}.log"
-
-show_err_onexit()
-{
- cat "${OUTPUT_PATH}.log"
-}
-
-trap show_err_onexit ERR
-
-# convert .tflite to .circle
-echo "${DRIVER_PATH}/tflite2circle" "${INPUT_PATH}" "${OUTPUT_PATH}" > "${OUTPUT_PATH}.log"
-
-"${DRIVER_PATH}/tflite2circle" "${INPUT_PATH}" "${OUTPUT_PATH}" >> "${OUTPUT_PATH}.log" 2>&1
+def get_driver_cfg_section():
+ return "one-import-tflite"
+
+
+def _get_parser():
+ parser = argparse.ArgumentParser(
+ description='command line tool to convert TensorFlow lite to circle')
+
+ oneutils.add_default_arg(parser)
+
+ ## tflite2circle arguments
+ tflite2circle_group = parser.add_argument_group('converter arguments')
+
+ # input and output path.
+ tflite2circle_group.add_argument(
+ '-i', '--input_path', type=str, help='full filepath of the input file')
+ tflite2circle_group.add_argument(
+ '-o', '--output_path', type=str, help='full filepath of the output file')
+
+ return parser
+
+
+def _verify_arg(parser, args):
+ """verify given arguments"""
+ # check if required arguments is given
+ missing = []
+ if not oneutils.is_valid_attr(args, 'input_path'):
+ missing.append('-i/--input_path')
+ if not oneutils.is_valid_attr(args, 'output_path'):
+ missing.append('-o/--output_path')
+ if len(missing):
+ parser.error('the following arguments are required: ' + ' '.join(missing))
+
+
+def _parse_arg(parser):
+ args = parser.parse_args()
+ # print version
+ if args.version:
+ oneutils.print_version_and_exit(__file__)
+
+ return args
+
+
+def _convert(args):
+ # get file path to log
+ dir_path = os.path.dirname(os.path.realpath(__file__))
+ logfile_path = os.path.realpath(args.output_path) + '.log'
+
+ with open(logfile_path, 'wb') as f:
+ # make a command to convert from tflite to circle
+ tflite2circle_path = os.path.join(dir_path, 'tflite2circle')
+ tflite2circle_cmd = _make_cmd.make_tflite2circle_cmd(tflite2circle_path,
+ getattr(args, 'input_path'),
+ getattr(args, 'output_path'))
+
+ f.write((' '.join(tflite2circle_cmd) + '\n').encode())
+
+ # convert tflite to circle
+ oneutils.run(tflite2circle_cmd, err_prefix="tflite2circle", logfile=f)
+
+
+def main():
+ # parse arguments
+ parser = _get_parser()
+ args = _parse_arg(parser)
+
+ # parse configuration file
+ oneutils.parse_cfg(args.config, 'one-import-tflite', args)
+
+ # verify arguments
+ _verify_arg(parser, args)
+
+ # convert
+ _convert(args)
+
+
+if __name__ == '__main__':
+ oneutils.safemain(main, __file__)