summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authoryunjayh <50489513+yunjayh@users.noreply.github.com>2022-06-30 10:12:58 +0900
committerGitHub <noreply@github.com>2022-06-30 10:12:58 +0900
commitc426200a29e4236c73562331c70e5bf35663c0a8 (patch)
tree16b4564723c6af5525b603669dc6ebc89a16d024
parent39b920db01bef399392ad2c3188773fc022b95a1 (diff)
downloadnnfw-c426200a29e4236c73562331c70e5bf35663c0a8.tar.gz
nnfw-c426200a29e4236c73562331c70e5bf35663c0a8.tar.bz2
nnfw-c426200a29e4236c73562331c70e5bf35663c0a8.zip
[one-cmds] Fix python files by formatter (#9353)
This commit fixes python files by formatter. Signed-off-by: yunjay-hong <yunjay.hong@samsung.com>
-rw-r--r--compiler/one-cmds/one-build3
-rw-r--r--compiler/one-cmds/one-import-bcq10
-rw-r--r--compiler/one-cmds/one-import-onnx4
-rw-r--r--compiler/one-cmds/one-import-pytorch7
-rw-r--r--compiler/one-cmds/one-import-tf8
-rw-r--r--compiler/one-cmds/one-import-tflite4
-rw-r--r--compiler/one-cmds/one-infer18
-rw-r--r--compiler/one-cmds/one-optimize6
-rw-r--r--compiler/one-cmds/one-partition14
-rw-r--r--compiler/one-cmds/one-quantize31
-rw-r--r--compiler/one-cmds/onecc13
-rw-r--r--compiler/one-cmds/onelib/constant.py4
-rw-r--r--compiler/one-cmds/onelib/make_cmd.py1
-rwxr-xr-xcompiler/one-cmds/onnx_legalizer.py59
-rw-r--r--compiler/one-cmds/utils.py70
15 files changed, 128 insertions, 124 deletions
diff --git a/compiler/one-cmds/one-build b/compiler/one-cmds/one-build
index a6594498c..4b1f98070 100644
--- a/compiler/one-cmds/one-build
+++ b/compiler/one-cmds/one-build
@@ -157,7 +157,8 @@ def main():
bin_dir = os.path.dirname(os.path.realpath(__file__))
import_drivers_dict = _utils._detect_one_import_drivers(bin_dir)
transform_drivers = [
- 'one-optimize', 'one-quantize', 'one-pack', 'one-codegen', 'one-profile', 'one-partition'
+ 'one-optimize', 'one-quantize', 'one-pack', 'one-codegen', 'one-profile',
+ 'one-partition'
]
_verify_cfg(import_drivers_dict, config)
diff --git a/compiler/one-cmds/one-import-bcq b/compiler/one-cmds/one-import-bcq
index 5aa1ca9c6..c3ef0b275 100644
--- a/compiler/one-cmds/one-import-bcq
+++ b/compiler/one-cmds/one-import-bcq
@@ -159,9 +159,9 @@ def _convert(args):
tmpdir,
os.path.splitext(
os.path.basename(generate_bcq_metadata_output_path))[0]) + '.tflite'
- tf2tfliteV2_cmd = _make_cmd.make_tf2tfliteV2_cmd(args, tf2tfliteV2_path,
- generate_bcq_metadata_output_path,
- tf2tfliteV2_output_path)
+ tf2tfliteV2_cmd = _make_cmd.make_tf2tfliteV2_cmd(
+ args, tf2tfliteV2_path, generate_bcq_metadata_output_path,
+ tf2tfliteV2_output_path)
try:
output_arrays_idx = tf2tfliteV2_cmd.index('--output_arrays')
tf2tfliteV2_cmd[output_arrays_idx + 1] = ','.join(bcq_output_arrays)
@@ -176,8 +176,8 @@ def _convert(args):
# make a command to convert from tflite to circle
tflite2circle_path = os.path.join(dir_path, 'tflite2circle')
tflite2circle_cmd = _make_cmd.make_tflite2circle_cmd(tflite2circle_path,
- tf2tfliteV2_output_path,
- getattr(args, 'output_path'))
+ tf2tfliteV2_output_path,
+ getattr(args, 'output_path'))
f.write((' '.join(tflite2circle_cmd) + '\n').encode())
diff --git a/compiler/one-cmds/one-import-onnx b/compiler/one-cmds/one-import-onnx
index 5305cc62e..8271d56cf 100644
--- a/compiler/one-cmds/one-import-onnx
+++ b/compiler/one-cmds/one-import-onnx
@@ -165,8 +165,8 @@ def _convert(args):
# make a command to convert from tflite to circle
tflite2circle_path = os.path.join(dir_path, 'tflite2circle')
tflite2circle_cmd = _make_cmd.make_tflite2circle_cmd(tflite2circle_path,
- tf2tfliteV2_output_path,
- getattr(args, 'output_path'))
+ tf2tfliteV2_output_path,
+ getattr(args, 'output_path'))
f.write((' '.join(tflite2circle_cmd) + '\n').encode())
diff --git a/compiler/one-cmds/one-import-pytorch b/compiler/one-cmds/one-import-pytorch
index dbf1ba6d7..7f39e61bb 100644
--- a/compiler/one-cmds/one-import-pytorch
+++ b/compiler/one-cmds/one-import-pytorch
@@ -80,7 +80,8 @@ def _get_parser():
tf2tflite_group.add_argument('--converter_version', default='v2')
parser.add_argument('--unroll_rnn', action='store_true', help='Unroll RNN operators')
- parser.add_argument('--unroll_lstm', action='store_true', help='Unroll LSTM operators')
+ parser.add_argument(
+ '--unroll_lstm', action='store_true', help='Unroll LSTM operators')
# save intermediate file(s)
parser.add_argument(
@@ -338,8 +339,8 @@ def _convert(args):
# make a command to convert from tflite to circle
tflite2circle_path = os.path.join(dir_path, 'tflite2circle')
tflite2circle_cmd = _make_cmd.make_tflite2circle_cmd(tflite2circle_path,
- tf2tfliteV2_output_path,
- getattr(args, 'output_path'))
+ tf2tfliteV2_output_path,
+ getattr(args, 'output_path'))
f.write((' '.join(tflite2circle_cmd) + '\n').encode())
diff --git a/compiler/one-cmds/one-import-tf b/compiler/one-cmds/one-import-tf
index 8c03c06c6..6623fa6a4 100644
--- a/compiler/one-cmds/one-import-tf
+++ b/compiler/one-cmds/one-import-tf
@@ -150,8 +150,8 @@ def _convert(args):
tmpdir,
os.path.splitext(os.path.basename(args.output_path))[0]) + '.tflite'
tf2tfliteV2_cmd = _make_cmd.make_tf2tfliteV2_cmd(args, tf2tfliteV2_path,
- getattr(args, 'input_path'),
- tf2tfliteV2_output_path)
+ getattr(args, 'input_path'),
+ tf2tfliteV2_output_path)
f.write((' '.join(tf2tfliteV2_cmd) + '\n').encode())
@@ -161,8 +161,8 @@ def _convert(args):
# make a command to convert from tflite to circle
tflite2circle_path = os.path.join(dir_path, 'tflite2circle')
tflite2circle_cmd = _make_cmd.make_tflite2circle_cmd(tflite2circle_path,
- tf2tfliteV2_output_path,
- getattr(args, 'output_path'))
+ tf2tfliteV2_output_path,
+ getattr(args, 'output_path'))
f.write((' '.join(tflite2circle_cmd) + '\n').encode())
diff --git a/compiler/one-cmds/one-import-tflite b/compiler/one-cmds/one-import-tflite
index 6072dad26..3d96b117f 100644
--- a/compiler/one-cmds/one-import-tflite
+++ b/compiler/one-cmds/one-import-tflite
@@ -82,8 +82,8 @@ def _convert(args):
# make a command to convert from tflite to circle
tflite2circle_path = os.path.join(dir_path, 'tflite2circle')
tflite2circle_cmd = _make_cmd.make_tflite2circle_cmd(tflite2circle_path,
- getattr(args, 'input_path'),
- getattr(args, 'output_path'))
+ getattr(args, 'input_path'),
+ getattr(args, 'output_path'))
f.write((' '.join(tflite2circle_cmd) + '\n').encode())
diff --git a/compiler/one-cmds/one-infer b/compiler/one-cmds/one-infer
index 610c08cb7..6dfca7588 100644
--- a/compiler/one-cmds/one-infer
+++ b/compiler/one-cmds/one-infer
@@ -58,15 +58,14 @@ def _get_backends_list():
# bin folder
files = [f for f in glob.glob(dir_path + '/*-infer')]
# backends folder
- files += [
- f for f in glob.glob(dir_path + '/../backends/**/*-infer', recursive=True)
- ]
+ files += [f for f in glob.glob(dir_path + '/../backends/**/*-infer', recursive=True)]
# TODO find backends in `$PATH`
backends_list = []
for cand in files:
base = ntpath.basename(cand)
- if (not base in backend_set) and os.path.isfile(cand) and os.access(cand, os.X_OK):
+ if (not base in backend_set) and os.path.isfile(cand) and os.access(
+ cand, os.X_OK):
backend_set.add(base)
backends_list.append(cand)
@@ -100,7 +99,8 @@ def _search_backend_driver(driver):
return driver_path
# CASE 2: one/backends/**/bin/{driver} is found
- for driver_path in glob.glob(dir_path + '/../backends/**/bin/' + driver, recursive=True):
+ for driver_path in glob.glob(
+ dir_path + '/../backends/**/bin/' + driver, recursive=True):
if os.path.isfile(driver_path) and os.access(driver_path, os.X_OK):
return driver_path
@@ -136,10 +136,12 @@ def _verify_arg(parser, args):
"""verify given arguments"""
# `-d/--driver` and `-b/--backend` are mutually exclusive arguments.
if _utils._is_valid_attr(args, 'driver') and _utils._is_valid_attr(args, 'backend'):
- parser.error('-d and -b options are mutually exclusive. Please use only one of them')
+ parser.error(
+ '-d and -b options are mutually exclusive. Please use only one of them')
missing = []
- if not _utils._is_valid_attr(args, 'driver') and not _utils._is_valid_attr(args, 'backend'):
+ if not _utils._is_valid_attr(args, 'driver') and not _utils._is_valid_attr(
+ args, 'backend'):
missing.append('{-d/--driver | -b/--backend}')
if len(missing):
parser.error('the following arguments are required: ' + ' '.join(missing))
@@ -153,7 +155,7 @@ def _parse_arg(parser):
del argv[0]
# split by '--'
args = [list(y) for x, y in itertools.groupby(argv, lambda z: z == '--') if not x]
-
+
# one-infer [-h] [-v] [-C CONFIG] [-d DRIVER] [-b BACKEND] -- [COMMANDS FOR BACKEND DRIVER]
if len(args):
infer_args = args[0]
diff --git a/compiler/one-cmds/one-optimize b/compiler/one-cmds/one-optimize
index b657351c4..481fc8459 100644
--- a/compiler/one-cmds/one-optimize
+++ b/compiler/one-cmds/one-optimize
@@ -88,7 +88,7 @@ def _verify_arg(parser, args):
# check if unrecognized arguments are given
diff = set(dir(args)) - set(dir(default))
if len(diff):
- parser.error('the following arguments are unrecognized: ' + ' '.join(diff))
+ parser.error('the following arguments are unrecognized: ' + ' '.join(diff))
def _parse_arg(parser):
@@ -109,8 +109,8 @@ def _optimize(args):
# make a command to optimize circle model
circle2circle_path = os.path.join(dir_path, 'circle2circle')
circle2circle_cmd = _make_cmd.make_circle2circle_cmd(args, circle2circle_path,
- getattr(args, 'input_path'),
- getattr(args, 'output_path'))
+ getattr(args, 'input_path'),
+ getattr(args, 'output_path'))
# verbose
if _utils._is_valid_attr(args, 'verbose'):
diff --git a/compiler/one-cmds/one-partition b/compiler/one-cmds/one-partition
index 521290eb5..c0d71e5d9 100644
--- a/compiler/one-cmds/one-partition
+++ b/compiler/one-cmds/one-partition
@@ -26,7 +26,6 @@ import sys
import utils as _utils
-
# TODO Find better way to suppress trackback on error
sys.tracebacklimit = 0
@@ -39,14 +38,15 @@ def _get_parser():
parser.add_argument(
'--backends', type=str, help='backends in CSV to use for partitioning')
- parser.add_argument(
- '--default', type=str, help='default backend to assign')
+ parser.add_argument('--default', type=str, help='default backend to assign')
- parser.add_argument('--part_file', type=str,help=
- 'partition file which provides backend to assign')
+ parser.add_argument(
+ '--part_file', type=str, help='partition file which provides backend to assign')
parser.add_argument('--input_file', type=str, help='input circle model filename')
- parser.add_argument('--work_path', type=str, help=
- 'work path of partition, input files exist and output files are produced')
+ parser.add_argument(
+ '--work_path',
+ type=str,
+ help='work path of partition, input files exist and output files are produced')
return parser
diff --git a/compiler/one-cmds/one-quantize b/compiler/one-cmds/one-quantize
index 2e973b521..11c7b4915 100644
--- a/compiler/one-cmds/one-quantize
+++ b/compiler/one-cmds/one-quantize
@@ -133,11 +133,7 @@ def _get_parser():
"Force MaxPool Op to have the same input/output quantparams. NOTE: This option can degrade accuracy of some models.)"
)
quantization_group.add_argument(
- '--quant_config',
- type=str,
- help=
- "Path to the quantization configuration file."
- )
+ '--quant_config', type=str, help="Path to the quantization configuration file.")
quantization_group.add_argument(
'--evaluate_result',
action='store_true',
@@ -145,11 +141,7 @@ def _get_parser():
"Evaluate accuracy of quantized model. Run inference for both fp32 model and the quantized model, and compare the inference results."
)
quantization_group.add_argument(
- '--test_data',
- type=str,
- help=
- "Path to the test data used for evaluation."
- )
+ '--test_data', type=str, help="Path to the test data used for evaluation.")
quantization_group.add_argument(
'--print_mae',
action='store_true',
@@ -226,9 +218,11 @@ def _set_default_values(args):
with open(getattr(args, 'quant_config')) as f:
qconf = json.load(f)
if 'default_quantization_dtype' in qconf:
- setattr(args, 'quantized_dtype', qconf['default_quantization_dtype'])
+ setattr(args, 'quantized_dtype',
+ qconf['default_quantization_dtype'])
except json.decoder.JSONDecodeError:
- print('Failed to decode ' + getattr(args, 'quant_config') + '. Please check it is a json file.')
+ print('Failed to decode ' + getattr(args, 'quant_config') +
+ '. Please check it is a json file.')
if not _utils._is_valid_attr(args, 'granularity'):
setattr(args, 'granularity', 'layer')
if _utils._is_valid_attr(args, 'quant_config'):
@@ -239,7 +233,8 @@ def _set_default_values(args):
if 'default_granularity' in qconf:
setattr(args, 'granularity', qconf['default_granularity'])
except json.decoder.JSONDecodeError:
- print('Failed to decode ' + getattr(args, 'quant_config') + '. Please check it is a json file.')
+ print('Failed to decode ' + getattr(args, 'quant_config') +
+ '. Please check it is a json file.')
if not _utils._is_valid_attr(args, 'mode'):
setattr(args, 'mode', 'percentile')
if not _utils._is_valid_attr(args, 'min_percentile'):
@@ -342,7 +337,8 @@ def _quantize(args):
circle_quantizer_cmd.append(getattr(args, 'input_path'))
tmp_weights_fake_quant_path = os.path.join(
tmpdir,
- os.path.splitext(os.path.basename(args.input_path))[0]) + '.weights_fake_quant.circle'
+ os.path.splitext(os.path.basename(
+ args.input_path))[0]) + '.weights_fake_quant.circle'
circle_quantizer_cmd.append(tmp_weights_fake_quant_path)
# profiling
if _utils._is_valid_attr(args, 'generate_profile_data'):
@@ -355,7 +351,8 @@ def _quantize(args):
tmp_minmax_recorded_path = os.path.join(
tmpdir,
- os.path.splitext(os.path.basename(args.input_path))[0]) + '.minmax_recorded.circle'
+ os.path.splitext(os.path.basename(
+ args.input_path))[0]) + '.minmax_recorded.circle'
## make a command to record min-max value of each tensor while running the representative dataset
record_minmax_cmd = Command(record_minmax_path, args, f)
@@ -420,7 +417,8 @@ def _quantize(args):
quant_model = getattr(args, 'output_path')
tmp_fake_quant_model = os.path.join(
tmpdir,
- os.path.splitext(os.path.basename(args.input_path))[0]) + '.fake_quant.circle'
+ os.path.splitext(os.path.basename(
+ args.input_path))[0]) + '.fake_quant.circle'
# do fake quantization
fake_quantize_cmd = Command(circle_quantizer_path, args, f)
@@ -442,6 +440,7 @@ def _quantize(args):
.add_noarg_option_if_valid_arg('--print_top5_match', 'print_top5_match') \
.run()
+
def _write_qparam(args):
# get file path to log
dir_path = os.path.dirname(os.path.realpath(__file__))
diff --git a/compiler/one-cmds/onecc b/compiler/one-cmds/onecc
index 8b79ff53f..9e8e1a6cf 100644
--- a/compiler/one-cmds/onecc
+++ b/compiler/one-cmds/onecc
@@ -146,8 +146,8 @@ def _parse_cfg(args):
def _is_available_driver(config, driver_name):
# if there's no `onecc` section, it will find `one-build` section because of backward compatibility
return (config.has_option('onecc', driver_name) and config.getboolean(
- 'onecc', driver_name)) or (config.has_option('one-build', driver_name) and config.getboolean(
- 'one-build', driver_name))
+ 'onecc', driver_name)) or (config.has_option('one-build', driver_name)
+ and config.getboolean('one-build', driver_name))
def _simple_warning(message, category, filename, lineno, file=None, line=None):
@@ -158,11 +158,11 @@ def _verify_cfg(import_driver_list, config):
if not config.has_section('onecc'):
if config.has_section('one-build'):
warnings.formatwarning = _simple_warning
- warnings.warn("[one-build] section will be deprecated. Please use [onecc] section.")
+ warnings.warn(
+ "[one-build] section will be deprecated. Please use [onecc] section.")
else:
raise ImportError('[onecc] section is required in configuration file')
-
import_driver_cnt = 0
for d in import_driver_list:
if _is_available_driver(config, d):
@@ -219,11 +219,12 @@ def main():
bin_dir = os.path.dirname(os.path.realpath(__file__))
import_drivers_dict = _utils._detect_one_import_drivers(bin_dir)
transform_drivers = [
- 'one-optimize', 'one-quantize', 'one-pack', 'one-codegen', 'one-profile', 'one-partition', 'one-infer'
+ 'one-optimize', 'one-quantize', 'one-pack', 'one-codegen', 'one-profile',
+ 'one-partition', 'one-infer'
]
_verify_cfg(import_drivers_dict, config)
- # verify optimization option file
+ # verify optimization option file
_verify_opt(args)
# get sections to run
diff --git a/compiler/one-cmds/onelib/constant.py b/compiler/one-cmds/onelib/constant.py
index a6a464601..4d330bd77 100644
--- a/compiler/one-cmds/onelib/constant.py
+++ b/compiler/one-cmds/onelib/constant.py
@@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+
class CONSTANT:
__slots__ = () # This prevents access via __dict__.
OPTIMIZATION_OPTS = (
@@ -62,7 +63,8 @@ class CONSTANT:
('remove_unnecessary_slice', 'remove unnecessary slice ops'),
('remove_unnecessary_strided_slice', 'remove unnecessary strided slice ops'),
('remove_unnecessary_split', 'remove unnecessary split ops'),
- ('replace_non_const_fc_with_batch_matmul', 'replace FullyConnected op with non-const weights to BatchMatMul op'),
+ ('replace_non_const_fc_with_batch_matmul',
+ 'replace FullyConnected op with non-const weights to BatchMatMul op'),
('resolve_customop_add', 'convert Custom(Add) op to Add op'),
('resolve_customop_batchmatmul',
'convert Custom(BatchMatmul) op to BatchMatmul op'),
diff --git a/compiler/one-cmds/onelib/make_cmd.py b/compiler/one-cmds/onelib/make_cmd.py
index d8380f28d..71d8adf97 100644
--- a/compiler/one-cmds/onelib/make_cmd.py
+++ b/compiler/one-cmds/onelib/make_cmd.py
@@ -19,6 +19,7 @@ import sys
import onelib.constant as _constant
+
def _is_valid_attr(args, attr):
return hasattr(args, attr) and getattr(args, attr)
diff --git a/compiler/one-cmds/onnx_legalizer.py b/compiler/one-cmds/onnx_legalizer.py
index 26c2b75b9..0141514b6 100755
--- a/compiler/one-cmds/onnx_legalizer.py
+++ b/compiler/one-cmds/onnx_legalizer.py
@@ -341,7 +341,8 @@ def _dtype_to_np(dtype):
raise NotImplementedError('unsupported data type')
-def _generate_one_direction_RNN(transformer, X, W, R, B, initial_h, clip, activation_name):
+def _generate_one_direction_RNN(transformer, X, W, R, B, initial_h, clip,
+ activation_name):
"""Generate subgraph of one direction of unrolled RNN layer
Args:
@@ -395,7 +396,7 @@ def _generate_one_direction_RNN(transformer, X, W, R, B, initial_h, clip, activa
def _transform_unidirectional_RNN(transformer, original_node, x, tensor_infos, activation,
- clip, direction, hidden_size, layout):
+ clip, direction, hidden_size, layout):
"""Generate Simple (forward or reverse) unrolled RNN
Args:
@@ -432,7 +433,7 @@ def _transform_unidirectional_RNN(transformer, original_node, x, tensor_infos, a
else:
initial_h = None
state_tensors = _generate_one_direction_RNN(transformer, x, w, r, b, initial_h, clip,
- activation)
+ activation)
y_direction_dim = layout + 1
y_h_direction_dim = layout
state_layout_tensors = []
@@ -447,12 +448,11 @@ def _transform_unidirectional_RNN(transformer, original_node, x, tensor_infos, a
transformer.make_node(
'Unsqueeze', [state_tensors[-1]], [Y_h], axes=[y_h_direction_dim])
Y = outputs[0]
- transformer.make_node(
- 'Concat', state_layout_tensors, [Y], axis=seq_length_dim)
+ transformer.make_node('Concat', state_layout_tensors, [Y], axis=seq_length_dim)
def _transform_bidirectional_RNN(transformer, original_node, x, tensor_infos, activations,
- clip, hidden_size, layout):
+ clip, hidden_size, layout):
"""Generate Bidirectional unrolled RNN
Args:
@@ -503,10 +503,10 @@ def _transform_bidirectional_RNN(transformer, original_node, x, tensor_infos, ac
initial_h[d] = transformer.make_squeeze(initial_h[d], axes=[direction_dim])
state_f_tensors = _generate_one_direction_RNN(transformer, x, w[0], r[0], b[0],
- initial_h[0], clip, activations[0])
+ initial_h[0], clip, activations[0])
x.reverse()
state_b_tensors = _generate_one_direction_RNN(transformer, x, w[1], r[1], b[1],
- initial_h[1], clip, activations[1])
+ initial_h[1], clip, activations[1])
state_b_tensors.reverse()
y_direction_dim = layout + 1
@@ -538,8 +538,7 @@ def _transform_bidirectional_RNN(transformer, original_node, x, tensor_infos, ac
axis=y_h_direction_dim)
Y = outputs[0]
- transformer.make_node(
- 'Concat', state_layout_tensors, [Y], axis=seq_length_dim)
+ transformer.make_node('Concat', state_layout_tensors, [Y], axis=seq_length_dim)
def _legalize_RNN(transformer, tensor_infos, node):
@@ -600,10 +599,10 @@ def _legalize_RNN(transformer, tensor_infos, node):
if direction in ['forward', 'reverse']:
_transform_unidirectional_RNN(transformer, node, x, tensor_infos, activations[0],
- clip, direction, hidden_size, layout)
+ clip, direction, hidden_size, layout)
elif direction == 'bidirectional':
- _transform_bidirectional_RNN(transformer, node, x, tensor_infos, activations, clip,
- hidden_size, layout)
+ _transform_bidirectional_RNN(transformer, node, x, tensor_infos, activations,
+ clip, hidden_size, layout)
else:
raise RuntimeError('Unknown RNN type')
@@ -611,7 +610,7 @@ def _legalize_RNN(transformer, tensor_infos, node):
def _generate_one_direction_LSTM(transformer, X, W, R, B, initial_h, initial_c, P, clip,
- act, dtype, hidden_size, batch_size):
+ act, dtype, hidden_size, batch_size):
"""Generate subgraph for one direction of unrolled LSTM layer
Args:
@@ -754,7 +753,7 @@ def _generate_one_direction_LSTM(transformer, X, W, R, B, initial_h, initial_c,
def _transform_unidirectional_LSTM(transformer, original_node, x, tensor_infos,
- activations, clip, direction, hidden_size, layout):
+ activations, clip, direction, hidden_size, layout):
"""Generate Simple (forward or reverse) unrolled LSTM
Args:
@@ -818,17 +817,15 @@ def _transform_unidirectional_LSTM(transformer, original_node, x, tensor_infos,
transformer.make_node(
'Unsqueeze', [state_h_tensors[-1]], [Y_h], axes=[y_h_direction_dim])
Y_c = outputs[2]
- transformer.make_node(
- 'Unsqueeze', [state_c_tensor], [Y_c], axes=[y_h_direction_dim])
+ transformer.make_node('Unsqueeze', [state_c_tensor], [Y_c], axes=[y_h_direction_dim])
if direction == 'reverse':
state_layout_tensors.reverse()
Y = outputs[0]
- transformer.make_node(
- 'Concat', state_layout_tensors, [Y], axis=seq_length_dim)
+ transformer.make_node('Concat', state_layout_tensors, [Y], axis=seq_length_dim)
-def _transform_bidirectional_LSTM(transformer, original_node, x, tensor_infos, activations,
- clip, hidden_size, layout):
+def _transform_bidirectional_LSTM(transformer, original_node, x, tensor_infos,
+ activations, clip, hidden_size, layout):
"""Generate Bidirectional unrolled LSTM
Args:
@@ -929,12 +926,10 @@ def _transform_bidirectional_LSTM(transformer, original_node, x, tensor_infos, a
Y_f_c = transformer.make_unsqueeze(state_f_c_tensor, axes=[y_c_direction_dim])
Y_b_c = transformer.make_unsqueeze(state_b_c_tensor, axes=[y_c_direction_dim])
Y_c = outputs[2]
- transformer.make_node(
- 'Concat', [Y_f_c, Y_b_c], [Y_c], axis=y_c_direction_dim)
+ transformer.make_node('Concat', [Y_f_c, Y_b_c], [Y_c], axis=y_c_direction_dim)
Y = outputs[0]
- transformer.make_node(
- 'Concat', state_layout_tensors, [Y], axis=seq_length_dim)
+ transformer.make_node('Concat', state_layout_tensors, [Y], axis=seq_length_dim)
def _legalize_LSTM(transformer, tensor_infos, node):
@@ -1001,10 +996,10 @@ def _legalize_LSTM(transformer, tensor_infos, node):
if direction in ['forward', 'reverse']:
_transform_unidirectional_LSTM(transformer, node, x, tensor_infos, activations,
- clip, direction, hidden_size, layout)
+ clip, direction, hidden_size, layout)
elif direction == 'bidirectional':
_transform_bidirectional_LSTM(transformer, node, x, tensor_infos, activations,
- clip, hidden_size, layout)
+ clip, hidden_size, layout)
else:
raise RuntimeError('Unknown LSTM type')
@@ -1052,10 +1047,12 @@ def legalize(model, options):
if __name__ == '__main__':
if len(sys.argv) < 3:
- print('usage: ./legalize_onnx.py <path to input model> <path to output model>\n'
- '\n'
- ' In stand-alone utility mode this tool provides basic funtionality\n'
- ' If you want to have more control over applied transformations, use this legalizer as a library')
+ print(
+ 'usage: ./legalize_onnx.py <path to input model> <path to output model>\n'
+ '\n'
+ ' In stand-alone utility mode this tool provides basic funtionality\n'
+ ' If you want to have more control over applied transformations, use this legalizer as a library'
+ )
exit(1)
options = LegalizeOptions()
options.unroll_lstm = True
diff --git a/compiler/one-cmds/utils.py b/compiler/one-cmds/utils.py
index 663ee82a8..7fb615216 100644
--- a/compiler/one-cmds/utils.py
+++ b/compiler/one-cmds/utils.py
@@ -61,41 +61,42 @@ def is_accumulated_arg(arg, driver):
def _is_valid_attr(args, attr):
return hasattr(args, attr) and getattr(args, attr)
+
class Command:
- def __init__(self, driver, args, log_file):
- self.cmd = [driver]
- self.driver = driver
- self.args = args
- self.log_file = log_file
-
- # Add option if attrs are valid
- # Option values are collected from self.args
- def add_option_with_valid_args(self, option, attrs):
- for attr in attrs:
- if not _is_valid_attr(self.args, attr):
+ def __init__(self, driver, args, log_file):
+ self.cmd = [driver]
+ self.driver = driver
+ self.args = args
+ self.log_file = log_file
+
+ # Add option if attrs are valid
+ # Option values are collected from self.args
+ def add_option_with_valid_args(self, option, attrs):
+ for attr in attrs:
+ if not _is_valid_attr(self.args, attr):
+ return self
+ self.cmd.append(option)
+ for attr in attrs:
+ self.cmd.append(getattr(self.args, attr))
+ return self
+
+ # Add option and values without any condition
+ def add_option_with_values(self, option, values):
+ self.cmd.append(option)
+ for value in values:
+ self.cmd.append(value)
return self
- self.cmd.append(option)
- for attr in attrs:
- self.cmd.append(getattr(self.args, attr))
- return self
-
- # Add option and values without any condition
- def add_option_with_values(self, option, values):
- self.cmd.append(option)
- for value in values:
- self.cmd.append(value)
- return self
-
- # Add option with no argument (ex: --verbose) if attr is valid
- def add_noarg_option_if_valid_arg(self, option, attr):
- if _is_valid_attr(self.args, attr):
- self.cmd.append(option)
- return self
-
- # Run cmd and save logs
- def run(self):
- self.log_file.write((' '.join(self.cmd) + '\n').encode())
- _run(self.cmd, err_prefix=self.driver, logfile=self.log_file)
+
+ # Add option with no argument (ex: --verbose) if attr is valid
+ def add_noarg_option_if_valid_arg(self, option, attr):
+ if _is_valid_attr(self.args, attr):
+ self.cmd.append(option)
+ return self
+
+ # Run cmd and save logs
+ def run(self):
+ self.log_file.write((' '.join(self.cmd) + '\n').encode())
+ _run(self.cmd, err_prefix=self.driver, logfile=self.log_file)
def _parse_cfg_and_overwrite(config_path, section, args):
@@ -189,8 +190,7 @@ def _run(cmd, err_prefix=None, logfile=None):
err_prefix: prefix to be put before every stderr lines
logfile: file stream to which both of stdout and stderr lines will be written
"""
- with subprocess.Popen(
- cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) as p:
+ with subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) as p:
import select
inputs = set([p.stdout, p.stderr])
while inputs: