summaryrefslogtreecommitdiff
path: root/model-optimizer/mo/utils
diff options
context:
space:
mode:
Diffstat (limited to 'model-optimizer/mo/utils')
-rw-r--r--model-optimizer/mo/utils/class_registration.py29
-rw-r--r--model-optimizer/mo/utils/cli_parser.py153
-rw-r--r--model-optimizer/mo/utils/custom_replacement_config.py18
-rw-r--r--model-optimizer/mo/utils/dsu.py18
-rw-r--r--model-optimizer/mo/utils/error.py8
-rw-r--r--model-optimizer/mo/utils/graph.py19
-rw-r--r--model-optimizer/mo/utils/logger.py5
-rw-r--r--model-optimizer/mo/utils/pipeline_config.py142
-rw-r--r--model-optimizer/mo/utils/replacement_pattern.py2
-rw-r--r--model-optimizer/mo/utils/simple_proto_parser.py12
-rw-r--r--model-optimizer/mo/utils/str_to.py2
-rw-r--r--model-optimizer/mo/utils/summarize_graph.py30
-rw-r--r--model-optimizer/mo/utils/utils.py13
-rw-r--r--model-optimizer/mo/utils/versions_checker.py1
14 files changed, 318 insertions, 134 deletions
diff --git a/model-optimizer/mo/utils/class_registration.py b/model-optimizer/mo/utils/class_registration.py
index 3fcc9a268..f69978b1d 100644
--- a/model-optimizer/mo/utils/class_registration.py
+++ b/model-optimizer/mo/utils/class_registration.py
@@ -14,12 +14,14 @@
limitations under the License.
"""
-from enum import Enum
import logging as log
+from enum import Enum
+
import networkx as nx
from mo.utils.error import Error
from mo.utils.utils import refer_to_faq_msg
+from mo.graph.graph import check_empty_graph
_registered_classes_dict = {}
@@ -74,8 +76,8 @@ def apply_replacements(graph: nx.MultiDiGraph, replacements_type):
for class_type, classes_set in _registered_classes_dict.items():
if class_type == replacements_type:
for cls in classes_set:
- replacers = [c for c in cls.registered_cls if not hasattr(c, 'op')] + [c for op, c in cls.registered_ops.items()
- if c]
+ replacers = [c for c in cls.registered_cls if not hasattr(c, 'op')] + \
+ [c for op, c in cls.registered_ops.items() if c]
for replacer_cls in replacers:
if replacer_cls in cls.excluded_replacers:
# skip infrastructure classes
@@ -97,6 +99,23 @@ def apply_replacements(graph: nx.MultiDiGraph, replacements_type):
' -> '.join([str(node) for node in list(cycles)[0]])) from exception
for replacer_cls in replacers_order:
- log.debug("Run replacer {}".format(replacer_cls))
replacer = replacer_cls()
- replacer.find_and_replace_pattern(graph)
+ replacement_id = 'REPLACEMENT_ID'
+ if hasattr(replacer, 'replacement_id'):
+ replacement_id = replacer.replacement_id
+
+ if hasattr(replacer, 'enabled') and not replacer.enabled:
+ log.info("Skip replacer {} (enabled = False)".format(replacer_cls))
+ continue
+
+ log.debug("Run replacer {}".format(replacer_cls))
+
+ try:
+ replacer.find_and_replace_pattern(graph)
+ check_empty_graph(graph, replacer_cls)
+ except Error as err:
+ raise Error('Exception occurred during running replacer "{}": {}'.format(replacement_id, str(err).replace(
+ '[REPLACEMENT_ID]', replacement_id))) from err
+ except Exception as err:
+ raise Exception('Exception occurred during running replacer "{}": {}'.format(
+ replacement_id, str(err).replace('[REPLACEMENT_ID]', replacement_id))) from err
diff --git a/model-optimizer/mo/utils/cli_parser.py b/model-optimizer/mo/utils/cli_parser.py
index 85c7017a5..a3a6460ba 100644
--- a/model-optimizer/mo/utils/cli_parser.py
+++ b/model-optimizer/mo/utils/cli_parser.py
@@ -33,6 +33,7 @@ class CanonicalizePathAction(argparse.Action):
"""
Expand user home directory paths and convert relative-paths to absolute.
"""
+
def __call__(self, parser, namespace, values, option_string=None):
if values is not None:
list_of_values = list()
@@ -52,6 +53,7 @@ class CanonicalizePathCheckExistenceAction(CanonicalizePathAction):
Expand user home directory paths and convert relative-paths to absolute and check specified file or directory
existence.
"""
+
def __call__(self, parser, namespace, values, option_string=None):
super().__call__(parser, namespace, values, option_string)
names = getattr(namespace, self.dest)
@@ -293,20 +295,17 @@ def get_common_cli_parser(parser: argparse.ArgumentParser = None):
action=CanonicalizePathAction,
type=writable_dir)
common_group.add_argument('--input_shape',
- help='Input shape(s) that should be fed to an input node(s) of the model. ' +
- 'Shape is defined as a comma-separated list of integer numbers enclosed in parentheses, ' +
- 'for example [1,3,227,227] or [1,227,227,3], where the order of dimensions ' +
- 'depends on the framework input layout of the model. ' +
- 'For example, [N,C,H,W] is used for Caffe* models and [N,H,W,C] for TensorFlow* models. '
- 'Model Optimizer performs necessary transformations ' +
- 'to convert the shape to the layout required by Inference Engine (N,C,H,W). '
- 'Two types of brackets are allowed to enclose the dimensions: [...] or (...). ' +
- 'The shape ' +
- 'should not contain undefined dimensions (? or -1) and should ' +
- 'fit the dimensions defined in the input ' +
- 'operation of the graph. If there are multiple inputs in the model, --input_shape ' +
- 'should contain definition of shape for each input separated by a comma, for example: ' +
- '[1,3,227,227],[2,4] for a model with two inputs with 4D and 2D shapes.')
+ help='Input shape(s) that should be fed to an input node(s) of the model. '
+ 'Shape is defined as a comma-separated list of integer numbers enclosed in '
+ 'parentheses or square brackets, for example [1,3,227,227] or (1,227,227,3), where '
+ 'the order of dimensions depends on the framework input layout of the model. '
+ 'For example, [N,C,H,W] is used for Caffe* models and [N,H,W,C] for TensorFlow* '
+ 'models. Model Optimizer performs necessary transformations to convert the shape to '
+ 'the layout required by Inference Engine (N,C,H,W). The shape should not contain '
+ 'undefined dimensions (? or -1) and should fit the dimensions defined in the input '
+ 'operation of the graph. If there are multiple inputs in the model, --input_shape '
+ 'should contain definition of shape for each input separated by a comma, for '
+ 'example: [1,3,227,227],[2,4] for a model with two inputs with 4D and 2D shapes.')
common_group.add_argument('--scale', '-s',
type=float,
help='All input values coming from original network inputs will be ' +
@@ -316,11 +315,11 @@ def get_common_cli_parser(parser: argparse.ArgumentParser = None):
'is not applied for any input that does not match with ' +
'the original input of the model.')
common_group.add_argument('--reverse_input_channels',
- help='Switches the input channels order from RGB to BGR (or vice versa). ' +
- 'Applied to original inputs of the model when and only when ' +
- 'a number of channels equals 3. Applied after application of ' +
- '--mean_values and --scale_values options, so numbers in --mean_values and --scale_values ' +
- 'go in the order of channels used in the original model.',
+ help='Switch the input channels order from RGB to BGR (or vice versa). Applied to '
+ 'original inputs of the model if and only if a number of channels equals 3. Applied '
+ 'after application of --mean_values and --scale_values options, so numbers in '
+ '--mean_values and --scale_values go in the order of channels used in the original '
+ 'model.',
action='store_true')
common_group.add_argument('--log_level',
help='Logger level',
@@ -337,16 +336,16 @@ def get_common_cli_parser(parser: argparse.ArgumentParser = None):
common_group.add_argument('--mean_values', '-ms',
help='Mean values to be used for the input image per channel. ' +
'Values to be provided in the (R,G,B) or [R,G,B] format. ' +
- 'Can be defined for desired input of the model, e.g.: ' +
- '"--mean_values data[255,255,255],info[255,255,255]"' +
- ' The exact meaning and order ' +
+ 'Can be defined for desired input of the model, for example: ' +
+ '"--mean_values data[255,255,255],info[255,255,255]". ' +
+ 'The exact meaning and order ' +
'of channels depend on how the original model was trained.',
default=())
common_group.add_argument('--scale_values',
help='Scale values to be used for the input image per channel. ' +
- 'Values are provided in the (R,G,B) or [R,G,B] format.' +
- 'Can be defined for desired input of the model, e.g.: ' +
- '"--scale_values data[255,255,255],info[255,255,255]"' +
+ 'Values are provided in the (R,G,B) or [R,G,B] format. ' +
+ 'Can be defined for desired input of the model, for example: ' +
+ '"--scale_values data[255,255,255],info[255,255,255]". ' +
'The exact meaning and order ' +
'of channels depend on how the original model was trained.',
default=())
@@ -358,16 +357,16 @@ def get_common_cli_parser(parser: argparse.ArgumentParser = None):
choices=["FP16", "FP32", "half", "float"],
default='float')
common_group.add_argument('--disable_fusing',
- help='Turns off fusing of linear operations to Convolution',
+ help='Turn off fusing of linear operations to Convolution',
action='store_true')
common_group.add_argument('--disable_resnet_optimization',
- help='Turns off resnet optimization',
+ help='Turn off resnet optimization',
action='store_true')
common_group.add_argument('--finegrain_fusing',
help='Regex for layers/operations that won\'t be fused. ' +
'Example: --finegrain_fusing Convolution1,.*Scale.*')
common_group.add_argument('--disable_gfusing',
- help='Turns off fusing of grouped convolutions',
+ help='Turn off fusing of grouped convolutions',
action='store_true')
common_group.add_argument('--move_to_preprocess',
help='Move mean values to IR preprocess section',
@@ -389,14 +388,22 @@ def get_common_cli_parser(parser: argparse.ArgumentParser = None):
help="Version of Model Optimizer")
common_group.add_argument('--silent',
- help='Prevents any output messages except those that correspond to log level equals'
+ help='Prevent any output messages except those that correspond to log level equals '
'ERROR, that can be set with the following option: --log_level. '
'By default, log level is already ERROR. ',
action='store_true',
default=False)
- common_group.add_argument('--freeze_placeholder_with_value', help='Replace input layer with constant node with '
- 'provided value, e.g.: node_name->True',
+ common_group.add_argument('--freeze_placeholder_with_value', help='Replaces input layer with constant node with '
+ 'provided value, e.g.: "node_name->True"',
default=None)
+ common_group.add_argument('--generate_deprecated_IR_V2',
+ help='Force to generate legacy/deprecated IR V2 to work with previous versions of the'
+ ' Inference Engine. The resulting IR may or may not be correctly loaded by'
+ ' Inference Engine API (including the most recent and old versions of Inference'
+ ' Engine) and provided as a partially-validated backup option for specific'
+ ' deployment scenarios. Use it at your own discretion. By default, without this'
+ ' option, the Model Optimizer generates IR V3.',
+ action='store_true')
return parser
@@ -443,6 +450,7 @@ def get_tf_cli_options():
'tensorflow_use_custom_operations_config': '- Use the config file',
'tensorflow_object_detection_api_pipeline_config': '- Use configuration file used to generate the model with '
'Object Detection API',
+ 'tensorflow_custom_layer_libraries': '- List of shared libraries with TensorFlow custom layers implementation',
'tensorboard_logdir': '- Path to model dump for TensorBoard'
}
@@ -504,7 +512,7 @@ def get_caffe_cli_parser(parser: argparse.ArgumentParser = None):
action=CanonicalizePathCheckExistenceAction)
caffe_group.add_argument('--mean_file', '-mf',
help='Mean image to be used for the input. Should be a binaryproto file',
- default="",
+ default=None,
action=CanonicalizePathCheckExistenceAction)
caffe_group.add_argument('--mean_file_offsets', '-mo',
help='Mean image offsets to be used for the input binaryproto file. ' +
@@ -543,17 +551,17 @@ def get_tf_cli_parser(parser: argparse.ArgumentParser = None):
tf_group = parser.add_argument_group('TensorFlow*-specific parameters')
tf_group.add_argument('--input_model_is_text',
- help='TensorFlow*: treat the input model file in a text protobuf format ' +
- 'instead of ' +
- 'binary, which is default.',
+ help='TensorFlow*: treat the input model file as a text protobuf format. If not specified, ' +
+ 'the Model Optimizer treats it as a binary file by default.',
action='store_true')
- tf_group.add_argument('--input_checkpoint', type=str, default="", help="TensorFlow*: variables file to load.",
+ tf_group.add_argument('--input_checkpoint', type=str, default=None, help="TensorFlow*: variables file to load.",
action=CanonicalizePathCheckExistenceAction)
tf_group.add_argument('--input_meta_graph',
- help='Tensorflow*: a file with a non-trained model before freezing',
+ help='Tensorflow*: a file with a meta-graph of the model before freezing',
action=CanonicalizePathCheckExistenceAction,
type=readable_file)
- tf_group.add_argument('--saved_model_dir', default=None, help="TensorFlow*: directory representing non frozen model",
+ tf_group.add_argument('--saved_model_dir', default=None,
+ help="TensorFlow*: directory representing non frozen model",
action=CanonicalizePathCheckExistenceAction,
type=readable_dirs)
tf_group.add_argument('--saved_model_tags', type=str, default=None,
@@ -585,6 +593,14 @@ def get_tf_cli_parser(parser: argparse.ArgumentParser = None):
help='TensorFlow*: dump the input graph to a given directory that should be used with TensorBoard.',
default=None,
action=CanonicalizePathCheckExistenceAction)
+ tf_group.add_argument('--tensorflow_custom_layer_libraries',
+ help='TensorFlow*: comma separated list of shared libraries with TensorFlow* custom '
+ 'operations implementation.',
+ default=None,
+ action=CanonicalizePathCheckExistenceAction)
+ tf_group.add_argument('--disable_nhwc_to_nchw',
+ help='Disables default translation from NHWC to NCHW',
+ action='store_true')
return parser
@@ -604,15 +620,15 @@ def get_mxnet_cli_parser(parser: argparse.ArgumentParser = None):
mx_group.add_argument('--input_symbol',
help='Symbol file (for example, model-symbol.json) that contains a topology structure ' +
- 'and layer attributes',
+ 'and layer attributes',
type=str,
action=CanonicalizePathCheckExistenceAction)
mx_group.add_argument("--nd_prefix_name",
help="Prefix name for args.nd and argx.nd files.",
- default="")
+ default=None)
mx_group.add_argument("--pretrained_model_name",
help="Pretrained model without extension and epoch number which will be merged with args.nd and argx.nd files.",
- default="")
+ default=None)
mx_group.add_argument("--save_params_from_nd",
action='store_true',
help="Enable save built params file from nd files.")
@@ -634,16 +650,16 @@ def get_kaldi_cli_parser(parser: argparse.ArgumentParser = None):
parser = argparse.ArgumentParser()
get_common_cli_parser(parser=parser)
- mx_group = parser.add_argument_group('Kaldi-specific parameters')
+ kaldi_group = parser.add_argument_group('Kaldi-specific parameters')
- mx_group.add_argument("--counts",
- help="Path to the counts file",
- default="",
- action=CanonicalizePathCheckExistenceAction)
+ kaldi_group.add_argument("--counts",
+ help="Path to the counts file",
+ default=None,
+ action=CanonicalizePathCheckExistenceAction)
- mx_group.add_argument("--remove_output_softmax",
- help="Removes the Softmax layer that is the output layer",
- action='store_true')
+ kaldi_group.add_argument("--remove_output_softmax",
+ help="Removes the Softmax layer that is the output layer",
+ action='store_true')
return parser
@@ -780,6 +796,11 @@ def parse_tuple_pairs(argv_values: str):
data_str = argv_values
while True:
tuples_matches = re.findall(r'[(\[]([0-9., -]+)[)\]]', data_str, re.IGNORECASE)
+ if not tuples_matches :
+ raise Error(
+ "Mean/scale values should be in format: data(1,2,3),info(2,3,4)" +
+ " or just plain set of them without naming any inputs: (1,2,3),(2,3,4). " +
+ refer_to_faq_msg(101), argv_values)
tuple_value = tuples_matches[0]
matches = data_str.split(tuple_value)
@@ -793,9 +814,8 @@ def parse_tuple_pairs(argv_values: str):
# error - tuple with name is also specified
raise Error(
"Mean/scale values should either contain names of input layers: data(1,2,3),info(2,3,4)" +
- " or just plain set of them without naming any inputs: (1,2,3),(2,3,4). " +
- "For more information, please refer to to Model Optimizer FAQ, question #84.".format(
- argv_values))
+ " or just plain set of them without naming any inputs: (1,2,3),(2,3,4)." +
+ refer_to_faq_msg(101), argv_values)
for match in tuples_matches:
res.append(np.fromstring(match, dtype=float, sep=','))
break
@@ -1000,3 +1020,32 @@ def check_positive(value):
return int_value
+
+def depersonalize(value: str):
+ if not isinstance(value, str):
+ return value
+ res = []
+ for path in value.split(','):
+ if os.path.isdir(path):
+ res.append('DIR')
+ elif os.path.isfile(path):
+ res.append(os.path.join('DIR', os.path.split(path)[1]))
+ else:
+ res.append(path)
+ return ','.join(res)
+
+
+def get_meta_info(argv: argparse.Namespace):
+ meta_data = {'unset': []}
+ for key, value in argv.__dict__.items():
+ if value is not None:
+ value = depersonalize(value)
+ meta_data[key] = value
+ else:
+ meta_data['unset'].append(key)
+ # The attribute 'k' is treated separately because it points to not existing file by default
+ for key in ['k']:
+ if key in meta_data:
+ meta_data[key] = ','.join([os.path.join('DIR', os.path.split(i)[1]) for i in meta_data[key].split(',')])
+ return meta_data
+
diff --git a/model-optimizer/mo/utils/custom_replacement_config.py b/model-optimizer/mo/utils/custom_replacement_config.py
index 5fdd00540..8709e19a5 100644
--- a/model-optimizer/mo/utils/custom_replacement_config.py
+++ b/model-optimizer/mo/utils/custom_replacement_config.py
@@ -87,11 +87,14 @@ class CustomReplacementDescriptor(object):
raise Exception("The function 'get_sub_graph_instances' must be implemented in the sub-class.")
def get_config_file_representation(self):
- return {
+ result = {
'match_kind': self.match_kind, 'instances': self.instances,
'inputs': self.inputs, 'outputs': self.outputs,
- 'custom_attributes': self.custom_attributes, 'id': self.id,
+ 'custom_attributes': self.custom_attributes, 'id': self.id
}
+ if self.has('op'):
+ result.update({'op': self.op})
+ return result
def get_inputs_description(self):
"""
@@ -155,17 +158,20 @@ class CustomReplacementDescriptorPoints(CustomReplacementDescriptor):
def __init__(self, replacement_id: str, attrs: dict = None):
super().__init__(replacement_id, attrs)
if not self.has('include_inputs_to_sub_graph'):
- super(CustomReplacementDescriptor, self).__setattr__('include_inputs_to_sub_graph', True)
+ super(CustomReplacementDescriptorPoints, self).__setattr__('include_inputs_to_sub_graph', True)
if not self.has('include_outputs_to_sub_graph'):
- super(CustomReplacementDescriptor, self).__setattr__('include_outputs_to_sub_graph', True)
+ super(CustomReplacementDescriptorPoints, self).__setattr__('include_outputs_to_sub_graph', True)
def get_config_file_representation(self):
- return {
+ result = {
'match_kind': self.match_kind, 'instances': self.instances,
'custom_attributes': self.custom_attributes, 'id': self.id,
'include_inputs_to_sub_graph': bool(self.include_inputs_to_sub_graph),
- 'include_outputs_to_sub_graph': bool(self.include_outputs_to_sub_graph),
+ 'include_outputs_to_sub_graph': bool(self.include_outputs_to_sub_graph)
}
+ if self.has('op'):
+ result.update({'op': self.op})
+ return result
def get_inputs_description(self):
return [[('^' + node_name + '$', 0)] for node_name in self.instances['start_points']]
diff --git a/model-optimizer/mo/utils/dsu.py b/model-optimizer/mo/utils/dsu.py
index 6a912ad67..849db9008 100644
--- a/model-optimizer/mo/utils/dsu.py
+++ b/model-optimizer/mo/utils/dsu.py
@@ -15,37 +15,41 @@
"""
-class DSU_elem:
+class DSUElem:
"""
An object that represents one DSU element.
"""
+ name = ''
+ parent = ''
+ rank = 1
def __init__(self, name):
- self.__setattr__('name', name)
- self.__setattr__('parent', name)
- self.__setattr__('rank', 1)
+ self.name = name
+ self.parent = name
+ self.rank = 1
class DSU:
"""
Naive implementation of the "disjoint set union" data structure.
"""
+ map = dict()
def __init__(self, elems: list):
- self.__setattr__('map', {elem.name: elem for elem in elems})
+ self.map = {elem.name: elem for elem in elems}
pass
def find_elem(self, name: str):
return self.map[name]
- def find_parent(self, elem: DSU_elem):
+ def find_parent(self, elem: DSUElem):
if elem.parent == elem.name:
return elem
parent_elem = self.find_parent(self.find_elem(elem.parent))
elem.parent = parent_elem.name
return parent_elem
- def union(self, elem1: DSU_elem, elem2: DSU_elem):
+ def union(self, elem1: DSUElem, elem2: DSUElem):
elem1 = self.find_parent(elem1)
elem2 = self.find_parent(elem2)
if elem1.name == elem2.name: # already in the same set
diff --git a/model-optimizer/mo/utils/error.py b/model-optimizer/mo/utils/error.py
index 59b794c85..4b188668b 100644
--- a/model-optimizer/mo/utils/error.py
+++ b/model-optimizer/mo/utils/error.py
@@ -27,7 +27,12 @@ class BasicError(Exception):
def __str__(self):
if len(self.args) <= 1:
return Exception.__str__(self)
- return self.args[0].format(*self.args[1:])
+ return self.args[0].format(*self.args[1:]) # pylint: disable=unsubscriptable-object
+
+
+class FrameworkError(BasicError):
+ """ User-friendly error: raised when the error on the framework side. """
+ pass
class Error(BasicError):
@@ -38,3 +43,4 @@ class Error(BasicError):
class InternalError(BasicError):
""" Not user-friendly error: user cannot fix it and it points to the bug inside MO. """
pass
+
diff --git a/model-optimizer/mo/utils/graph.py b/model-optimizer/mo/utils/graph.py
index 7e7d9f7c1..43417f2e6 100644
--- a/model-optimizer/mo/utils/graph.py
+++ b/model-optimizer/mo/utils/graph.py
@@ -167,7 +167,7 @@ def is_connected_component(graph: nx.MultiDiGraph, node_names: list):
return set(node_names).issubset(visited)
-def sub_graph_between_nodes(graph: nx.MultiDiGraph, start_nodes: list, end_nodes: list):
+def sub_graph_between_nodes(graph: nx.MultiDiGraph, start_nodes: list, end_nodes: list, detect_extra_start_node: callable=None):
"""
Finds nodes of the sub-graph between 'start_nodes' and 'end_nodes'. Input nodes for the sub-graph nodes are also
added to the sub-graph. Constant inputs of the 'start_nodes' are also added to the sub-graph.
@@ -179,6 +179,7 @@ def sub_graph_between_nodes(graph: nx.MultiDiGraph, start_nodes: list, end_nodes
sub_graph_nodes = list()
visited = set(start_nodes)
d = deque(start_nodes)
+ extra_start_nodes = []
nx.set_node_attributes(graph, name='prev', values=None)
while len(d) != 0:
@@ -194,9 +195,12 @@ def sub_graph_between_nodes(graph: nx.MultiDiGraph, start_nodes: list, end_nodes
for src_node_name, _ in graph.in_edges(cur_node_name):
# add input nodes for the non-start_nodes
if cur_node_name not in start_nodes and src_node_name not in visited:
- d.append(src_node_name)
- graph.node[src_node_name]['prev'] = cur_node_name
- visited.add(src_node_name)
+ if detect_extra_start_node is not None and detect_extra_start_node(Node(graph, cur_node_name)):
+ extra_start_nodes.append(cur_node_name)
+ else:
+ d.append(src_node_name)
+ graph.node[src_node_name]['prev'] = cur_node_name
+ visited.add(src_node_name)
# use forward dfs to check that all end nodes are reachable from at least one of input nodes
forward_visited = set()
@@ -216,9 +220,12 @@ def sub_graph_between_nodes(graph: nx.MultiDiGraph, start_nodes: list, end_nodes
path.append(str(cur_node))
cur_node = graph.node[cur_node]['prev']
log.debug("The path from input node is the following: {}".format('\n'.join(path)))
- raise Error('Sub-graph contains network input node "{}". '.format(node_name) +
+ raise Error('The matched sub-graph contains network input node "{}". '.format(node_name) +
refer_to_faq_msg(75))
- return sub_graph_nodes
+ if detect_extra_start_node is None:
+ return sub_graph_nodes
+ else:
+ return sub_graph_nodes, extra_start_nodes
def node_neighbourhood(node_name: str, depth: int, next_node_fn):
diff --git a/model-optimizer/mo/utils/logger.py b/model-optimizer/mo/utils/logger.py
index 20be46a3f..26b7c2fdc 100644
--- a/model-optimizer/mo/utils/logger.py
+++ b/model-optimizer/mo/utils/logger.py
@@ -27,7 +27,8 @@ class LvlFormatter(log.Formatter):
log.INFO: "[ %(levelname)s ] %(msg)s",
log.WARNING: "[ WARNING ] %(msg)s",
log.ERROR: "[ %(levelname)s ] %(msg)s",
- log.CRITICAL: "[ %(levelname)s ] %(msg)s"
+ log.CRITICAL: "[ %(levelname)s ] %(msg)s",
+ 'framework_error': "[ FRAMEWORK ERROR ] %(msg)s"
}
def __init__(self, lvl, fmt=None):
@@ -41,6 +42,8 @@ class LvlFormatter(log.Formatter):
self._style._fmt = self.format_dict[record.levelno]
if 'is_warning' in record.__dict__.keys():
self._style._fmt = self.format_dict[log.WARNING]
+ if 'framework_error' in record.__dict__.keys():
+ self._style._fmt = self.format_dict['framework_error']
return log.Formatter.format(self, record)
diff --git a/model-optimizer/mo/utils/pipeline_config.py b/model-optimizer/mo/utils/pipeline_config.py
index c5ae47f89..901bf453c 100644
--- a/model-optimizer/mo/utils/pipeline_config.py
+++ b/model-optimizer/mo/utils/pipeline_config.py
@@ -13,11 +13,65 @@
See the License for the specific language governing permissions and
limitations under the License.
"""
+import logging as log
+import re
from mo.utils.error import Error
from mo.utils.simple_proto_parser import SimpleProtoParser
+# The list of rules how to map the value from the pipeline.config file to the dictionary with attributes.
+# The rule is either a string or a tuple with two elements. In the first case the rule string is used as a key to
+# search in the parsed pipeline.config file attributes dictionary and a key to save found value. In the second case the
+# first element of the tuple is the key to save found value; the second element of the tuple is a string defining the
+# path to the value of the attribute in the pipeline.config file. The path consists of the regular expression strings
+# defining the dictionary key to look for separated with a '/' character.
+mapping_rules = [
+ 'num_classes',
+ # preprocessing block attributes
+ ('resizer_image_height', 'image_resizer/fixed_shape_resizer/height'),
+ ('resizer_image_width', 'image_resizer/fixed_shape_resizer/width'),
+ ('resizer_min_dimension', 'image_resizer/keep_aspect_ratio_resizer/min_dimension'),
+ ('resizer_max_dimension', 'image_resizer/keep_aspect_ratio_resizer/max_dimension'),
+ # anchor generator attributes
+ ('anchor_generator_height', 'first_stage_anchor_generator/grid_anchor_generator/height$', 256),
+ ('anchor_generator_width', 'first_stage_anchor_generator/grid_anchor_generator/width$', 256),
+ ('anchor_generator_height_stride', 'first_stage_anchor_generator/grid_anchor_generator/height_stride', 16),
+ ('anchor_generator_width_stride', 'first_stage_anchor_generator/grid_anchor_generator/width_stride', 16),
+ ('anchor_generator_scales', 'first_stage_anchor_generator/grid_anchor_generator/scales'),
+ ('anchor_generator_aspect_ratios', 'first_stage_anchor_generator/grid_anchor_generator/aspect_ratios'),
+ ('multiscale_anchor_generator_min_level', 'anchor_generator/multiscale_anchor_generator/min_level'),
+ ('multiscale_anchor_generator_max_level', 'anchor_generator/multiscale_anchor_generator/max_level'),
+ ('multiscale_anchor_generator_anchor_scale', 'anchor_generator/multiscale_anchor_generator/anchor_scale'),
+ ('multiscale_anchor_generator_aspect_ratios', 'anchor_generator/multiscale_anchor_generator/aspect_ratios'),
+ ('multiscale_anchor_generator_scales_per_octave', 'anchor_generator/multiscale_anchor_generator/scales_per_octave'),
+ # SSD anchor generator attributes
+ ('ssd_anchor_generator_min_scale', 'anchor_generator/ssd_anchor_generator/min_scale'),
+ ('ssd_anchor_generator_max_scale', 'anchor_generator/ssd_anchor_generator/max_scale'),
+ ('ssd_anchor_generator_num_layers', 'anchor_generator/ssd_anchor_generator/num_layers'),
+ ('ssd_anchor_generator_aspect_ratios', 'anchor_generator/ssd_anchor_generator/aspect_ratios'),
+ ('ssd_anchor_generator_reduce_lowest', 'anchor_generator/ssd_anchor_generator/reduce_boxes_in_lowest_layer'),
+ ('ssd_anchor_generator_base_anchor_height', 'anchor_generator/ssd_anchor_generator/base_anchor_height', 1.0),
+ ('ssd_anchor_generator_base_anchor_width', 'anchor_generator/ssd_anchor_generator/base_anchor_width', 1.0),
+ # Proposal and ROI Pooling layers attributes
+ ('first_stage_nms_score_threshold', '.*_nms_score_threshold'),
+ ('first_stage_nms_iou_threshold', '.*_nms_iou_threshold'),
+ ('first_stage_max_proposals', '.*_max_proposals'),
+ 'initial_crop_size',
+ # Detection Output layer attributes
+ ('postprocessing_score_converter', '.*/score_converter'),
+ ('postprocessing_score_threshold', '.*/batch_non_max_suppression/score_threshold'),
+ ('postprocessing_iou_threshold', '.*/batch_non_max_suppression/iou_threshold'),
+ ('postprocessing_max_detections_per_class', '.*/batch_non_max_suppression/max_detections_per_class'),
+ ('postprocessing_max_total_detections', '.*/batch_non_max_suppression/max_total_detections'),
+ # Variances for predicted bounding box deltas (tx, ty, tw, th)
+ ('frcnn_variance_x', 'box_coder/faster_rcnn_box_coder/x_scale', 10.0),
+ ('frcnn_variance_y', 'box_coder/faster_rcnn_box_coder/y_scale', 10.0),
+ ('frcnn_variance_width', 'box_coder/faster_rcnn_box_coder/width_scale', 5.0),
+ ('frcnn_variance_height', 'box_coder/faster_rcnn_box_coder/height_scale', 5.0)
+]
+
+
class PipelineConfig:
"""
The class that parses pipeline.config files used to generate TF models generated using Object Detection API.
@@ -33,57 +87,55 @@ class PipelineConfig:
self._initialize_model_params()
+ @staticmethod
+ def _get_value_by_path(params: dict, path: list):
+ if not path or len(path) == 0:
+ return None
+ if not isinstance(params, dict):
+ return None
+ compiled_regexp = re.compile(path[0])
+ for key in params.keys():
+ if re.match(compiled_regexp, key):
+ if len(path) == 1:
+ return params[key]
+ else:
+ value = __class__._get_value_by_path(params[key], path[1:])
+ if value is not None:
+ return value
+ return None
+
+ def _update_param_using_rule(self, params: dict, rule: [str, tuple]):
+ if isinstance(rule, str):
+ if rule in params:
+ self._model_params[rule] = params[rule]
+ log.debug('Found value "{}" for path "{}"'.format(params[rule], rule))
+ elif isinstance(rule, tuple):
+ if len(rule) != 2 and len(rule) != 3:
+ raise Error('Invalid rule length. Rule must be a tuple with two elements: key and path, or three '
+ 'elements: key, path, default_value.')
+ value = __class__._get_value_by_path(params, rule[1].split('/'))
+ if value is not None:
+ log.debug('Found value "{}" for path "{}"'.format(value, rule[1]))
+ self._model_params[rule[0]] = value
+ elif len(rule) == 3:
+ self._model_params[rule[0]] = rule[2]
+ log.debug('There is no value path "{}". Set default value "{}"'.format(value, rule[2]))
+
+ else:
+ raise Error('Invalid rule type. Rule can be either string or tuple')
+
def _initialize_model_params(self):
"""
Store global params in the dedicated dictionary self._model_params for easier use.
:return: None
"""
- params = list(self._raw_data_dict['model'].values())[0]
- # global topology parameters
- self._model_params['num_classes'] = params['num_classes']
-
- # pre-processing of the image
- self._model_params['image_resizer'] = list(params['image_resizer'].keys())[0]
- image_resize_params = list(params['image_resizer'].values())[0]
- if self._model_params['image_resizer'] == 'keep_aspect_ratio_resizer':
- self._model_params['preprocessed_image_height'] = image_resize_params['min_dimension']
- self._model_params['preprocessed_image_width'] = self._model_params['preprocessed_image_height']
- elif self._model_params['image_resizer'] == 'fixed_shape_resizer':
- self._model_params['preprocessed_image_height'] = image_resize_params['height']
- self._model_params['preprocessed_image_width'] = image_resize_params['width']
- else:
- raise Error('Unknown image resizer type "{}"'.format(self._model_params['image_resizer']))
-
- # grid anchors generator
- if 'first_stage_anchor_generator' in params:
- grid_params = params['first_stage_anchor_generator']['grid_anchor_generator']
- self._model_params['anchor_generator_base_size'] = 256
- self._model_params['anchor_generator_stride'] = grid_params['height_stride']
- self._model_params['anchor_generator_scales'] = grid_params['scales']
- self._model_params['anchor_generator_aspect_ratios'] = grid_params['aspect_ratios']
-
- if 'feature_extractor' in params:
- if 'first_stage_features_stride' in params['feature_extractor']:
- self._model_params['features_extractor_stride'] = params['feature_extractor']['first_stage_features_stride']
- else: # the value is not specified in the configuration file for NASNet so use default value here
- self._model_params['features_extractor_stride'] = 16
-
- # Proposal and ROI Pooling layers
- for param in ['first_stage_nms_score_threshold', 'first_stage_nms_iou_threshold', 'first_stage_max_proposals',
- 'initial_crop_size']:
- if param in params:
- self._model_params[param] = params[param]
-
- # post-processing parameters
- postprocessing_params = None
- for postpocessing_type in ['post_processing', 'second_stage_post_processing']:
- if postpocessing_type in params:
- postprocessing_params = params[postpocessing_type]['batch_non_max_suppression']
- self._model_params['postprocessing_score_converter'] = params[postpocessing_type]['score_converter']
- if postprocessing_params is not None:
- for param in ['score_threshold', 'iou_threshold', 'max_detections_per_class', 'max_total_detections']:
- self._model_params['postprocessing_' + param] = postprocessing_params[param]
+ if 'model' not in self._raw_data_dict:
+ raise Error('The "model" key is not found in the configuration file. Looks like the parsed file is not '
+ 'Object Detection API model configuration file.')
+ params = list(self._raw_data_dict['model'].values())[0]
+ for rule in mapping_rules:
+ self._update_param_using_rule(params, rule)
def get_param(self, param: str):
if param not in self._model_params:
diff --git a/model-optimizer/mo/utils/replacement_pattern.py b/model-optimizer/mo/utils/replacement_pattern.py
index 53f76114d..d77f7ce63 100644
--- a/model-optimizer/mo/utils/replacement_pattern.py
+++ b/model-optimizer/mo/utils/replacement_pattern.py
@@ -25,7 +25,7 @@ class ReplacementPattern(object):
excluded_replacers = []
def find_and_replace_pattern(self, graph: nx.MultiDiGraph):
- apply_pattern(graph, **self.pattern(), action=self.replace_pattern)
+ apply_pattern(graph, **self.pattern(), action=self.replace_pattern) # pylint: disable=no-member
def run_before(self):
"""
diff --git a/model-optimizer/mo/utils/simple_proto_parser.py b/model-optimizer/mo/utils/simple_proto_parser.py
index 1ffae1c86..9009f6be5 100644
--- a/model-optimizer/mo/utils/simple_proto_parser.py
+++ b/model-optimizer/mo/utils/simple_proto_parser.py
@@ -156,11 +156,17 @@ class SimpleProtoParser(object):
if line.startswith('#'): # skip comments
continue
for char in line:
- if char == '"':
- if string_started: # string ended
+ if string_started:
+ if char == '"': # string ended
self._add_non_empty_token(cur_token)
+ cur_token = '' # start of a new string
+ string_started = False
+ else:
+ cur_token += char
+ elif char == '"':
+ self._add_non_empty_token(cur_token)
cur_token = '' # start of a new string
- string_started = not string_started
+ string_started = True
elif (char == " " and not string_started) or char == '\n':
self._add_non_empty_token(cur_token)
cur_token = ''
diff --git a/model-optimizer/mo/utils/str_to.py b/model-optimizer/mo/utils/str_to.py
index c3e591db4..c27a5813a 100644
--- a/model-optimizer/mo/utils/str_to.py
+++ b/model-optimizer/mo/utils/str_to.py
@@ -18,6 +18,8 @@
class StrTo(object):
@staticmethod
def tuple(type_of_elements: type, string: str):
+ if type_of_elements == int:
+ string = string.replace('L', '')
return tuple(type_of_elements(x) for x in string[1:-1].split(','))
@staticmethod
diff --git a/model-optimizer/mo/utils/summarize_graph.py b/model-optimizer/mo/utils/summarize_graph.py
index 2ce5b9da3..bfaea4d92 100644
--- a/model-optimizer/mo/utils/summarize_graph.py
+++ b/model-optimizer/mo/utils/summarize_graph.py
@@ -1,4 +1,21 @@
#!/usr/bin/env python3
+
+"""
+ Copyright (c) 2018 Intel Corporation
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+"""
+
import argparse
import os
import sys
@@ -18,9 +35,9 @@ def summarize_graph(graph_def):
placeholders = dict()
outputs = list()
graph = tf.Graph()
- with graph.as_default():
+ with graph.as_default(): # pylint: disable=not-context-manager
tf.import_graph_def(graph_def, name='')
- for node in graph.as_graph_def().node:
+ for node in graph.as_graph_def().node: # pylint: disable=no-member
if node.op == 'Placeholder':
node_dict = dict()
node_dict['type'] = tf.DType(node.attr['dtype'].type).name
@@ -42,10 +59,11 @@ if __name__ == "__main__": # pragma: no cover
parser = argparse.ArgumentParser()
parser.add_argument("--input_model", type=str, help="Path to tensorflow model", default="")
parser.add_argument('--input_model_is_text', dest='text',
- help='TensorFlow*: treat the input model file in a text protobuf format instead of ' +
- 'binary, which is default.', action='store_true', default=False)
+ help='TensorFlow*: treat the input model file as a text protobuf format. If not specified, '
+ 'the Model Optimizer treats it as a binary file by default.', action='store_true',
+ default=False)
parser.add_argument('--input_meta', action='store_true',
- help='TensorFlow*: treat the input model file in a meta graph def format', default=False)
+ help='TensorFlow*: treat the input model file as a meta graph def format', default=False)
parser.add_argument("--input_checkpoint", type=str, help='TensorFlow variables file to load.', default="")
parser.add_argument('--saved_model_dir', type=str, default="", help="TensorFlow saved_model_dir")
parser.add_argument('--saved_model_tags', type=str, default="",
@@ -54,7 +72,7 @@ if __name__ == "__main__": # pragma: no cover
argv = parser.parse_args()
if not argv.input_model and not argv.saved_model_dir:
- print("[ ERROR ] Please, provide --input_model and --input_model_is_text if needed or --input_dir for saved " \
+ print("[ ERROR ] Please, provide --input_model and --input_model_is_text if needed or --input_dir for saved "
"model directory")
sys.exit(1)
if argv.input_model and argv.saved_model_dir:
diff --git a/model-optimizer/mo/utils/utils.py b/model-optimizer/mo/utils/utils.py
index b613d41fa..ad5c3f4d6 100644
--- a/model-optimizer/mo/utils/utils.py
+++ b/model-optimizer/mo/utils/utils.py
@@ -15,6 +15,9 @@
"""
+import numpy as np
+
+
def refer_to_faq_msg(question_num: int):
return '\n For more information please refer to Model Optimizer FAQ' \
' (<INSTALL_DIR>/deployment_tools/documentation/docs/MO_FAQ.html),' \
@@ -25,3 +28,13 @@ class NamedAttrsClass:
def __init__(self, class_attrs: dict):
for key, val in class_attrs.items():
self.__setattr__(key, val)
+
+
+def match_shapes(pattern: np.array, shape: np.array):
+ '''Check if shape matches shape pattern handling -1 and 0 in the pattern.'''
+ # Elements with values -1 and 0 in pattern are just ignored.
+ # Other elements should match.
+ if pattern.size != shape.size:
+ return False
+ indices = [i for i, n in enumerate(pattern) if n not in [0, -1]]
+ return np.array_equal(pattern[indices], shape[indices])
diff --git a/model-optimizer/mo/utils/versions_checker.py b/model-optimizer/mo/utils/versions_checker.py
index ad11f38a1..328ff441e 100644
--- a/model-optimizer/mo/utils/versions_checker.py
+++ b/model-optimizer/mo/utils/versions_checker.py
@@ -7,7 +7,6 @@
http://www.apache.org/licenses/LICENSE-2.0
-
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.