summaryrefslogtreecommitdiff
path: root/caffe2/python/pybind_state.cc
diff options
context:
space:
mode:
Diffstat (limited to 'caffe2/python/pybind_state.cc')
-rw-r--r--caffe2/python/pybind_state.cc14
1 files changed, 14 insertions, 0 deletions
diff --git a/caffe2/python/pybind_state.cc b/caffe2/python/pybind_state.cc
index 18f9097088..f437b4e0c3 100644
--- a/caffe2/python/pybind_state.cc
+++ b/caffe2/python/pybind_state.cc
@@ -25,6 +25,7 @@
#include "caffe2/opt/onnxifi_transformer.h"
#include "caffe2/opt/optimize_ideep.h"
#include "caffe2/opt/passes.h"
+#include "caffe2/predictor/emulator/data_filler.h"
#include "caffe2/predictor/predictor.h"
#include "caffe2/python/pybind_state_registry.h"
#include "caffe2/utils/cpuid.h"
@@ -1146,6 +1147,19 @@ void addGlobalMethods(py::module& m) {
return gWorkspace->HasBlob(name);
});
m.def(
+ "fill_random_network_inputs",
+ [](const py::bytes& net_def,
+ const std::vector<std::vector<std::vector<int64_t>>>& inputDims,
+ const std::vector<std::vector<std::string>>& inputTypes) {
+ CAFFE_ENFORCE(gWorkspace);
+ py::gil_scoped_release g;
+ NetDef net;
+ CAFFE_ENFORCE(
+ ParseProtoFromLargeString(net_def.cast<std::string>(), &net));
+ caffe2::emulator::fillRandomNetworkInputs(
+ net, inputDims, inputTypes, gWorkspace);
+ });
+ m.def(
"create_net",
[](py::bytes net_def, bool overwrite) {
CAFFE_ENFORCE(gWorkspace);