summaryrefslogtreecommitdiff
path: root/compiler/bcq-tools/generate_bcq_output_arrays
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/bcq-tools/generate_bcq_output_arrays')
-rw-r--r--compiler/bcq-tools/generate_bcq_output_arrays90
1 files changed, 0 insertions, 90 deletions
diff --git a/compiler/bcq-tools/generate_bcq_output_arrays b/compiler/bcq-tools/generate_bcq_output_arrays
deleted file mode 100644
index 48e8a9373..000000000
--- a/compiler/bcq-tools/generate_bcq_output_arrays
+++ /dev/null
@@ -1,90 +0,0 @@
-#!/usr/bin/env python3
-
-import tensorflow as tf
-
-import argparse
-import sys
-
-
-def _get_parser():
- """
- Returns an ArgumentParser for generating output_arrays.
- """
- parser = argparse.ArgumentParser(
- description=("Command line tool to generated output_arrays of BCQ nodes"))
-
- # Input and output path.
- parser.add_argument(
- "-i",
- "--input_path",
- type=str,
- help="Full filepath of the input file.",
- required=True)
- parser.add_argument(
- "-o",
- "--output_path",
- type=str,
- help="Full filepath of the output file.",
- required=True)
-
- return parser
-
-
-def load_graph(frozen_graph_filename):
- """
- Load graph from frozen pb file
- """
- with tf.compat.v1.gfile.GFile(frozen_graph_filename, "rb") as f:
- graph_def = tf.compat.v1.GraphDef()
- graph_def.ParseFromString(f.read())
- with tf.Graph().as_default() as graph:
- tf.import_graph_def(graph_def, name='')
- return graph
-
-
-def dtype2str(dtype):
- if dtype == "int32":
- return "TF_INT32"
- elif dtype == "int64":
- return "TF_INT64"
- elif dtype == "float32":
- return "TF_FLOAT"
- elif dtype == "bool":
- return "TF_BOOL"
- else:
- raise Exception("Not supported dtype")
-
-
-def print_output_arrays(flags):
- graph_model = load_graph(flags.input_path)
- graph_model_def = graph_model.as_graph_def()
- ops = graph_model.get_operations()
-
- output_names = [op.outputs[0].name for op in ops
- if op.type == "Const" and "bcqinfo_" in op.outputs[0].name]
-
- output_arrays = ""
- for output_name in output_names:
- output_arrays += ","
-
- colon_index = output_name.find(":")
- if colon_index == -1:
- output_arrays += output_name
- else:
- output_arrays += output_name[:colon_index]
-
- f = open(flags.output_path, 'w')
- f.write(output_arrays)
- f.close()
-
-
-def main():
- # Parse argument.
- parser = _get_parser()
- flags = parser.parse_known_args(args=sys.argv[1:])
-
- print_output_arrays(flags[0])
-
-
-if __name__ == "__main__":
- main()