summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDuc Ngo <duc@fb.com>2019-04-08 11:48:42 -0700
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>2019-04-08 11:54:10 -0700
commite7b2669151ad73a9833eab519c093ef6750ac635 (patch)
treeff29c4524921c140141f8f49027998aa671957b1
parent66a3277dfa028a4a00693f78184202c48395dab6 (diff)
downloadpytorch-e7b2669151ad73a9833eab519c093ef6750ac635.tar.gz
pytorch-e7b2669151ad73a9833eab519c093ef6750ac635.tar.bz2
pytorch-e7b2669151ad73a9833eab519c093ef6750ac635.zip
caffe2 - Expose tensor filler util to Python (#18886)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/18886 Expose tensor filler util to Python and add a unit test (both C++/Python) Reviewed By: salexspb Differential Revision: D14784470 fbshipit-source-id: bb8e013d1755c27c166e87d5a8491a97c65d3d8d
-rw-r--r--caffe2/CMakeLists.txt1
-rw-r--r--caffe2/core/test_utils.h2
-rw-r--r--caffe2/predictor/emulator/CMakeLists.txt13
-rw-r--r--caffe2/predictor/emulator/data_filler.cc9
-rw-r--r--caffe2/predictor/emulator/data_filler.h7
-rw-r--r--caffe2/predictor/emulator/data_filler_test.cc25
-rw-r--r--caffe2/python/filler_test.py20
-rw-r--r--caffe2/python/pybind_state.cc14
-rw-r--r--caffe2/python/workspace.py5
9 files changed, 95 insertions, 1 deletions
diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt
index 31b842b9d6..4e3f036f7f 100644
--- a/caffe2/CMakeLists.txt
+++ b/caffe2/CMakeLists.txt
@@ -69,6 +69,7 @@ if(NOT BUILD_ATEN_ONLY)
add_subdirectory(core)
add_subdirectory(utils)
add_subdirectory(predictor)
+ add_subdirectory(predictor/emulator)
add_subdirectory(core/nomnigraph)
add_subdirectory(serialize)
if (USE_NVRTC)
diff --git a/caffe2/core/test_utils.h b/caffe2/core/test_utils.h
index fcf069d607..7e286e1d30 100644
--- a/caffe2/core/test_utils.h
+++ b/caffe2/core/test_utils.h
@@ -56,7 +56,7 @@ void assertTensorListEquals(
const Workspace& workspace2);
// Read a tensor from the workspace.
-const caffe2::Tensor& getTensor(
+CAFFE2_API const caffe2::Tensor& getTensor(
const caffe2::Workspace& workspace,
const std::string& name);
diff --git a/caffe2/predictor/emulator/CMakeLists.txt b/caffe2/predictor/emulator/CMakeLists.txt
new file mode 100644
index 0000000000..690699040e
--- /dev/null
+++ b/caffe2/predictor/emulator/CMakeLists.txt
@@ -0,0 +1,13 @@
+set(Caffe2_EMULATOR_CPU_SRC
+ "${CMAKE_CURRENT_SOURCE_DIR}/data_filler.h"
+ "${CMAKE_CURRENT_SOURCE_DIR}/data_filler.cc"
+)
+set(Caffe2_EMULATOR_CPU_TEST_SRC
+ "${CMAKE_CURRENT_SOURCE_DIR}/data_filler_test.cc")
+
+# Common files that are always going to be included.
+list(APPEND Caffe2_CPU_SRCS ${Caffe2_EMULATOR_CPU_SRC})
+list(APPEND Caffe2_CPU_TEST_SRCS ${Caffe2_EMULATOR_CPU_TEST_SRC})
+
+set(Caffe2_CPU_SRCS ${Caffe2_CPU_SRCS} PARENT_SCOPE)
+set(Caffe2_CPU_TEST_SRCS ${Caffe2_CPU_TEST_SRCS} PARENT_SCOPE)
diff --git a/caffe2/predictor/emulator/data_filler.cc b/caffe2/predictor/emulator/data_filler.cc
index e4e64a3fd8..1979025c14 100644
--- a/caffe2/predictor/emulator/data_filler.cc
+++ b/caffe2/predictor/emulator/data_filler.cc
@@ -245,5 +245,14 @@ void TestDataRandomFiller::fillInputToWorkspace(Workspace* workspace) const {
}
}
+void fillRandomNetworkInputs(
+ const NetDef& net,
+ const std::vector<std::vector<std::vector<int64_t>>>& inputDims,
+ const std::vector<std::vector<std::string>>& inputTypes,
+ Workspace* workspace) {
+ TestDataRandomFiller(net, inputDims, inputTypes)
+ .fillInputToWorkspace(workspace);
+}
+
} // namespace emulator
} // namespace caffe2
diff --git a/caffe2/predictor/emulator/data_filler.h b/caffe2/predictor/emulator/data_filler.h
index a540e4aee3..78692ac9b3 100644
--- a/caffe2/predictor/emulator/data_filler.h
+++ b/caffe2/predictor/emulator/data_filler.h
@@ -138,5 +138,12 @@ class TestDataRandomFiller : public DataRandomFiller {
void fillInputToWorkspace(Workspace* workspace) const;
};
+// Convenient helpers to fill data to workspace.
+CAFFE2_API void fillRandomNetworkInputs(
+ const NetDef& net,
+ const std::vector<std::vector<std::vector<int64_t>>>& inputDims,
+ const std::vector<std::vector<std::string>>& inputTypes,
+ Workspace* workspace);
+
} // namespace emulator
} // namespace caffe2
diff --git a/caffe2/predictor/emulator/data_filler_test.cc b/caffe2/predictor/emulator/data_filler_test.cc
new file mode 100644
index 0000000000..b29bcec718
--- /dev/null
+++ b/caffe2/predictor/emulator/data_filler_test.cc
@@ -0,0 +1,25 @@
+#include "caffe2/core/common.h"
+#include "caffe2/core/test_utils.h"
+#include "caffe2/predictor/emulator/data_filler.h"
+
+#include <gtest/gtest.h>
+
+TEST(DataFiller, FillNetInputTest) {
+ using namespace caffe2::testing;
+ using namespace caffe2::emulator;
+ caffe2::NetDef net;
+ NetMutator(&net)
+ .newOp("Concat", {"X0", "X1", "X2"}, {"concat_out", "split_info"})
+ .addArgument("axis", 1);
+
+ std::vector<int64_t> input_dim = {30, 20};
+ std::vector<std::vector<std::vector<int64_t>>> input_dims = {
+ {/* X0 */ input_dim, /* X1 */ input_dim, /* X2 */ input_dim}};
+ std::vector<std::vector<std::string>> input_types = {
+ {"float", "float", "float"}};
+ caffe2::Workspace workspace;
+ EXPECT_FALSE(workspace.HasBlob("X0"));
+ fillRandomNetworkInputs(net, input_dims, input_types, &workspace);
+ EXPECT_TRUE(workspace.HasBlob("X0"));
+ EXPECT_EQ(getTensor(workspace, "X0").sizes(), input_dim);
+}
diff --git a/caffe2/python/filler_test.py b/caffe2/python/filler_test.py
new file mode 100644
index 0000000000..52ea756d5b
--- /dev/null
+++ b/caffe2/python/filler_test.py
@@ -0,0 +1,20 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from caffe2.python import core, test_util, workspace
+
+
+class TestFiller(test_util.TestCase):
+ def test_filler(self):
+ net = core.Net("test_filler")
+ net.Concat(["X0", "X1", "X2"], ["concat_out", "split_info"])
+ self.assertFalse(workspace.HasBlob("X0"))
+ input_dim = (30, 20)
+ workspace.FillRandomNetworkInputs(net, [[input_dim, input_dim, input_dim]], [["float", "float", "float"]])
+ self.assertTrue(workspace.HasBlob("X0"))
+ self.assertEqual(workspace.FetchBlob("X0").shape, input_dim)
+
+ with self.assertRaises(RuntimeError):
+ # Filler should throw if number of input dims/types is mismatched.
+ workspace.FillRandomNetworkInputs(net, [[input_dim]], [["float"]])
diff --git a/caffe2/python/pybind_state.cc b/caffe2/python/pybind_state.cc
index 18f9097088..f437b4e0c3 100644
--- a/caffe2/python/pybind_state.cc
+++ b/caffe2/python/pybind_state.cc
@@ -25,6 +25,7 @@
#include "caffe2/opt/onnxifi_transformer.h"
#include "caffe2/opt/optimize_ideep.h"
#include "caffe2/opt/passes.h"
+#include "caffe2/predictor/emulator/data_filler.h"
#include "caffe2/predictor/predictor.h"
#include "caffe2/python/pybind_state_registry.h"
#include "caffe2/utils/cpuid.h"
@@ -1146,6 +1147,19 @@ void addGlobalMethods(py::module& m) {
return gWorkspace->HasBlob(name);
});
m.def(
+ "fill_random_network_inputs",
+ [](const py::bytes& net_def,
+ const std::vector<std::vector<std::vector<int64_t>>>& inputDims,
+ const std::vector<std::vector<std::string>>& inputTypes) {
+ CAFFE_ENFORCE(gWorkspace);
+ py::gil_scoped_release g;
+ NetDef net;
+ CAFFE_ENFORCE(
+ ParseProtoFromLargeString(net_def.cast<std::string>(), &net));
+ caffe2::emulator::fillRandomNetworkInputs(
+ net, inputDims, inputTypes, gWorkspace);
+ });
+ m.def(
"create_net",
[](py::bytes net_def, bool overwrite) {
CAFFE_ENFORCE(gWorkspace);
diff --git a/caffe2/python/workspace.py b/caffe2/python/workspace.py
index 18fcd9bb42..c288650970 100644
--- a/caffe2/python/workspace.py
+++ b/caffe2/python/workspace.py
@@ -83,6 +83,11 @@ GetNumNUMANodes = C.get_num_numa_nodes
GetBlobNUMANode = C.get_blob_numa_node
GetBlobSizeBytes = C.get_blob_size_bytes
+
+def FillRandomNetworkInputs(net, input_dims, input_types):
+ C.fill_random_network_inputs(net.Proto().SerializeToString(), input_dims, input_types)
+
+
def _GetFreeFlaskPort():
"""Get a free flask port."""
# We will prefer to use 5000. If not, we will then pick a random port.