summaryrefslogtreecommitdiff
path: root/infra/packaging/res/tf2nnpkg.20210910
blob: 0d44818a116d79eae52b2d3a48f856fd2aad39d0 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
#!/bin/bash

# Stop at the first command that fails.
set -e

# ROOT is the absolute path of the parent of the directory containing this script.
script_dir="$(dirname "${BASH_SOURCE[0]}")"
ROOT="$(cd "${script_dir}/.." && pwd)"

# Look up the given command name(s) with `command -v`, discarding all output.
# Returns 1 when called without arguments, otherwise the status of `command -v`.
command_exists() {
  # Nothing to look up without at least one name.
  [ "$#" -gt 0 ] || return 1
  command -v "$@" > /dev/null 2>&1
}

# Print the command-line help to stdout, then terminate the script with status 255.
usage() {
  printf '%s\n' \
    "Convert TensorFlow model to nnpackage." \
    "Usage: tf2nnpkg" \
    "    --info <path/to/info>" \
    "    --graphdef <path/to/pb>" \
    "    -o <path/to/nnpkg/directory>" \
    "    --v2 (optional) Use TF 2.x interface"
  exit 255
}

TF_INTERFACE="--v1"

# Parse command-line arguments
#
# parse_args consumes the argument list it is given:
#   --help             prints the usage text via usage
#   --info <path>      exported as INFO_FILE
#   --graphdef <path>  exported as GRAPHDEF_FILE
#   -o <path>          exported as OUTPUT_DIR
#   --v2               sets TF_INTERFACE to "--v2" (the default is "--v1")
# Any other argument is echoed and skipped.
#
# An option that needs a value but is the last argument is reported and the
# script exits with status 2; a bare `shift 2` would fail there without
# telling the user what went wrong.
parse_args() {
  while [ "$#" -ne 0 ]; do
    CUR="$1"

    # Options that take a value must be followed by one.
    case "${CUR}" in
      '--info' | '--graphdef' | '-o')
        if [ "$#" -lt 2 ]; then
          echo "${CUR} requires an argument."
          exit 2
        fi
        ;;
    esac

    case "${CUR}" in
      '--help')
        usage
        ;;
      '--info')
        export INFO_FILE="$2"
        shift 2
        ;;
      '--graphdef')
        export GRAPHDEF_FILE="$2"
        shift 2
        ;;
      '-o')
        export OUTPUT_DIR="$2"
        shift 2
        ;;
      '--v2')
        TF_INTERFACE="--v2"
        shift
        ;;
      *)
        echo "${CUR}"
        shift
        ;;
    esac
  done
}

parse_args "$@"

# Validate the parsed options. Each failed check prints a hint and exits with
# status 2. The expansions are quoted so that an empty value or a path that
# contains whitespace is always tested as exactly one word.

# --graphdef must be given and must exist.
if [ -z "${GRAPHDEF_FILE}" ] || [ ! -e "${GRAPHDEF_FILE}" ]; then
  echo "pb is not found. Please check --graphdef is correct."
  exit 2
fi

# --info must be given and must exist.
if [ -z "${INFO_FILE}" ] || [ ! -e "${INFO_FILE}" ]; then
  echo "info is not found. Please check --info is correct."
  exit 2
fi

# -o must be given; its existence is not checked here.
if [ -z "${OUTPUT_DIR}" ]; then
  echo "output directory is not specifed. Please check -o is correct.."
  exit 2
fi

# Model name: the graphdef file name with its directory and last extension removed.
FILE_BASE=$(basename -- "${GRAPHDEF_FILE}")
MODEL_NAME="${FILE_BASE%.*}"

# Scratch directory for intermediate files.
TMPDIR=$(mktemp -d)

# Remove the scratch directory. The path is quoted so whitespace or glob
# characters in it cannot make rm act on other paths, and `:?` refuses to run
# rm with an empty path.
cleanup_tmpdir() {
  rm -rf -- "${TMPDIR:?}"
}
# Run the cleanup whenever the script exits.
trap cleanup_tmpdir EXIT

# activate python virtual environment
# Two locations under ROOT are probed, in order; the first activate script
# that exists is sourced. When neither exists nothing is sourced and the
# script carries on. Paths are quoted so a ROOT containing whitespace works.
VIRTUALENV_LINUX="${ROOT}/bin/venv/bin/activate"
VIRTUALENV_WINDOWS="${ROOT}/bin/venv/Scripts/activate"

if [ -e "${VIRTUALENV_LINUX}" ]; then
  source "${VIRTUALENV_LINUX}"
elif [ -e "${VIRTUALENV_WINDOWS}" ]; then
  source "${VIRTUALENV_WINDOWS}"
fi

# parse inputs, outputs from info file

# Print, comma-joined on one line, the second comma-separated field of every
# line of an info file that starts with the given prefix. Each field is cut at
# its first ':' and stripped of spaces.
#   $1 - line prefix to select ("input" or "output")
#   $2 - path to the info file
# The file is fed through stdin so that an unusual path (leading '-', or one
# containing '=') is never interpreted by awk as an option or an assignment.
info_tensor_names() {
  awk -F, -v kind="$1" 'index($0, kind) == 1 { print $2 }' < "$2" \
    | cut -d: -f1 | tr -d ' ' | paste -d, -s
}

# Print the text between '[' and ']' of every line starting with "input",
# with spaces removed, joined by ':'.
#   $1 - path to the info file
info_input_shapes() {
  grep '^input' < "$1" \
    | cut -d "[" -f2 | cut -d "]" -f1 | tr -d ' ' | xargs | tr ' ' ':'
}

INPUT=$(info_tensor_names input "${INFO_FILE}")
OUTPUT=$(info_tensor_names output "${INFO_FILE}")

INPUT_SHAPES=$(info_input_shapes "${INFO_FILE}")

# The command line is kept as an array so every path stays a single argument
# even when it contains whitespace.
ONE_IMPORT_BCQ_SCRIPT=(
  "${ROOT}/bin/one-import-bcq" "${TF_INTERFACE}"
  -i "${GRAPHDEF_FILE}"
  -o "${TMPDIR}/${MODEL_NAME}.tmp.circle"
  -I "${INPUT}"
  -O "${OUTPUT}"
)
# -s is only passed when the info file yielded at least one input shape.
if [ -n "${INPUT_SHAPES}" ]; then
  ONE_IMPORT_BCQ_SCRIPT+=(-s "${INPUT_SHAPES}")
fi

"${ONE_IMPORT_BCQ_SCRIPT[@]}"

# optimize
# circle2circle is run with --O1 on the .tmp.circle file in TMPDIR and is given
# the .circle path in the same directory as its second path argument.
IMPORTED_CIRCLE="${TMPDIR}/${MODEL_NAME}.tmp.circle"
OPTIMIZED_CIRCLE="${TMPDIR}/${MODEL_NAME}.circle"
"${ROOT}/bin/circle2circle" --O1 "${IMPORTED_CIRCLE}" "${OPTIMIZED_CIRCLE}"

# package
# model2nnpkg.sh receives OUTPUT_DIR through -o and the .circle path from above.
"${ROOT}/bin/model2nnpkg.sh" -o "${OUTPUT_DIR}" "${OPTIMIZED_CIRCLE}"