diff options
Diffstat (limited to 'compiler/one-cmds/one-init')
-rw-r--r-- | compiler/one-cmds/one-init | 280 |
1 files changed, 280 insertions, 0 deletions
diff --git a/compiler/one-cmds/one-init b/compiler/one-cmds/one-init new file mode 100644 index 000000000..04c4534cd --- /dev/null +++ b/compiler/one-cmds/one-init @@ -0,0 +1,280 @@ +#!/usr/bin/env bash +''''export SCRIPT_PATH="$(cd "$(dirname "$(readlink -f "${BASH_SOURCE[0]}")")" && pwd)" # ''' +''''export PY_PATH=${SCRIPT_PATH}/venv/bin/python # ''' +''''test -f ${PY_PATH} && exec ${PY_PATH} "$0" "$@" # ''' +''''echo "Error: Virtual environment not found. Please run 'one-prepare-venv' command." # ''' +''''exit 255 # ''' + +# Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import copy +import glob +import itertools +import ntpath +import os +import sys + +import configparser +import utils as _utils + +# TODO Find better way to suppress trackback on error +sys.tracebacklimit = 0 + + +class CommentableConfigParser(configparser.ConfigParser): + """ + ConfigParser where comment can be stored + In Python ConfigParser, comment in ini file ( starting with ';') is considered a key of which + value is None. + Ref: https://stackoverflow.com/questions/6620637/writing-comments-to-files-with-configparser + """ + + def __init__(self): + # allow_no_value=True to add comment + # ref: https://stackoverflow.com/a/19432072 + configparser.ConfigParser.__init__(self, allow_no_value=True) + self.optionxform = str + + def add_comment(self, section, comment): + comment_sign = ';' + self[section][f'{comment_sign} {comment}'] = None + + +def _get_backends_list(): + """ + [one hierarchy] + one + ├── backends + ├── bin + ├── doc + ├── include + ├── lib + ├── optimization + └── test + + The list where `one-init` finds its backends + - `bin` folder where `one-init` exists + - `backends` folder + + NOTE If there are backends of the same name in different places, + the closer to the top in the list, the higher the priority. + """ + dir_path = os.path.dirname(os.path.realpath(__file__)) + backend_set = set() + + # bin folder + files = [f for f in glob.glob(dir_path + '/*-init')] + # backends folder + files += [f for f in glob.glob(dir_path + '/../backends/**/*-init', recursive=True)] + # TODO find backends in `$PATH` + + backends_list = [] + for cand in files: + base = ntpath.basename(cand) + if (not base in backend_set) and os.path.isfile(cand) and os.access( + cand, os.X_OK): + backend_set.add(base) + backends_list.append(cand) + + return backends_list + + +# TODO Add support for TF graphdef and bcq +def _get_parser(backends_list): + init_usage = ( + 'one-init [-h] [-v] [-V] ' + '[-i INPUT_PATH] ' + '[-o OUTPUT_PATH] ' + '[-m MODEL_TYPE] ' + '[-b BACKEND] ' + # args for onnx model + '[--convert_nchw_to_nhwc] ' + '[--nchw_to_nhwc_input_shape] ' + '[--nchw_to_nhwc_output_shape] ' + # args for backend driver + '[--] [COMMANDS FOR BACKEND DRIVER]') + """ + NOTE + layout options for onnx model could be difficult to users. + In one-init, we could consider easier args for the the above three: + For example, we could have another option, e.g., --input_img_layout LAYOUT + - When LAYOUT is NHWC, apply 'nchw_to_nhwc_input_shape=True' into cfg + - When LAYOUT is NCHW, apply 'nchw_to_nhwc_input_shape=False' into cfg + """ + + parser = argparse.ArgumentParser( + description='Command line tool to generate initial cfg file. ' + 'Currently tflite and onnx models are supported', + usage=init_usage) + + _utils._add_default_arg_no_CS(parser) + + parser.add_argument( + '-i', '--input_path', type=str, help='full filepath of the input model file') + parser.add_argument( + '-o', '--output_path', type=str, help='full filepath of the output cfg file') + parser.add_argument( + '-m', + '--model_type', + type=str, + help=('type of input model: "onnx", "tflite". ' + 'If the file extension passed to --input_path is ' + '".tflite" or ".onnx", this arg can be omitted.')) + + onnx_group = parser.add_argument_group('arguments when model type is onnx') + onnx_group.add_argument( + '--convert_nchw_to_nhwc', + action='store_true', + help= + 'Convert NCHW operators to NHWC under the assumption that input model is NCHW.') + onnx_group.add_argument( + '--nchw_to_nhwc_input_shape', + action='store_true', + help='Convert the input shape of the model (argument for convert_nchw_to_nhwc)') + onnx_group.add_argument( + '--nchw_to_nhwc_output_shape', + action='store_true', + help='Convert the output shape of the model (argument for convert_nchw_to_nhwc)') + + # get backend list in the directory + backends_name = [ntpath.basename(f) for f in backends_list] + if not backends_name: + backends_name_message = '(There is no available backend drivers)' + else: + backends_name_message = '(available backend drivers: ' + ', '.join( + backends_name) + ')' + backend_help_message = 'backend name to use ' + backends_name_message + parser.add_argument('-b', '--backend', type=str, help=backend_help_message) + + return parser + + +def _verify_arg(parser, args): + # check if required arguments is given + missing = [] + if not _utils._is_valid_attr(args, 'input_path'): + missing.append('-i/--input_path') + if not _utils._is_valid_attr(args, 'output_path'): + missing.append('-o/--output_path') + if not _utils._is_valid_attr(args, 'backend'): + missing.append('-b/--backend') + + if _utils._is_valid_attr(args, 'model_type'): + # TODO Support model types other than onnx and tflite (e.g., TF) + if getattr(args, 'model_type') not in ['onnx', 'tflite']: + parser.error('Allowed value for --model_type: "onnx" or "tflite"') + + if _utils._is_valid_attr(args, 'nchw_to_nhwc_input_shape'): + if not _utils._is_valid_attr(args, 'convert_nchw_to_nhwc'): + missing.append('--convert_nchw_to_nhwc') + if _utils._is_valid_attr(args, 'nchw_to_nhwc_output_shape'): + if not _utils._is_valid_attr(args, 'convert_nchw_to_nhwc'): + missing.append('--convert_nchw_to_nhwc') + + if len(missing): + parser.error('the following arguments are required: ' + ' '.join(missing)) + + +def _parse_arg(parser): + init_args = [] + backend_args = [] + argv = copy.deepcopy(sys.argv) + # delete file name + del argv[0] + # split by '--' + args = [list(y) for x, y in itertools.groupby(argv, lambda z: z == '--') if not x] + + # one-init [-h] [-v] ... + if len(args): + init_args = args[0] + init_args = parser.parse_args(init_args) + backend_args = backend_args if len(args) < 2 else args[1] + # print version + if len(args) and init_args.version: + _utils._print_version_and_exit(__file__) + + return init_args, backend_args + + +def _get_executable(args, backends_list): + if _utils._is_valid_attr(args, 'backend'): + backend_base = getattr(args, 'backend') + '-init' + for cand in backends_list: + if ntpath.basename(cand) == backend_base: + return cand + raise FileNotFoundError(backend_base + ' not found') + + +# TODO Support workflow format (https://github.com/Samsung/ONE/pull/9354) +def _generate(): + # generate cfg file + config = CommentableConfigParser() + + def _add_onecc_sections(): + pass # NYI + + def _gen_import(): + pass # NYI + + def _gen_optimize(): + pass # NYI + + def _gen_quantize(): + pass # NYI + + def _gen_codegen(): + pass # NYI + + # + # NYI: one-profile, one-partition, one-pack, one-infer + # + + _add_onecc_sections() + + _gen_import() + _gen_optimize() + _gen_quantize() + _gen_codegen() + + with open(args.output_path, 'w') as f: + config.write(f) + + +def main(): + # get backend list + backends_list = _get_backends_list() + + # parse arguments + parser = _get_parser(backends_list) + args, backend_args = _parse_arg(parser) + + # verify arguments + _verify_arg(parser, args) + + # make a command to run given backend driver + driver_path = _get_executable(args, backends_list) + init_cmd = [driver_path] + backend_args + + # run backend driver + _utils._run(init_cmd, err_prefix=ntpath.basename(driver_path)) + + #TODO generate cfg file + + raise NotImplementedError("NYI") + + +if __name__ == '__main__': + _utils._safemain(main, __file__) |