blob: 80847b7df1dc3d69ade36a019f2650f0e8716118 (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
|
from common_place import *
import tensorflow as tf
def run_tflite(model, input_path, output_path=''):
    """Run a TensorFlow Lite model on a single input and save the first output.

    Args:
        model: Path to the TFLite model file, passed to the interpreter as
            ``model_path``.
        input_path: Path handed to ``read_input`` to load the input data.
            NOTE(review): the returned array must already match the model's
            first input tensor in dtype and shape, otherwise ``set_tensor``
            raises; confirm ``read_input`` guarantees this.
        output_path: Destination handed to ``save_result``; defaults to ''.

    Returns:
        The value of the model's first output tensor (also printed and saved).
    """
    input_data = read_input(input_path)
    # The interpreter lives at tf.lite.Interpreter in current TensorFlow;
    # tf.contrib (and tf.contrib.lite) was removed in TF 2.x. Fall back to the
    # contrib location only for old 1.x releases that lack tf.lite.
    lite_module = getattr(tf, 'lite', None)
    interpreter_cls = getattr(lite_module, 'Interpreter', None)
    if interpreter_cls is None:
        interpreter_cls = tf.contrib.lite.Interpreter
    interpreter = interpreter_cls(model_path=model)
    interpreter.allocate_tensors()
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()
    # Only the first input and first output tensor of the model are used.
    interpreter.set_tensor(input_details[0]['index'], input_data)
    interpreter.invoke()
    output_data = interpreter.get_tensor(output_details[0]['index'])
    print(output_data)
    save_result(output_path, output_data)
    return output_data
if __name__ == '__main__':
    # Parse the command line via the shared helper, then run inference with
    # the first model path supplied.
    parsed = regular_step()
    first_model = parsed.model[0]
    run_tflite(first_model, parsed.input, parsed.output_path)
|