summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--tools/nnpackage_tool/tflite2circle/README.md28
-rwxr-xr-xtools/nnpackage_tool/tflite2circle/tflite2circle.sh73
-rwxr-xr-xtools/nnpackage_tool/tflite2circle/tflitejson2circlejson.py28
3 files changed, 129 insertions, 0 deletions
diff --git a/tools/nnpackage_tool/tflite2circle/README.md b/tools/nnpackage_tool/tflite2circle/README.md
new file mode 100644
index 000000000..94ef5068c
--- /dev/null
+++ b/tools/nnpackage_tool/tflite2circle/README.md
@@ -0,0 +1,28 @@
+# tflite2circle
+
+`tflite2circle` is a tool to convert tflite into circle.
+
+## Usage
+
+```
+Usage: tflite2circle.sh [options] tflite
+Convert tflite to circle
+
+Returns
+ 0 success
+ non-zero failure
+
+Options:
+ -h show this help
+ -o set output directory (default=.)
+
+Environment variables:
+ flatc path to flatc
+ (default=./build/externals/FLATBUFFERS/build/flatc)
+ tflite_schema path to schema.fbs
+ (default=./externals/TENSORFLOW-1.12/tensorflow/contrib/lite/schema/schema.fbs)
+
+Examples:
+ tflite2circle.sh Add_000.tflite => convert Add_000.tflite into Add_000.circle
+ tflite2circle.sh -o my/circles Add_000 => convert Add_000.tflite into my/circles/Add_000.circle
+```
diff --git a/tools/nnpackage_tool/tflite2circle/tflite2circle.sh b/tools/nnpackage_tool/tflite2circle/tflite2circle.sh
new file mode 100755
index 000000000..259e57d83
--- /dev/null
+++ b/tools/nnpackage_tool/tflite2circle/tflite2circle.sh
@@ -0,0 +1,73 @@
+#!/bin/bash
+
+set -u
+
+progname=$(basename "${BASH_SOURCE[0]}")
+script_dir="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
+outdir="."
+flatc=${flatc:-"./build/externals/FLATBUFFERS/build/flatc"}
+tflite_schema=${tflite_schema:-"./externals/TENSORFLOW-1.12/tensorflow/contrib/lite/schema/schema.fbs"}
+circle_schema=${circle_schema:-"./nnpackage/schema/circle_schema.fbs"}
+
+usage() {
+ echo "Usage: $progname [options] tflite"
+ echo "Convert tflite to circle"
+ echo ""
+ echo "Returns"
+ echo " 0 success"
+ echo " non-zero failure"
+ echo ""
+ echo "Options:"
+ echo " -h show this help"
+ echo " -o set output directory (default=$outdir)"
+ echo ""
+ echo "Environment variables:"
+ echo " flatc path to flatc"
+ echo " (default=./build/externals/FLATBUFFERS/build/flatc)"
+ echo " tflite_schema path to tflite schema (i.e. schema.fbs)"
+ echo " (default=./externals/TENSORFLOW-1.12/tensorflow/contrib/lite/schema/schema.fbs)"
+ echo " circle_schema path to circle schema"
+ echo " (default=./nnpackage/schema/circle_schema.fbs)"
+ echo ""
+ echo "Examples:"
+ echo " $progname Add_000.tflite => convert Add_000.tflite into Add_000.circle"
+ echo " $progname -o my/circles Add_000 => convert Add_000.tflite into my/circles/Add_000.circle"
+ exit 1
+}
+
+if [ $# -eq 0 ]; then
+ echo "For help, type $progname -h"
+ exit 1
+fi
+
+while getopts "ho:" OPTION; do
+case "${OPTION}" in
+ h) usage;;
+ o) outdir=$OPTARG;;
+ ?) exit 1;;
+esac
+done
+
+shift $((OPTIND-1))
+
+if [ $# -ne 1 ]; then
+ echo "error: wrong argument (no argument or too many arguments)."
+ echo "For help, type $progname -h"
+ exit 1
+fi
+
+if [ ! -e Product ]; then
+ echo "error: please make sure to run this script in nnfw home."
+ exit 1
+fi
+
+tflite_base=$(basename "$1")
+name=${tflite_base%.*}
+
+# convert
+
+mkdir -p "${outdir}"
+${flatc} -o ${outdir} --defaults-json --strict-json -t ${tflite_schema} -- $1
+${script_dir}/tflitejson2circlejson.py "${outdir}/${name}.json" > "${outdir}/${name}.circle"
+${flatc} -o ${outdir} -b ${circle_schema} "${outdir}/${name}.circle"
+rm -f ${outdir}/${name}.json
diff --git a/tools/nnpackage_tool/tflite2circle/tflitejson2circlejson.py b/tools/nnpackage_tool/tflite2circle/tflitejson2circlejson.py
new file mode 100755
index 000000000..c20a0c53e
--- /dev/null
+++ b/tools/nnpackage_tool/tflite2circle/tflitejson2circlejson.py
@@ -0,0 +1,28 @@
+#!/usr/bin/python3
+
+import json
+import os
+import sys
+from collections import OrderedDict
+
+
+def usage():
+ script = os.path.basename(os.path.basename(__file__))
+ print("Usage: {} path_to_tflite_in_json".format(script))
+ sys.exit(-1)
+
+
+if __name__ == '__main__':
+ if len(sys.argv) != 2:
+ usage()
+
+ json_path = sys.argv[1]
+ with open(json_path, "r") as f:
+ try:
+ json_dict = json.load(f, object_pairs_hook=OrderedDict)
+ for subgraph in json_dict["subgraphs"]:
+ subgraph["data_format"] = "CHANNELS_LAST"
+ print(json.dumps(json_dict, indent=2))
+ except KeyError:
+ print("subgraphs attribute does not exist.")
+ sys.exit(-2)