summaryrefslogtreecommitdiff
path: root/python
diff options
context:
space:
mode:
authorJonathan L Long <jonlong@cs.berkeley.edu>2014-04-25 14:33:25 -0700
committerJonathan L Long <jonlong@cs.berkeley.edu>2014-05-02 13:25:51 -0700
commit76c255449f62232f1179fecf19f1188f53c22600 (patch)
treeabb039d57af85e9fceb03750380ee03fd00bb8e1 /python
parente1072a66d467b743df75435e1a28a1e34a1a4f25 (diff)
downloadcaffe-76c255449f62232f1179fecf19f1188f53c22600.tar.gz
caffe-76c255449f62232f1179fecf19f1188f53c22600.tar.bz2
caffe-76c255449f62232f1179fecf19f1188f53c22600.zip
pycaffe: add Net.set_input_arrays for input from numpy
This requires a net whose first layer is a MemoryDataLayer.
Diffstat (limited to 'python')
-rw-r--r--python/caffe/_caffe.cpp66
1 files changed, 65 insertions, 1 deletions
diff --git a/python/caffe/_caffe.cpp b/python/caffe/_caffe.cpp
index 853ddbe1..f5d0d22f 100644
--- a/python/caffe/_caffe.cpp
+++ b/python/caffe/_caffe.cpp
@@ -158,6 +158,8 @@ struct CaffeNet {
virtual ~CaffeNet() {}
+ // this function is mostly redundant with the one below, but should go away
+ // with new pycaffe
inline void check_array_against_blob(
PyArrayObject* arr, Blob<float>* blob) {
CHECK(PyArray_FLAGS(arr) & NPY_ARRAY_C_CONTIGUOUS);
@@ -170,6 +172,29 @@ struct CaffeNet {
CHECK_EQ(dims[3], blob->width());
}
+ // generate Python exceptions for badly shaped or discontiguous arrays
+ inline void check_contiguous_array(PyArrayObject* arr, string name,
+ int channels, int height, int width) {
+ if (!(PyArray_FLAGS(arr) & NPY_ARRAY_C_CONTIGUOUS)) {
+ throw std::runtime_error(name + " must be C contiguous");
+ }
+ if (PyArray_NDIM(arr) != 4) {
+ throw std::runtime_error(name + " must be 4-d");
+ }
+ if (PyArray_TYPE(arr) != NPY_FLOAT32) {
+ throw std::runtime_error(name + " must be float32");
+ }
+ if (PyArray_DIMS(arr)[1] != channels) {
+ throw std::runtime_error(name + " has wrong number of channels");
+ }
+ if (PyArray_DIMS(arr)[2] != height) {
+ throw std::runtime_error(name + " has wrong height");
+ }
+ if (PyArray_DIMS(arr)[3] != width) {
+ throw std::runtime_error(name + " has wrong width");
+ }
+ }
+
// The actual forward function. It takes in a python list of numpy arrays as
// input and a python list of numpy arrays as output. The input and output
// should all have correct shapes, are single-precisionabcdnt- and
@@ -267,6 +292,41 @@ struct CaffeNet {
net_->ForwardPrefilled();
}
+ void set_input_arrays(object data_obj, object labels_obj) {
+ // check that this network has an input MemoryDataLayer
+ shared_ptr<MemoryDataLayer<float> > md_layer =
+ boost::dynamic_pointer_cast<MemoryDataLayer<float> >(net_->layers()[0]);
+ if (!md_layer) {
+ throw std::runtime_error("set_input_arrays may only be called if the"
+ " first layer is a MemoryDataLayer");
+ }
+
+ // check that we were passed appropriately-sized contiguous memory
+ PyArrayObject* data_arr =
+ reinterpret_cast<PyArrayObject*>(data_obj.ptr());
+ PyArrayObject* labels_arr =
+ reinterpret_cast<PyArrayObject*>(labels_obj.ptr());
+ check_contiguous_array(data_arr, "data array", md_layer->datum_channels(),
+ md_layer->datum_height(), md_layer->datum_width());
+ check_contiguous_array(labels_arr, "labels array", 1, 1, 1);
+ if (PyArray_DIMS(data_arr)[0] != PyArray_DIMS(labels_arr)[0]) {
+ throw std::runtime_error("data and labels must have the same first"
+ " dimension");
+ }
+ if (PyArray_DIMS(data_arr)[0] % md_layer->batch_size() != 0) {
+ throw std::runtime_error("first dimensions of input arrays must be a"
+ " multiple of batch size");
+ }
+
+ // hold references
+ input_data_ = data_obj;
+ input_labels_ = labels_obj;
+
+ md_layer->Reset(static_cast<float*>(PyArray_DATA(data_arr)),
+ static_cast<float*>(PyArray_DATA(labels_arr)),
+ PyArray_DIMS(data_arr)[0]);
+ }
+
// The caffe::Caffe utility functions.
void set_mode_cpu() { Caffe::set_mode(Caffe::CPU); }
void set_mode_gpu() { Caffe::set_mode(Caffe::GPU); }
@@ -292,6 +352,9 @@ struct CaffeNet {
// The pointer to the internal caffe::Net instant.
shared_ptr<Net<float> > net_;
+ // if taking input from an ndarray, we need to hold references
+ object input_data_;
+ object input_labels_;
};
class CaffeSGDSolver {
@@ -334,7 +397,8 @@ BOOST_PYTHON_MODULE(_caffe) {
.def("set_device", &CaffeNet::set_device)
// rename blobs here since the pycaffe.py wrapper will replace it
.add_property("_blobs", &CaffeNet::blobs)
- .add_property("layers", &CaffeNet::layers);
+ .add_property("layers", &CaffeNet::layers)
+ .def("set_input_arrays", &CaffeNet::set_input_arrays);
boost::python::class_<CaffeBlob, CaffeBlobWrap>(
"Blob", boost::python::no_init)