diff options
Diffstat (limited to 'model-optimizer/mo/middle/passes/shape.py')
-rw-r--r-- | model-optimizer/mo/middle/passes/shape.py | 316 |
1 files changed, 182 insertions, 134 deletions
diff --git a/model-optimizer/mo/middle/passes/shape.py b/model-optimizer/mo/middle/passes/shape.py index cfb4d89f0..b0285edd4 100644 --- a/model-optimizer/mo/middle/passes/shape.py +++ b/model-optimizer/mo/middle/passes/shape.py @@ -19,11 +19,12 @@ import logging as log import networkx as nx import numpy as np -from mo.front.common.layout import nchw_to_nhwc_permute, nhwc_to_nchw_permute from mo.front.extractor import update_attrs from mo.graph.graph import Node, create_edge -from mo.middle.passes.eliminate import remove_op_node +from mo.middle.passes.eliminate import remove_op_node, merge_data_nodes, graph_clean_up_tf +from mo.middle.passes.fusing.helpers import get_next_operation from mo.middle.pattern_match import apply_pattern +from mo.ops.op import PermuteAttrs, Op from mo.ops.permute import Permute from mo.utils.error import Error from mo.utils.utils import refer_to_faq_msg @@ -38,7 +39,9 @@ def reshape_squeeze_transform(graph: nx.MultiDiGraph, match: dict): reshape.op = 'Reshape' reshape['type'] = 'Reshape' reshape['axis'] = 0 # TODO what does it mean? - reshape['dim'] = reshape.shape.copy() + if not reshape.has_valid('dim'): + # do not override value 'dim' if it is set. It may contain specific values like -1 and 0 + reshape['dim'] = reshape.shape.copy() update_attrs(reshape, 'shape_attrs', 'dim') reshape['num_axes'] = -1 # TODO what does it mean? if 'shape' in match: @@ -52,9 +55,8 @@ def convert_squeeze(graph: nx.MultiDiGraph): ('reshape', dict(kind='op', op='Squeeze')), ('output', dict(kind='data'))], edges=[('reshape', 'output')], - action=reshape_squeeze_transform, - node_attrs=['kind', 'op'], - edge_attrs=[]) + action=reshape_squeeze_transform + ) def convert_reshape(graph: nx.MultiDiGraph): @@ -65,129 +67,170 @@ def convert_reshape(graph: nx.MultiDiGraph): ('reshape', dict(kind='op', op='Reshape')), ('output', dict(kind='data'))], edges=[('shape', 'reshape', {'in': 1}), ('reshape', 'output')], - action=reshape_squeeze_transform, - node_attrs=['kind', 'op'], - edge_attrs=['in']) - - -def permute_attrs(attrs: dict, perm: np.ndarray, inv: np.ndarray): - log.debug("perm = {}, inv = {}".format(perm, inv)) - assert len(perm) == len(inv) - assert all(perm[inv] == range(len(perm))) - if 'dim_attrs' in attrs: - for a in attrs['dim_attrs']: # inv is applicable for dim_attrs - if a in attrs and attrs[a] is not None: - try: - attrs[a] = inv[attrs[a]] if not isinstance(a, np.ndarray) else np.array(inv[attrs[a]]) - except: - raise Error("Can not transpose attribute '{}' with value {} for node '{}'. {}".format(a, - attrs[a], attrs['name'] if 'name' in attrs else '<unknown_name>', refer_to_faq_msg(98))) - if 'shape_attrs' in attrs: - for a in attrs['shape_attrs']: # perm is applicable for shape_attrs - if a in attrs and attrs[a] is not None: - length = 0 - try: - length = len(attrs[a]) - except TypeError: - log.warning("Can not transpose attribute '{}' with value {} for node '{}'.".format(a, attrs[a], attrs['name'] if 'name' in attrs else '<unknown_name>')) - continue - if length == 4: - try: - if not isinstance(attrs[a], np.ndarray): - attrs[a] = np.array(attrs[a]) - log.debug("a = {}, attrs[a] = {}, type(attrs[a]) = {}".format(a, attrs[a], type(attrs[a]))) - attrs[a] = np.array(attrs[a][perm]) - except: - raise Error("Can not transpose attribute '{}' with value {} for node '{}'. {}".format(a, - attrs[a], attrs['name'] if 'name' in attrs else '<unknown_name>', refer_to_faq_msg(98))) - - -def permute_value(attrs: dict, perm: np.ndarray): - if 'value' in attrs and attrs['value'] is not None and attrs['value'].ndim == 4: - log.debug( - 'Permutation {} of value with shape {} for node "{}". Expected target shape: {}.'.format( - perm, - attrs['value'].shape, - attrs['name'] if 'name' in attrs else '<NONE>', - attrs['shape'] if 'shape' in attrs else '<NONE>' - ) - ) - assert 'shape' in attrs and len(attrs['shape']) == 4 - attrs['value'] = np.transpose(attrs['value'], perm) - assert (list(attrs['shape']) == list(attrs['value'].shape)) - - -def repack_fully_connected_weights_nhwc_to_nchw(node: Node): + action=reshape_squeeze_transform + ) + + +def repack_fully_connected_weights_nhwc_to_nchw(graph: nx.MultiDiGraph): """ Repack weights of FullyConnected layer as a part of nhwc_to_nchw translation if Reshape of that involves dimensions that we are repacking appears right before FullyConnected layer. """ - if node.has_valid('type') and node.type == 'FullyConnected': - assert node.in_node(0).kind == 'data' - if len(node.in_node(0).in_nodes()) == 1: - input = node.in_node(0).in_node(0) - if input.has_valid('type') and input.type == 'Reshape': - assert len(input.in_nodes()) > 0 - if input.in_node(0).has_valid('shape') and input.out_node().has_valid('shape'): - - orig_shape = input.in_node(0).shape - new_shape = input.out_node().shape - - # TODO a bit conservative condition; relax it checking specific dimensions - # that are involved in NHWC to NCWH translation - if len(orig_shape) != len(new_shape) or any(orig_shape != new_shape): - # OK, here we are; need to repack node.in_node(1) to maintain it compatible with original - # input order - - # TODO here is a couple of limitations that makes this pass simpler; consider to relax them - if len(orig_shape) == 4 and len(new_shape) == 2 and orig_shape[0] == new_shape[0]: - # that means orig_shape is in NCHW and new_shape is in NC - # and we need to map CHW part to C after HWC to CHW transform - # Assuming that FullyConnected weights haven't been converted from IO to OI yet. - # So format is IO. - - assert all(orig_shape != -1), 'Input shape for {} can not be negative.'.format(node.id) - assert all(new_shape != -1), 'Output shape for {} can not be negative.'.format(node.id) - assert orig_shape[1] * orig_shape[2] * orig_shape[3] == new_shape[1], \ - 'Input shape does not correspond to output shape for layer {}.'.format(node.id) - assert node.in_node(1).has_valid('value'), 'Node {} does not have value.'.format(node.id) - - weights = node.in_node(1) - - log.debug("orig_shape = {}".format(orig_shape)) - log.debug("new_shape = {}".format(new_shape)) - log.debug("weights.shape = {}".format(weights.shape)) - log.debug("weights.shape[1] = {}, new_shape[1] = {}".format(weights.shape[1], new_shape[1])) - - assert weights.shape[0] == new_shape[1], \ - 'First dim of weights does not correspond to output shape of {}'.format(node.id) - # interpret I dimension of the weights as packed HWC - # orig shape is already converted to NCHW, so provide transposed order for I repacking - tmp_shape = (orig_shape[2], orig_shape[3], orig_shape[1], weights.shape[1]) - weights.value = np.transpose(weights.value.reshape(tmp_shape), (2, 0, 1, 3)).reshape( - weights.shape) - - else: - log.warning("Cannot do the complete NHWC to NCHW translation for FullyConnected weights. " + - "The final model can be broken.") - - -def convert_nhwc_to_nchw(graph: nx.MultiDiGraph): - """ - It doesn't check if the model really has NHWC, it assumes this. - So it is just do global permute for all data (without values) and ops. - """ - nodes = nx.topological_sort(graph) - for node in nodes: + for node in graph.nodes(): + node = Node(graph, node) + if node.has_valid('type') and node.type == 'FullyConnected': + assert node.in_node(0).kind == 'data' + if len(node.in_node(0).in_nodes()) == 1: + input = node.in_node(0).in_node(0) + if input.has_valid('type') and input.type == 'Reshape': + assert len(input.in_nodes()) > 0 + if input.in_node(0).has_valid('shape') and input.out_node().has_valid('shape'): + + orig_shape = input.in_node(0).shape + new_shape = input.out_node().shape + + # TODO a bit conservative condition; relax it checking specific dimensions + # that are involved in NHWC to NCWH translation + if len(orig_shape) != len(new_shape) or any(orig_shape != new_shape): + # OK, here we are; need to repack node.in_node(1) to maintain it compatible with original + # input order + + # TODO here is a couple of limitations that makes this pass simpler; consider to relax them + if len(orig_shape) == 4 and len(new_shape) == 2 and orig_shape[0] == new_shape[0]: + # that means orig_shape is in NCHW and new_shape is in NC + # and we need to map CHW part to C after HWC to CHW transform + # Assuming that FullyConnected weights haven't been converted from IO to OI yet. + # So format is IO. + + assert all(orig_shape != -1), 'Input shape for {} can not be negative.'.format(node.id) + assert all(new_shape != -1), 'Output shape for {} can not be negative.'.format(node.id) + assert orig_shape[1] * orig_shape[2] * orig_shape[3] == new_shape[1], \ + 'Input shape does not correspond to output shape for layer {}.'.format(node.id) + assert node.in_node(1).has_valid('value'), 'Node {} does not have value.'.format(node.id) + + weights = node.in_node(1) + + log.debug("orig_shape = {}".format(orig_shape)) + log.debug("new_shape = {}".format(new_shape)) + log.debug("weights.shape = {}".format(weights.shape)) + log.debug("weights.shape[1] = {}, new_shape[1] = {}".format(weights.shape[1], new_shape[1])) + + assert weights.shape[0] == new_shape[1], \ + 'First dim of weights does not correspond to output shape of {}'.format(node.id) + # interpret I dimension of the weights as packed HWC + # orig shape is already converted to NCHW, so provide transposed order for I repacking + tmp_shape = (orig_shape[2], orig_shape[3], orig_shape[1], weights.shape[1]) + weights.value = np.transpose(weights.value.reshape(tmp_shape), (2, 0, 1, 3)).reshape( + weights.shape) + else: + log.warning("Cannot do the complete NHWC to NCHW translation for FullyConnected weights. " + "The final model can be broken.") + + +def apply_nhwc_to_nchw_permutation(graph: nx.MultiDiGraph): + # Add NHWC to NCHW permutation for all data nodes (only for nodes without permutation) + if graph.graph['layout'] == 'NCHW': + return + for node in graph.nodes(): + node = Node(graph, node) + if node.kind == 'data': + if node.has_and_set('nchw_layout'): + continue + + # Get NHWC to NCHW permutation for N dims, where N = len(node.shape) + permutation = PermuteAttrs().get_nhwc_to_nchw_permutation(len(node.shape)) + + # Check that data node already has permutation + skip_permutation = False + for in_node in node.in_nodes(): + edge_attrs = node.graph.get_edge_data(in_node.id, node.id)[0] + if 'permutation' in edge_attrs: + skip_permutation = True + for out_node in node.out_nodes(): + edge_attrs = node.graph.get_edge_data(node.id, out_node.id)[0] + if 'permutation' in edge_attrs: + skip_permutation = True + + if skip_permutation: + continue + + # Set permutation to all in/out edges + for in_node in node.in_nodes(): + PermuteAttrs.set_permutation(in_node, node, permutation) + + for out_node in node.out_nodes(): + PermuteAttrs.set_permutation(node, out_node, permutation) + + +def merge_nodes_permutations(graph: nx.MultiDiGraph): + # Iterate over all data nodes and check all permutations for similarity + # In case of equal permutations, this permutation will be set as attribute for data node + # otherwise exception will be raised + for node in graph.nodes(): + node = Node(graph, node) + if node.kind != 'data': + continue + + permutations = [] + + # Get all permutations from in edges + for in_node in node.in_nodes(): + edge_attrs = node.graph.get_edge_data(in_node.id, node.id)[0] + if 'permutation' in edge_attrs: + permutations.append(edge_attrs['permutation']) + + # Get all permutations from out edges + for out_node in node.out_nodes(): + edge_attrs = node.graph.get_edge_data(node.id, out_node.id)[0] + if 'permutation' in edge_attrs: + permutations.append(edge_attrs['permutation']) + + # Check that all permutations are equal + final_permutations = [] + for p in permutations: + if p is not None: + final_permutations.append(p.perm) + else: + final_permutations.append(np.arange(node.shape.size)) + + if len(final_permutations) == 0: + continue + + if not all([np.array_equal(final_permutations[0], perm) for perm in final_permutations]): + raise Error( + 'Permutations requested for {} data node are not equal! List of permutations: {}'.format(node.name, + [p.perm for + p in + permutations])) + + assert not node.has_valid('permutation') or np.array_equal(node.permutation, permutations[0]) + node['permutation'] = permutations[0] + if node.permutation is not None and node.permutation.perm.size == 0: + node.permutation = None + + +def permute_data_nodes_attrs(graph: nx.MultiDiGraph): + # Iterate over all data nodes and apply permutation if exists + for node in graph.nodes(): node = Node(graph, node) - if node.has_and_set('nchw_layout'): - # this operation already produces output in the NCHW format + if node.kind != 'data' or not node.has_valid('permutation'): continue - # TODO Consider move it to a separate pass with pattern matcher - log.debug("node.name = {}".format(node.name if node.has_valid('name') else '<NO NAME>')) - permute_attrs(graph.node[node.id], nhwc_to_nchw_permute, nchw_to_nhwc_permute) - permute_value(graph.node[node.id], nhwc_to_nchw_permute) - repack_fully_connected_weights_nhwc_to_nchw(node) + + # Apply permutation for shape and value if exists + node.shape = np.array(node.shape)[node.permutation.perm] + if node.has_valid('value'): + if len(node.value.shape) != len(node.permutation.perm): + log.warning('Node {} has shape {} and permutation {} that is not satisfied'.format(node.name, node.value.shape, node.permutation.perm)) + continue + #print(node.name, node.value.shape, node.shape, node.permutation) + node.value = np.array(node.value.transpose(node.permutation.perm)) + + +def permute_op_nodes_attrs(graph: nx.MultiDiGraph): + for node in graph.nodes(): + node = Node(graph, node) + if node.kind == 'op' and node.has_valid('permute_attrs'): + node.permute_attrs.permute_attrs(node) def reverse_input_channels(graph: nx.MultiDiGraph): @@ -265,14 +308,14 @@ def reverse_input_channels(graph: nx.MultiDiGraph): tmp_shape_for_reorder = list(bottom_weights.value.shape) src_shape = list(tmp_shape_for_reorder) log.debug('weights shape = {}'.format(tmp_shape_for_reorder)) - assert (tmp_shape_for_reorder[bottom_weights.input_channel_dim[0]] == bottom_channels) - tmp_shape_for_reorder[bottom_weights.input_channel_dim[0]] = ngroups + assert (tmp_shape_for_reorder[bottom_weights.input_channel_dim] == bottom_channels) + tmp_shape_for_reorder[bottom_weights.input_channel_dim] = ngroups tmp_shape_for_reorder = tmp_shape_for_reorder + [multiplier] log.debug('tmp_shape_for_reorder = {}'.format(tmp_shape_for_reorder)) # temporary change shape of weights to do reordering # bottom_weights.value.shape = tuple(tmp_shape_for_reorder) bottom_weights.value = np.flip(bottom_weights.value.reshape(tuple(tmp_shape_for_reorder)), - bottom_weights.input_channel_dim[0]) + bottom_weights.input_channel_dim) # change shape of weights back log.debug('back to shape = {}'.format(tuple(src_shape))) bottom_weights.value = bottom_weights.value.reshape(tuple(src_shape)) @@ -283,14 +326,17 @@ def reverse_input_channels(graph: nx.MultiDiGraph): 'Reverse input channels are not applied: there is no Conv2D after DepthwiseConv2dNative to ' + 'complete the flip') - conv.in_node(1).value = np.flip(conv.in_node(1).value, conv.in_node(1).input_channel_dim[0]) + conv.in_node(1).value = np.flip(conv.in_node(1).value, conv.in_node(1).input_channel_dim) log.debug('Applied reversing input channels for weights of convolution {}'.format(conv.id)) log.debug('Shape was (shape){}, (value.shape){}'.format(conv.in_node(1).shape, conv.in_node(1).value.shape)) - log.debug('Flipped dim: {}'.format(conv.in_node(1).input_channel_dim[0])) + log.debug('Flipped dim: {}'.format(conv.in_node(1).input_channel_dim)) def conv_flatten_concat_action(graph: nx.MultiDiGraph, match: dict): reshape_node = match['reshape'] + reshape_data_node = match['reshape_data'] + concat_node = match['concat'] + concat_data_node = match['concat_data'] conv_name = match['conv'].name conv_data_node = match['conv_data'] assert len(graph.in_edges(reshape_node.id)) == 1 @@ -298,17 +344,20 @@ def conv_flatten_concat_action(graph: nx.MultiDiGraph, match: dict): new_permute_op = Permute(graph, {'order': np.array([0, 2, 3, 1])}) permute_data_node = new_permute_op.create_node_with_data([conv_data_node], dict(name=conv_name + '/Permute_')) create_edge(permute_data_node, reshape_node) + # Disable permutation for Reshape and Concat + PermuteAttrs.set_permutation(reshape_node, reshape_data_node, None) + PermuteAttrs.set_permutation(concat_node, concat_data_node, None, skip_if_exists=True) def conv_flatten_concat(graph: nx.MultiDiGraph): apply_pattern( graph, nodes=[ - ('conv', dict(kind='op', op='Conv2D')), + ('conv', dict(kind='op', type='Convolution')), ('conv_data', dict(kind='data')), - ('reshape', dict(kind='op', op='Reshape')), + ('reshape', dict(kind='op', type='Reshape')), ('reshape_data', dict(kind='data')), - ('concat', dict(kind='op', op='ConcatV2')), + ('concat', dict(kind='op', type='Concat')), ('concat_data', dict(kind='data')) ], edges=[ @@ -318,9 +367,8 @@ def conv_flatten_concat(graph: nx.MultiDiGraph): ('reshape_data', 'concat'), ('concat', 'concat_data') ], - action=conv_flatten_concat_action, - node_attrs=['kind', 'op'], - edge_attrs=[]) + action=conv_flatten_concat_action + ) def fuse_sequence_of_reshapes(graph: nx.MultiDiGraph): |