diff options
Diffstat (limited to 'compiler/bcq-tools/generate_bcq_output_arrays')
-rw-r--r-- | compiler/bcq-tools/generate_bcq_output_arrays | 90 |
1 files changed, 90 insertions, 0 deletions
diff --git a/compiler/bcq-tools/generate_bcq_output_arrays b/compiler/bcq-tools/generate_bcq_output_arrays new file mode 100644 index 000000000..48e8a9373 --- /dev/null +++ b/compiler/bcq-tools/generate_bcq_output_arrays @@ -0,0 +1,90 @@ +#!/usr/bin/env python3 + +import tensorflow as tf + +import argparse +import sys + + +def _get_parser(): + """ + Returns an ArgumentParser for generating output_arrays. + """ + parser = argparse.ArgumentParser( + description=("Command line tool to generated output_arrays of BCQ nodes")) + + # Input and output path. + parser.add_argument( + "-i", + "--input_path", + type=str, + help="Full filepath of the input file.", + required=True) + parser.add_argument( + "-o", + "--output_path", + type=str, + help="Full filepath of the output file.", + required=True) + + return parser + + +def load_graph(frozen_graph_filename): + """ + Load graph from frozen pb file + """ + with tf.compat.v1.gfile.GFile(frozen_graph_filename, "rb") as f: + graph_def = tf.compat.v1.GraphDef() + graph_def.ParseFromString(f.read()) + with tf.Graph().as_default() as graph: + tf.import_graph_def(graph_def, name='') + return graph + + +def dtype2str(dtype): + if dtype == "int32": + return "TF_INT32" + elif dtype == "int64": + return "TF_INT64" + elif dtype == "float32": + return "TF_FLOAT" + elif dtype == "bool": + return "TF_BOOL" + else: + raise Exception("Not supported dtype") + + +def print_output_arrays(flags): + graph_model = load_graph(flags.input_path) + graph_model_def = graph_model.as_graph_def() + ops = graph_model.get_operations() + + output_names = [op.outputs[0].name for op in ops + if op.type == "Const" and "bcqinfo_" in op.outputs[0].name] + + output_arrays = "" + for output_name in output_names: + output_arrays += "," + + colon_index = output_name.find(":") + if colon_index == -1: + output_arrays += output_name + else: + output_arrays += output_name[:colon_index] + + f = open(flags.output_path, 'w') + f.write(output_arrays) + f.close() + + +def main(): + # Parse argument. + parser = _get_parser() + flags = parser.parse_known_args(args=sys.argv[1:]) + + print_output_arrays(flags[0]) + + +if __name__ == "__main__": + main() |