summaryrefslogtreecommitdiff
path: root/compiler/nnc/utils/model_runner/model_runner_tflite.py
blob: 80847b7df1dc3d69ade36a019f2650f0e8716118 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
from common_place import *
import tensorflow as tf


def run_tflite(model, input_path, output_path=''):
    """Run a TensorFlow Lite model on one input and save its output.

    Only the model's first input tensor and first output tensor are used.
    The output tensor is also printed to stdout before being saved.

    Args:
        model: Path to the ``.tflite`` model file, passed to the
            interpreter as ``model_path``.
        input_path: Path handed to ``read_input`` (from ``common_place``)
            to load the input data; presumably yields an array matching
            the shape/dtype of the model's first input — TODO confirm.
        output_path: Destination handed to ``save_result``; defaults to
            ``''``. How an empty path is handled is up to ``save_result``.
    """
    # Named input_data so the builtin ``input`` is not shadowed.
    input_data = read_input(input_path)

    # tf.contrib was removed in TensorFlow 2.x and the interpreter now lives
    # at tf.lite.Interpreter. Fall back to the contrib location for old 1.x
    # releases that predate tf.lite.
    try:
        interpreter_cls = tf.lite.Interpreter
    except AttributeError:
        interpreter_cls = tf.contrib.lite.Interpreter

    interpreter = interpreter_cls(model_path=model)
    # Tensors must be allocated before any set_tensor/invoke call.
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()
    interpreter.set_tensor(input_details[0]['index'], input_data)

    interpreter.invoke()
    output_data = interpreter.get_tensor(output_details[0]['index'])
    print(output_data)
    save_result(output_path, output_data)


if __name__ == '__main__':
    # Script entry point: parse the command line via the shared helper, then
    # run the first model given on the command line.
    cli_args = regular_step()
    model_file = cli_args.model[0]
    run_tflite(model_file, cli_args.input, cli_args.output_path)