diff options
author | Jonathan L Long <jonlong@cs.berkeley.edu> | 2014-04-25 14:33:25 -0700 |
---|---|---|
committer | Jonathan L Long <jonlong@cs.berkeley.edu> | 2014-05-02 13:25:51 -0700 |
commit | 76c255449f62232f1179fecf19f1188f53c22600 (patch) | |
tree | abb039d57af85e9fceb03750380ee03fd00bb8e1 /python | |
parent | e1072a66d467b743df75435e1a28a1e34a1a4f25 (diff) | |
download | caffe-76c255449f62232f1179fecf19f1188f53c22600.tar.gz caffe-76c255449f62232f1179fecf19f1188f53c22600.tar.bz2 caffe-76c255449f62232f1179fecf19f1188f53c22600.zip |
pycaffe: add Net.set_input_arrays for input from numpy
This requires a net whose first layer is a MemoryDataLayer.
Diffstat (limited to 'python')
-rw-r--r-- | python/caffe/_caffe.cpp | 66 |
1 files changed, 65 insertions, 1 deletions
diff --git a/python/caffe/_caffe.cpp b/python/caffe/_caffe.cpp index 853ddbe1..f5d0d22f 100644 --- a/python/caffe/_caffe.cpp +++ b/python/caffe/_caffe.cpp @@ -158,6 +158,8 @@ struct CaffeNet { virtual ~CaffeNet() {} + // this function is mostly redundant with the one below, but should go away + // with new pycaffe inline void check_array_against_blob( PyArrayObject* arr, Blob<float>* blob) { CHECK(PyArray_FLAGS(arr) & NPY_ARRAY_C_CONTIGUOUS); @@ -170,6 +172,29 @@ struct CaffeNet { CHECK_EQ(dims[3], blob->width()); } + // generate Python exceptions for badly shaped or discontiguous arrays + inline void check_contiguous_array(PyArrayObject* arr, string name, + int channels, int height, int width) { + if (!(PyArray_FLAGS(arr) & NPY_ARRAY_C_CONTIGUOUS)) { + throw std::runtime_error(name + " must be C contiguous"); + } + if (PyArray_NDIM(arr) != 4) { + throw std::runtime_error(name + " must be 4-d"); + } + if (PyArray_TYPE(arr) != NPY_FLOAT32) { + throw std::runtime_error(name + " must be float32"); + } + if (PyArray_DIMS(arr)[1] != channels) { + throw std::runtime_error(name + " has wrong number of channels"); + } + if (PyArray_DIMS(arr)[2] != height) { + throw std::runtime_error(name + " has wrong height"); + } + if (PyArray_DIMS(arr)[3] != width) { + throw std::runtime_error(name + " has wrong width"); + } + } + // The actual forward function. It takes in a python list of numpy arrays as // input and a python list of numpy arrays as output. The input and output // should all have correct shapes, are single-precisionabcdnt- and @@ -267,6 +292,41 @@ struct CaffeNet { net_->ForwardPrefilled(); } + void set_input_arrays(object data_obj, object labels_obj) { + // check that this network has an input MemoryDataLayer + shared_ptr<MemoryDataLayer<float> > md_layer = + boost::dynamic_pointer_cast<MemoryDataLayer<float> >(net_->layers()[0]); + if (!md_layer) { + throw std::runtime_error("set_input_arrays may only be called if the" + " first layer is a MemoryDataLayer"); + } + + // check that we were passed appropriately-sized contiguous memory + PyArrayObject* data_arr = + reinterpret_cast<PyArrayObject*>(data_obj.ptr()); + PyArrayObject* labels_arr = + reinterpret_cast<PyArrayObject*>(labels_obj.ptr()); + check_contiguous_array(data_arr, "data array", md_layer->datum_channels(), + md_layer->datum_height(), md_layer->datum_width()); + check_contiguous_array(labels_arr, "labels array", 1, 1, 1); + if (PyArray_DIMS(data_arr)[0] != PyArray_DIMS(labels_arr)[0]) { + throw std::runtime_error("data and labels must have the same first" + " dimension"); + } + if (PyArray_DIMS(data_arr)[0] % md_layer->batch_size() != 0) { + throw std::runtime_error("first dimensions of input arrays must be a" + " multiple of batch size"); + } + + // hold references + input_data_ = data_obj; + input_labels_ = labels_obj; + + md_layer->Reset(static_cast<float*>(PyArray_DATA(data_arr)), + static_cast<float*>(PyArray_DATA(labels_arr)), + PyArray_DIMS(data_arr)[0]); + } + // The caffe::Caffe utility functions. void set_mode_cpu() { Caffe::set_mode(Caffe::CPU); } void set_mode_gpu() { Caffe::set_mode(Caffe::GPU); } @@ -292,6 +352,9 @@ struct CaffeNet { // The pointer to the internal caffe::Net instant. shared_ptr<Net<float> > net_; + // if taking input from an ndarray, we need to hold references + object input_data_; + object input_labels_; }; class CaffeSGDSolver { @@ -334,7 +397,8 @@ BOOST_PYTHON_MODULE(_caffe) { .def("set_device", &CaffeNet::set_device) // rename blobs here since the pycaffe.py wrapper will replace it .add_property("_blobs", &CaffeNet::blobs) - .add_property("layers", &CaffeNet::layers); + .add_property("layers", &CaffeNet::layers) + .def("set_input_arrays", &CaffeNet::set_input_arrays); boost::python::class_<CaffeBlob, CaffeBlobWrap>( "Blob", boost::python::no_init) |