diff options
Diffstat (limited to 'caffe2/python/pybind_state.cc')
-rw-r--r-- | caffe2/python/pybind_state.cc | 14 |
1 files changed, 14 insertions, 0 deletions
diff --git a/caffe2/python/pybind_state.cc b/caffe2/python/pybind_state.cc index 18f9097088..f437b4e0c3 100644 --- a/caffe2/python/pybind_state.cc +++ b/caffe2/python/pybind_state.cc @@ -25,6 +25,7 @@ #include "caffe2/opt/onnxifi_transformer.h" #include "caffe2/opt/optimize_ideep.h" #include "caffe2/opt/passes.h" +#include "caffe2/predictor/emulator/data_filler.h" #include "caffe2/predictor/predictor.h" #include "caffe2/python/pybind_state_registry.h" #include "caffe2/utils/cpuid.h" @@ -1146,6 +1147,19 @@ void addGlobalMethods(py::module& m) { return gWorkspace->HasBlob(name); }); m.def( + "fill_random_network_inputs", + [](const py::bytes& net_def, + const std::vector<std::vector<std::vector<int64_t>>>& inputDims, + const std::vector<std::vector<std::string>>& inputTypes) { + CAFFE_ENFORCE(gWorkspace); + py::gil_scoped_release g; + NetDef net; + CAFFE_ENFORCE( + ParseProtoFromLargeString(net_def.cast<std::string>(), &net)); + caffe2::emulator::fillRandomNetworkInputs( + net, inputDims, inputTypes, gWorkspace); + }); + m.def( "create_net", [](py::bytes net_def, bool overwrite) { CAFFE_ENFORCE(gWorkspace); |