Diffstat (limited to 'model-optimizer/mo/middle/passes/shape.py')
-rw-r--r-- | model-optimizer/mo/middle/passes/shape.py | 174 |
1 file changed, 111 insertions, 63 deletions
diff --git a/model-optimizer/mo/middle/passes/shape.py b/model-optimizer/mo/middle/passes/shape.py
index b0285edd4..647502bf9 100644
--- a/model-optimizer/mo/middle/passes/shape.py
+++ b/model-optimizer/mo/middle/passes/shape.py
@@ -21,7 +21,7 @@ import numpy as np
 from mo.front.extractor import update_attrs
 from mo.graph.graph import Node, create_edge
-from mo.middle.passes.eliminate import remove_op_node, merge_data_nodes, graph_clean_up_tf
+from mo.middle.passes.eliminate import remove_op_node_with_data_node, merge_data_nodes, graph_clean_up_tf, get_nodes_with_attributes
 from mo.middle.passes.fusing.helpers import get_next_operation
 from mo.middle.pattern_match import apply_pattern
 from mo.ops.op import PermuteAttrs, Op
@@ -38,12 +38,10 @@ def reshape_squeeze_transform(graph: nx.MultiDiGraph, match: dict):
     reshape['shape'] = output.shape
     reshape.op = 'Reshape'
     reshape['type'] = 'Reshape'
-    reshape['axis'] = 0  # TODO what does it mean?
     if not reshape.has_valid('dim'):
         # do not override value 'dim' if it is set. It may contain specific values like -1 and 0
         reshape['dim'] = reshape.shape.copy()
     update_attrs(reshape, 'shape_attrs', 'dim')
-    reshape['num_axes'] = -1  # TODO what does it mean?
 
     if 'shape' in match:
         graph.remove_edge(match['shape'].node, match['reshape'].node)
@@ -71,60 +69,81 @@ def convert_reshape(graph: nx.MultiDiGraph):
     )
 
 
+def can_repack_fully_connected_weights_nhwc_to_nchw(fc_node: Node):
+    """
+    Checks that it is possible to repack the weights of the FullyConnected layer when a Reshape layer that satisfies
+    several conditions produces the input of the FullyConnected layer.
+    :param fc_node: the FullyConnected node to check
+    :return: the result of the check
+    """
+    if len(fc_node.in_node(0).in_nodes()) != 1:
+        return False
+
+    reshape_node = fc_node.in_node(0).in_node(0)
+    if not reshape_node.has_valid('type') or reshape_node.type != 'Reshape':
+        return False
+
+    if not reshape_node.in_node(0).has_valid('shape') or not reshape_node.out_node().has_valid('shape'):
+        return False
+
+    orig_shape = reshape_node.in_node(0).shape
+    new_shape = reshape_node.out_node().shape
+
+    # TODO a somewhat conservative condition; relax it by checking the specific dimensions that are involved in the
+    # NHWC to NCHW translation
+    if len(orig_shape) == len(new_shape) and all(orig_shape == new_shape):
+        return False
+
+    # TODO there are a couple of limitations here that keep this pass simple; consider relaxing them
+    if len(orig_shape) == 4 and len(new_shape) == 2 and orig_shape[0] == new_shape[0]:
+        # this means orig_shape is in NCHW and new_shape is in NC,
+        # so we need to map the CHW part to C after the HWC to CHW transform,
+        # assuming that the FullyConnected weights have not been converted from IO to OI yet
+        # (i.e. the format is IO)
+        return True
+    else:
+        log.warning("Cannot do the complete NHWC to NCHW translation for FullyConnected weights. "
+                    "The final model can be broken.")
+        return False
+
+
 def repack_fully_connected_weights_nhwc_to_nchw(graph: nx.MultiDiGraph):
     """
     Repack the weights of a FullyConnected layer as a part of the NHWC to NCHW translation if a Reshape that involves
     the repacked dimensions appears right before the FullyConnected layer.
     """
-    for node in graph.nodes():
-        node = Node(graph, node)
-        if node.has_valid('type') and node.type == 'FullyConnected':
-            assert node.in_node(0).kind == 'data'
-            if len(node.in_node(0).in_nodes()) == 1:
-                input = node.in_node(0).in_node(0)
-                if input.has_valid('type') and input.type == 'Reshape':
-                    assert len(input.in_nodes()) > 0
-                    if input.in_node(0).has_valid('shape') and input.out_node().has_valid('shape'):
-
-                        orig_shape = input.in_node(0).shape
-                        new_shape = input.out_node().shape
-
-                        # TODO a bit conservative condition; relax it checking specific dimensions
-                        # that are involved in NHWC to NCWH translation
-                        if len(orig_shape) != len(new_shape) or any(orig_shape != new_shape):
-                            # OK, here we are; need to repack node.in_node(1) to maintain it compatible with original
-                            # input order
-
-                            # TODO here is a couple of limitations that makes this pass simpler; consider to relax them
-                            if len(orig_shape) == 4 and len(new_shape) == 2 and orig_shape[0] == new_shape[0]:
-                                # that means orig_shape is in NCHW and new_shape is in NC
-                                # and we need to map CHW part to C after HWC to CHW transform
-                                # Assuming that FullyConnected weights haven't been converted from IO to OI yet.
-                                # So format is IO.
-
-                                assert all(orig_shape != -1), 'Input shape for {} can not be negative.'.format(node.id)
-                                assert all(new_shape != -1), 'Output shape for {} can not be negative.'.format(node.id)
-                                assert orig_shape[1] * orig_shape[2] * orig_shape[3] == new_shape[1], \
-                                    'Input shape does not correspond to output shape for layer {}.'.format(node.id)
-                                assert node.in_node(1).has_valid('value'), 'Node {} does not have value.'.format(node.id)
-
-                                weights = node.in_node(1)
-
-                                log.debug("orig_shape = {}".format(orig_shape))
-                                log.debug("new_shape = {}".format(new_shape))
-                                log.debug("weights.shape = {}".format(weights.shape))
-                                log.debug("weights.shape[1] = {}, new_shape[1] = {}".format(weights.shape[1], new_shape[1]))
-
-                                assert weights.shape[0] == new_shape[1], \
-                                    'First dim of weights does not correspond to output shape of {}'.format(node.id)
-                                # interpret I dimension of the weights as packed HWC
-                                # orig shape is already converted to NCHW, so provide transposed order for I repacking
-                                tmp_shape = (orig_shape[2], orig_shape[3], orig_shape[1], weights.shape[1])
-                                weights.value = np.transpose(weights.value.reshape(tmp_shape), (2, 0, 1, 3)).reshape(
-                                    weights.shape)
-                            else:
-                                log.warning("Cannot do the complete NHWC to NCHW translation for FullyConnected weights. "
-                                            "The final model can be broken.")
+    for node_id in get_nodes_with_attributes(graph, type='FullyConnected'):
+        fc_node = Node(graph, node_id)
+
+        if not can_repack_fully_connected_weights_nhwc_to_nchw(fc_node):
+            continue
+
+        reshape_node = fc_node.in_node(0).in_node(0)
+
+        orig_shape = reshape_node.in_node(0).shape
+        new_shape = reshape_node.out_node().shape
+
+        # OK, here we are; fc_node.in_node(1) needs to be repacked to keep it compatible with the original input order
+
+        assert all(orig_shape != -1), 'Input shape for {} cannot be negative.'.format(fc_node.id)
+        assert all(new_shape != -1), 'Output shape for {} cannot be negative.'.format(fc_node.id)
+        assert orig_shape[1] * orig_shape[2] * orig_shape[3] == new_shape[1], \
+            'Input shape does not correspond to output shape for layer {}.'.format(fc_node.id)
+        assert fc_node.in_node(1).has_valid('value'), 'Node {} does not have a value.'.format(fc_node.id)
+
+        weights = fc_node.in_node(1)
+
+        log.debug("orig_shape = {}".format(orig_shape))
+        log.debug("new_shape = {}".format(new_shape))
+        log.debug("weights.shape = {}".format(weights.shape))
+        log.debug("weights.shape[1] = {}, new_shape[1] = {}".format(weights.shape[1], new_shape[1]))
+
+        assert weights.shape[0] == new_shape[1], \
+            'First dim of weights does not correspond to output shape of {}'.format(fc_node.id)
+        # interpret the I dimension of the weights as packed HWC;
+        # orig_shape is already converted to NCHW, so provide the transposed order for I repacking
+        tmp_shape = (orig_shape[2], orig_shape[3], orig_shape[1], weights.shape[1])
+        weights.value = np.transpose(weights.value.reshape(tmp_shape), (2, 0, 1, 3)).reshape(weights.shape)
 
 
 def apply_nhwc_to_nchw_permutation(graph: nx.MultiDiGraph):
@@ -230,7 +249,10 @@ def permute_op_nodes_attrs(graph: nx.MultiDiGraph):
     for node in graph.nodes():
         node = Node(graph, node)
         if node.kind == 'op' and node.has_valid('permute_attrs'):
-            node.permute_attrs.permute_attrs(node)
+            try:
+                node.permute_attrs.permute_attrs(node)
+            except Exception as e:
+                raise Error('Can\'t permute attrs for node {}. Error message: {}'.format(node.id, e))
 
 
 def reverse_input_channels(graph: nx.MultiDiGraph):
@@ -333,20 +355,30 @@ def reverse_input_channels(graph: nx.MultiDiGraph):
 def conv_flatten_concat_action(graph: nx.MultiDiGraph, match: dict):
+    assert graph.graph['layout'] == 'NHWC'
     reshape_node = match['reshape']
     reshape_data_node = match['reshape_data']
-    concat_node = match['concat']
-    concat_data_node = match['concat_data']
     conv_name = match['conv'].name
     conv_data_node = match['conv_data']
+    # the pattern should be applied only when the Reshape operation changes the number of dimensions
+    if len(reshape_data_node.shape) == len(conv_data_node.shape) or reshape_node.has_and_set('nchw_layout'):
+        return
+
+    if len(reshape_data_node.out_nodes()) == 1 and reshape_data_node.out_node().has_valid('type') and \
+            reshape_data_node.out_node().type == 'FullyConnected' and \
+            can_repack_fully_connected_weights_nhwc_to_nchw(reshape_data_node.out_node()):
+        log.info('There is a FullyConnected layer after the node "{}" whose weights will be repacked, so there is no '
+                 'need to insert a Permute'.format(reshape_node.soft_get('name')))
+        return
 
     assert len(graph.in_edges(reshape_node.id)) == 1
     graph.remove_edge(conv_data_node.id, reshape_node.id)
-    new_permute_op = Permute(graph, {'order': np.array([0, 2, 3, 1])})
+
+    permutation_order = PermuteAttrs.get_nchw_to_nhwc_permutation(len(conv_data_node.shape)).perm
+    new_permute_op = Permute(graph, {'order': permutation_order})
     permute_data_node = new_permute_op.create_node_with_data([conv_data_node], dict(name=conv_name + '/Permute_'))
     create_edge(permute_data_node, reshape_node)
-    # Disable permutation for Reshape and Concat
+    # Disable permutation for the Reshape layer attributes
     PermuteAttrs.set_permutation(reshape_node, reshape_data_node, None)
-    PermuteAttrs.set_permutation(concat_node, concat_data_node, None, skip_if_exists=True)
 
 
 def conv_flatten_concat(graph: nx.MultiDiGraph):
@@ -357,15 +389,31 @@ def conv_flatten_concat(graph: nx.MultiDiGraph):
             ('conv_data', dict(kind='data')),
             ('reshape', dict(kind='op', type='Reshape')),
             ('reshape_data', dict(kind='data')),
-            ('concat', dict(kind='op', type='Concat')),
-            ('concat_data', dict(kind='data'))
         ],
         edges=[
            ('conv', 'conv_data'),
             ('conv_data', 'reshape'),
             ('reshape', 'reshape_data'),
-            ('reshape_data', 'concat'),
-            ('concat', 'concat_data')
+        ],
+        action=conv_flatten_concat_action
+    )
+
+    apply_pattern(
+        graph,
+        nodes=[
+            ('real_conv', dict(kind='op', type='Convolution')),
+            ('real_conv_data', dict(kind='data')),
+            ('conv', dict(kind='op', type='ReLU')),
+            ('conv_data', dict(kind='data')),
+            ('reshape', dict(kind='op', type='Reshape')),
+            ('reshape_data', dict(kind='data')),
+        ],
+        edges=[
+            ('real_conv', 'real_conv_data'),
+            ('real_conv_data', 'conv'),
+            ('conv', 'conv_data'),
+            ('conv_data', 'reshape'),
+            ('reshape', 'reshape_data'),
         ],
        action=conv_flatten_concat_action
    )
@@ -390,4 +438,4 @@ def fuse_sequence_of_reshapes(graph: nx.MultiDiGraph):
             # Detected Reshape1 --> data --> Reshape2 pattern without side edges
             # Remove Reshape1
             log.debug('Second phase for Reshape: {}'.format(node.name))
-            remove_op_node(graph, node)
+            remove_op_node_with_data_node(graph, node)
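To see why the repacking performed by repack_fully_connected_weights_nhwc_to_nchw preserves the network output, the following self-contained NumPy sketch replays the transform on random data. The sizes and variable names are illustrative assumptions; only the reshape/transpose/reshape step mirrors what the pass actually does to the weights.

import numpy as np

# Hypothetical sizes: N is the batch, H/W/C are the spatial and channel dims of the
# tensor feeding the Reshape, O is the FullyConnected output size.
N, H, W, C, O = 2, 4, 5, 3, 7
flat_size = H * W * C  # the flattened FC input size, i.e. new_shape[1] in the pass

rng = np.random.default_rng(0)
x_nhwc = rng.standard_normal((N, H, W, C))
w_io = rng.standard_normal((flat_size, O))  # IO layout: rows follow the NHWC flattening order

# Reference result: the original framework flattens the NHWC tensor and applies the FC layer.
y_ref = x_nhwc.reshape(N, -1) @ w_io

# After the NHWC -> NCHW conversion the activations are permuted, so the Reshape now
# flattens in CHW order. Repack the weight rows the same way the pass does:
# interpret the I dimension as packed HWC and transpose it to CHW.
x_nchw = x_nhwc.transpose(0, 3, 1, 2)
w_repacked = w_io.reshape(H, W, C, O).transpose(2, 0, 1, 3).reshape(flat_size, O)
y_converted = x_nchw.reshape(N, -1) @ w_repacked

assert np.allclose(y_ref, y_converted)  # the converted graph computes the same result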
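The patch also replaces the hardcoded Permute order np.array([0, 2, 3, 1]) with PermuteAttrs.get_nchw_to_nhwc_permutation(len(conv_data_node.shape)).perm, which makes the pass independent of the tensor rank. The sketch below is a hypothetical re-implementation of just the order computation, assuming the usual convention that the batch axis stays first and the channel axis moves last; the real helper returns a richer permutation object (the diff accesses its .perm attribute), so this is not its actual implementation.

import numpy as np

def nchw_to_nhwc_order(rank: int) -> np.ndarray:
    # Hypothetical helper for illustration: keep the batch axis first, move the
    # channel axis (position 1 in NCHW-like layouts) to the end, and preserve
    # the order of the spatial axes.
    assert rank >= 3
    return np.array([0] + list(range(2, rank)) + [1])

print(nchw_to_nhwc_order(4))  # [0 2 3 1], the order the old code hardcoded
print(nchw_to_nhwc_order(5))  # [0 2 3 4 1], e.g. NCDHW -> NDHWC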