diff options
author | Duc Ngo <duc@fb.com> | 2019-04-08 11:48:42 -0700 |
---|---|---|
committer | Facebook Github Bot <facebook-github-bot@users.noreply.github.com> | 2019-04-08 11:54:10 -0700 |
commit | e7b2669151ad73a9833eab519c093ef6750ac635 (patch) | |
tree | ff29c4524921c140141f8f49027998aa671957b1 /caffe2 | |
parent | 66a3277dfa028a4a00693f78184202c48395dab6 (diff) | |
download | pytorch-e7b2669151ad73a9833eab519c093ef6750ac635.tar.gz pytorch-e7b2669151ad73a9833eab519c093ef6750ac635.tar.bz2 pytorch-e7b2669151ad73a9833eab519c093ef6750ac635.zip |
caffe2 - Expose tensor filler util to Python (#18886)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18886
Expose tensor filler util to Python and add a unit test (both C++/Python)
Reviewed By: salexspb
Differential Revision: D14784470
fbshipit-source-id: bb8e013d1755c27c166e87d5a8491a97c65d3d8d
Diffstat (limited to 'caffe2')
-rw-r--r-- | caffe2/CMakeLists.txt | 1 | ||||
-rw-r--r-- | caffe2/core/test_utils.h | 2 | ||||
-rw-r--r-- | caffe2/predictor/emulator/CMakeLists.txt | 13 | ||||
-rw-r--r-- | caffe2/predictor/emulator/data_filler.cc | 9 | ||||
-rw-r--r-- | caffe2/predictor/emulator/data_filler.h | 7 | ||||
-rw-r--r-- | caffe2/predictor/emulator/data_filler_test.cc | 25 | ||||
-rw-r--r-- | caffe2/python/filler_test.py | 20 | ||||
-rw-r--r-- | caffe2/python/pybind_state.cc | 14 | ||||
-rw-r--r-- | caffe2/python/workspace.py | 5 |
9 files changed, 95 insertions, 1 deletions
diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt index 31b842b9d6..4e3f036f7f 100644 --- a/caffe2/CMakeLists.txt +++ b/caffe2/CMakeLists.txt @@ -69,6 +69,7 @@ if(NOT BUILD_ATEN_ONLY) add_subdirectory(core) add_subdirectory(utils) add_subdirectory(predictor) + add_subdirectory(predictor/emulator) add_subdirectory(core/nomnigraph) add_subdirectory(serialize) if (USE_NVRTC) diff --git a/caffe2/core/test_utils.h b/caffe2/core/test_utils.h index fcf069d607..7e286e1d30 100644 --- a/caffe2/core/test_utils.h +++ b/caffe2/core/test_utils.h @@ -56,7 +56,7 @@ void assertTensorListEquals( const Workspace& workspace2); // Read a tensor from the workspace. -const caffe2::Tensor& getTensor( +CAFFE2_API const caffe2::Tensor& getTensor( const caffe2::Workspace& workspace, const std::string& name); diff --git a/caffe2/predictor/emulator/CMakeLists.txt b/caffe2/predictor/emulator/CMakeLists.txt new file mode 100644 index 0000000000..690699040e --- /dev/null +++ b/caffe2/predictor/emulator/CMakeLists.txt @@ -0,0 +1,13 @@ +set(Caffe2_EMULATOR_CPU_SRC + "${CMAKE_CURRENT_SOURCE_DIR}/data_filler.h" + "${CMAKE_CURRENT_SOURCE_DIR}/data_filler.cc" +) +set(Caffe2_EMULATOR_CPU_TEST_SRC + "${CMAKE_CURRENT_SOURCE_DIR}/data_filler_test.cc") + +# Common files that are always going to be included. +list(APPEND Caffe2_CPU_SRCS ${Caffe2_EMULATOR_CPU_SRC}) +list(APPEND Caffe2_CPU_TEST_SRCS ${Caffe2_EMULATOR_CPU_TEST_SRC}) + +set(Caffe2_CPU_SRCS ${Caffe2_CPU_SRCS} PARENT_SCOPE) +set(Caffe2_CPU_TEST_SRCS ${Caffe2_CPU_TEST_SRCS} PARENT_SCOPE) diff --git a/caffe2/predictor/emulator/data_filler.cc b/caffe2/predictor/emulator/data_filler.cc index e4e64a3fd8..1979025c14 100644 --- a/caffe2/predictor/emulator/data_filler.cc +++ b/caffe2/predictor/emulator/data_filler.cc @@ -245,5 +245,14 @@ void TestDataRandomFiller::fillInputToWorkspace(Workspace* workspace) const { } } +void fillRandomNetworkInputs( + const NetDef& net, + const std::vector<std::vector<std::vector<int64_t>>>& inputDims, + const std::vector<std::vector<std::string>>& inputTypes, + Workspace* workspace) { + TestDataRandomFiller(net, inputDims, inputTypes) + .fillInputToWorkspace(workspace); +} + } // namespace emulator } // namespace caffe2 diff --git a/caffe2/predictor/emulator/data_filler.h b/caffe2/predictor/emulator/data_filler.h index a540e4aee3..78692ac9b3 100644 --- a/caffe2/predictor/emulator/data_filler.h +++ b/caffe2/predictor/emulator/data_filler.h @@ -138,5 +138,12 @@ class TestDataRandomFiller : public DataRandomFiller { void fillInputToWorkspace(Workspace* workspace) const; }; +// Convenient helpers to fill data to workspace. +CAFFE2_API void fillRandomNetworkInputs( + const NetDef& net, + const std::vector<std::vector<std::vector<int64_t>>>& inputDims, + const std::vector<std::vector<std::string>>& inputTypes, + Workspace* workspace); + } // namespace emulator } // namespace caffe2 diff --git a/caffe2/predictor/emulator/data_filler_test.cc b/caffe2/predictor/emulator/data_filler_test.cc new file mode 100644 index 0000000000..b29bcec718 --- /dev/null +++ b/caffe2/predictor/emulator/data_filler_test.cc @@ -0,0 +1,25 @@ +#include "caffe2/core/common.h" +#include "caffe2/core/test_utils.h" +#include "caffe2/predictor/emulator/data_filler.h" + +#include <gtest/gtest.h> + +TEST(DataFiller, FillNetInputTest) { + using namespace caffe2::testing; + using namespace caffe2::emulator; + caffe2::NetDef net; + NetMutator(&net) + .newOp("Concat", {"X0", "X1", "X2"}, {"concat_out", "split_info"}) + .addArgument("axis", 1); + + std::vector<int64_t> input_dim = {30, 20}; + std::vector<std::vector<std::vector<int64_t>>> input_dims = { + {/* X0 */ input_dim, /* X1 */ input_dim, /* X2 */ input_dim}}; + std::vector<std::vector<std::string>> input_types = { + {"float", "float", "float"}}; + caffe2::Workspace workspace; + EXPECT_FALSE(workspace.HasBlob("X0")); + fillRandomNetworkInputs(net, input_dims, input_types, &workspace); + EXPECT_TRUE(workspace.HasBlob("X0")); + EXPECT_EQ(getTensor(workspace, "X0").sizes(), input_dim); +} diff --git a/caffe2/python/filler_test.py b/caffe2/python/filler_test.py new file mode 100644 index 0000000000..52ea756d5b --- /dev/null +++ b/caffe2/python/filler_test.py @@ -0,0 +1,20 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from caffe2.python import core, test_util, workspace + + +class TestFiller(test_util.TestCase): + def test_filler(self): + net = core.Net("test_filler") + net.Concat(["X0", "X1", "X2"], ["concat_out", "split_info"]) + self.assertFalse(workspace.HasBlob("X0")) + input_dim = (30, 20) + workspace.FillRandomNetworkInputs(net, [[input_dim, input_dim, input_dim]], [["float", "float", "float"]]) + self.assertTrue(workspace.HasBlob("X0")) + self.assertEqual(workspace.FetchBlob("X0").shape, input_dim) + + with self.assertRaises(RuntimeError): + # Filler should throw if number of input dims/types is mismatched. + workspace.FillRandomNetworkInputs(net, [[input_dim]], [["float"]]) diff --git a/caffe2/python/pybind_state.cc b/caffe2/python/pybind_state.cc index 18f9097088..f437b4e0c3 100644 --- a/caffe2/python/pybind_state.cc +++ b/caffe2/python/pybind_state.cc @@ -25,6 +25,7 @@ #include "caffe2/opt/onnxifi_transformer.h" #include "caffe2/opt/optimize_ideep.h" #include "caffe2/opt/passes.h" +#include "caffe2/predictor/emulator/data_filler.h" #include "caffe2/predictor/predictor.h" #include "caffe2/python/pybind_state_registry.h" #include "caffe2/utils/cpuid.h" @@ -1146,6 +1147,19 @@ void addGlobalMethods(py::module& m) { return gWorkspace->HasBlob(name); }); m.def( + "fill_random_network_inputs", + [](const py::bytes& net_def, + const std::vector<std::vector<std::vector<int64_t>>>& inputDims, + const std::vector<std::vector<std::string>>& inputTypes) { + CAFFE_ENFORCE(gWorkspace); + py::gil_scoped_release g; + NetDef net; + CAFFE_ENFORCE( + ParseProtoFromLargeString(net_def.cast<std::string>(), &net)); + caffe2::emulator::fillRandomNetworkInputs( + net, inputDims, inputTypes, gWorkspace); + }); + m.def( "create_net", [](py::bytes net_def, bool overwrite) { CAFFE_ENFORCE(gWorkspace); diff --git a/caffe2/python/workspace.py b/caffe2/python/workspace.py index 18fcd9bb42..c288650970 100644 --- a/caffe2/python/workspace.py +++ b/caffe2/python/workspace.py @@ -83,6 +83,11 @@ GetNumNUMANodes = C.get_num_numa_nodes GetBlobNUMANode = C.get_blob_numa_node GetBlobSizeBytes = C.get_blob_size_bytes + +def FillRandomNetworkInputs(net, input_dims, input_types): + C.fill_random_network_inputs(net.Proto().SerializeToString(), input_dims, input_types) + + def _GetFreeFlaskPort(): """Get a free flask port.""" # We will prefer to use 5000. If not, we will then pick a random port. |