summaryrefslogtreecommitdiff
path: root/model-optimizer/mo/front/tf/loader.py
diff options
context:
space:
mode:
Diffstat (limited to 'model-optimizer/mo/front/tf/loader.py')
-rw-r--r--model-optimizer/mo/front/tf/loader.py95
1 files changed, 89 insertions, 6 deletions
diff --git a/model-optimizer/mo/front/tf/loader.py b/model-optimizer/mo/front/tf/loader.py
index b8d8ca192..8310e0acb 100644
--- a/model-optimizer/mo/front/tf/loader.py
+++ b/model-optimizer/mo/front/tf/loader.py
@@ -14,9 +14,12 @@
limitations under the License.
"""
+import logging as log
import os
import re
+import networkx as nx
+
from mo.utils.error import Error, FrameworkError
from mo.utils.utils import refer_to_faq_msg
@@ -31,6 +34,55 @@ from mo.graph.graph import create_graph_with_nodes
from mo.utils.summarize_graph import summarize_graph
+def freeze_checkpoints(graph_def: tf.GraphDef, checkpoint_dir: str, output_node_names: list):
+ """
+ Loads all the variables in a graph and stores them in a separate dictionary. Freezes output nodes in the graph
+ :param graph_def: GraphDef object holding the network.
+ :param checkpoint_dir: path to directory with checkpoint files with values of graph variables.
+ :param output_node_names: list of output node names.
+ :return: GraphDef containing a simplified version of the original.
+ """
+ log.debug("Loading checkpoint files from directory: {}".format(checkpoint_dir))
+ checkpoint_files = []
+ for checkpoint_name in sorted(os.listdir(checkpoint_dir)):
+ checkpoint_path = os.path.join(checkpoint_dir, checkpoint_name)
+ if os.path.isfile(checkpoint_path):
+ checkpoint_files.append(checkpoint_path)
+ log.debug("File {} will be loaded".format(checkpoint_path))
+ else:
+ log.debug("Path {} is not a file. Skipping")
+
+ if len(checkpoint_files) == 0:
+ raise Error("There are no checkpoint files in directory: {}".format(checkpoint_dir))
+
+ tf.import_graph_def(graph_def, name='')
+
+ with tf.Session() as sess:
+ uninitialized_variables = [str(v, 'utf-8') for v in set(sess.run(tf.report_uninitialized_variables()))]
+ all_variables = [n.name for n in sess.graph.as_graph_def().node if n.op in ['Variable', 'VariableV2']]
+ white_list = [v for v in all_variables if v not in uninitialized_variables]
+ black_list = [v for v in all_variables if v in uninitialized_variables]
+ output_graph_def = tf.graph_util.convert_variables_to_constants(sess, graph_def, output_node_names,
+ variable_names_whitelist=white_list,
+ variable_names_blacklist=black_list)
+ variable_values = {}
+ for checkpoint_file in checkpoint_files:
+ log.debug("Loading {}".format(checkpoint_file))
+ with tf.Session() as sess:
+ var_list = {}
+ var_to_shape_map = tf.pywrap_tensorflow.NewCheckpointReader(checkpoint_file).get_variable_to_shape_map()
+ for key in var_to_shape_map:
+ try:
+ tensor = sess.graph.get_operation_by_name(key).outputs[0]
+ except KeyError:
+ continue
+ var_list[key] = tensor
+ tf.train.Saver(var_list=var_list).restore(sess, checkpoint_file)
+ for name, tensor in var_list.items():
+ variable_values[name] = sess.run(tensor)
+ return output_graph_def, variable_values
+
+
def freeze_checkpoint(graph_def, checkpoint, output_node_names):
"""
Replaces all the variables in a graph with constants of the same values.
@@ -40,6 +92,7 @@ def freeze_checkpoint(graph_def, checkpoint, output_node_names):
:return: GraphDef containing a simplified version of the original.
"""
tf.import_graph_def(graph_def, name="")
+
with tf.Session() as sess:
var_list = {}
var_to_shape_map = tf.pywrap_tensorflow.NewCheckpointReader(checkpoint).get_variable_to_shape_map()
@@ -54,7 +107,8 @@ def freeze_checkpoint(graph_def, checkpoint, output_node_names):
return output_graph_def
-def read_file_to_graph_def(graph_def: [tf.GraphDef, tf.MetaGraphDef], graph_file_name: str = "", is_binary: bool = True):
+def read_file_to_graph_def(graph_def: [tf.GraphDef, tf.MetaGraphDef], graph_file_name: str = "",
+ is_binary: bool = True):
"""
Reads file to protobuf
:param graph_def: GraphDef orr MetaGraphDef object to store the network
@@ -141,16 +195,22 @@ def load_tf_graph_def(graph_file_name: str = "", is_binary: bool = True, checkpo
'--input_checkpoint "path/to/*.ckpt"'
'\n\n2. For "*.meta" file:'
'\npython3 mo_tf.py --input_meta_graph "path/to/*.meta"')
-
+ variables_values = {}
try:
if graph_file_name and not meta_graph_file and not checkpoint:
# frozen graph
- return read_file_to_graph_def(graph_def, graph_file_name, is_binary)
+ return read_file_to_graph_def(graph_def, graph_file_name, is_binary), variables_values
if graph_file_name and not meta_graph_file and checkpoint:
# inference graph and checkpoint
graph_def = read_file_to_graph_def(graph_def, graph_file_name, is_binary)
outputs = get_output_node_names_list(graph_def, user_output_node_names_list)
- return freeze_checkpoint(graph_def=graph_def, checkpoint=checkpoint, output_node_names=outputs)
+ if os.path.isfile(checkpoint):
+ graph_def = freeze_checkpoint(graph_def=graph_def, checkpoint=checkpoint, output_node_names=outputs)
+ elif os.path.isdir(checkpoint):
+ graph_def, variables_values = freeze_checkpoints(graph_def=graph_def, checkpoint_dir=checkpoint,
+ output_node_names=outputs)
+ # we are sure that checkpoint is existing file or directory due to cli_parser configuration
+ return graph_def, variables_values
if not graph_file_name and meta_graph_file:
meta_graph_file = deducing_metagraph_path(meta_graph_file)
input_meta_graph_def = read_file_to_graph_def(tf.MetaGraphDef(), meta_graph_file, is_binary)
@@ -159,14 +219,16 @@ def load_tf_graph_def(graph_file_name: str = "", is_binary: bool = True, checkpo
restorer = tf.train.import_meta_graph(input_meta_graph_def)
restorer.restore(sess, re.sub('\.meta$', '', meta_graph_file))
outputs = get_output_node_names_list(input_meta_graph_def.graph_def, user_output_node_names_list)
- return tf.graph_util.convert_variables_to_constants(sess, input_meta_graph_def.graph_def, outputs)
+ graph_def = tf.graph_util.convert_variables_to_constants(sess, input_meta_graph_def.graph_def, outputs)
+ return graph_def, variables_values
if model_dir:
# saved model directory
tags = saved_model_tags if saved_model_tags is not None else [tf.saved_model.tag_constants.SERVING]
with tf.Session() as sess:
meta_graph_def = tf.saved_model.loader.load(sess, tags, model_dir)
outputs = get_output_node_names_list(meta_graph_def.graph_def, user_output_node_names_list)
- return tf.graph_util.convert_variables_to_constants(sess, meta_graph_def.graph_def, outputs)
+ graph_def = tf.graph_util.convert_variables_to_constants(sess, meta_graph_def.graph_def, outputs)
+ return graph_def, variables_values
except Exception as e:
raise FrameworkError('Cannot load input model: {}', e) from e
raise Error("Unknown configuration of input model parameters")
@@ -194,3 +256,24 @@ def protobuf2nx(pb: tf.GraphDef):
index = index + 1
return graph
+
+
+def variables_to_constants(graph: nx.MultiDiGraph, variables_values: dict):
+ """
+ Converts `Variable<V2>` operations to FakeConst operations with `value` from `variables_values` dictionary
+ :param graph: graph to operate on
+ :param variables_values: dictionary with variable names as keys and np.array data as values
+ """
+ variable_operations = ['Variable', 'VariableV2']
+ for node_name in graph.nodes():
+ node_attr_dict = graph.node[node_name]
+ if 'op' not in node_attr_dict:
+ continue
+ op_name = node_attr_dict['op']
+ if op_name not in variable_operations:
+ continue
+ if node_name not in variables_values:
+ log.debug("There is no value for '{}': {} in checkpoint variable values".format(op_name, node_name))
+ continue
+ graph.node[node_name]['op'] = 'FakeConst'
+ graph.node[node_name]['value'] = variables_values[node_name]