summaryrefslogtreecommitdiff
path: root/model-optimizer/extensions/middle/ConvertGroupedStridedSlice.py
blob: cec09ccb786c6054e63d137d5eef75d84b9967f2 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
"""
 Copyright (c) 2018 Intel Corporation

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

      http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
"""

import numpy as np
import networkx as nx
import logging as log

from extensions.ops.splitv import SplitV
from mo.graph.graph import Node
from mo.ops.op import Op
from mo.ops.reshape import Reshape
from mo.middle.replacement import MiddleReplacementPattern
from extensions.middle.SliceConverter import ConvertSlice

class ConvertGroupedStridedSlice(MiddleReplacementPattern):
    """
        This pass converts subgraphs where StridedSlices used for splitting single channel to single Split layers
        In case if StrdedSlices consume not entire tensor will be created fake outputs for Split layer
        For example:
            Let's suppose we have next graph:
            Data(1,H,W,54)
               |`---->Sslice1_out (1,H,W,(10,18))
               `---->Sslice2_out (1,H,W,(18,36))

            In this case StridedSlices takes only [10, 36] from input tensor in 3rd dim
            So this pass will convert this graph to the next one:
            Split(1,H,W,54)
               |`---->Fake_data (1,H,W,10)
               |`---->Sslice1_out (1,H,W,8)
               |`---->Sslice2_out (1,H,W,18)
               `----->Fake_data (1,H,W,18)
            Where Fake_data - data nodes that have not any consumers.
    """

    enabled = True

    def run_after(self):
        return [ConvertSlice]

    def find_and_replace_pattern(self, graph: nx.MultiDiGraph):
        # Iterate over all data nodes and find all with >= 1 consumers
        data_nodes = [Node(graph, node) for node in graph.node if Node(graph, node).kind == 'data']
        for input_data in data_nodes:
            # We don't use constant data nodes
            if input_data.value is not None:
                continue

            input_shape = np.array(input_data.shape)

            # Get all StridedSlice consumers
            out_nodes = [node for node in input_data.out_nodes() if node.op == 'StridedSlice']
            if len(out_nodes) < 1:
                continue

            valid_for_replacement = True

            # Detect dimension for splitting
            split_channel_dim = None
            for dim_id, s in enumerate(out_nodes[0].slices):
                l, r, stride = s.start, s.stop, s.step
                if l != 0 or r != input_shape[dim_id]:
                    if split_channel_dim is None:
                        split_channel_dim = dim_id
                    else:
                        valid_for_replacement = False

            # split_dims contains tuples with split range and output data node
            split_dims = []
            for out_id, node in enumerate(out_nodes):
                # Check that StridedSlice op has no shrink_axis_mask attribute
                if not np.all([x == False for x in node.shrink_axis_mask]):
                    valid_for_replacement = False
                # Check that StridedSlice op has stride eq 1 and splits only feature channel
                for id, s in enumerate(node.slices):
                    l, r, stride = s.start, s.stop, s.step
                    # We don't support StridedSlice with stride != 1
                    if stride != 1:
                        valid_for_replacement = False
                    if id == split_channel_dim:
                        split_dims.append((s.start, s.stop, node.out_node()))

            if not valid_for_replacement:
                continue

            # Check feature split intersection
            final_data_nodes_list = []
            sorted_split_dims = sorted(split_dims)
            size_splits = []
            prev_r = 0
            for l, r, out in sorted_split_dims:
                # Split dims shouldn't intersect
                if l < prev_r:
                    valid_for_replacement = False
                # Save missing tensor part
                if l > prev_r:
                    shape = np.array(input_shape)
                    size_splits.append(l - prev_r)
                    shape[split_channel_dim] = l - prev_r
                    data_node = Op._create_data_node(graph, 'fake_data', {'shape': shape, 'is_output': True})
                    final_data_nodes_list.append(data_node)


                prev_r = r
                size_splits.append(r - l)
                final_data_nodes_list.append(out)

            if prev_r > input_shape[split_channel_dim]:
                valid_for_replacement = False
            elif prev_r != input_shape[split_channel_dim]:
                # Add last part of tensor
                shape = input_shape.copy()
                shape[split_channel_dim] = input_shape[split_channel_dim] - prev_r
                size_splits.append(input_shape[split_channel_dim] - prev_r)
                data_node = Op._create_data_node(graph, 'fake_data', {'shape': shape, 'is_output': True})
                final_data_nodes_list.append(data_node)

            if not valid_for_replacement:
                continue

            # Insert Split layer and remove old StridedSlice layers
            # 1. Remove connections from input_data to StridedSlice ops
            out_data_nodes = []
            name_for_future_split = out_nodes[0].name
            for node in out_nodes:
                out_data_nodes.append(node.out_node())
                graph.remove_edge(input_data.id, node.id)
                graph.remove_edge(node.id, node.out_node().id)
                graph.remove_node(node.id)
                log.debug("Removed: {}".format(node.id))

            # 2. Create Split layer and reorder outputs
            split = SplitV(graph, dict(name=name_for_future_split + "/Split", axis=split_channel_dim,
                                       size_splits=size_splits))
            split.create_node_with_data(inputs=[input_data], data_nodes=final_data_nodes_list)