summaryrefslogtreecommitdiff
path: root/model-optimizer/mo/pipeline/onnx.py
diff options
context:
space:
mode:
Diffstat (limited to 'model-optimizer/mo/pipeline/onnx.py')
-rw-r--r--model-optimizer/mo/pipeline/onnx.py80
1 files changed, 40 insertions, 40 deletions
diff --git a/model-optimizer/mo/pipeline/onnx.py b/model-optimizer/mo/pipeline/onnx.py
index ded5cdb02..fcb6dc27d 100644
--- a/model-optimizer/mo/pipeline/onnx.py
+++ b/model-optimizer/mo/pipeline/onnx.py
@@ -20,53 +20,48 @@ from __future__ import print_function
from __future__ import unicode_literals
import argparse
-import copy
import logging as log
-import onnx
-import os
-
import numpy as np
-from mo.front.common.custom_replacement_registry import CustomReplacementRegistry
-from mo.front.common.find_unsupported_ops import find_unsupported_ops
+from extensions.middle.EltwiseInputNormalization import EltwiseInputNormalize
+from extensions.middle.NormalizeFullyConnected import NormalizeFullyConnected
from mo.front.common.register_custom_ops import check_for_duplicates
from mo.front.common.register_custom_ops import update_extractors_with_extensions
-from mo.front.extractor import restore_edges, add_output_ops, add_input_ops, \
+from mo.front.extractor import add_output_ops, add_input_ops, \
extract_node_attrs, create_tensor_nodes, remove_output_ops, user_data_repack
from mo.front.onnx.extractor import common_onnx_fields, onnx_op_extractor, onnx_op_extractors
from mo.front.onnx.loader import load_onnx_model, protobuf2nx
-from mo.middle.passes.conv import convert_add_to_scaleshift, \
- convert_weights_yxio_to_oiyx, convert_weights_yxio_to_goiyx, convert_gemm_to_fully_connected, \
- convert_muladd_to_scaleshift_or_power, fuse_pad, transpose_fully_connected_weights, \
- convert_dilated_convolution, convert_mul_to_scaleshift, convert_nasnet
-from mo.middle.passes.fusing.decomposition import convert_batch_norm, convert_scale_shift_to_mul_add
+from mo.middle.passes.conv import convert_add_to_scaleshift, convert_gemm_to_fully_connected, \
+ convert_muladd_to_scaleshift_or_power, fuse_pad, convert_dilated_convolution, convert_mul_to_scaleshift
from mo.middle.passes.eliminate import graph_clean_up, remove_op_nodes, remove_useless_split
+from mo.middle.passes.fusing.decomposition import convert_batch_norm, convert_scale_shift_to_mul_add
+from mo.middle.passes.fusing.fuse_grouped_conv import grouped_convolutions_fusing
from mo.middle.passes.fusing.fuse_linear_ops import fuse_linear_ops
from mo.middle.passes.fusing.fuse_linear_seq import fuse_mul_add_sequence
-from mo.middle.passes.fusing.fuse_grouped_conv import grouped_convolutions_fusing
from mo.middle.passes.fusing.mark_unfused_nodes import mark_unfused_nodes
-from mo.middle.passes.mean_scale_values import move_scaleshift_to_preprocess
from mo.middle.passes.infer import scale_input, override_placeholder_shapes, partial_infer, convert_mul_add_to_power, \
update_fully_connected_shapes, add_mean_scale_values, override_batch
-from mo.middle.passes.l2normalization import l2_norm_to_norm
-from mo.middle.passes.pool import mean_to_avgpool
-from mo.middle.passes.shape import convert_squeeze, convert_reshape, convert_nhwc_to_nchw, reverse_input_channels, \
- conv_flatten_concat, fuse_sequence_of_reshapes
-from mo.utils import class_registration
+from mo.middle.passes.mean_scale_values import move_scaleshift_to_preprocess
+from mo.middle.passes.shape import convert_reshape, reverse_input_channels, \
+ fuse_sequence_of_reshapes, merge_nodes_permutations, permute_data_nodes_attrs, permute_op_nodes_attrs
from mo.pipeline.common import prepare_emit_ir
-from mo.utils.custom_replacement_config import update_custom_replacement_config_file
+from mo.utils import class_registration
+from mo.utils.cli_parser import get_meta_info
from mo.utils.error import Error
from mo.utils.utils import refer_to_faq_msg
+from mo.graph.graph import check_empty_graph
def driver(argv: argparse.Namespace, model_file_name: str, output_model_name: str, outputs: list, output_dir: str,
- scale: float,
- user_shapes: [None, list, np.array] = None,
- mean_scale_values: [dict, list] = ()):
+ scale: float,
+ user_shapes: [None, list, np.array] = None,
+ mean_scale_values: [dict, list] = ()):
+
+ meta_info = get_meta_info(argv)
model_proto = load_onnx_model(model_file_name)
- model_graph = model_proto.graph
+ model_graph = model_proto.graph # pylint: disable=no-member
#print(model_graph)
#assert len(model_graph) == 1, "An ONNX model contains more than 1 graph: unsupported"
log.debug("Number of nodes in graph_def: {}".format(len(model_graph.node)))
@@ -78,11 +73,12 @@ def driver(argv: argparse.Namespace, model_file_name: str, output_model_name: st
try:
graph = protobuf2nx(model_proto)
log.debug("Number of nodes in NX graph: {}".format(graph.number_of_nodes()))
- graph.__setattr__('name', output_model_name if output_model_name else model_proto.graph.name)
+ graph.__setattr__('name', output_model_name if output_model_name else model_proto.graph.name) # pylint: disable=no-member
graph.graph['layout'] = 'NCHW'
graph.graph['cmd_params'] = argv
graph.graph['fw'] = 'onnx'
graph.graph['feature_dim'] = 1 if graph.graph['layout'] == 'NCHW' else 3
+ graph.graph['ir_version'] = 2 if argv.generate_deprecated_IR_V2 else 3
# extract basic attributes earlier to enable some passes that relies on them before full attribute
# extractor is called
extract_node_attrs(graph, lambda node: (True, common_onnx_fields(node)))
@@ -94,15 +90,15 @@ def driver(argv: argparse.Namespace, model_file_name: str, output_model_name: st
model_file_name,
str(e)
) from e
+ check_empty_graph(graph, 'protobuf2nx. It may happen due to problems with loaded model')
+ packed_user_shapes, packed_outputs, _ = user_data_repack(graph, user_shapes, outputs, None)
- user_shapes, outputs, _ = user_data_repack(graph, user_shapes, outputs, None)
-
- graph, output_op_nodes = add_output_ops(graph, outputs)
- graph, input_op_nodes = add_input_ops(graph, user_shapes, True)
+ output_op_nodes = add_output_ops(graph, packed_outputs)
+ input_op_nodes = add_input_ops(graph, packed_user_shapes, True)
# this call of 'graph_clean_up' removes child nodes of outputs which is useful when custom output is specified
graph_clean_up(graph)
-
+ check_empty_graph(graph, 'add_output_ops and add_input_ops')
extract_node_attrs(graph, lambda node: onnx_op_extractor(node, check_for_duplicates(onnx_op_extractors)))
class_registration.apply_replacements(graph, class_registration.ClassType.FRONT_REPLACER)
@@ -110,7 +106,7 @@ def driver(argv: argparse.Namespace, model_file_name: str, output_model_name: st
create_tensor_nodes(graph)
graph_clean_up(graph)
- override_placeholder_shapes(graph, user_shapes)
+ override_placeholder_shapes(graph, packed_user_shapes)
override_batch(graph, argv.batch)
graph_clean_up(graph)
@@ -122,11 +118,11 @@ def driver(argv: argparse.Namespace, model_file_name: str, output_model_name: st
partial_infer(graph)
graph_clean_up(graph)
+ check_empty_graph(graph, 'partial_infer')
-
- graph, input_op_nodes = add_input_ops(graph, user_shapes, False)
+ input_op_nodes = add_input_ops(graph, packed_user_shapes, False)
graph_clean_up(graph)
-
+ check_empty_graph(graph, 'add_input_ops')
#change_placeholders_types_to_FP32(graph)
scale_input(graph, scale)
@@ -143,6 +139,7 @@ def driver(argv: argparse.Namespace, model_file_name: str, output_model_name: st
class_registration.apply_replacements(graph, class_registration.ClassType.MIDDLE_REPLACER)
convert_gemm_to_fully_connected(graph)
+ NormalizeFullyConnected().find_and_replace_pattern(graph)
fuse_pad(graph)
graph_clean_up(graph)
@@ -179,12 +176,7 @@ def driver(argv: argparse.Namespace, model_file_name: str, output_model_name: st
graph_clean_up(graph)
convert_mul_add_to_power(graph)
-
- # Need to eliminate dead nodes before doing update_fully_connected_shapes
- # because update_fully_connected_shapes does partial inference and dead
- # nodes will lead to sporadic failures.
graph_clean_up(graph)
- update_fully_connected_shapes(graph)
convert_reshape(graph)
convert_add_to_scaleshift(graph) # scale = 1
@@ -203,8 +195,16 @@ def driver(argv: argparse.Namespace, model_file_name: str, output_model_name: st
fuse_sequence_of_reshapes(graph)
graph_clean_up(graph)
+ pattern = EltwiseInputNormalize()
+ pattern.find_and_replace_pattern(graph)
+
+ merge_nodes_permutations(graph)
+ permute_data_nodes_attrs(graph)
+ permute_op_nodes_attrs(graph)
+
class_registration.apply_replacements(graph, class_registration.ClassType.BACK_REPLACER)
- prepare_emit_ir(graph=graph, data_type=argv.data_type, output_dir=output_dir, output_model_name=output_model_name)
+ prepare_emit_ir(graph=graph, data_type=argv.data_type, output_dir=output_dir, output_model_name=output_model_name,
+ meta_info=meta_info)
return 0