Diffstat (limited to 'model-optimizer/mo/middle/passes/shape.py')
-rw-r--r--  model-optimizer/mo/middle/passes/shape.py | 316
1 file changed, 182 insertions(+), 134 deletions(-)
diff --git a/model-optimizer/mo/middle/passes/shape.py b/model-optimizer/mo/middle/passes/shape.py
index cfb4d89f0..b0285edd4 100644
--- a/model-optimizer/mo/middle/passes/shape.py
+++ b/model-optimizer/mo/middle/passes/shape.py
@@ -19,11 +19,12 @@ import logging as log
import networkx as nx
import numpy as np
-from mo.front.common.layout import nchw_to_nhwc_permute, nhwc_to_nchw_permute
from mo.front.extractor import update_attrs
from mo.graph.graph import Node, create_edge
-from mo.middle.passes.eliminate import remove_op_node
+from mo.middle.passes.eliminate import remove_op_node, merge_data_nodes, graph_clean_up_tf
+from mo.middle.passes.fusing.helpers import get_next_operation
from mo.middle.pattern_match import apply_pattern
+from mo.ops.op import PermuteAttrs, Op
from mo.ops.permute import Permute
from mo.utils.error import Error
from mo.utils.utils import refer_to_faq_msg
@@ -38,7 +39,9 @@ def reshape_squeeze_transform(graph: nx.MultiDiGraph, match: dict):
reshape.op = 'Reshape'
reshape['type'] = 'Reshape'
reshape['axis'] = 0 # TODO what does it mean?
- reshape['dim'] = reshape.shape.copy()
+ if not reshape.has_valid('dim'):
+ # Do not override the 'dim' attribute if it is already set; it may contain special values like -1 and 0
+ reshape['dim'] = reshape.shape.copy()
update_attrs(reshape, 'shape_attrs', 'dim')
reshape['num_axes'] = -1 # TODO what does it mean?
if 'shape' in match:
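The guard above matters because Reshape's 'dim' attribute may carry the special values -1 (infer this dimension) and 0 (copy the corresponding input dimension); overwriting it with the already-inferred output shape would lose them. A minimal sketch of those semantics, using a hypothetical resolve_dim helper that is not part of this module:

    import numpy as np

    def resolve_dim(input_shape, dim):
        # 0 means "keep the corresponding input dimension" (assumed semantics)
        out = [input_shape[i] if d == 0 else d for i, d in enumerate(dim)]
        # -1 means "infer this dimension from the remaining element count"
        if -1 in out:
            known = int(np.prod([d for d in out if d != -1]))
            out[out.index(-1)] = int(np.prod(input_shape)) // known
        return out

    assert resolve_dim([2, 3, 4], [0, -1]) == [2, 12]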
@@ -52,9 +55,8 @@ def convert_squeeze(graph: nx.MultiDiGraph):
('reshape', dict(kind='op', op='Squeeze')),
('output', dict(kind='data'))],
edges=[('reshape', 'output')],
- action=reshape_squeeze_transform,
- node_attrs=['kind', 'op'],
- edge_attrs=[])
+ action=reshape_squeeze_transform
+ )
def convert_reshape(graph: nx.MultiDiGraph):
@@ -65,129 +67,170 @@ def convert_reshape(graph: nx.MultiDiGraph):
('reshape', dict(kind='op', op='Reshape')),
('output', dict(kind='data'))],
edges=[('shape', 'reshape', {'in': 1}), ('reshape', 'output')],
- action=reshape_squeeze_transform,
- node_attrs=['kind', 'op'],
- edge_attrs=['in'])
-
-
-def permute_attrs(attrs: dict, perm: np.ndarray, inv: np.ndarray):
- log.debug("perm = {}, inv = {}".format(perm, inv))
- assert len(perm) == len(inv)
- assert all(perm[inv] == range(len(perm)))
- if 'dim_attrs' in attrs:
- for a in attrs['dim_attrs']: # inv is applicable for dim_attrs
- if a in attrs and attrs[a] is not None:
- try:
- attrs[a] = inv[attrs[a]] if not isinstance(a, np.ndarray) else np.array(inv[attrs[a]])
- except:
- raise Error("Can not transpose attribute '{}' with value {} for node '{}'. {}".format(a,
- attrs[a], attrs['name'] if 'name' in attrs else '<unknown_name>', refer_to_faq_msg(98)))
- if 'shape_attrs' in attrs:
- for a in attrs['shape_attrs']: # perm is applicable for shape_attrs
- if a in attrs and attrs[a] is not None:
- length = 0
- try:
- length = len(attrs[a])
- except TypeError:
- log.warning("Can not transpose attribute '{}' with value {} for node '{}'.".format(a, attrs[a], attrs['name'] if 'name' in attrs else '<unknown_name>'))
- continue
- if length == 4:
- try:
- if not isinstance(attrs[a], np.ndarray):
- attrs[a] = np.array(attrs[a])
- log.debug("a = {}, attrs[a] = {}, type(attrs[a]) = {}".format(a, attrs[a], type(attrs[a])))
- attrs[a] = np.array(attrs[a][perm])
- except:
- raise Error("Can not transpose attribute '{}' with value {} for node '{}'. {}".format(a,
- attrs[a], attrs['name'] if 'name' in attrs else '<unknown_name>', refer_to_faq_msg(98)))
-
-
-def permute_value(attrs: dict, perm: np.ndarray):
- if 'value' in attrs and attrs['value'] is not None and attrs['value'].ndim == 4:
- log.debug(
- 'Permutation {} of value with shape {} for node "{}". Expected target shape: {}.'.format(
- perm,
- attrs['value'].shape,
- attrs['name'] if 'name' in attrs else '<NONE>',
- attrs['shape'] if 'shape' in attrs else '<NONE>'
- )
- )
- assert 'shape' in attrs and len(attrs['shape']) == 4
- attrs['value'] = np.transpose(attrs['value'], perm)
- assert (list(attrs['shape']) == list(attrs['value'].shape))
-
-
-def repack_fully_connected_weights_nhwc_to_nchw(node: Node):
+ action=reshape_squeeze_transform
+ )
+
+
+def repack_fully_connected_weights_nhwc_to_nchw(graph: nx.MultiDiGraph):
"""
Repack the weights of a FullyConnected layer as part of the NHWC to NCHW translation when a Reshape
that touches the dimensions being repacked appears right before the FullyConnected layer.
"""
- if node.has_valid('type') and node.type == 'FullyConnected':
- assert node.in_node(0).kind == 'data'
- if len(node.in_node(0).in_nodes()) == 1:
- input = node.in_node(0).in_node(0)
- if input.has_valid('type') and input.type == 'Reshape':
- assert len(input.in_nodes()) > 0
- if input.in_node(0).has_valid('shape') and input.out_node().has_valid('shape'):
-
- orig_shape = input.in_node(0).shape
- new_shape = input.out_node().shape
-
- # TODO a bit conservative condition; relax it checking specific dimensions
- # that are involved in NHWC to NCWH translation
- if len(orig_shape) != len(new_shape) or any(orig_shape != new_shape):
- # OK, here we are; need to repack node.in_node(1) to maintain it compatible with original
- # input order
-
- # TODO here is a couple of limitations that makes this pass simpler; consider to relax them
- if len(orig_shape) == 4 and len(new_shape) == 2 and orig_shape[0] == new_shape[0]:
- # that means orig_shape is in NCHW and new_shape is in NC
- # and we need to map CHW part to C after HWC to CHW transform
- # Assuming that FullyConnected weights haven't been converted from IO to OI yet.
- # So format is IO.
-
- assert all(orig_shape != -1), 'Input shape for {} can not be negative.'.format(node.id)
- assert all(new_shape != -1), 'Output shape for {} can not be negative.'.format(node.id)
- assert orig_shape[1] * orig_shape[2] * orig_shape[3] == new_shape[1], \
- 'Input shape does not correspond to output shape for layer {}.'.format(node.id)
- assert node.in_node(1).has_valid('value'), 'Node {} does not have value.'.format(node.id)
-
- weights = node.in_node(1)
-
- log.debug("orig_shape = {}".format(orig_shape))
- log.debug("new_shape = {}".format(new_shape))
- log.debug("weights.shape = {}".format(weights.shape))
- log.debug("weights.shape[1] = {}, new_shape[1] = {}".format(weights.shape[1], new_shape[1]))
-
- assert weights.shape[0] == new_shape[1], \
- 'First dim of weights does not correspond to output shape of {}'.format(node.id)
- # interpret I dimension of the weights as packed HWC
- # orig shape is already converted to NCHW, so provide transposed order for I repacking
- tmp_shape = (orig_shape[2], orig_shape[3], orig_shape[1], weights.shape[1])
- weights.value = np.transpose(weights.value.reshape(tmp_shape), (2, 0, 1, 3)).reshape(
- weights.shape)
-
- else:
- log.warning("Cannot do the complete NHWC to NCHW translation for FullyConnected weights. " +
- "The final model can be broken.")
-
-
-def convert_nhwc_to_nchw(graph: nx.MultiDiGraph):
- """
- It doesn't check if the model really has NHWC, it assumes this.
- So it is just do global permute for all data (without values) and ops.
- """
- nodes = nx.topological_sort(graph)
- for node in nodes:
+ for node in graph.nodes():
+ node = Node(graph, node)
+ if node.has_valid('type') and node.type == 'FullyConnected':
+ assert node.in_node(0).kind == 'data'
+ if len(node.in_node(0).in_nodes()) == 1:
+ input = node.in_node(0).in_node(0)
+ if input.has_valid('type') and input.type == 'Reshape':
+ assert len(input.in_nodes()) > 0
+ if input.in_node(0).has_valid('shape') and input.out_node().has_valid('shape'):
+
+ orig_shape = input.in_node(0).shape
+ new_shape = input.out_node().shape
+
+ # TODO: this condition is a bit conservative; relax it by checking the specific dimensions
+ # that are involved in the NHWC to NCHW translation
+ if len(orig_shape) != len(new_shape) or any(orig_shape != new_shape):
+ # OK, here we are; need to repack node.in_node(1) to maintain it compatible with original
+ # input order
+
+ # TODO: here are a couple of limitations that keep this pass simple; consider relaxing them
+ if len(orig_shape) == 4 and len(new_shape) == 2 and orig_shape[0] == new_shape[0]:
+ # that means orig_shape is in NCHW and new_shape is in NC
+ # and we need to map CHW part to C after HWC to CHW transform
+ # Assuming that FullyConnected weights haven't been converted from IO to OI yet.
+ # So format is IO.
+
+ assert all(orig_shape != -1), 'Input shape for {} cannot be negative.'.format(node.id)
+ assert all(new_shape != -1), 'Output shape for {} cannot be negative.'.format(node.id)
+ assert orig_shape[1] * orig_shape[2] * orig_shape[3] == new_shape[1], \
+ 'Input shape does not correspond to output shape for layer {}.'.format(node.id)
+ assert node.in_node(1).has_valid('value'), 'Node {} does not have value.'.format(node.id)
+
+ weights = node.in_node(1)
+
+ log.debug("orig_shape = {}".format(orig_shape))
+ log.debug("new_shape = {}".format(new_shape))
+ log.debug("weights.shape = {}".format(weights.shape))
+ log.debug("weights.shape[1] = {}, new_shape[1] = {}".format(weights.shape[1], new_shape[1]))
+
+ assert weights.shape[0] == new_shape[1], \
+ 'First dim of weights does not correspond to output shape of {}'.format(node.id)
+ # interpret I dimension of the weights as packed HWC
+ # orig shape is already converted to NCHW, so provide transposed order for I repacking
+ tmp_shape = (orig_shape[2], orig_shape[3], orig_shape[1], weights.shape[1])
+ weights.value = np.transpose(weights.value.reshape(tmp_shape), (2, 0, 1, 3)).reshape(
+ weights.shape)
+ else:
+ log.warning("Cannot do the complete NHWC to NCHW translation for FullyConnected weights. "
+ "The final model can be broken.")
+
+
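+
To see why the transpose order (2, 0, 1, 3) is correct: with orig_shape already in NCHW, tmp_shape views the IO-format weights as (H, W, C, O), i.e. rows packed in the original HWC order; moving C to the front reorders the rows to the CHW packing that the permuted input now produces. A standalone numpy check of this repack (sizes are made up for illustration):

    import numpy as np

    N, C, H, W, O = 1, 3, 2, 2, 5
    x_nhwc = np.arange(N * H * W * C, dtype=np.float32).reshape(N, H, W, C)
    weights = np.random.rand(C * H * W, O)           # IO format, rows packed as HWC

    y_ref = x_nhwc.reshape(N, -1) @ weights          # original NHWC flatten semantics

    x_nchw = x_nhwc.transpose(0, 3, 1, 2)            # the global NHWC -> NCHW permute
    # mirror the pass: view rows as (H, W, C, O), move C first, flatten back
    repacked = weights.reshape(H, W, C, O).transpose(2, 0, 1, 3).reshape(C * H * W, O)

    assert np.allclose(x_nchw.reshape(N, -1) @ repacked, y_ref)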
+def apply_nhwc_to_nchw_permutation(graph: nx.MultiDiGraph):
+ # Add NHWC to NCHW permutation for all data nodes (only for nodes without permutation)
+ if graph.graph['layout'] == 'NCHW':
+ return
+ for node in graph.nodes():
+ node = Node(graph, node)
+ if node.kind == 'data':
+ if node.has_and_set('nchw_layout'):
+ continue
+
+ # Get NHWC to NCHW permutation for N dims, where N = len(node.shape)
+ permutation = PermuteAttrs().get_nhwc_to_nchw_permutation(len(node.shape))
+
+ # Check whether any of the node's input or output edges already carries a permutation
+ skip_permutation = False
+ for in_node in node.in_nodes():
+ edge_attrs = node.graph.get_edge_data(in_node.id, node.id)[0]
+ if 'permutation' in edge_attrs:
+ skip_permutation = True
+ for out_node in node.out_nodes():
+ edge_attrs = node.graph.get_edge_data(node.id, out_node.id)[0]
+ if 'permutation' in edge_attrs:
+ skip_permutation = True
+
+ if skip_permutation:
+ continue
+
+ # Set permutation to all in/out edges
+ for in_node in node.in_nodes():
+ PermuteAttrs.set_permutation(in_node, node, permutation)
+
+ for out_node in node.out_nodes():
+ PermuteAttrs.set_permutation(node, out_node, permutation)
+
+
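+
The permutation requested for each data node is the usual layout transform generalized to rank N: the batch dimension stays first, the channel dimension moves from last to second place, and the spatial dimensions shift right. The exact values come from PermuteAttrs.get_nhwc_to_nchw_permutation, which is outside this diff; a sketch of the expected mapping, assuming identity for ranks below 3:

    import numpy as np

    def nhwc_to_nchw_perm(rank):
        if rank < 3:
            return np.arange(rank)                   # nothing to move (assumption)
        return np.array([0, rank - 1] + list(range(1, rank - 1)))

    assert list(nhwc_to_nchw_perm(4)) == [0, 3, 1, 2]     # NHWC  -> NCHW
    assert list(nhwc_to_nchw_perm(5)) == [0, 4, 1, 2, 3]  # NDHWC -> NCDHW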
+def merge_nodes_permutations(graph: nx.MultiDiGraph):
+ # Iterate over all data nodes and check that the permutations requested on their edges are equal.
+ # If they are, that permutation is stored as an attribute of the data node;
+ # otherwise an Error is raised.
+ for node in graph.nodes():
+ node = Node(graph, node)
+ if node.kind != 'data':
+ continue
+
+ permutations = []
+
+ # Get all permutations from in edges
+ for in_node in node.in_nodes():
+ edge_attrs = node.graph.get_edge_data(in_node.id, node.id)[0]
+ if 'permutation' in edge_attrs:
+ permutations.append(edge_attrs['permutation'])
+
+ # Get all permutations from out edges
+ for out_node in node.out_nodes():
+ edge_attrs = node.graph.get_edge_data(node.id, out_node.id)[0]
+ if 'permutation' in edge_attrs:
+ permutations.append(edge_attrs['permutation'])
+
+ # Check that all permutations are equal
+ final_permutations = []
+ for p in permutations:
+ if p is not None:
+ final_permutations.append(p.perm)
+ else:
+ final_permutations.append(np.arange(node.shape.size))
+
+ if len(final_permutations) == 0:
+ continue
+
+ if not all([np.array_equal(final_permutations[0], perm) for perm in final_permutations]):
+ raise Error('Permutations requested for {} data node are not equal! '
+ 'List of permutations: {}'.format(node.name, [p.perm if p is not None else None for p in permutations]))
+
+ assert not node.has_valid('permutation') or np.array_equal(node.permutation, permutations[0])
+ node['permutation'] = permutations[0]
+ if node.permutation is not None and node.permutation.perm.size == 0:
+ node.permutation = None
+
+
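+
In the equality check above, an edge without a permutation is treated as requesting the identity, so a data node shared between a layout-aware consumer and a layout-agnostic one only merges cleanly when the requested permutation is trivial. A small standalone illustration of that comparison:

    import numpy as np

    requested = [np.array([0, 3, 1, 2]), None]       # second edge has no permutation
    rank = 4
    resolved = [p if p is not None else np.arange(rank) for p in requested]
    consistent = all(np.array_equal(resolved[0], p) for p in resolved)
    assert not consistent    # mismatch -> merge_nodes_permutations raises an Error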
+def permute_data_nodes_attrs(graph: nx.MultiDiGraph):
+ # Iterate over all data nodes and apply permutation if exists
+ for node in graph.nodes():
node = Node(graph, node)
- if node.has_and_set('nchw_layout'):
- # this operation already produces output in the NCHW format
+ if node.kind != 'data' or not node.has_valid('permutation'):
continue
- # TODO Consider move it to a separate pass with pattern matcher
- log.debug("node.name = {}".format(node.name if node.has_valid('name') else '<NO NAME>'))
- permute_attrs(graph.node[node.id], nhwc_to_nchw_permute, nchw_to_nhwc_permute)
- permute_value(graph.node[node.id], nhwc_to_nchw_permute)
- repack_fully_connected_weights_nhwc_to_nchw(node)
+
+ # Apply the permutation to the node's shape, and to its value if one exists
+ node.shape = np.array(node.shape)[node.permutation.perm]
+ if node.has_valid('value'):
+ if len(node.value.shape) != len(node.permutation.perm):
+ log.warning('Node {} has value of shape {} that does not match permutation {}; skipping'.format(node.name, node.value.shape, node.permutation.perm))
+ continue
+ node.value = np.array(node.value.transpose(node.permutation.perm))
+
+
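+
Applying the merged permutation is a plain index/transpose operation on the shape and the value; the rank guard above skips constants whose rank differs from the permutation's (for example a 1D bias attached next to 4D data). A standalone illustration:

    import numpy as np

    perm = np.array([0, 3, 1, 2])                    # NHWC -> NCHW
    shape = np.array([1, 5, 7, 3])
    value = np.zeros(tuple(shape))

    new_shape = shape[perm]                          # [1 3 5 7]
    new_value = value.transpose(perm)
    assert list(new_value.shape) == list(new_shape)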
+def permute_op_nodes_attrs(graph: nx.MultiDiGraph):
+ for node in graph.nodes():
+ node = Node(graph, node)
+ if node.kind == 'op' and node.has_valid('permute_attrs'):
+ node.permute_attrs.permute_attrs(node)
def reverse_input_channels(graph: nx.MultiDiGraph):
@@ -265,14 +308,14 @@ def reverse_input_channels(graph: nx.MultiDiGraph):
tmp_shape_for_reorder = list(bottom_weights.value.shape)
src_shape = list(tmp_shape_for_reorder)
log.debug('weights shape = {}'.format(tmp_shape_for_reorder))
- assert (tmp_shape_for_reorder[bottom_weights.input_channel_dim[0]] == bottom_channels)
- tmp_shape_for_reorder[bottom_weights.input_channel_dim[0]] = ngroups
+ assert (tmp_shape_for_reorder[bottom_weights.input_channel_dim] == bottom_channels)
+ tmp_shape_for_reorder[bottom_weights.input_channel_dim] = ngroups
tmp_shape_for_reorder = tmp_shape_for_reorder + [multiplier]
log.debug('tmp_shape_for_reorder = {}'.format(tmp_shape_for_reorder))
# temporary change shape of weights to do reordering
# bottom_weights.value.shape = tuple(tmp_shape_for_reorder)
bottom_weights.value = np.flip(bottom_weights.value.reshape(tuple(tmp_shape_for_reorder)),
- bottom_weights.input_channel_dim[0])
+ bottom_weights.input_channel_dim)
# change shape of weights back
log.debug('back to shape = {}'.format(tuple(src_shape)))
bottom_weights.value = bottom_weights.value.reshape(tuple(src_shape))
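The reshape-then-flip above reverses the input channels in whole groups: the channel axis (assumed here to be the one the appended multiplier splits) is viewed as (ngroups, multiplier), flipped along the group axis only, and flattened back, so the ordering inside each group is preserved. A standalone sketch with made-up sizes:

    import numpy as np

    ngroups, multiplier = 3, 2
    w = np.arange(ngroups * multiplier)              # channels [0 1 | 2 3 | 4 5]
    flipped = np.flip(w.reshape(ngroups, multiplier), 0).reshape(w.shape)
    assert list(flipped) == [4, 5, 2, 3, 0, 1]       # groups reversed, intact inside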
@@ -283,14 +326,17 @@ def reverse_input_channels(graph: nx.MultiDiGraph):
'Reverse input channels are not applied: there is no Conv2D after DepthwiseConv2dNative to ' +
'complete the flip')
- conv.in_node(1).value = np.flip(conv.in_node(1).value, conv.in_node(1).input_channel_dim[0])
+ conv.in_node(1).value = np.flip(conv.in_node(1).value, conv.in_node(1).input_channel_dim)
log.debug('Applied reversing input channels for weights of convolution {}'.format(conv.id))
log.debug('Shape was (shape){}, (value.shape){}'.format(conv.in_node(1).shape, conv.in_node(1).value.shape))
- log.debug('Flipped dim: {}'.format(conv.in_node(1).input_channel_dim[0]))
+ log.debug('Flipped dim: {}'.format(conv.in_node(1).input_channel_dim))
def conv_flatten_concat_action(graph: nx.MultiDiGraph, match: dict):
reshape_node = match['reshape']
+ reshape_data_node = match['reshape_data']
+ concat_node = match['concat']
+ concat_data_node = match['concat_data']
conv_name = match['conv'].name
conv_data_node = match['conv_data']
assert len(graph.in_edges(reshape_node.id)) == 1
@@ -298,17 +344,20 @@ def conv_flatten_concat_action(graph: nx.MultiDiGraph, match: dict):
new_permute_op = Permute(graph, {'order': np.array([0, 2, 3, 1])})
permute_data_node = new_permute_op.create_node_with_data([conv_data_node], dict(name=conv_name + '/Permute_'))
create_edge(permute_data_node, reshape_node)
+ # Disable permutation for Reshape and Concat
+ PermuteAttrs.set_permutation(reshape_node, reshape_data_node, None)
+ PermuteAttrs.set_permutation(concat_node, concat_data_node, None, skip_if_exists=True)
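The inserted Permute with order [0, 2, 3, 1] turns the (now NCHW) convolution output back into NHWC right before the Reshape, so the flatten keeps the element order the original TF graph produced; disabling the permutation on the Reshape and Concat data nodes then keeps later passes from transposing them again. A quick numpy check of that order (illustration only):

    import numpy as np

    x_nhwc = np.arange(24).reshape(1, 2, 3, 4)       # original TF layout
    x_nchw = x_nhwc.transpose(0, 3, 1, 2)            # after the global NHWC -> NCHW pass
    restored = x_nchw.transpose(0, 2, 3, 1)          # the inserted Permute
    assert np.array_equal(restored.reshape(1, -1), x_nhwc.reshape(1, -1))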
def conv_flatten_concat(graph: nx.MultiDiGraph):
apply_pattern(
graph,
nodes=[
- ('conv', dict(kind='op', op='Conv2D')),
+ ('conv', dict(kind='op', type='Convolution')),
('conv_data', dict(kind='data')),
- ('reshape', dict(kind='op', op='Reshape')),
+ ('reshape', dict(kind='op', type='Reshape')),
('reshape_data', dict(kind='data')),
- ('concat', dict(kind='op', op='ConcatV2')),
+ ('concat', dict(kind='op', type='Concat')),
('concat_data', dict(kind='data'))
],
edges=[
@@ -318,9 +367,8 @@ def conv_flatten_concat(graph: nx.MultiDiGraph):
('reshape_data', 'concat'),
('concat', 'concat_data')
],
- action=conv_flatten_concat_action,
- node_attrs=['kind', 'op'],
- edge_attrs=[])
+ action=conv_flatten_concat_action
+ )
def fuse_sequence_of_reshapes(graph: nx.MultiDiGraph):