summaryrefslogtreecommitdiff
path: root/compiler/one-cmds/onelib/utils.py
blob: f7a1a963a87057037f1d5a8ec6b9c24c6a3c50a7 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
#!/usr/bin/env python

# Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import configparser
import glob
import importlib.machinery
import importlib.util
import ntpath
import os
import subprocess
import sys

from typing import Union

import onelib.constant as _constant


def add_default_arg(parser):
    """
    Add the common -v/--version, -V/--verbose, -C/--config and -S/--section
    arguments to the given argparse parser.

    -S/--section is registered with argparse.SUPPRESS as its help, so it is
    hidden from the generated help/usage text.
    """
    # version
    parser.add_argument(
        '-v',
        '--version',
        action='store_true',
        help='show program\'s version number and exit')

    # verbose
    parser.add_argument(
        '-V',
        '--verbose',
        action='store_true',
        help='output additional information to stdout or stderr')

    # configuration file
    parser.add_argument('-C', '--config', type=str, help='run with configuration file')
    # section name that you want to run in configuration file
    parser.add_argument('-S', '--section', type=str, help=argparse.SUPPRESS)


def add_default_arg_no_CS(parser):
    """
    Register only the common -v/--version and -V/--verbose flags on parser.

    Unlike add_default_arg, neither -C/--config nor -S/--section is added.
    """
    # (short option, long option, help text), registered in this order
    flag_specs = (
        ('-v', '--version', 'show program\'s version number and exit'),
        ('-V', '--verbose', 'output additional information to stdout or stderr'),
    )
    for short_opt, long_opt, help_text in flag_specs:
        parser.add_argument(short_opt, long_opt, action='store_true', help=help_text)


def is_accumulated_arg(arg, driver):
    """
    Return True when option `arg` of `driver` may occur several times
    (parse_cfg collects such values into a list instead of skipping them).

    Only the "one-quantize" driver currently has such options.
    """
    if driver != "one-quantize":
        return False
    return arg in ("tensor_name", "scale", "zero_point", "src_tensor_name",
                   "dst_tensor_name")


def is_valid_attr(args, attr):
    """
    Return the value of `args.attr` if the attribute exists, else False.

    The result is meant to be used for its truthiness: a missing attribute
    and a falsy value are both treated as "not set".
    """
    if not hasattr(args, attr):
        return False
    return getattr(args, attr)


def parse_cfg(config_path: Union[str, None], section_to_parse: str, args):
    """
    parse configuration file and store the information to args

    A value already set (truthy) on `args` takes precedence over the file.
    Options accepted multiple times (see is_accumulated_arg) are instead
    collected into a list on `args`.

    :param config_path: path to configuration file; nothing is done when None
    :param section_to_parse: section name to parse
    :param args: object to store the parsed information
    :raises AssertionError: if the file cannot be read or lacks the section
    """
    if config_path is None:
        return

    parser = configparser.ConfigParser()
    # keep option names case-sensitive (ConfigParser lower-cases by default)
    parser.optionxform = str
    # read() silently skips files it cannot open and returns the list of
    # files it did read; report that explicitly instead of misreporting a
    # missing section below
    if not parser.read(config_path):
        raise AssertionError('cannot read configuration file \'' + str(config_path) +
                             '\'')

    if not parser.has_section(section_to_parse):
        raise AssertionError('configuration file must have \'' + section_to_parse +
                             '\' section')

    section = parser[section_to_parse]
    for key in section:
        if is_accumulated_arg(key, section_to_parse):
            if not is_valid_attr(args, key):
                setattr(args, key, [section[key]])
            else:
                getattr(args, key).append(section[key])
            continue
        # values given explicitly (e.g. on the command line) win over the file
        if is_valid_attr(args, key):
            continue
        setattr(args, key, section[key])


def print_version_and_exit(file_path):
    """
    Run the `one-version` tool that sits next to `file_path` (after symlink
    resolution), passing it the script name without extension, then exit.
    """
    resolved = os.path.realpath(file_path)
    bin_dir, file_name = os.path.split(resolved)
    tool_name, _ext = os.path.splitext(file_name)
    version_cmd = [os.path.join(bin_dir, 'one-version'), tool_name]
    subprocess.call(version_cmd)
    sys.exit()


def safemain(main, mainpath):
    """
    Call `main()`; on any uncaught Exception print
    "<program>: <ExceptionType>: <message>" to stderr and exit with 255.

    SystemExit and KeyboardInterrupt are not Exception subclasses, so they
    propagate unchanged.
    """
    try:
        main()
    except Exception as err:
        program = os.path.basename(mainpath)
        message = f"{program}: {type(err).__name__}: {err}"
        print(message, file=sys.stderr)
        sys.exit(255)


def run(cmd, err_prefix=None, logfile=None):
    """Execute command in subprocess

    The child's stdout lines are relayed to sys.stdout and its stderr lines
    to sys.stderr as soon as they arrive. If the command exits with a
    non-zero status, the current process exits with that status.

    Args:
        cmd: command to be executed in subprocess
        err_prefix: prefix to be put before every stderr lines
        logfile: file stream to which both of stdout and stderr lines will be
            written (as bytes, so presumably opened in binary mode)
    """
    import select

    with subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) as p:
        prefix = f"{err_prefix}: ".encode() if err_prefix else b""
        # fd -> (destination stream, bytes put before each relayed line)
        routes = {
            p.stdout.fileno(): (sys.stdout, b""),
            p.stderr.fileno(): (sys.stderr, prefix),
        }
        # fd -> bytes received after the last newline (an incomplete line)
        partial = {fd: b"" for fd in routes}
        open_fds = set(routes)
        while open_fds:
            # NOTE raw descriptors are read with os.read() instead of the
            # buffered readline(): select() only sees the kernel pipe, so
            # lines already pulled into a Python-side buffer would stay
            # unprinted, and readline() could block on a partial line while
            # the child is stuck writing to the other pipe.
            readable, _, _ = select.select(list(open_fds), [], [])
            for fd in readable:
                chunk = os.read(fd, 65536)
                if chunk:
                    pieces = (partial[fd] + chunk).split(b"\n")
                    partial[fd] = pieces.pop()
                    lines = [piece + b"\n" for piece in pieces]
                else:
                    # EOF: emit an unterminated last line, if any
                    open_fds.discard(fd)
                    lines = [partial[fd]] if partial[fd] else []
                    partial[fd] = b""
                out, line_prefix = routes[fd]
                for line in lines:
                    line = line_prefix + line
                    out.buffer.write(line)
                    out.buffer.flush()
                    if logfile is not None:
                        logfile.write(line)
    # Popen.__exit__ has waited for the child, so returncode is set here
    if p.returncode != 0:
        sys.exit(p.returncode)


def remove_prefix(str, prefix):
    """
    Return `str` with a leading `prefix` stripped, or `str` unchanged if it
    does not start with `prefix`.
    """
    has_prefix = str.startswith(prefix)
    return str[len(prefix):] if has_prefix else str


def remove_suffix(str, suffix):
    """
    Return `str` with a trailing `suffix` stripped, or `str` unchanged if it
    does not end with `suffix`. An empty suffix leaves `str` unchanged.
    """
    # guard the empty suffix: str.endswith('') is always True, and
    # str[:-0] is str[:0] == '', which would wipe the whole string
    if suffix and str.endswith(suffix):
        return str[:-len(suffix)]
    return str


def get_optimization_list(get_name=False, opt_dir=None):
    """
    returns a list of optimization. If `get_name` is True,
    only basename without extension is returned rather than full file path.

    [one hierarchy]
    one
    ├── backends
    ├── bin
    ├── doc
    ├── include
    ├── lib
    ├── optimization
    └── test

    Optimization options must be placed in `optimization` folder

    Args:
        get_name: return names such as 'O1' instead of full file paths
        opt_dir: directory to search; defaults to '../../optimization'
            relative to the directory of this file

    Returns:
        sorted list of readable `O*.cfg` files (or their names) whose file
        name contains no space
    """
    if opt_dir is None:
        dir_path = os.path.dirname(os.path.realpath(__file__))
        opt_dir = os.path.join(dir_path, '..', '..', 'optimization')

    # escape the directory so '[', '*' or '?' in it are matched literally
    candidates = glob.glob(os.path.join(glob.escape(opt_dir), 'O*.cfg'))

    opt_list = []
    # sorted: glob order depends on the filesystem
    for cand in sorted(candidates):
        # exclude if the file name has space (checking only the name, so a
        # space in the install directory does not hide every optimization)
        if ' ' in os.path.basename(cand):
            continue
        if os.path.isfile(cand) and os.access(cand, os.R_OK):
            opt_list.append(cand)

    if get_name:
        # NOTE the name includes prefix 'O'
        # e.g. O1, O2, ONCHW not just 1, 2, NCHW
        opt_list = [os.path.splitext(os.path.basename(f))[0] for f in opt_list]

    return opt_list


def detect_one_import_drivers(search_path):
    """Looks for import drivers in given directory

    Every regular file whose name starts with "one-import-" is executed as a
    Python source module. If it defines get_driver_cfg_section(), the section
    name it returns is mapped to the file name. Files that fail to load are
    skipped.

    Args:
        search_path: path to the directory where to search import drivers

    Returns:
    dict: each entry is related to single detected driver,
          key is a config section name, value is a driver name

    """
    import_drivers_dict = {}
    # sorted(): os.listdir() order is arbitrary, and when two drivers report
    # the same section the later one wins, so keep that choice deterministic
    for module_name in sorted(os.listdir(search_path)):
        if not module_name.startswith("one-import-"):
            continue
        full_path = os.path.join(search_path, module_name)
        if not os.path.isfile(full_path):
            continue
        module_loader = importlib.machinery.SourceFileLoader(module_name, full_path)
        module_spec = importlib.util.spec_from_loader(module_name, module_loader)
        module = importlib.util.module_from_spec(module_spec)
        try:
            module_loader.exec_module(module)
            if hasattr(module, "get_driver_cfg_section"):
                section = module.get_driver_cfg_section()
                import_drivers_dict[section] = module_name
        except (Exception, SystemExit):
            # best effort: a file that fails (or exits) while loading is not
            # treated as a driver; KeyboardInterrupt still propagates
            continue
    return import_drivers_dict