summaryrefslogtreecommitdiff
path: root/compiler/nnc/utils/model_runner/model_runner_caffe2.py
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/nnc/utils/model_runner/model_runner_caffe2.py')
-rwxr-xr-xcompiler/nnc/utils/model_runner/model_runner_caffe2.py23
1 files changed, 23 insertions, 0 deletions
diff --git a/compiler/nnc/utils/model_runner/model_runner_caffe2.py b/compiler/nnc/utils/model_runner/model_runner_caffe2.py
new file mode 100755
index 000000000..0c8feca92
--- /dev/null
+++ b/compiler/nnc/utils/model_runner/model_runner_caffe2.py
@@ -0,0 +1,23 @@
+from common_place import *
+
+from caffe2.python import workspace
+
+
+def run_caffe2(init_net, predict_net, input_path, output_path=''):
+ x = read_input(input_path)
+ with open(init_net, 'rb') as f:
+ init_net = f.read()
+
+ with open(predict_net, 'rb') as f:
+ predict_net = f.read()
+ p = workspace.Predictor(init_net, predict_net)
+ # TODO get 'data' parameter more universal, blobs contain other names
+ results = p.run({'data': x})
+ print(results)
+ save_result(output_path, results)
+
+
+if __name__ == '__main__':
+ args = regular_step()
+
+ run_caffe2(args.model[0], args.model[1], args.input, args.output_path)