summaryrefslogtreecommitdiff
path: root/model-optimizer/mo/ops/op.py
blob: 378d8a378f26659932c2476377ae7a1fcad9f64d (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
"""
 Copyright (c) 2018 Intel Corporation

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

      http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
"""

import logging as log
from collections import namedtuple

import networkx as nx
import numpy as np

from mo.front.extractor import add_attrs_props
from mo.front.extractor import update_ie_fields
from mo.graph.graph import Node, unique_id
from mo.utils import class_registration
from mo.utils.error import Error


class Op(object):
    """
    Base class for Model Optimizer operations.

    An instance binds the operation's attributes (``self.attrs``) to a graph and provides helpers that
    create op nodes, data nodes and the edges between them inside that graph.
    """
    registered_ops = {}
    registered_cls = []
    # Add the derived class to excluded_classes if one should not be registered in registered_ops
    excluded_classes = []

    def __init__(self, graph: nx.MultiDiGraph, attrs1: dict = None, attrs2: dict = None):
        """
        :param graph: graph in which nodes of this operation are created
        :param attrs1: attributes applied on top of the defaults ('precision', 'kind')
        :param attrs2: attributes applied on top of attrs1 (they win on conflicts)
        """
        self.graph = graph
        # 'ir_version' is optional graph metadata; None selects the default backend_attrs mapping
        # in substitute_ie_attrs. Only lookup failures are treated as "not set".
        try:
            self.ir_version = graph.graph['ir_version']
        except (AttributeError, KeyError, TypeError):
            self.ir_version = None

        self.attrs = {
            'precision': "FP32",
            'kind': 'op'
        }
        self.default_backend_attrs = []
        if attrs1 is not None:
            self.attrs.update(attrs1)
        if attrs2 is not None:
            self.attrs.update(attrs2)

    def add_node(self, attrs: dict = None):
        """
        Create a new op node in self.graph from self.attrs overridden by attrs.

        The node id is made unique from the 'name' attribute (empty prefix if absent) and is written back
        as the node's 'name'.
        :param attrs: extra attributes for this particular node
        :return: Node wrapping the created graph node
        """
        new_attrs = dict(self.attrs)
        if attrs is not None:
            new_attrs.update(attrs)
        node_id = unique_id(self.graph, new_attrs.get('name', ''))
        new_attrs['name'] = node_id
        new_attrs = add_attrs_props(new_attrs)
        update_ie_fields(new_attrs, self.ir_version)
        self.substitute_ie_attrs(new_attrs)
        self.graph.add_node(node_id, **new_attrs)
        return Node(self.graph, node_id)

    def substitute_ie_attrs(self, new_attrs: dict):
        """
        Replace the 'IE' attribute description in new_attrs with one whose 'data' section lists the
        attributes returned by the backend_attrs hook that matches self.ir_version, followed by
        self.default_backend_attrs.
        :raises Error: if self.ir_version is not one of None, 2, 3
        """
        backend_attrs_mapping = {
            None: self.backend_attrs,
            3: self.backend_attrs,
            2: self.backend_attrs_v2
        }

        if self.ir_version not in backend_attrs_mapping:
            raise Error("Unrecognized IR version was specified: {}".format(self.ir_version))

        new_attrs.update({
            'IE': [(
                'layer',
                [('id', lambda node: node.node), 'name', 'precision', 'type'],
                [
                    ('data', backend_attrs_mapping[self.ir_version]() + self.default_backend_attrs, []),
                    '@ports',
                    '@consts'])]
        })

    @staticmethod
    def extract_port(node_port):
        """
        Split node_port into a (node, port) pair.

        :param node_port: either a (node, port) tuple or a bare node, which means output port 0
        :return: (node, port)
        :raises Error: if a bare node is given but its out edges use more than one distinct 'out' port
        """
        if isinstance(node_port, tuple):
            return node_port[0], node_port[1]
        node = node_port
        # 'data' nodes do not have 'out' edge attribute but always have one output
        out_ids = {attr['out'] for _, __, attr in node.graph.out_edges(node.id, data=True) if 'out' in attr}
        if len(out_ids) > 1:
            raise Error('Node {} has more than one outputs. Provide output port explicitly. '.format(node.name))
        return node, 0

    def cut_edge_and_create_node(self, node: Node, out_port: int, attrs: dict = None):
        """
        Remove the edges leaving node's out_port, create a new node with attrs and connect node to it by an
        edge carrying the attributes of the (first) removed edge.
        :param node: input node to cut the edge from
        :param out_port: output port whose edges are cut
        :param attrs: attributes of the new node
        :return: Node instance of the created node
        :raises Error: if node has no edge on out_port
        """
        edges = [(u, v, key, params) for u, v, key, params in node.graph.out_edges(node.id, data=True, keys=True)
                 if 'out' in params and params['out'] == out_port]
        if not edges:
            raise Error('Node {} has no output edge on port {}'.format(node.name, out_port))
        edge_attrs = edges[0][3]
        for u, v, key, _ in edges:
            self.graph.remove_edge(u, v, key=key)
        new_node = self.add_node(attrs if attrs is not None else {})
        self.graph.add_edge(node.id, new_node.id, **edge_attrs)
        return new_node

    def create_node(self, inputs: list = None, attrs: dict = None, edge_attrs: dict = None):
        """
        Create an op node and connect the given inputs to its input ports 0..N-1 in order.

        :param inputs: nodes or (node, out_port) tuples, see extract_port
        :param attrs: attributes of the new node
        :param edge_attrs: extra attributes merged into every created input edge
        :return: Node instance of the created node
        """
        # TODO attrs should be matched with attrs()
        inputs = [Op.extract_port(inp) for inp in inputs] if inputs is not None else []
        new_node = self.add_node(attrs if attrs is not None else {})
        # Missed careful handling of debug information
        for in_port, (src_node, src_port) in enumerate(inputs):
            # Edges from op nodes (or nodes without a valid 'kind') carry the source output port;
            # edges from data nodes do not.
            if not src_node.has_valid('kind') or src_node.kind == 'op':
                edge_attr = {'in': in_port, 'out': src_port,
                             'in_attrs': ['in', 'permutation'],
                             'out_attrs': ['out', 'permutation'],
                             'data_attrs': []}
            else:
                edge_attr = {'in': in_port, 'in_attrs': ['in', 'permutation']}
            if edge_attrs is not None:
                edge_attr.update(edge_attrs)
            self.graph.add_edge(src_node.id, new_node.id, **edge_attr)
        return new_node

    def create_node_with_data(self, inputs: list = None, attrs: dict = None,
                              data_nodes: [Node, np.ndarray, list] = None, edge_attrs: list = None):
        """
        Create a new op node with given inputs and attrs, plus the data node(s) holding its output.

        Inputs should be data nodes (not op nodes); input i is connected to input port i and receives
        edge_attrs[i] when present. If data_nodes is None a single new data node is created; otherwise the
        given data node(s) are connected to output ports 0..N-1 in order. If the op has a valid 'infer',
        it is run and must not contradict values/shapes the given data nodes already had.
        :return: the data node if there is exactly one, otherwise the collection of data nodes
        """
        if inputs is None:
            inputs = []
        if attrs is None:
            attrs = {}
        # No need to extract port, because input node should be a data node,
        # so there is no choice.
        new_op_node = self.add_node(attrs)

        # TODO Preserve debug info
        inputs_with_edge_attrs = []
        for in_port, inp in enumerate(inputs):
            edge_attr = {'in': in_port}
            if edge_attrs is not None and in_port < len(edge_attrs):
                edge_attr.update(edge_attrs[in_port])
            inputs_with_edge_attrs.append((inp.id, new_op_node.id, edge_attr))
        self.graph.add_edges_from(inputs_with_edge_attrs)

        # TODO: Extend to the case when multiple output ports
        old_data_value = [None]
        old_data_shape = [None]
        if data_nodes is None:
            data_node_id = unique_id(self.graph)
            self.graph.add_node(data_node_id, **add_attrs_props(
                dict(kind='data', precision="FP32", name=data_node_id, value=None, shape=None, data_type=None,
                     infer=None)))
            data_nodes = [Node(self.graph, data_node_id)]
        else:
            if not isinstance(data_nodes, (list, tuple, np.ndarray)):
                data_nodes = [data_nodes]
            # Remember pre-existing values/shapes so that infer() results can be checked against them.
            old_data_value = [dn.value.copy() if dn.has_valid('value') else None for dn in data_nodes]
            old_data_shape = [dn.shape.copy() if dn.has_valid('shape') else None for dn in data_nodes]
        for out_port, data_node in enumerate(data_nodes):
            self.graph.add_edges_from([(new_op_node.id, data_node.id, {'out': out_port})])
        if new_op_node.has_valid('infer'):
            log.debug('Start running infer function for individual op node with attributes: %s',
                      new_op_node.graph.node[new_op_node.id])
            new_op_node.infer(new_op_node)
            assert all(old_value is None for old_value in old_data_value) or all(
                np.array_equal(old_data_value[i], dn.value) for i, dn in enumerate(data_nodes))
            assert all(old_shape is None for old_shape in old_data_shape) or all(
                np.array_equal(old_data_shape[i], dn.shape) for i, dn in enumerate(data_nodes))
            for data_node in data_nodes:
                log.debug('Finished running infer function, data nodes attributes: %s',
                          data_node.graph.node[data_node.id])
        return data_nodes[0] if len(data_nodes) == 1 else data_nodes

    @staticmethod
    def create_data_node(graph: nx.MultiDiGraph, op_node: Node, attrs: dict = None, edge_attrs: dict = None):
        """
        Create a data node and connect it to output port 0 of op_node, which must have no outputs yet.

        :param attrs: attributes overriding the data node defaults
        :param edge_attrs: extra attributes of the op -> data edge (may override 'out')
        :return: Node instance of the created data node
        """
        assert op_node is not None and op_node.kind == 'op'
        assert len(op_node.out_nodes()) == 0
        data_node = Op._create_data_node(graph, op_node.id, attrs)
        edge_attr = {'out': 0}
        if edge_attrs is not None:
            edge_attr.update(edge_attrs)
        graph.add_edges_from([(op_node.id, data_node.id, edge_attr)])
        return data_node

    @staticmethod
    def _create_data_node(graph: nx.MultiDiGraph, name: str, attrs: dict = None):
        """
        Create an unconnected data node whose id is made unique from name.

        :param attrs: attributes overriding the defaults (value/shape/data_type/infer are None)
        :return: Node instance of the created data node
        """
        if attrs is None:
            attrs = {}

        data_node = unique_id(graph, name)
        default_attrs = dict(kind='data', precision="FP32", name=data_node, value=None, shape=None, data_type=None,
                             infer=None)
        default_attrs.update(attrs)
        graph.add_node(data_node, **add_attrs_props(default_attrs))
        return Node(graph, data_node)

    @staticmethod
    def create_input_data_node(graph: nx.MultiDiGraph, name: str, value: np.ndarray, attrs: dict = None):
        """
        Create an unconnected data node holding a copy of value, with 'shape' set to the value's shape.

        :param value: array-like content of the node (converted with np.array)
        :param attrs: attributes overriding the defaults
        :return: Node instance of the created data node
        """
        value = np.array(value)
        data_node = unique_id(graph, name)
        default_attrs = dict(kind='data', precision="FP32", name=data_node, value=value, shape=value.shape,
                             data_type=None, infer=None)
        if attrs is not None:
            default_attrs.update(attrs)
        graph.add_node(data_node, **add_attrs_props(default_attrs))
        return Node(graph, data_node)

    @staticmethod
    def create_and_connect_input_data_node(graph: nx.MultiDiGraph, op_node: Node, attrs: dict = None,
                                           edge_attrs: dict = None):
        """
        Create a data node and connect it as an input of op_node.

        :param attrs: attributes overriding the data node defaults
        :param edge_attrs: attributes of the data -> op edge
        :return: Node instance of the created data node
        """
        assert op_node is not None and op_node.kind == 'op'
        data_node = Op._create_data_node(graph, op_node.id, attrs)
        graph.add_edges_from([(data_node.id, op_node.id, edge_attrs if edge_attrs is not None else {})])
        return data_node

    def update_node(self, node: Node, attrs: dict = None):
        """
        Updates/creates new attributes in node based on self.attrs and attrs.
        """
        new_attrs = dict(self.attrs)
        if attrs:
            new_attrs.update(attrs)
        new_attrs = add_attrs_props(new_attrs)
        update_ie_fields(new_attrs, self.ir_version)
        self.substitute_ie_attrs(new_attrs)
        for key, value in new_attrs.items():
            node[key] = value

    @classmethod
    def update_node_stat(cls, node: Node, attrs: dict = None):
        """
        Instantiate cls on node's graph with attrs and apply the resulting attributes to node.
        """
        op = cls(node.graph, attrs if attrs is not None else {})
        op.update_node(node)

    def supported_attrs(self):
        """
        Attributes that user should/can set for the operation
        """
        return []

    def backend_attrs(self):
        """
        Attributes that will be translated to back-end IR
        """
        return self.supported_attrs()

    def backend_attrs_v2(self):
        """
        Attributes translated to back-end IR version 2; defaults to backend_attrs().
        """
        return self.backend_attrs()

    @staticmethod
    def get_op_class_by_name(name: str):
        """
        Return the operation class registered under name.
        :raises KeyError: if no class is registered under name
        """
        return Op.registered_ops[name]

    @classmethod
    def class_type(cls):
        return class_registration.ClassType.OP

    @staticmethod
    def expand_node_shape(node: Node, dims_to_add):
        """
        Append dims_to_add trailing unit dimensions to node.value and refresh node.shape.
        Does nothing if node is None or has no valid 'value'.
        """
        if node is None or not node.has_valid('value'):
            return
        for _ in range(dims_to_add):
            node.value = np.expand_dims(node.value, axis=-1)
        node.shape = np.array(node.value.shape)


class PermuteAttrs:
    """
    Describes which attributes of a node must be permuted when the layout of one of its input/output
    data nodes is permuted, and applies those permutations.
    """
    Permutation = namedtuple('Permutation', ['perm', 'inv'])
    Attr = namedtuple('Attr', ['name', 'port', 'func'])

    # Reorder the attribute value itself: new[i] = old[perm[i]] (per-dimension values such as strides).
    common_permutation = lambda node, permutation, attr: node[attr][permutation.perm]
    # Remap an attribute holding dimension indices: old dimension d moves to position inv[d].
    common_permutation_inv = lambda node, permutation, attr: permutation.inv[node[attr]]

    # List of default permutations
    common_attrs_permutation = {
            'dim': common_permutation,
            'pad': common_permutation,
            'shape': common_permutation,
            'order': lambda node, permutation, attr: permutation.inv[node[attr][permutation.perm]],
            'stride': common_permutation,
            'window': common_permutation,
            'dilation': common_permutation,
            'kernel_shape': common_permutation,
            'output_shape': common_permutation,
            'slices': common_permutation,
            'shrink_axis_mask': common_permutation,
            'new_axis_mask': common_permutation,

            'axis': common_permutation_inv,
            'batch_dims': common_permutation_inv,
            'channel_dims': common_permutation_inv,
            'spatial_dims': common_permutation_inv,

            'input_channel_dim': common_permutation_inv,
            'output_channel_dim': common_permutation_inv,
            'kernel_spatial_idx': common_permutation_inv,
            'input_feature_channel': common_permutation_inv,
            'output_feature_channel': common_permutation_inv,
    }

    @staticmethod
    def __attr(name, port, func=None):
        """
        Build an Attr record after validating it.

        :param name: attribute name; when func is None it must be a key of common_attrs_permutation
        :param port: 'input:<idx>' or 'output:<idx>' - the node whose permutation is applied
        :param func: optional custom permutation function(node, permutation, attr_name)
        :raises Error: on an unknown attribute without func, or a malformed port
        """
        if func is None:
            if name not in PermuteAttrs.common_attrs_permutation:
                raise Error('Attr {} is missing in PermuteAttrs.common_attrs_permutation. Please update '
                            'common_attrs_permutation with permutation for your attribute!'.format(name))
            func = PermuteAttrs.common_attrs_permutation[name]

        port_parts = port.split(':')
        if len(port_parts) != 2 or port_parts[0] not in ['input', 'output']:
            raise Error("Attribute port {} for {} wasn't set correctly!".format(port, name))

        return PermuteAttrs.Attr(name=name, port=port, func=func)

    def __init__(self):
        # attribute name -> Attr(name, port, func)
        self.attrs = {}

    def update_attrs(self, attrs):
        """
        Register (attribute_name, port) or (attribute_name, port, func) tuples.
        :return: self, to allow chaining
        :raises Error: if an item is not such a tuple
        """
        for attr in attrs:
            if not isinstance(attr, tuple) or len(attr) not in [2, 3]:
                raise Error('attr object must be a tuple: (attribute_name, port) or (attribute_name, port, func)')
            self.attrs.update({attr[0]: self.__attr(*attr)})
        return self

    def permute_attrs(self, node):
        """
        Apply the registered permutations to node's attributes.

        For each registered attribute, the permutation stored on the referenced input/output node is used;
        attributes whose referenced node has no valid 'permutation' are left unchanged.
        """
        for name, port, func in self.attrs.values():
            node_type, port_idx = port.split(':')
            port_idx = int(port_idx)
            node_with_permutation = node.in_node(port_idx) if node_type == 'input' else node.out_node(port_idx)

            if not node_with_permutation.has_valid('permutation'):
                continue
            permutation = node_with_permutation.permutation
            # A permutation may be stored lazily as a callable computing it for the node being permuted.
            if callable(permutation):
                permutation = permutation(node)
            node[name] = func(node, permutation, name)

    @staticmethod
    def create_permute_attrs(node, attrs=None):
        """
        Create node['permute_attrs'] if it does not exist yet and register attrs (if any) in it.
        """
        if not node.has_valid('permute_attrs'):
            node['permute_attrs'] = PermuteAttrs()
        node['permute_attrs'].update_attrs(attrs if attrs is not None else [])

    @staticmethod
    def set_permutation(node1, node2, permutation, skip_if_exists=False):
        """
        Store permutation on the edge node1 -> node2 (edge key 0).

        :param skip_if_exists: if True, silently keep an already present permutation
        :raises Error: if there is no such edge, or a permutation already exists and skip_if_exists is False
        """
        edge_data = node1.graph.get_edge_data(node1.id, node2.id)
        if edge_data is None:
            raise Error('There is no edge between {} and {}'.format(node1.name, node2.name))
        edge_attrs = edge_data[0]
        if 'permutation' in edge_attrs:
            if skip_if_exists:
                return
            raise Error('Permutation already exists in edge between {} and {}'.format(node1.name, node2.name))
        nx.set_edge_attributes(G=node1.graph,
                               values={(node1.id, node2.id, 0): permutation},
                               name='permutation')

    @staticmethod
    def get_inverse_permutation(perm):
        """
        Return the inverse of perm as a list: inv[perm[i]] == i.
        """
        inv = [0] * len(perm)
        for index, pos in enumerate(perm):
            inv[pos] = index
        return inv

    @staticmethod
    def _make_permutation(perm: list):
        # Pack perm and its inverse as numpy arrays into a Permutation record.
        inv = PermuteAttrs.get_inverse_permutation(perm)
        return PermuteAttrs.Permutation(perm=np.array(perm), inv=np.array(inv))

    @staticmethod
    def get_nhwc_to_nchw_permutation(dims_number: int):
        """
        Return the Permutation from NHWC to NCHW layout for the given rank.
        Ranks <= 1 and rank 3 (excluded from the permutation process) get the identity permutation.
        """
        if dims_number == 3 or dims_number <= 1:
            perm = list(range(dims_number))
        else:
            perm = [0, dims_number - 1, *range(1, dims_number - 1)]
        return PermuteAttrs._make_permutation(perm)

    @staticmethod
    def get_nchw_to_nhwc_permutation(dims_number: int):
        """
        Return the Permutation from NCHW to NHWC layout for the given rank.
        Ranks <= 1 and rank 3 (excluded from the permutation process) get the identity permutation.
        """
        if dims_number == 3 or dims_number <= 1:
            perm = list(range(dims_number))
        else:
            perm = [0, *range(2, dims_number), 1]
        return PermuteAttrs._make_permutation(perm)