diff options
Diffstat (limited to 'compiler/bcq-tools/preserve_bcq_info')
-rw-r--r-- | compiler/bcq-tools/preserve_bcq_info | 116 |
1 files changed, 0 insertions, 116 deletions
diff --git a/compiler/bcq-tools/preserve_bcq_info b/compiler/bcq-tools/preserve_bcq_info deleted file mode 100644 index 2ede8d4d0..000000000 --- a/compiler/bcq-tools/preserve_bcq_info +++ /dev/null @@ -1,116 +0,0 @@ -#!/usr/bin/env python3 - -import tensorflow as tf -import numpy as np - -import argparse -import sys - - -def _get_parser(): - """ - Returns an ArgumentParser for preserving BCQ information. - """ - parser = argparse.ArgumentParser( - description=("Command line tool to preserve BCQ information")) - - # Input and output path. - parser.add_argument( - "-i", - "--input_path", - type=str, - help="Full filepath of the input file.", - required=True) - parser.add_argument( - "-o", - "--output_path", - type=str, - help="Full filepath of the output file.", - required=True) - - return parser - - -def load_graph(frozen_graph_filename): - """ - Load graph from frozen pb file - """ - with tf.compat.v1.gfile.GFile(frozen_graph_filename, "rb") as f: - graph_def = tf.compat.v1.GraphDef() - graph_def.ParseFromString(f.read()) - with tf.Graph().as_default() as graph: - tf.import_graph_def(graph_def, name='') - return graph - - -def preserve_bcq_info(flags): - """ - Generate unique dummy value from -1 to -N. - - We use negative values to preserve BCQ information because - positive values may cause some confusion with real BCQ information values. - """ - - class UniqueValueGen: - def __init__(self): - self.unique_value = -1 - - def gen(self): - val = self.unique_value - self.unique_value = val - 1 - return val - - unique_value = UniqueValueGen() - - original_graph_model = load_graph(flags.input_path) - original_graph_model_def = original_graph_model.as_graph_def() - - new_graph = tf.compat.v1.GraphDef() - substitution_dict = {} - - DT_INT32 = None # Just for copying DT_INT32 attribute value - - for node in original_graph_model_def.node: - if node.op == "Const": - # Because bcqinfo_do_w_x is BOOL type, we cannot add dummy value at the end. - # Therefore we should convert the type to INT32 type. - if "/bcqinfo_do_w_x" in node.name: - original_tensor = tf.make_ndarray(node.attr["value"].tensor) - substitution_dict[node.name] = tf.make_tensor_proto( - [int(original_tensor[0]), unique_value.gen()], tf.int32) - - preserved_bcqinfo_list = ["/bcqinfo_number_of_clusters", "/bcqinfo_size_of_clusters", - "/bcqinfo_qbits_of_clusters"] - - if any(name in node.name for name in preserved_bcqinfo_list): - original_tensor = tf.make_ndarray( - node.attr["value"].tensor) # variable name change - substitution_dict[node.name] = tf.make_tensor_proto( - np.append(original_tensor, unique_value.gen()), tf.int32) - DT_INT32 = node.attr["dtype"] - - for node in original_graph_model_def.node: - if node.name in substitution_dict: - new_node = new_graph.node.add() - new_node.op = "Const" - new_node.name = node.name - new_node.attr["dtype"].CopyFrom(DT_INT32) - new_node.attr["value"].tensor.CopyFrom(substitution_dict[node.name]) - else: - new_node = new_graph.node.add() - new_node.CopyFrom(node) - - tf.io.write_graph(new_graph, '.', flags.output_path, False) - - -def main(): - # Parse argument. - parser = _get_parser() - flags = parser.parse_known_args(args=sys.argv[1:]) - - # Generate a new pb file, which BCQ information is preserved. - preserve_bcq_info(flags[0]) - - -if __name__ == "__main__": - main() |