summaryrefslogtreecommitdiff
path: root/compiler/one-cmds/one-init
blob: 04c4534cd575acbca6b7e21434da4473fa8c33e4 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
#!/usr/bin/env bash
# NOTE This header is a bash/Python polyglot.
#  - bash reads each of the next five lines as a command: the four leading quotes
#    are two empty strings, and the trailing part after '#' is a comment.
#    It exports SCRIPT_PATH and PY_PATH, then re-executes this same file with
#    ${SCRIPT_PATH}/venv/bin/python when that file exists; otherwise it prints
#    an error and exits with status 255.
#  - Python reads each of those lines as a standalone triple-quoted string
#    expression, which has no effect.
''''export SCRIPT_PATH="$(cd "$(dirname "$(readlink -f "${BASH_SOURCE[0]}")")" && pwd)" # '''
''''export PY_PATH=${SCRIPT_PATH}/venv/bin/python                                       # '''
''''test -f ${PY_PATH} && exec ${PY_PATH} "$0" "$@"                                     # '''
''''echo "Error: Virtual environment not found. Please run 'one-prepare-venv' command." # '''
''''exit 255                                                                            # '''

# Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import copy
import glob
import itertools
import ntpath
import os
import sys

import configparser
import utils as _utils

# TODO Find better way to suppress traceback on error
# With the limit set to 0, an uncaught exception is reported with its type and
# message only; no stack frames are printed.
sys.tracebacklimit = 0


class CommentableConfigParser(configparser.ConfigParser):
    """
    ConfigParser that is able to carry comment lines.

    A comment is stored as an option whose name is the comment text prefixed with ';'
    and whose value is None, which requires the parser to accept options without value.
    Ref: https://stackoverflow.com/questions/6620637/writing-comments-to-files-with-configparser
    """

    def __init__(self):
        # Options without a value must be allowed, otherwise a comment cannot be stored.
        # ref: https://stackoverflow.com/a/19432072
        super().__init__(allow_no_value=True)
        # Keep option names exactly as given instead of lower-casing them.
        self.optionxform = str

    def add_comment(self, section, comment):
        """
        Append `comment` to the existing section `section`.

        Raises KeyError when `section` does not exist.
        """
        target = self[section]
        target['; {}'.format(comment)] = None


def _get_backends_list():
    """
    Collect the paths of the executable `*-init` files that `one-init` can see.

    [one hierarchy]
    one
    ├── backends
    ├── bin
    ├── doc
    ├── include
    ├── lib
    ├── optimization
    └── test

    Search order
    1. `bin` folder, i.e. the folder that holds this script
    2. `backends` folder next to it, searched recursively

    NOTE When the same file name is found more than once, only the first hit in
     the search order is kept, so earlier locations have higher priority.

    Returns a list of paths; it is empty when nothing is found.
    """
    here = os.path.dirname(os.path.realpath(__file__))

    # TODO find backends in `$PATH`
    candidates = glob.glob(here + '/*-init')
    candidates.extend(glob.glob(here + '/../backends/**/*-init', recursive=True))

    seen_names = set()
    drivers = []
    for path in candidates:
        name = ntpath.basename(path)
        # a driver of the same name was already taken from an earlier location
        if name in seen_names:
            continue
        # only regular files that can be executed qualify
        if not os.path.isfile(path):
            continue
        if not os.access(path, os.X_OK):
            continue
        seen_names.add(name)
        drivers.append(path)

    return drivers


# TODO Add support for TF graphdef and bcq
def _get_parser(backends_list):
    """
    Build the argument parser of `one-init`.

    backends_list: paths of backend init drivers; only their base names are used,
                   to list them in the help text of -b/--backend.

    Returns the configured argparse.ArgumentParser.
    """
    # The usage line is spelled out by hand, including the trailing part that
    # describes the commands for the backend driver.
    usage_parts = [
        'one-init [-h] [-v] [-V]',
        '[-i INPUT_PATH]',
        '[-o OUTPUT_PATH]',
        '[-m MODEL_TYPE]',
        '[-b BACKEND]',
        # args for onnx model
        '[--convert_nchw_to_nhwc]',
        '[--nchw_to_nhwc_input_shape]',
        '[--nchw_to_nhwc_output_shape]',
        # args for backend driver
        '[--] [COMMANDS FOR BACKEND DRIVER]',
    ]
    # NOTE
    # layout options for onnx model could be difficult to users.
    # In one-init, we could consider easier args for the above three:
    # For example, we could have another option, e.g., --input_img_layout LAYOUT
    #   - When LAYOUT is NHWC, apply 'nchw_to_nhwc_input_shape=True' into cfg
    #   - When LAYOUT is NCHW, apply 'nchw_to_nhwc_input_shape=False' into cfg

    parser = argparse.ArgumentParser(
        description=('Command line tool to generate initial cfg file. '
                     'Currently tflite and onnx models are supported'),
        usage=' '.join(usage_parts))

    _utils._add_default_arg_no_CS(parser)

    # general options
    parser.add_argument(
        '-i', '--input_path', type=str, help='full filepath of the input model file')
    parser.add_argument(
        '-o', '--output_path', type=str, help='full filepath of the output cfg file')
    model_type_help = ('type of input model: "onnx", "tflite". '
                       'If the file extension passed to --input_path is '
                       '".tflite" or ".onnx", this arg can be omitted.')
    parser.add_argument('-m', '--model_type', type=str, help=model_type_help)

    # onnx-only switches; all of them are plain boolean flags
    onnx_flags = (
        ('--convert_nchw_to_nhwc',
         'Convert NCHW operators to NHWC under the assumption that input model is NCHW.'),
        ('--nchw_to_nhwc_input_shape',
         'Convert the input shape of the model (argument for convert_nchw_to_nhwc)'),
        ('--nchw_to_nhwc_output_shape',
         'Convert the output shape of the model (argument for convert_nchw_to_nhwc)'),
    )
    onnx_group = parser.add_argument_group('arguments when model type is onnx')
    for flag, description in onnx_flags:
        onnx_group.add_argument(flag, action='store_true', help=description)

    # show the names of the discovered backend drivers in the help of -b
    driver_names = [ntpath.basename(path) for path in backends_list]
    if driver_names:
        availability = '(available backend drivers: ' + ', '.join(driver_names) + ')'
    else:
        availability = '(There is no available backend drivers)'
    parser.add_argument(
        '-b', '--backend', type=str, help='backend name to use ' + availability)

    return parser


def _verify_arg(parser, args):
    """
    Check the parsed one-init arguments and report problems via `parser.error`.

    parser: parser used to report errors
    args: parsed one-init arguments

    Checks
    - -i/--input_path, -o/--output_path and -b/--backend must be given
    - --model_type, when given, must be "onnx" or "tflite"
    - --nchw_to_nhwc_input_shape and --nchw_to_nhwc_output_shape both need
      --convert_nchw_to_nhwc
    """
    required = (
        ('input_path', '-i/--input_path'),
        ('output_path', '-o/--output_path'),
        ('backend', '-b/--backend'),
    )
    missing = [flag for attr, flag in required if not _utils._is_valid_attr(args, attr)]

    if _utils._is_valid_attr(args, 'model_type'):
        # TODO Support model types other than onnx and tflite (e.g., TF)
        if getattr(args, 'model_type') not in ['onnx', 'tflite']:
            parser.error('Allowed value for --model_type: "onnx" or "tflite"')

    # Both shape options depend on --convert_nchw_to_nhwc. They are checked together
    # so that the flag is listed only once even when both shape options are given.
    shape_options = ('nchw_to_nhwc_input_shape', 'nchw_to_nhwc_output_shape')
    shape_requested = any(_utils._is_valid_attr(args, opt) for opt in shape_options)
    if shape_requested and not _utils._is_valid_attr(args, 'convert_nchw_to_nhwc'):
        missing.append('--convert_nchw_to_nhwc')

    if missing:
        parser.error('the following arguments are required: ' + ' '.join(missing))


def _parse_arg(parser):
    """
    Split the command line into one-init options and backend driver arguments.

    Everything before the first '--' is parsed with `parser`. Everything after it
    is returned untouched, including any further '--', so that it can be handed
    to the backend driver as is.

    parser: argument parser for the one-init options

    Returns a tuple (init_args, backend_args)
    - init_args: argparse.Namespace produced by `parser`, also when no option is given
    - backend_args: list of str, empty when there is nothing after '--'
    """
    # drop the file name
    argv = sys.argv[1:]

    # Split at the first '--' only. A leading '--' therefore yields no one-init
    # options at all, instead of letting backend arguments be parsed as ours.
    if '--' in argv:
        separator = argv.index('--')
        init_argv = argv[:separator]
        backend_args = argv[separator + 1:]
    else:
        init_argv = argv
        backend_args = []

    # one-init [-h] [-v] ...
    init_args = parser.parse_args(init_argv)

    # print version
    if getattr(init_args, 'version', False):
        _utils._print_version_and_exit(__file__)

    return init_args, backend_args


def _get_executable(args, backends_list):
    """
    Find the init driver that belongs to the backend named in `args.backend`.

    args: parsed one-init arguments
    backends_list: candidate driver paths, searched in order

    Returns the first path whose base name is '<backend>-init', or None when no
    backend is given.
    Raises FileNotFoundError when a backend is given but no such driver exists.
    """
    if not _utils._is_valid_attr(args, 'backend'):
        return None

    wanted = args.backend + '-init'
    found = next((path for path in backends_list if ntpath.basename(path) == wanted),
                 None)
    if found is None:
        raise FileNotFoundError(wanted + ' not found')
    return found


# TODO Support workflow format (https://github.com/Samsung/ONE/pull/9354)
def _generate(args=None):
    """
    Write the initial cfg file to `args.output_path`.

    args: parsed one-init arguments; only `output_path` is read here. It is
          optional in the signature for compatibility with the former
          zero-argument form, but a usable `output_path` is required.

    Raises ValueError when `args` is missing or has no `output_path`.

    NOTE All section generators below are placeholders (NYI), so nothing is added
     to the config before it is written.
    """
    # `args` used to be read without being defined in this scope, which raised
    # NameError on every call. It is now an explicit parameter, checked up front.
    output_path = getattr(args, 'output_path', None)
    if not output_path:
        raise ValueError('output_path is required to generate cfg file')

    # generate cfg file
    config = CommentableConfigParser()

    def _add_onecc_sections():
        pass  # NYI

    def _gen_import():
        pass  # NYI

    def _gen_optimize():
        pass  # NYI

    def _gen_quantize():
        pass  # NYI

    def _gen_codegen():
        pass  # NYI

    #
    # NYI: one-profile, one-partition, one-pack, one-infer
    #

    _add_onecc_sections()

    _gen_import()
    _gen_optimize()
    _gen_quantize()
    _gen_codegen()

    with open(output_path, 'w') as f:
        config.write(f)


def main():
    """
    Entry point of `one-init`.

    Discovers the backend init drivers, parses and verifies the command line, then
    runs the selected driver with the arguments given after '--'.

    Raises NotImplementedError once the driver has run, because generating the cfg
    file is not implemented yet.
    """
    # discover backend init drivers and build the parser that lists them
    available_drivers = _get_backends_list()
    arg_parser = _get_parser(available_drivers)

    # split the command line, then validate the one-init part
    init_args, driver_args = _parse_arg(arg_parser)
    _verify_arg(arg_parser, init_args)

    # run the selected backend driver with its own arguments
    driver = _get_executable(init_args, available_drivers)
    _utils._run([driver] + driver_args, err_prefix=ntpath.basename(driver))

    #TODO generate cfg file

    raise NotImplementedError("NYI")


# Run only when this file is executed as a script. main is not called directly
# but handed to _utils._safemain together with this file's path.
if __name__ == '__main__':
    _utils._safemain(main, __file__)