summaryrefslogtreecommitdiff
path: root/model-optimizer/mo/front/common/partial_infer/eltwise.py
blob: 53c3330627391b4392e06bb5ea022a0eadc3e282 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
"""
 Copyright (c) 2018 Intel Corporation

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

      http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
"""

import numpy as np
import logging as log

from mo.graph.graph import get_sorted_inputs, Node


def _is_axis_aligned_value(value):
    """Return True if *value* has a shape like (N,), (N, 1), (N, 1, 1), ...

    Only the leading dimension may be larger than 1. A 0-d value holds a
    single element and is treated like shape (1,). The previous check raised
    here, because ``np.max`` of an empty shape raises ValueError.
    """
    shape = np.shape(value)
    if len(shape) == 0:
        return True
    largest = max(shape)
    return np.prod(shape) == largest and shape[0] == largest


def _broadcast_shapes(shapes):
    """Compute the broadcast output shape of a non-empty list of shapes.

    This follows the NumPy broadcasting rules
    (https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html):
    - Shapes are left-padded with 1's to a common rank.
    - Each dimension is then reduced pairwise across the inputs.
    - A dimension that is -1 (unknown) in any input stays -1 in the result.
    - A pair of sizes that cannot be broadcast also yields -1.

    Returns a new int64 numpy array; the inputs are not modified.
    """
    max_dims = max(len(s) for s in shapes)
    extended_shapes = [np.concatenate((np.ones(max_dims - len(s), dtype=np.int64), s))
                       for s in shapes]
    output_shape = extended_shapes[0]
    for shape in extended_shapes[1:]:
        for dim in range(max_dims):
            mind = min(output_shape[dim], shape[dim])
            maxd = max(output_shape[dim], shape[dim])
            if mind == -1:
                output_shape[dim] = -1
            elif mind == 1:
                output_shape[dim] = maxd
            elif mind == 0 and maxd == 1:
                # NumPy broadcasts a size-1 dimension against a size-0 one to 0.
                output_shape[dim] = 0
            elif mind != maxd:
                output_shape[dim] = -1
    return output_shape


def eltwise_infer(node, op=None, **kwargs):
    """Infer the output shape (and, when possible, the value) of an eltwise node.

    Only inputs reached through edges without a truthy ``control_flow_edge``
    attribute are considered. The output shape is computed from the input
    shapes alone. When ``op`` is given and every input value is known, the
    output value is ``op(*values, **kwargs)``.

    If the node has a valid ``axis`` attribute:
    - When ``axis`` equals the graph's ``feature_dim`` and every known value
      has a shape like (N, 1, ..., 1), each known value is padded with
      trailing 1's so that it lines up with ``axis``. The reshaped value and
      shape are written back to the corresponding input node.
    - Otherwise the node is marked ``can_be_fused = False``.

    Args:
        node: The eltwise graph node to infer.
        op: Optional callable that computes the output value from input values.
        **kwargs: Extra keyword arguments forwarded to ``op``.

    Side effects:
        Sets ``node.out_node().shape`` and possibly ``node.out_node().value``.
        May set ``node['can_be_fused']`` and update the value/shape of input
        nodes. Returns early, leaving the node untouched, when any input
        shape is unknown, when there are no inputs, or when an ``axis`` makes
        a value impossible to align; the latter two cases are logged as errors.
    """
    inputs = [Node(node.graph, inp) for inp, attr in get_sorted_inputs(node)
              if not attr.get('control_flow_edge')]
    shapes = [node.graph.node[inp.id]['shape'] for inp in inputs]
    values = [node.graph.node[inp.id]['value'] for inp in inputs]

    if any(s is None for s in shapes):
        # Nothing is known yet.
        return

    if node.has_valid('axis') and all(value is None for value in values):
        log.error('Eltwise operation with axis is not supported')
        return

    if not shapes:
        # No data inputs: there is nothing to broadcast (max() would fail).
        return

    max_dims = max(len(s) for s in shapes)
    if node.has_valid('axis'):
        # Fusing is only safe when the axis is the feature dim and every constant
        # operand is laid out along it (shape N,1,1,...).
        if node.axis == node.graph.graph['feature_dim'] and \
                all(_is_axis_aligned_value(value) for value in values if value is not None):
            # Validate every value before mutating anything so that a failure
            # leaves the graph (and the shapes we log) untouched.
            pad_counts = {}
            for idx, value in enumerate(values):
                if value is None:
                    continue
                # Number of trailing 1's needed so that the value starts at `axis`.
                dims_to_add = max_dims - node.axis - len(np.shape(value))
                if dims_to_add < 0:
                    log.error('Axis attribute for {} node is wrong (axis={}, input_shapes={})'.format(node.name, node.axis, shapes))
                    return
                pad_counts[idx] = dims_to_add

            for idx, dims_to_add in pad_counts.items():
                new_shape = np.append(np.shape(values[idx]), [1] * dims_to_add).astype(dtype=np.int64)
                new_value = np.reshape(values[idx], new_shape)
                shapes[idx], values[idx] = np.array(new_shape), np.array(new_value)
                # Propagate the reshaped constant back into the graph.
                inputs[idx].value, inputs[idx].shape = np.array(new_value), np.array(new_shape)
        else:
            node['can_be_fused'] = False

    node.out_node().shape = _broadcast_shapes(shapes)

    if op is None or any(v is None for v in values):
        return

    node.out_node().value = op(*values, **kwargs)