"""
 Copyright (c) 2018 Intel Corporation

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

      http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
"""

import logging as log
import numpy as np
from mo.utils.error import Error


def part_sizes_to_indices(part_sizes: list):
    """
    Calculates indices of splits in the array based on part sizes for the split.
    Output list can be used as the second argument for np.split function.
    """
    indices = list(np.cumsum(part_sizes))
    # the last element equals the size of the original array and is redundant for np.split
    del indices[-1]
    log.debug("part_sizes: {}   -->   indices: {}".format(part_sizes, indices))
    return np.array(indices)
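
# For example, part_sizes_to_indices([2, 3, 5]) returns np.array([2, 5]);
# np.split(np.arange(10), [2, 5]) then yields pieces of sizes 2, 3 and 5.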


def split(input, node, outputs, axis, part_sizes):
    """
    Partial inference of a generic split node.

    Args:
        input: input tensor node that is the subject of the split
        node: the split node itself; its 'axis' attribute is set to the normalized axis
        outputs: output tensor nodes where the inferred output shapes (and values) are stored
        axis: index of the dimension to split along; a negative value is normalized
        part_sizes: a list or NumPy array with the sizes of all pieces to split into;
            a single -1 entry is deduced from the remaining sizes
    """

    if input.shape is None:
        return

    if len(outputs) != len(part_sizes):
        log.error('The number of outputs does not match the number of parts in part_sizes.')
        return

    # normalize axis
    if axis < 0:
        axis = input.shape.size + axis

    if axis < 0 or axis >= input.shape.size:
        log.error('Model is incorrect: axis for split node is out of range')
        return

    undef_indices = np.argwhere(part_sizes == -1)
    if undef_indices.size > 1:
        log.error('Desired split part sizes contain more than one -1 element; cannot deduce the real sizes for them')
        return

    if undef_indices.size == 1:
        undef_index = undef_indices[0]
        part_sizes[undef_index] = 0
        deduced_dim = input.shape[axis] - np.add.reduce(part_sizes)
        if deduced_dim < 0:
            log.error('The deduced dimension for the split is negative, which means that the split input shape '
                      'and the desired parts are not compatible')
            return
        # write the deduced size back so the consistency check below sees the real part sizes
        part_sizes[undef_index] = deduced_dim
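    # Illustrative example (assumed values): with input.shape[axis] == 10 and
    # part_sizes == [2, -1, 3], the -1 entry is deduced as 10 - (2 + 3) = 5.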

    all_parts_size = np.add.reduce(part_sizes)
    if all_parts_size != input.shape[axis]:
        log.error("input.shape[{}] = {}  !=  {} = sum of all parts in part_sizes".format(axis, input.shape[axis],
                                                                                         all_parts_size))
        return

    for i, part_size in enumerate(part_sizes):
        shape = input.shape.copy()
        shape[axis] = part_size
        outputs[i].shape = shape

    if input.value is not None:
        split_values = np.split(input.value, part_sizes_to_indices(part_sizes), axis)
        for i, part in enumerate(split_values):
            outputs[i].value = part
            assert all(outputs[i].value.shape == outputs[i].shape)

    assert not node.has_valid('axis') or node.axis == axis
    node.axis = axis
    # WARNING: != 4 is supposed to work for NHWC to NCHW translation only; if other global permutations happen this will fail
    # TODO: redesign it to have this logic built in NHWC to NCHW translation pass; it requires
    #       additional attributes with layout to be propagated through the network
    if len(input.shape) != 4 and node.has_valid('dim_attrs') and 'axis' in node.dim_attrs:
        log.warning(
            'Removed "axis" attribute from the scope of the model relayout pass because len(input.shape) == {} != 4 for node {}'.format(
                len(input.shape),
                node.name if node.has_valid('name') else '<UNKNOWN>'))
        node.dim_attrs.remove('axis')
        assert 'axis' not in node.dim_attrs
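
# Shape-propagation sketch for split() (hypothetical values): an input of shape
# [1, 10, 4] with axis == 1 and part_sizes == [4, 6] yields output shapes
# [1, 4, 4] and [1, 6, 4]; when input.value is known, the values are split too.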


def tf_split_infer(node):
    """
    Partial inference of a split node similar to the Split op of TF.
    """

    if len(node.in_nodes()) == 1:
        return True

    # Two inputs: [split_dim, input]
    assert len(node.in_nodes()) == 2
    split_dim = node.in_node(0).value
    if split_dim is None:
        log.error('split_dim value for node {} is None. Cannot do shape inference.'.format(node.name))
        return
    assert split_dim.ndim == 0
    split_dim = split_dim.item()
    input = node.in_node(1)

    if input.shape is None:
        return

    log.debug('input shape for split: {}, should be split along {} dim'.format(input.shape, split_dim))
    split_dim_size = input.shape[split_dim]
    log.debug('split_dim_size type = {}'.format(type(split_dim_size)))

    if split_dim_size % node.num_split != 0:
        log.error("split_dim cannot be evenly divided by a given number of parts")
        return

    outputs = node.out_nodes()
    # split_dim was converted to a Python scalar by .item() above
    log.debug('split_dim_size = {}, node.num_split = {}, div = {}, typeof div = {}'.format(
        split_dim_size, node.num_split, split_dim_size / node.num_split, type(split_dim_size / node.num_split)))
    split(input, node, [outputs[i] for i in range(len(outputs))], split_dim,
          [int(split_dim_size // node.num_split)] * node.num_split)
    log.debug('output shapes after split: {}'.format([v.shape for v in outputs.values()]))
    node.graph.remove_edge(node.in_node(0).id, node.id)
    node['input_port'] = 1
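
# Sketch of tf_split_infer (hypothetical values): for an input of shape [8, 6],
# split_dim == 1 and num_split == 3, each of the three outputs gets shape [8, 2];
# the split_dim input edge is removed and the remaining data input is on port 1.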


def tf_split_v_infer(node):
    """
    Partial inference of a split node similar to the SplitV op of TF.
    """

    if len(node.in_nodes()) == 1:
        return True

    # Three inputs: [input, size_splits, split_dim]
    assert len(node.in_nodes()) == 3
    split_dim = node.in_node(2).value
    if split_dim is None:
        log.error('split_dim value for node {} is None. Cannot do shape inference.'.format(node.name))
        return
    assert split_dim.ndim == 0
    split_dim = split_dim.item()
    input = node.in_node(0)
    size_splits = node.in_node(1)

    log.debug(
        'split_dim = {}, input.shape = {}, size_splits.value = {}'.format(split_dim, input.shape, size_splits.value))

    if input.shape is None or size_splits.value is None:
        return

    outputs = node.out_nodes()
    # split_dim was converted to a Python scalar by .item() above
    split(input, node, [outputs[i] for i in range(len(outputs))], split_dim, size_splits.value)
    log.debug('output shapes after split: {}'.format([v.shape for v in outputs.values()]))
    node.graph.remove_edge(node.in_node(1).id, node.id)
    node.graph.remove_edge(node.in_node(2).id, node.id)
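
# Sketch of tf_split_v_infer (hypothetical values): for an input of shape [6, 4],
# size_splits == [1, 2, 3] and split_dim == 0, the outputs get shapes [1, 4],
# [2, 4] and [3, 4]; a single -1 in size_splits is deduced inside split().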


def tf_unpack_infer(node):
    if len(node.in_nodes()) != 1:
        log.debug('Unpack node "{}" must have one input.'.format(node.name))
        return

    in_shape = node.in_node().shape
    if in_shape is None:
        log.debug('Unpack node "{}" input node shape is not defined.'.format(node.name))
        return

    split_dim = node.axis
    log.debug('input shape for unpack: {}, should be split along {} dim'.format(in_shape, split_dim))
    split_dim_size = in_shape[split_dim]
    log.debug('split_dim_size type = {}'.format(type(split_dim_size)))

    if node.num_split is not None and node.num_split != split_dim_size:
        log.debug('Unpack with num ({}) not equal to the size ({}) of the dimension being unpacked '
                  'is not supported'.format(node.num_split, split_dim_size))
        return

    if node.num_split is None:
        node.num_split = split_dim_size

    if split_dim_size % node.num_split != 0:
        log.error("split_dim cannot be evenly divided by a given number of parts")
        return

    outputs = node.out_nodes()
    split(node.in_node(), node, [outputs[i] for i in range(len(outputs))], split_dim,
          [int(split_dim_size // node.num_split)] * node.num_split)

    # Squeeze out the dimension that was used for unpacking
    for output in outputs.values():
        if output.shape[split_dim] != 1:
            raise Error('Cannot infer the output shape for Unpack: dimension {} of shape {} must be 1 '
                        'to be squeezed, but it is not.'.format(split_dim, output.shape))
        output.shape = np.delete(output.shape, split_dim)
        if output.value is not None:
            output.value = np.squeeze(output.value, split_dim)
            assert np.all(output.shape == output.value.shape)

    # node shapes will be squeezed in a separate pass
    log.debug('output shapes after split: {}'.format([v.shape for v in outputs.values()]))
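
# Sketch of tf_unpack_infer (hypothetical values): for an input of shape [3, 4]
# and axis == 0, num_split is deduced as 3 and each output first gets shape
# [1, 4], after which the unpacked dimension is squeezed to give shape [4].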