1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
|
"""
Copyright (c) 2018 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import builtins

import mxnet as mx

from mo.utils.error import Error
from mo.utils.str_to import StrTo
from mo.utils.utils import refer_to_faq_msg
class AttrDictionary(object):
    """Typed accessor over a layer's raw attribute mapping.

    Stored values may be strings or already-typed Python objects; the typed
    getters below accept both forms.

    Throughout this class ``default=None`` means "the attribute is required":
    when the key is absent a ``ValueError`` is raised instead of returning a
    default.

    NOTE: several methods are deliberately named like builtins (``str``,
    ``bool``, ``int``, ``float``, ``tuple``, ``list``, ``dict``). Inside method
    bodies those names still resolve to the builtins, but in the class body
    itself (e.g. default-argument expressions) they resolve to the methods
    defined above them -- hence ``builtins.str`` in ``tuple``'s signature.
    """

    def __init__(self, dict):
        # Wrapped mapping; may be None, in which case is_valid() is False.
        self._dict = dict

    def is_valid(self):
        """Return True when a mapping (not None) is wrapped."""
        return self._dict is not None

    def dict(self):
        """Return the wrapped mapping itself (not a copy)."""
        return self._dict

    def add_dict(self, dict):
        """Merge ``dict`` into the wrapped mapping in place."""
        self._dict.update(dict)

    def set(self, key, value):
        """Store ``value`` under ``key`` in the wrapped mapping."""
        self._dict[key] = value

    def remove(self, key):
        """Delete ``key`` if present; a missing key is silently ignored."""
        if key in self._dict:
            del self._dict[key]

    def str(self, key, default=None):
        """Return the raw value stored under ``key``.

        :param key: attribute name.
        :param default: value returned when ``key`` is missing (or no mapping
            is wrapped); ``None`` marks the attribute as required.
        :raises ValueError: if the attribute is missing and ``default`` is None.
        """
        # is_valid must be *called*: the bare bound method is always truthy,
        # which used to let a None mapping reach ``key in None``.
        if self.is_valid() and key in self._dict:
            return self._dict[key]
        if default is None:
            raise ValueError("Missing required parameter: " + key)
        return default

    def bool(self, key, default=None):
        """Return the attribute as a bool.

        Digit strings are interpreted numerically ("0" -> False, "1" -> True);
        other strings are delegated to ``StrTo.bool``; non-string values are
        returned unchanged.
        """
        attr = self.str(key, default)
        if isinstance(attr, str):
            if attr.isdigit():
                return bool(int(attr))
            return StrTo.bool(attr)
        else:
            return attr

    def float(self, key, default=None):
        """Return the attribute converted with ``float`` (see ``val``)."""
        return self.val(key, float, default)

    def int(self, key, default=None):
        """Return the attribute converted with ``int`` (see ``val``)."""
        return self.val(key, int, default)

    def tuple(self, key, valtype=builtins.str, default=None):
        """Return the attribute as a tuple whose items are passed through ``valtype``.

        String values are parsed by ``StrTo.tuple``; an empty string or an
        empty bracketed literal such as ``"()"`` falls back to ``default``
        (or ``()`` when no default is given). Non-string values are treated
        as iterables and converted element-wise.
        """
        attr = self.str(key, default)
        if isinstance(attr, str):
            if (not attr) or (not attr[1:-1].split(',')[0]):
                # Empty literal: use the caller's default, if any.
                if default is None:
                    return ()
                return tuple([valtype(x) for x in default])
            return StrTo.tuple(valtype, attr)
        else:
            return tuple([valtype(x) for x in attr])

    def list(self, key, valtype, default=None, sep=","):
        """Return the attribute as a list of ``valtype`` items.

        Stored lists are converted element-wise; anything else is handed to
        ``StrTo.list`` together with the separator ``sep``.
        """
        attr = self.str(key, default)
        if isinstance(attr, list):
            attr = [valtype(x) for x in attr]
            return attr
        else:
            return StrTo.list(attr, valtype, sep)

    def val(self, key, valtype, default=None):
        """Return the attribute converted with ``valtype``.

        The value is returned unchanged when ``valtype`` is None or when it is
        already an instance of ``valtype``.
        """
        attr = self.str(key, default)
        if valtype is None:
            return attr
        else:
            if not isinstance(attr, valtype):
                return valtype(attr)
            else:
                return attr

    def has(self, key):
        """Return True if a mapping is wrapped and contains ``key``."""
        # Call is_valid(): the unbound-reference form was always truthy.
        return self.is_valid() and key in self._dict
def get_mxnet_node_edges(node: dict, node_id: [int, str], nodes_list: list, index_node_key: dict):
    """Build the list of incoming edges for one node of an MXNet symbol graph.

    Each entry of ``node['inputs']`` describes one producer: item ``[0]`` is an
    index into ``nodes_list`` and ``index_node_key``; item ``[1]`` is recorded
    in the edge's debug info (presumably the producer's output index -- see the
    TODO below).

    :param node: consumer node; only its ``'inputs'`` entry is read.
    :param node_id: key of the consumer in ``index_node_key``.
    :param nodes_list: all nodes; ``nodes_list[i]['name']`` names producer ``i``.
    :param index_node_key: maps node indices to graph node keys.
    :return: list of ``(src_key, dst_key, attrs)`` triples, one per input port,
        in input order.
    """
    def _edge_for(in_port, src_ref):
        src_idx = src_ref[0]
        # Debug anchor: name of the tensor consumed at this input port.
        debug_info = [(nodes_list[src_idx]['name'], src_ref[1])]
        attrs = {
            'in': in_port,
            # TODO Check if src_ref[1] should be here (already used as fw_tensor_debug_info)
            'out': 0,
            'fw_tensor_debug_info': debug_info,
            'in_attrs': ['in'],
            'out_attrs': ['out'],
            'data_attrs': ['fw_tensor_debug_info'],
        }
        return index_node_key[src_idx], index_node_key[node_id], attrs

    return [_edge_for(port, ref) for port, ref in enumerate(node['inputs'])]
def get_mxnet_layer_attrs(json_dic: dict):
    """Wrap a layer's attribute mapping from its JSON description in an AttrDictionary.

    The first key present among ``'attr'``, ``'attrs'`` and ``'param'`` (in that
    priority order) is used; if none exists an empty AttrDictionary is returned.
    """
    for attr_key in ('attr', 'attrs', 'param'):
        if attr_key in json_dic:
            return AttrDictionary(json_dic[attr_key])
    return AttrDictionary({})
def get_json_layer_attrs(json_dic):
    """Return the raw attribute mapping of a layer's JSON description.

    Looks up ``'attr'`` first, then ``'attrs'``, and otherwise ``'param'``.

    :raises KeyError: if none of those keys is present.
    """
    chosen = next((candidate for candidate in ('attr', 'attrs') if candidate in json_dic), 'param')
    return json_dic[chosen]
def load_params(input_model, data_names=('data',)):
    """Load MXNet weights from ``input_model`` into a bare ``mx.mod.Module``.

    The file extension selects the layout:

    * ``.params`` -- keys are split on ``':'``; ``aux:<name>`` entries become
      aux params and ``arg:<name>`` entries become arg params. Other keys are
      skipped.
    * ``.nd`` -- every entry goes to aux params if the substring ``'auxs'``
      occurs anywhere in the path (checked first), else to arg params if
      ``'args'`` occurs; otherwise nothing is collected.

    :param input_model: path to the weights file.
    :param data_names: input names; only the first element is used, as both
        the module's data name and its label name.
    :return: an ``mx.mod.Module`` whose private ``_arg_params``,
        ``_aux_params``, ``_param_names`` and ``_aux_names`` fields are set
        directly from the loaded file.
    :raises Error: if the extension is neither ``params`` nor ``nd``.
    """
    file_format = input_model.split('.')[-1]
    # Validate the extension before reading the file, so an unsupported input
    # yields the Model Optimizer error instead of an mx.nd.load failure (or a
    # wasted full load of a file that is rejected afterwards).
    if file_format not in ('params', 'nd'):
        raise Error(
            'Unsupported Input model file type {}. Model Optimizer support only .params and .nd files format. ' +
            refer_to_faq_msg(85), file_format)

    arg_params = {}
    aux_params = {}
    arg_keys = []
    aux_keys = []
    loaded_weight = mx.nd.load(input_model)
    if file_format == 'params':
        for key in loaded_weight:
            keys = key.split(':')
            if len(keys) > 1 and 'aux' == keys[0]:
                aux_keys.append(keys[1])
                aux_params[keys[1]] = loaded_weight[key]
            elif len(keys) > 1 and 'arg' == keys[0]:
                arg_keys.append(keys[1])
                arg_params[keys[1]] = loaded_weight[key]
    else:  # file_format == 'nd'
        for key in loaded_weight:
            # NOTE(review): the substring test covers the whole path, including
            # directory names -- confirm that is intended.
            if 'auxs' in input_model:
                aux_keys.append(key)
                aux_params[key] = loaded_weight[key]
            elif 'args' in input_model:
                arg_keys.append(key)
                arg_params[key] = loaded_weight[key]

    data = mx.sym.Variable(data_names[0])
    model_params = mx.mod.Module(data, data_names=(data_names[0],), label_names=(data_names[0],))
    model_params._arg_params = arg_params
    model_params._aux_params = aux_params
    model_params._param_names = arg_keys
    model_params._aux_names = aux_keys
    return model_params
|