Diffstat (limited to 'model-optimizer/mo/middle/passes/shape.py')
 model-optimizer/mo/middle/passes/shape.py | 174 ++++++++++++++++----------
 1 file changed, 111 insertions(+), 63 deletions(-)
diff --git a/model-optimizer/mo/middle/passes/shape.py b/model-optimizer/mo/middle/passes/shape.py
index b0285edd4..647502bf9 100644
--- a/model-optimizer/mo/middle/passes/shape.py
+++ b/model-optimizer/mo/middle/passes/shape.py
@@ -21,7 +21,7 @@ import numpy as np
from mo.front.extractor import update_attrs
from mo.graph.graph import Node, create_edge
-from mo.middle.passes.eliminate import remove_op_node, merge_data_nodes, graph_clean_up_tf
+from mo.middle.passes.eliminate import remove_op_node_with_data_node, merge_data_nodes, graph_clean_up_tf, get_nodes_with_attributes
from mo.middle.passes.fusing.helpers import get_next_operation
from mo.middle.pattern_match import apply_pattern
from mo.ops.op import PermuteAttrs, Op
@@ -38,12 +38,10 @@ def reshape_squeeze_transform(graph: nx.MultiDiGraph, match: dict):
reshape['shape'] = output.shape
reshape.op = 'Reshape'
reshape['type'] = 'Reshape'
- reshape['axis'] = 0 # TODO what does it mean?
if not reshape.has_valid('dim'):
# do not override value 'dim' if it is set. It may contain specific values like -1 and 0
reshape['dim'] = reshape.shape.copy()
update_attrs(reshape, 'shape_attrs', 'dim')
- reshape['num_axes'] = -1 # TODO what does it mean?
if 'shape' in match:
graph.remove_edge(match['shape'].node, match['reshape'].node)
@@ -71,60 +69,81 @@ def convert_reshape(graph: nx.MultiDiGraph):
)
+def can_repack_fully_connected_weights_nhwc_to_nchw(fc_node: Node):
+    """
+    Checks whether it is possible to repack the weights of the FullyConnected layer: the input of the FullyConnected
+    layer must be produced by a Reshape layer that satisfies several conditions.
+    :param fc_node: the FullyConnected node to check
+    :return: the result of the check
+    """
+ if len(fc_node.in_node(0).in_nodes()) != 1:
+ return False
+
+ reshape_node = fc_node.in_node(0).in_node(0)
+ if not reshape_node.has_valid('type') or reshape_node.type != 'Reshape':
+ return False
+
+ if not reshape_node.in_node(0).has_valid('shape') or not reshape_node.out_node().has_valid('shape'):
+ return False
+
+ orig_shape = reshape_node.in_node(0).shape
+ new_shape = reshape_node.out_node().shape
+
+    # TODO this condition is a bit conservative; relax it by checking the specific dimensions that are involved
+    # in the NHWC to NCHW translation
+ if len(orig_shape) == len(new_shape) and all(orig_shape == new_shape):
+ return False
+
+    # TODO here are a couple of limitations that make this pass simpler; consider relaxing them
+ if len(orig_shape) == 4 and len(new_shape) == 2 and orig_shape[0] == new_shape[0]:
+        # that means orig_shape is in NCHW and new_shape is in NC,
+        # and we need to map the CHW part to C after the HWC to CHW transform.
+        # Assuming that the FullyConnected weights haven't been converted from IO to OI yet,
+        # the format is IO.
+ return True
+ else:
+ log.warning("Cannot do the complete NHWC to NCHW translation for FullyConnected weights. "
+ "The final model can be broken.")
+ return False
+
+
def repack_fully_connected_weights_nhwc_to_nchw(graph: nx.MultiDiGraph):
"""
     Repack the weights of a FullyConnected layer as a part of the NHWC to NCHW translation if a Reshape that
     involves the dimensions being repacked appears right before the FullyConnected layer.
"""
- for node in graph.nodes():
- node = Node(graph, node)
- if node.has_valid('type') and node.type == 'FullyConnected':
- assert node.in_node(0).kind == 'data'
- if len(node.in_node(0).in_nodes()) == 1:
- input = node.in_node(0).in_node(0)
- if input.has_valid('type') and input.type == 'Reshape':
- assert len(input.in_nodes()) > 0
- if input.in_node(0).has_valid('shape') and input.out_node().has_valid('shape'):
-
- orig_shape = input.in_node(0).shape
- new_shape = input.out_node().shape
-
- # TODO a bit conservative condition; relax it checking specific dimensions
- # that are involved in NHWC to NCWH translation
- if len(orig_shape) != len(new_shape) or any(orig_shape != new_shape):
- # OK, here we are; need to repack node.in_node(1) to maintain it compatible with original
- # input order
-
- # TODO here is a couple of limitations that makes this pass simpler; consider to relax them
- if len(orig_shape) == 4 and len(new_shape) == 2 and orig_shape[0] == new_shape[0]:
- # that means orig_shape is in NCHW and new_shape is in NC
- # and we need to map CHW part to C after HWC to CHW transform
- # Assuming that FullyConnected weights haven't been converted from IO to OI yet.
- # So format is IO.
-
- assert all(orig_shape != -1), 'Input shape for {} can not be negative.'.format(node.id)
- assert all(new_shape != -1), 'Output shape for {} can not be negative.'.format(node.id)
- assert orig_shape[1] * orig_shape[2] * orig_shape[3] == new_shape[1], \
- 'Input shape does not correspond to output shape for layer {}.'.format(node.id)
- assert node.in_node(1).has_valid('value'), 'Node {} does not have value.'.format(node.id)
-
- weights = node.in_node(1)
-
- log.debug("orig_shape = {}".format(orig_shape))
- log.debug("new_shape = {}".format(new_shape))
- log.debug("weights.shape = {}".format(weights.shape))
- log.debug("weights.shape[1] = {}, new_shape[1] = {}".format(weights.shape[1], new_shape[1]))
-
- assert weights.shape[0] == new_shape[1], \
- 'First dim of weights does not correspond to output shape of {}'.format(node.id)
- # interpret I dimension of the weights as packed HWC
- # orig shape is already converted to NCHW, so provide transposed order for I repacking
- tmp_shape = (orig_shape[2], orig_shape[3], orig_shape[1], weights.shape[1])
- weights.value = np.transpose(weights.value.reshape(tmp_shape), (2, 0, 1, 3)).reshape(
- weights.shape)
- else:
- log.warning("Cannot do the complete NHWC to NCHW translation for FullyConnected weights. "
- "The final model can be broken.")
+ for node_id in get_nodes_with_attributes(graph, type='FullyConnected'):
+ fc_node = Node(graph, node_id)
+
+ if not can_repack_fully_connected_weights_nhwc_to_nchw(fc_node):
+ continue
+
+ reshape_node = fc_node.in_node(0).in_node(0)
+
+ orig_shape = reshape_node.in_node(0).shape
+ new_shape = reshape_node.out_node().shape
+
+    # OK, here we are; we need to repack fc_node.in_node(1) to keep it compatible with the original input order
+
+    assert all(orig_shape != -1), 'Input shape for {} cannot be negative.'.format(fc_node.id)
+    assert all(new_shape != -1), 'Output shape for {} cannot be negative.'.format(fc_node.id)
+ assert orig_shape[1] * orig_shape[2] * orig_shape[3] == new_shape[1], \
+ 'Input shape does not correspond to output shape for layer {}.'.format(fc_node.id)
+ assert fc_node.in_node(1).has_valid('value'), 'Node {} does not have value.'.format(fc_node.id)
+
+ weights = fc_node.in_node(1)
+
+ log.debug("orig_shape = {}".format(orig_shape))
+ log.debug("new_shape = {}".format(new_shape))
+ log.debug("weights.shape = {}".format(weights.shape))
+ log.debug("weights.shape[1] = {}, new_shape[1] = {}".format(weights.shape[1], new_shape[1]))
+
+ assert weights.shape[0] == new_shape[1], \
+ 'First dim of weights does not correspond to output shape of {}'.format(fc_node.id)
+    # interpret the I dimension of the weights as packed HWC;
+    # orig_shape is already converted to NCHW, so provide a transposed order for the I repacking
+ tmp_shape = (orig_shape[2], orig_shape[3], orig_shape[1], weights.shape[1])
+ weights.value = np.transpose(weights.value.reshape(tmp_shape), (2, 0, 1, 3)).reshape(weights.shape)
def apply_nhwc_to_nchw_permutation(graph: nx.MultiDiGraph):
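A minimal numpy sketch of why the (2, 0, 1, 3) transpose is the right repacking (the sizes n, c, h, w, o below are hypothetical): the original NHWC model packs the I dimension of the weights in HWC order, and after the graph is converted to NCHW the flattened activation enumerates the same elements in CHW order, so the weights must be repacked accordingly.

import numpy as np

n, c, h, w, o = 1, 2, 3, 4, 5                               # hypothetical sizes

x_nhwc = np.random.rand(n, h, w, c).astype(np.float32)      # activation in the original NHWC layout
x_nchw = x_nhwc.transpose(0, 3, 1, 2)                       # the same data after the NHWC -> NCHW conversion

weights = np.random.rand(h * w * c, o).astype(np.float32)   # IO layout, I packed in HWC order

# reference result: the original model flattens in HWC order
reference = x_nhwc.reshape(n, -1) @ weights

# repack the I dimension from HWC packing to CHW packing, exactly as the pass does
repacked = np.transpose(weights.reshape(h, w, c, o), (2, 0, 1, 3)).reshape(weights.shape)

# the converted model flattens in CHW order; the repacked weights give the same output
assert np.allclose(x_nchw.reshape(n, -1) @ repacked, reference)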
@@ -230,7 +249,10 @@ def permute_op_nodes_attrs(graph: nx.MultiDiGraph):
for node in graph.nodes():
node = Node(graph, node)
if node.kind == 'op' and node.has_valid('permute_attrs'):
- node.permute_attrs.permute_attrs(node)
+ try:
+ node.permute_attrs.permute_attrs(node)
+ except Exception as e:
+ raise Error('Can\'t permute attrs for node {}. Error message: {}'.format(node.id, e))
def reverse_input_channels(graph: nx.MultiDiGraph):
@@ -333,20 +355,30 @@ def reverse_input_channels(graph: nx.MultiDiGraph):
def conv_flatten_concat_action(graph: nx.MultiDiGraph, match: dict):
+ assert graph.graph['layout'] == 'NHWC'
reshape_node = match['reshape']
reshape_data_node = match['reshape_data']
- concat_node = match['concat']
- concat_data_node = match['concat_data']
conv_name = match['conv'].name
conv_data_node = match['conv_data']
+    # the pattern should be applied only when the Reshape operation changes the number of dimensions
+ if len(reshape_data_node.shape) == len(conv_data_node.shape) or reshape_node.has_and_set('nchw_layout'):
+ return
+
+ if len(reshape_data_node.out_nodes()) == 1 and reshape_data_node.out_node().has_valid('type') and \
+ reshape_data_node.out_node().type == 'FullyConnected' and \
+ can_repack_fully_connected_weights_nhwc_to_nchw(reshape_data_node.out_node()):
+        log.info('There is a FullyConnected layer after the node "{}" whose weights will be repacked, so there is no '
+                 'need to insert a Permute'.format(reshape_node.soft_get('name')))
+ return
assert len(graph.in_edges(reshape_node.id)) == 1
graph.remove_edge(conv_data_node.id, reshape_node.id)
- new_permute_op = Permute(graph, {'order': np.array([0, 2, 3, 1])})
+
+ permutation_order = PermuteAttrs.get_nchw_to_nhwc_permutation(len(conv_data_node.shape)).perm
+ new_permute_op = Permute(graph, {'order': permutation_order})
permute_data_node = new_permute_op.create_node_with_data([conv_data_node], dict(name=conv_name + '/Permute_'))
create_edge(permute_data_node, reshape_node)
- # Disable permutation for Reshape and Concat
+    # Disable attribute permutation for the Reshape layer
PermuteAttrs.set_permutation(reshape_node, reshape_data_node, None)
- PermuteAttrs.set_permutation(concat_node, concat_data_node, None, skip_if_exists=True)
def conv_flatten_concat(graph: nx.MultiDiGraph):
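The inserted Permute compensates for the layout conversion: once the graph is in NCHW, a flattening Reshape would enumerate elements in CHW order instead of the HWC order the original model used. A minimal numpy sketch, assuming that for a 4D tensor PermuteAttrs.get_nchw_to_nhwc_permutation(...).perm equals [0, 2, 3, 1] (the hard-coded constant it replaces above):

import numpy as np

n, c, h, w = 1, 2, 3, 4                         # hypothetical sizes
t_nhwc = np.random.rand(n, h, w, c)             # what the original NHWC model computed
t_nchw = t_nhwc.transpose(0, 3, 1, 2)           # the same data after the NHWC -> NCHW conversion

flat_original = t_nhwc.reshape(n, -1)           # the original model flattened in HWC order

# Permute with the NCHW -> NHWC order (0, 2, 3, 1), then Reshape: the original order is restored
flat_converted = t_nchw.transpose(0, 2, 3, 1).reshape(n, -1)
assert np.array_equal(flat_original, flat_converted)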
@@ -357,15 +389,31 @@ def conv_flatten_concat(graph: nx.MultiDiGraph):
('conv_data', dict(kind='data')),
('reshape', dict(kind='op', type='Reshape')),
('reshape_data', dict(kind='data')),
- ('concat', dict(kind='op', type='Concat')),
- ('concat_data', dict(kind='data'))
],
edges=[
('conv', 'conv_data'),
('conv_data', 'reshape'),
('reshape', 'reshape_data'),
- ('reshape_data', 'concat'),
- ('concat', 'concat_data')
+ ],
+ action=conv_flatten_concat_action
+ )
+
+ apply_pattern(
+ graph,
+ nodes=[
+ ('real_conv', dict(kind='op', type='Convolution')),
+ ('real_conv_data', dict(kind='data')),
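+            # note: the ReLU node is deliberately bound to the 'conv' label so that
+            # conv_flatten_concat_action, which reads match['conv'] and match['conv_data'],
+            # can be reused for this pattern unchanged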
+ ('conv', dict(kind='op', type='ReLU')),
+ ('conv_data', dict(kind='data')),
+ ('reshape', dict(kind='op', type='Reshape')),
+ ('reshape_data', dict(kind='data')),
+ ],
+ edges=[
+ ('real_conv', 'real_conv_data'),
+ ('real_conv_data', 'conv'),
+ ('conv', 'conv_data'),
+ ('conv_data', 'reshape'),
+ ('reshape', 'reshape_data'),
],
action=conv_flatten_concat_action
)
@@ -390,4 +438,4 @@ def fuse_sequence_of_reshapes(graph: nx.MultiDiGraph):
# Detected Reshape1 --> data --> Reshape2 pattern without side edges
# Remove Reshape1
log.debug('Second phase for Reshape: {}'.format(node.name))
- remove_op_node(graph, node)
+ remove_op_node_with_data_node(graph, node)
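Removing the first Reshape here is safe because two back-to-back Reshapes with no side consumers are equivalent to the last one alone; a one-line numpy check:

import numpy as np

x = np.arange(24)
# Reshape1 to (2, 12) followed by Reshape2 to (4, 6) equals Reshape2 applied directly
assert np.array_equal(x.reshape(2, 12).reshape(4, 6), x.reshape(4, 6))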