diff options
Diffstat (limited to 'compiler/bcq-tools/generate_bcq_output_arrays')
-rw-r--r-- | compiler/bcq-tools/generate_bcq_output_arrays | 90 |
1 files changed, 0 insertions, 90 deletions
diff --git a/compiler/bcq-tools/generate_bcq_output_arrays b/compiler/bcq-tools/generate_bcq_output_arrays deleted file mode 100644 index 48e8a9373..000000000 --- a/compiler/bcq-tools/generate_bcq_output_arrays +++ /dev/null @@ -1,90 +0,0 @@ -#!/usr/bin/env python3 - -import tensorflow as tf - -import argparse -import sys - - -def _get_parser(): - """ - Returns an ArgumentParser for generating output_arrays. - """ - parser = argparse.ArgumentParser( - description=("Command line tool to generated output_arrays of BCQ nodes")) - - # Input and output path. - parser.add_argument( - "-i", - "--input_path", - type=str, - help="Full filepath of the input file.", - required=True) - parser.add_argument( - "-o", - "--output_path", - type=str, - help="Full filepath of the output file.", - required=True) - - return parser - - -def load_graph(frozen_graph_filename): - """ - Load graph from frozen pb file - """ - with tf.compat.v1.gfile.GFile(frozen_graph_filename, "rb") as f: - graph_def = tf.compat.v1.GraphDef() - graph_def.ParseFromString(f.read()) - with tf.Graph().as_default() as graph: - tf.import_graph_def(graph_def, name='') - return graph - - -def dtype2str(dtype): - if dtype == "int32": - return "TF_INT32" - elif dtype == "int64": - return "TF_INT64" - elif dtype == "float32": - return "TF_FLOAT" - elif dtype == "bool": - return "TF_BOOL" - else: - raise Exception("Not supported dtype") - - -def print_output_arrays(flags): - graph_model = load_graph(flags.input_path) - graph_model_def = graph_model.as_graph_def() - ops = graph_model.get_operations() - - output_names = [op.outputs[0].name for op in ops - if op.type == "Const" and "bcqinfo_" in op.outputs[0].name] - - output_arrays = "" - for output_name in output_names: - output_arrays += "," - - colon_index = output_name.find(":") - if colon_index == -1: - output_arrays += output_name - else: - output_arrays += output_name[:colon_index] - - f = open(flags.output_path, 'w') - f.write(output_arrays) - f.close() - - -def main(): - # Parse argument. - parser = _get_parser() - flags = parser.parse_known_args(args=sys.argv[1:]) - - print_output_arrays(flags[0]) - - -if __name__ == "__main__": - main() |