summaryrefslogtreecommitdiff
path: root/model-optimizer/mo/front/mxnet/extractors/utils.py
blob: 36f7ba3fa4183fe4c8609960a3d01b21d56e0176 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
"""
 Copyright (c) 2018 Intel Corporation

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

      http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
"""

import mxnet as mx

from mo.utils.error import Error
from mo.utils.str_to import StrTo
from mo.utils.utils import refer_to_faq_msg


class AttrDictionary(object):
    """
    Wrapper around a layer attribute dictionary that offers typed accessors.

    Values are looked up by key and converted on demand. The wrapped object may be
    ``None``; such a wrapper reports ``is_valid() == False`` and its read accessors
    (``str``, ``has`` and everything built on ``str``) behave as if no key were present.

    NOTE: several methods intentionally carry the names of builtins (``str``, ``int``,
    ``list``, ...). Inside the class *body* those names resolve to the methods once they
    are defined, so builtins must not be used in default-argument expressions here.
    Inside method bodies the names still resolve to the builtins.
    """

    def __init__(self, dict):
        # Stored by reference, not copied: set/remove/add_dict mutate the caller's object.
        self._dict = dict

    def is_valid(self):
        """Return True when a dictionary (possibly empty) is wrapped, False for None."""
        return self._dict is not None

    def dict(self):
        """Return the wrapped dictionary itself (may be None)."""
        return self._dict

    def add_dict(self, dict):
        """Merge ``dict`` into the wrapped dictionary, overwriting existing keys."""
        self._dict.update(dict)

    def set(self, key, value):
        """Store ``value`` under ``key``."""
        self._dict[key] = value

    def remove(self, key):
        """Delete ``key`` if it is present; do nothing otherwise."""
        if key in self._dict:
            del self._dict[key]

    def str(self, key, default=None):
        """
        Return the raw stored value for ``key`` (no conversion is applied).

        :param key: attribute name
        :param default: value returned when the key is absent; ``None`` marks the
                        attribute as required
        :raises ValueError: if the key is absent (or nothing is wrapped) and no default
                            was given
        """
        # is_valid must be *called*: the bound method object itself is always truthy.
        if self.is_valid() and key in self._dict:
            return self._dict[key]
        if default is None:
            raise ValueError("Missing required parameter: " + key)
        return default

    def bool(self, key, default=None):
        """
        Return the value for ``key`` as a boolean.

        Digit-only strings go through ``int`` first, other strings are handed to
        ``StrTo.bool``; non-string values are returned unchanged.
        """
        attr = self.str(key, default)
        if isinstance(attr, str):
            if attr.isdigit():
                return bool(int(attr))
            return StrTo.bool(attr)
        else:
            return attr

    def float(self, key, default=None):
        """Return the value for ``key`` converted to ``float`` (see ``val``)."""
        return self.val(key, float, default)

    def int(self, key, default=None):
        """Return the value for ``key`` converted to ``int`` (see ``val``)."""
        return self.val(key, int, default)

    def tuple(self, key, valtype=None, default=None):
        """
        Return the value for ``key`` as a tuple whose items are converted by ``valtype``.

        :param key: attribute name
        :param valtype: callable applied to every item; ``None`` (the default) means the
                        builtin ``str``
        :param default: fallback used when the key is absent or its string value is empty
        :raises ValueError: if the value is missing or empty and no default was given
        """
        # ``valtype=str`` cannot be written in the signature: at that point of the class
        # body ``str`` is the method defined above, not the builtin.
        if valtype is None:
            valtype = str
        attr = self.str(key, default)
        if isinstance(attr, str):
            # Empty string, or brackets with nothing before the first comma: use default.
            if (not attr) or (not attr[1:-1].split(',')[0]):
                if default is None:
                    raise ValueError("Missing required parameter: " + key)
                return tuple([valtype(x) for x in default])
            return StrTo.tuple(valtype, attr)
        else:
            return tuple([valtype(x) for x in attr])

    def list(self, key, valtype, default=None, sep=","):
        """
        Return the value for ``key`` as a list of ``valtype`` items.

        A stored list is converted element by element; anything else is passed to
        ``StrTo.list`` together with ``valtype`` and ``sep``.
        """
        attr = self.str(key, default)
        if isinstance(attr, list):
            attr = [valtype(x) for x in attr]
            return attr
        else:
            return StrTo.list(attr, valtype, sep)

    def val(self, key, valtype, default=None):
        """
        Return the value for ``key`` converted with ``valtype``.

        ``valtype=None`` returns the raw value; a value that already is an instance of
        ``valtype`` is returned as is.
        """
        attr = self.str(key, default)
        if valtype is None:
            return attr
        else:
            if not isinstance(attr, valtype):
                return valtype(attr)
            else:
                return attr

    def has(self, key):
        """Return True if ``key`` is present; False if it is not or nothing is wrapped."""
        return self.is_valid() and key in self._dict


def get_mxnet_node_edges(node: dict, node_id: [int, str], nodes_list: list, index_node_key: dict):
    """
    Describe every incoming edge of ``node`` as a ``(source, destination, attrs)`` tuple.

    :param node: node description; its ``'inputs'`` entry is iterated, one edge per item
    :param node_id: key of ``node`` in ``index_node_key``
    :param nodes_list: all node descriptions, indexed by the first element of an input item
    :param index_node_key: maps node indices to the keys used as edge endpoints
    :return: list of edges in the order of ``node['inputs']``
    """

    def build_edge(port, input_ref):
        producer_idx = input_ref[0]
        attrs = {
            'in': port,
            'out': 0,  # TODO Check if input_ref[1] should be here (already used as fw_tensor_debug_info)
            # debug anchor for name of tensor consumed at this input port
            'fw_tensor_debug_info': [(nodes_list[producer_idx]['name'], input_ref[1])],
            'in_attrs': ['in'],
            'out_attrs': ['out'],
            'data_attrs': ['fw_tensor_debug_info']
        }
        return index_node_key[producer_idx], index_node_key[node_id], attrs

    return [build_edge(port, input_ref) for port, input_ref in enumerate(node['inputs'])]


def get_mxnet_layer_attrs(json_dic: dict):
    """
    Wrap the attribute section of a layer description in an ``AttrDictionary``.

    The section is taken from the first of the keys ``'attr'``, ``'attrs'``, ``'param'``
    that is present in ``json_dic``; when none is present an empty dictionary is wrapped.
    """
    for section in ('attr', 'attrs', 'param'):
        if section in json_dic:
            return AttrDictionary(json_dic[section])
    return AttrDictionary({})


def get_json_layer_attrs(json_dic):
    """
    Return the raw attribute section of a layer description.

    ``'attr'`` takes precedence over ``'attrs'``; if neither key is present the value
    stored under ``'param'`` is returned, and a missing ``'param'`` raises ``KeyError``.
    """
    for section in ('attr', 'attrs'):
        if section in json_dic:
            return json_dic[section]
    return json_dic['param']


def load_params(input_model, data_names = ('data',)):
    """
    Load weights from an MXNet ``.params`` or ``.nd`` file into an ``mx.mod.Module``.

    For ``.params`` files, keys of the form ``arg:<name>`` / ``aux:<name>`` are sorted
    into argument and auxiliary parameters; other keys are ignored. For ``.nd`` files
    every key goes to the auxiliary parameters if the file path contains ``'auxs'``,
    otherwise to the argument parameters if the path contains ``'args'``; if it contains
    neither, nothing is collected.

    :param input_model: path to the weights file; the text after the last ``'.'`` selects
                        the format
    :param data_names: input names; only the first element is used
    :return: ``mx.mod.Module`` whose ``_arg_params``, ``_aux_params``, ``_param_names``
             and ``_aux_names`` are filled from the file
    :raises Error: if the file extension is neither ``params`` nor ``nd``
    """
    file_format = input_model.split('.')[-1]
    # Validate the extension before loading, so that an unsupported file is reported with
    # the FAQ reference instead of whatever the loader raises while parsing it.
    if file_format not in ('params', 'nd'):
        raise Error(
            'Unsupported Input model file type {}. Model Optimizer support only .params and .nd files format. ' +
            refer_to_faq_msg(85), file_format)

    arg_params = {}
    aux_params = {}
    arg_keys = []
    aux_keys = []
    loaded_weight = mx.nd.load(input_model)
    if file_format == 'params':
        for key in loaded_weight:
            # Split on the first ':' only, so a parameter name containing ':' stays whole.
            prefix, separator, name = key.partition(':')
            if not separator:
                continue
            if prefix == 'aux':
                aux_keys.append(name)
                aux_params[name] = loaded_weight[key]
            elif prefix == 'arg':
                arg_keys.append(name)
                arg_params[name] = loaded_weight[key]
    else:
        # The destination depends only on the file path, so decide it once.
        goes_to_aux = 'auxs' in input_model
        goes_to_arg = 'args' in input_model
        for key in loaded_weight:
            if goes_to_aux:
                aux_keys.append(key)
                aux_params[key] = loaded_weight[key]
            elif goes_to_arg:
                arg_keys.append(key)
                arg_params[key] = loaded_weight[key]

    data = mx.sym.Variable(data_names[0])
    # NOTE(review): label_names reuses the data name, as in the original code - confirm intent.
    model_params = mx.mod.Module(data, data_names=(data_names[0],), label_names=(data_names[0],))
    model_params._arg_params = arg_params
    model_params._aux_params = aux_params
    model_params._param_names = arg_keys
    model_params._aux_names = aux_keys
    return model_params