summaryrefslogtreecommitdiff
path: root/tools/tflkit
diff options
context:
space:
mode:
Diffstat (limited to 'tools/tflkit')
-rwxr-xr-xtools/tflkit/freeze_graph.sh26
-rw-r--r--tools/tflkit/summarize_pb.py2
-rwxr-xr-xtools/tflkit/tflite_convert.sh42
3 files changed, 38 insertions, 32 deletions
diff --git a/tools/tflkit/freeze_graph.sh b/tools/tflkit/freeze_graph.sh
index c491ba4d2..ae771cf80 100755
--- a/tools/tflkit/freeze_graph.sh
+++ b/tools/tflkit/freeze_graph.sh
@@ -4,13 +4,11 @@ usage()
{
echo "usage : $0"
echo " --info=Information file"
- echo " --tensorflow_path=TensorFlow path (Use externals/tensorflow by default)"
+ echo " [--tensorflow_path=TensorFlow path] (If omitted, the module installed in system will be used by default.)"
}
SCRIPT_PATH="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
-TF_DIR="${SCRIPT_PATH}/../../externals/tensorflow"
-
for i in "$@"
do
case $i in
@@ -37,11 +35,6 @@ if [ -z "$INFO" ]; then
usage
exit 1
fi
-if [ -z "$TF_DIR" ]; then
- echo "tensorflow_path is unset or set to the empty string"
- usage
- exit 1
-fi
if [ ! -x "$(command -v bazel)" ]; then
echo "Cannot find bazel. Please install bazel."
@@ -74,17 +67,22 @@ fi
CUR_DIR=$(pwd)
{
- echo "Enter $TF_DIR"
- pushd $TF_DIR > /dev/null
+ if [ -e "$TF_DIR" ]; then
+ echo "Enter $TF_DIR"
+ pushd $TF_DIR > /dev/null
+ FREEZE_GRAPH="bazel run tensorflow/python/tools:freeze_graph -- "
+ else
+ FREEZE_GRAPH="python -m tensorflow.python.tools.freeze_graph "
+ fi
if [ ! -z $SAVED_MODEL ]; then
- bazel run tensorflow/python/tools:freeze_graph -- \
+ $FREEZE_GRAPH \
--input_saved_model_dir="$SAVED_MODEL" \
--input_binary=True \
--output_node_names="$OUTPUT" \
--output_graph="$FROZEN_PATH"
else
- bazel run tensorflow/python/tools:freeze_graph -- \
+ $FREEZE_GRAPH \
--input_meta_graph="$META_GRAPH" \
--input_checkpoint="$CKPT_PATH" \
--input_binary=True \
@@ -92,7 +90,9 @@ CUR_DIR=$(pwd)
--output_graph="$FROZEN_PATH"
fi
- popd
+ if [ -e "$TF_DIR" ]; then
+ popd
+ fi
echo "OUTPUT FILE : $FROZEN_PATH"
}
diff --git a/tools/tflkit/summarize_pb.py b/tools/tflkit/summarize_pb.py
index 633804114..bdc6b252c 100644
--- a/tools/tflkit/summarize_pb.py
+++ b/tools/tflkit/summarize_pb.py
@@ -40,7 +40,7 @@ def PrintInput(data):
def PrintOutput(data):
print("Outputs")
- sub = re.findall('\((.*?)\)', data)
+ sub = re.findall(r'\((.*?)\)', data)
for i in sub:
print('\t' + i)
diff --git a/tools/tflkit/tflite_convert.sh b/tools/tflkit/tflite_convert.sh
index 2056797ab..f5b94ede0 100755
--- a/tools/tflkit/tflite_convert.sh
+++ b/tools/tflkit/tflite_convert.sh
@@ -3,15 +3,12 @@
usage()
{
echo "usage : $0"
- echo " --info=Information file"
- echo " --tensorflow_path=TensorFlow path (Use externals/tensorflow by default)"
- echo " --tensorflow_version=TensorFlow version (Must be entered)"
+ echo " --info=<infroamtion file>"
+ echo " [ --tensorflow_path=<path> --tensorflow_version=<version> ] (If omitted, the module installed in system will be used by default.)"
}
SCRIPT_PATH="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
-TF_DIR="${SCRIPT_PATH}/../../externals/tensorflow"
-
for i in "$@"
do
case $i in
@@ -41,15 +38,16 @@ if [ -z "$INFO" ]; then
usage
exit 1
fi
-if [ -z "$TF_DIR" ]; then
- echo "tensorflow_path is unset or set to the empty string"
- usage
- exit 1
-fi
+
if [ -z "$TF_VERSION" ]; then
- echo "tensorflow_version is unset or set to the empty string"
- usage
- exit 1
+ if [ -z "$TF_DIR" ]; then
+ TF_VERSION=$(python -c 'import tensorflow as tf; print(tf.__version__)')
+ echo "TensorFlow version detected : $TF_VERSION"
+ else
+ echo "tensorflow_version is unset or set to the empty string"
+ usage
+ exit 1
+ fi
fi
if [ ! -x "$(command -v bazel)" ]; then
@@ -87,8 +85,13 @@ fi
CUR_DIR=$(pwd)
{
- echo "Enter $TF_DIR"
- pushd $TF_DIR > /dev/null
+ if [ -e "$TF_DIR" ]; then
+ echo "Enter $TF_DIR"
+ pushd $TF_DIR > /dev/null
+ TFLITE_CONVERT="bazel run tensorflow/lite/python:tflite_convert -- "
+ else
+ TFLITE_CONVERT="python -m tensorflow.lite.python.tflite_convert "
+ fi
NAME_LIST=()
INPUT_SHAPE_LIST=()
@@ -111,7 +114,7 @@ CUR_DIR=$(pwd)
for (( i=0; i < ${#NAME_LIST[@]}; ++i )); do
if [ "${TF_VERSION%%.*}" = "2" ]; then
- bazel run tensorflow/lite/python:tflite_convert -- \
+ $TFLITE_CONVERT \
--output_file="${NAME_LIST[$i]}" \
--graph_def_file="$GRAPHDEF_PATH" \
--input_arrays="$INPUT" \
@@ -119,7 +122,7 @@ CUR_DIR=$(pwd)
--output_arrays="$OUTPUT" \
--allow_custom_ops=true
else
- bazel run tensorflow/contrib/lite/python:tflite_convert -- \
+ $TFLITE_CONVERT \
--output_file="${NAME_LIST[$i]}" \
--graph_def_file="$GRAPHDEF_PATH" \
--input_arrays="$INPUT" \
@@ -128,7 +131,10 @@ CUR_DIR=$(pwd)
--allow_custom_ops
fi
done
- popd
+
+ if [ -e "$TF_DIR" ]; then
+ popd
+ fi
for (( i=0; i < ${#NAME_LIST[@]}; ++i )); do
echo "OUTPUT FILE : ${NAME_LIST[$i]}"