diff options
Diffstat (limited to 'model-optimizer/extensions/back/FuseReshapesSequence.py')
-rw-r--r-- | model-optimizer/extensions/back/FuseReshapesSequence.py | 5 |
1 files changed, 5 insertions, 0 deletions
diff --git a/model-optimizer/extensions/back/FuseReshapesSequence.py b/model-optimizer/extensions/back/FuseReshapesSequence.py index 2f2459f7d..27ca9485f 100644 --- a/model-optimizer/extensions/back/FuseReshapesSequence.py +++ b/model-optimizer/extensions/back/FuseReshapesSequence.py @@ -46,6 +46,11 @@ class FuseReshapesSequence(BackReplacementPattern): next_op = get_next_operation(node)[0] log.debug('second node: id={}, type={}'.format(next_op.soft_get('id'), next_op.soft_get('type'))) if next_op.has_valid('type') and next_op.type == 'Reshape': + dim_value = next_op.in_port(1).data.get_value() + if dim_value is None or 0 in dim_value or -1 in dim_value: + # we do not fuse reshape sequences with special symbols: 0, -1 + continue + # Detected Reshape1 --> data --> Reshape2 pattern without side edges. Remove Reshape1 log.debug('Second phase for Reshape: {}'.format(node.soft_get('name'))) remove_op_node_with_data_node(graph, node) |