summaryrefslogtreecommitdiff
path: root/compiler/one-cmds/one-init
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/one-cmds/one-init')
-rw-r--r--compiler/one-cmds/one-init280
1 files changed, 280 insertions, 0 deletions
diff --git a/compiler/one-cmds/one-init b/compiler/one-cmds/one-init
new file mode 100644
index 000000000..04c4534cd
--- /dev/null
+++ b/compiler/one-cmds/one-init
@@ -0,0 +1,280 @@
+#!/usr/bin/env bash
+''''export SCRIPT_PATH="$(cd "$(dirname "$(readlink -f "${BASH_SOURCE[0]}")")" && pwd)" # '''
+''''export PY_PATH=${SCRIPT_PATH}/venv/bin/python # '''
+''''test -f ${PY_PATH} && exec ${PY_PATH} "$0" "$@" # '''
+''''echo "Error: Virtual environment not found. Please run 'one-prepare-venv' command." # '''
+''''exit 255 # '''
+
+# Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import copy
+import glob
+import itertools
+import ntpath
+import os
+import sys
+
+import configparser
+import utils as _utils
+
+# TODO Find better way to suppress trackback on error
+sys.tracebacklimit = 0
+
+
+class CommentableConfigParser(configparser.ConfigParser):
+ """
+ ConfigParser where comment can be stored
+ In Python ConfigParser, comment in ini file ( starting with ';') is considered a key of which
+ value is None.
+ Ref: https://stackoverflow.com/questions/6620637/writing-comments-to-files-with-configparser
+ """
+
+ def __init__(self):
+ # allow_no_value=True to add comment
+ # ref: https://stackoverflow.com/a/19432072
+ configparser.ConfigParser.__init__(self, allow_no_value=True)
+ self.optionxform = str
+
+ def add_comment(self, section, comment):
+ comment_sign = ';'
+ self[section][f'{comment_sign} {comment}'] = None
+
+
+def _get_backends_list():
+ """
+ [one hierarchy]
+ one
+ ├── backends
+ ├── bin
+ ├── doc
+ ├── include
+ ├── lib
+ ├── optimization
+ └── test
+
+ The list where `one-init` finds its backends
+ - `bin` folder where `one-init` exists
+ - `backends` folder
+
+ NOTE If there are backends of the same name in different places,
+ the closer to the top in the list, the higher the priority.
+ """
+ dir_path = os.path.dirname(os.path.realpath(__file__))
+ backend_set = set()
+
+ # bin folder
+ files = [f for f in glob.glob(dir_path + '/*-init')]
+ # backends folder
+ files += [f for f in glob.glob(dir_path + '/../backends/**/*-init', recursive=True)]
+ # TODO find backends in `$PATH`
+
+ backends_list = []
+ for cand in files:
+ base = ntpath.basename(cand)
+ if (not base in backend_set) and os.path.isfile(cand) and os.access(
+ cand, os.X_OK):
+ backend_set.add(base)
+ backends_list.append(cand)
+
+ return backends_list
+
+
+# TODO Add support for TF graphdef and bcq
+def _get_parser(backends_list):
+ init_usage = (
+ 'one-init [-h] [-v] [-V] '
+ '[-i INPUT_PATH] '
+ '[-o OUTPUT_PATH] '
+ '[-m MODEL_TYPE] '
+ '[-b BACKEND] '
+ # args for onnx model
+ '[--convert_nchw_to_nhwc] '
+ '[--nchw_to_nhwc_input_shape] '
+ '[--nchw_to_nhwc_output_shape] '
+ # args for backend driver
+ '[--] [COMMANDS FOR BACKEND DRIVER]')
+ """
+ NOTE
+ layout options for onnx model could be difficult to users.
+ In one-init, we could consider easier args for the the above three:
+ For example, we could have another option, e.g., --input_img_layout LAYOUT
+ - When LAYOUT is NHWC, apply 'nchw_to_nhwc_input_shape=True' into cfg
+ - When LAYOUT is NCHW, apply 'nchw_to_nhwc_input_shape=False' into cfg
+ """
+
+ parser = argparse.ArgumentParser(
+ description='Command line tool to generate initial cfg file. '
+ 'Currently tflite and onnx models are supported',
+ usage=init_usage)
+
+ _utils._add_default_arg_no_CS(parser)
+
+ parser.add_argument(
+ '-i', '--input_path', type=str, help='full filepath of the input model file')
+ parser.add_argument(
+ '-o', '--output_path', type=str, help='full filepath of the output cfg file')
+ parser.add_argument(
+ '-m',
+ '--model_type',
+ type=str,
+ help=('type of input model: "onnx", "tflite". '
+ 'If the file extension passed to --input_path is '
+ '".tflite" or ".onnx", this arg can be omitted.'))
+
+ onnx_group = parser.add_argument_group('arguments when model type is onnx')
+ onnx_group.add_argument(
+ '--convert_nchw_to_nhwc',
+ action='store_true',
+ help=
+ 'Convert NCHW operators to NHWC under the assumption that input model is NCHW.')
+ onnx_group.add_argument(
+ '--nchw_to_nhwc_input_shape',
+ action='store_true',
+ help='Convert the input shape of the model (argument for convert_nchw_to_nhwc)')
+ onnx_group.add_argument(
+ '--nchw_to_nhwc_output_shape',
+ action='store_true',
+ help='Convert the output shape of the model (argument for convert_nchw_to_nhwc)')
+
+ # get backend list in the directory
+ backends_name = [ntpath.basename(f) for f in backends_list]
+ if not backends_name:
+ backends_name_message = '(There is no available backend drivers)'
+ else:
+ backends_name_message = '(available backend drivers: ' + ', '.join(
+ backends_name) + ')'
+ backend_help_message = 'backend name to use ' + backends_name_message
+ parser.add_argument('-b', '--backend', type=str, help=backend_help_message)
+
+ return parser
+
+
+def _verify_arg(parser, args):
+ # check if required arguments is given
+ missing = []
+ if not _utils._is_valid_attr(args, 'input_path'):
+ missing.append('-i/--input_path')
+ if not _utils._is_valid_attr(args, 'output_path'):
+ missing.append('-o/--output_path')
+ if not _utils._is_valid_attr(args, 'backend'):
+ missing.append('-b/--backend')
+
+ if _utils._is_valid_attr(args, 'model_type'):
+ # TODO Support model types other than onnx and tflite (e.g., TF)
+ if getattr(args, 'model_type') not in ['onnx', 'tflite']:
+ parser.error('Allowed value for --model_type: "onnx" or "tflite"')
+
+ if _utils._is_valid_attr(args, 'nchw_to_nhwc_input_shape'):
+ if not _utils._is_valid_attr(args, 'convert_nchw_to_nhwc'):
+ missing.append('--convert_nchw_to_nhwc')
+ if _utils._is_valid_attr(args, 'nchw_to_nhwc_output_shape'):
+ if not _utils._is_valid_attr(args, 'convert_nchw_to_nhwc'):
+ missing.append('--convert_nchw_to_nhwc')
+
+ if len(missing):
+ parser.error('the following arguments are required: ' + ' '.join(missing))
+
+
+def _parse_arg(parser):
+ init_args = []
+ backend_args = []
+ argv = copy.deepcopy(sys.argv)
+ # delete file name
+ del argv[0]
+ # split by '--'
+ args = [list(y) for x, y in itertools.groupby(argv, lambda z: z == '--') if not x]
+
+ # one-init [-h] [-v] ...
+ if len(args):
+ init_args = args[0]
+ init_args = parser.parse_args(init_args)
+ backend_args = backend_args if len(args) < 2 else args[1]
+ # print version
+ if len(args) and init_args.version:
+ _utils._print_version_and_exit(__file__)
+
+ return init_args, backend_args
+
+
+def _get_executable(args, backends_list):
+ if _utils._is_valid_attr(args, 'backend'):
+ backend_base = getattr(args, 'backend') + '-init'
+ for cand in backends_list:
+ if ntpath.basename(cand) == backend_base:
+ return cand
+ raise FileNotFoundError(backend_base + ' not found')
+
+
+# TODO Support workflow format (https://github.com/Samsung/ONE/pull/9354)
+def _generate():
+ # generate cfg file
+ config = CommentableConfigParser()
+
+ def _add_onecc_sections():
+ pass # NYI
+
+ def _gen_import():
+ pass # NYI
+
+ def _gen_optimize():
+ pass # NYI
+
+ def _gen_quantize():
+ pass # NYI
+
+ def _gen_codegen():
+ pass # NYI
+
+ #
+ # NYI: one-profile, one-partition, one-pack, one-infer
+ #
+
+ _add_onecc_sections()
+
+ _gen_import()
+ _gen_optimize()
+ _gen_quantize()
+ _gen_codegen()
+
+ with open(args.output_path, 'w') as f:
+ config.write(f)
+
+
+def main():
+ # get backend list
+ backends_list = _get_backends_list()
+
+ # parse arguments
+ parser = _get_parser(backends_list)
+ args, backend_args = _parse_arg(parser)
+
+ # verify arguments
+ _verify_arg(parser, args)
+
+ # make a command to run given backend driver
+ driver_path = _get_executable(args, backends_list)
+ init_cmd = [driver_path] + backend_args
+
+ # run backend driver
+ _utils._run(init_cmd, err_prefix=ntpath.basename(driver_path))
+
+ #TODO generate cfg file
+
+ raise NotImplementedError("NYI")
+
+
+if __name__ == '__main__':
+ _utils._safemain(main, __file__)