diff options
Diffstat (limited to 'model-optimizer/mo/front/tf/loader.py')
-rw-r--r-- | model-optimizer/mo/front/tf/loader.py | 95 |
1 file changed, 89 insertions(+), 6 deletions(-)
diff --git a/model-optimizer/mo/front/tf/loader.py b/model-optimizer/mo/front/tf/loader.py index b8d8ca192..8310e0acb 100644 --- a/model-optimizer/mo/front/tf/loader.py +++ b/model-optimizer/mo/front/tf/loader.py @@ -14,9 +14,12 @@ limitations under the License. """ +import logging as log import os import re +import networkx as nx + from mo.utils.error import Error, FrameworkError from mo.utils.utils import refer_to_faq_msg @@ -31,6 +34,55 @@ from mo.graph.graph import create_graph_with_nodes from mo.utils.summarize_graph import summarize_graph +def freeze_checkpoints(graph_def: tf.GraphDef, checkpoint_dir: str, output_node_names: list): + """ + Loads all the variables in a graph and stores them in a separate dictionary. Freezes output nodes in the graph + :param graph_def: GraphDef object holding the network. + :param checkpoint_dir: path to directory with checkpoint files with values of graph variables. + :param output_node_names: list of output node names. + :return: GraphDef containing a simplified version of the original. + """ + log.debug("Loading checkpoint files from directory: {}".format(checkpoint_dir)) + checkpoint_files = [] + for checkpoint_name in sorted(os.listdir(checkpoint_dir)): + checkpoint_path = os.path.join(checkpoint_dir, checkpoint_name) + if os.path.isfile(checkpoint_path): + checkpoint_files.append(checkpoint_path) + log.debug("File {} will be loaded".format(checkpoint_path)) + else: + log.debug("Path {} is not a file. 
Skipping") + + if len(checkpoint_files) == 0: + raise Error("There are no checkpoint files in directory: {}".format(checkpoint_dir)) + + tf.import_graph_def(graph_def, name='') + + with tf.Session() as sess: + uninitialized_variables = [str(v, 'utf-8') for v in set(sess.run(tf.report_uninitialized_variables()))] + all_variables = [n.name for n in sess.graph.as_graph_def().node if n.op in ['Variable', 'VariableV2']] + white_list = [v for v in all_variables if v not in uninitialized_variables] + black_list = [v for v in all_variables if v in uninitialized_variables] + output_graph_def = tf.graph_util.convert_variables_to_constants(sess, graph_def, output_node_names, + variable_names_whitelist=white_list, + variable_names_blacklist=black_list) + variable_values = {} + for checkpoint_file in checkpoint_files: + log.debug("Loading {}".format(checkpoint_file)) + with tf.Session() as sess: + var_list = {} + var_to_shape_map = tf.pywrap_tensorflow.NewCheckpointReader(checkpoint_file).get_variable_to_shape_map() + for key in var_to_shape_map: + try: + tensor = sess.graph.get_operation_by_name(key).outputs[0] + except KeyError: + continue + var_list[key] = tensor + tf.train.Saver(var_list=var_list).restore(sess, checkpoint_file) + for name, tensor in var_list.items(): + variable_values[name] = sess.run(tensor) + return output_graph_def, variable_values + + def freeze_checkpoint(graph_def, checkpoint, output_node_names): """ Replaces all the variables in a graph with constants of the same values. @@ -40,6 +92,7 @@ def freeze_checkpoint(graph_def, checkpoint, output_node_names): :return: GraphDef containing a simplified version of the original. 
""" tf.import_graph_def(graph_def, name="") + with tf.Session() as sess: var_list = {} var_to_shape_map = tf.pywrap_tensorflow.NewCheckpointReader(checkpoint).get_variable_to_shape_map() @@ -54,7 +107,8 @@ def freeze_checkpoint(graph_def, checkpoint, output_node_names): return output_graph_def -def read_file_to_graph_def(graph_def: [tf.GraphDef, tf.MetaGraphDef], graph_file_name: str = "", is_binary: bool = True): +def read_file_to_graph_def(graph_def: [tf.GraphDef, tf.MetaGraphDef], graph_file_name: str = "", + is_binary: bool = True): """ Reads file to protobuf :param graph_def: GraphDef orr MetaGraphDef object to store the network @@ -141,16 +195,22 @@ def load_tf_graph_def(graph_file_name: str = "", is_binary: bool = True, checkpo '--input_checkpoint "path/to/*.ckpt"' '\n\n2. For "*.meta" file:' '\npython3 mo_tf.py --input_meta_graph "path/to/*.meta"') - + variables_values = {} try: if graph_file_name and not meta_graph_file and not checkpoint: # frozen graph - return read_file_to_graph_def(graph_def, graph_file_name, is_binary) + return read_file_to_graph_def(graph_def, graph_file_name, is_binary), variables_values if graph_file_name and not meta_graph_file and checkpoint: # inference graph and checkpoint graph_def = read_file_to_graph_def(graph_def, graph_file_name, is_binary) outputs = get_output_node_names_list(graph_def, user_output_node_names_list) - return freeze_checkpoint(graph_def=graph_def, checkpoint=checkpoint, output_node_names=outputs) + if os.path.isfile(checkpoint): + graph_def = freeze_checkpoint(graph_def=graph_def, checkpoint=checkpoint, output_node_names=outputs) + elif os.path.isdir(checkpoint): + graph_def, variables_values = freeze_checkpoints(graph_def=graph_def, checkpoint_dir=checkpoint, + output_node_names=outputs) + # we are sure that checkpoint is existing file or directory due to cli_parser configuration + return graph_def, variables_values if not graph_file_name and meta_graph_file: meta_graph_file = 
deducing_metagraph_path(meta_graph_file) input_meta_graph_def = read_file_to_graph_def(tf.MetaGraphDef(), meta_graph_file, is_binary) @@ -159,14 +219,16 @@ def load_tf_graph_def(graph_file_name: str = "", is_binary: bool = True, checkpo restorer = tf.train.import_meta_graph(input_meta_graph_def) restorer.restore(sess, re.sub('\.meta$', '', meta_graph_file)) outputs = get_output_node_names_list(input_meta_graph_def.graph_def, user_output_node_names_list) - return tf.graph_util.convert_variables_to_constants(sess, input_meta_graph_def.graph_def, outputs) + graph_def = tf.graph_util.convert_variables_to_constants(sess, input_meta_graph_def.graph_def, outputs) + return graph_def, variables_values if model_dir: # saved model directory tags = saved_model_tags if saved_model_tags is not None else [tf.saved_model.tag_constants.SERVING] with tf.Session() as sess: meta_graph_def = tf.saved_model.loader.load(sess, tags, model_dir) outputs = get_output_node_names_list(meta_graph_def.graph_def, user_output_node_names_list) - return tf.graph_util.convert_variables_to_constants(sess, meta_graph_def.graph_def, outputs) + graph_def = tf.graph_util.convert_variables_to_constants(sess, meta_graph_def.graph_def, outputs) + return graph_def, variables_values except Exception as e: raise FrameworkError('Cannot load input model: {}', e) from e raise Error("Unknown configuration of input model parameters") @@ -194,3 +256,24 @@ def protobuf2nx(pb: tf.GraphDef): index = index + 1 return graph + + +def variables_to_constants(graph: nx.MultiDiGraph, variables_values: dict): + """ + Converts `Variable<V2>` operations to FakeConst operations with `value` from `variables_values` dictionary + :param graph: graph to operate on + :param variables_values: dictionary with variable names as keys and np.array data as values + """ + variable_operations = ['Variable', 'VariableV2'] + for node_name in graph.nodes(): + node_attr_dict = graph.node[node_name] + if 'op' not in node_attr_dict: + continue + 
op_name = node_attr_dict['op']
+        if op_name not in variable_operations:
+            continue
+        if node_name not in variables_values:
+            log.debug("There is no value for '{}': {} in checkpoint variable values".format(op_name, node_name))
+            continue
+        graph.node[node_name]['op'] = 'FakeConst'
+        graph.node[node_name]['value'] = variables_values[node_name]