diff options
author | Evan Shelhamer <shelhamer@imaginarynumber.net> | 2014-05-14 14:02:54 -0700 |
---|---|---|
committer | Evan Shelhamer <shelhamer@imaginarynumber.net> | 2014-05-14 14:02:54 -0700 |
commit | 9d4324e5e7f0187027c4cf6634d8b00116ffb8ce (patch) | |
tree | a753e46df03e70b5f910d3076fb9a6f819034ec0 /python | |
parent | 0e5a5cf50e9d17dbfe96b8269145b934e99b29a5 (diff) | |
download | caffe-9d4324e5e7f0187027c4cf6634d8b00116ffb8ce.tar.gz caffe-9d4324e5e7f0187027c4cf6634d8b00116ffb8ce.tar.bz2 caffe-9d4324e5e7f0187027c4cf6634d8b00116ffb8ce.zip |
bad forward/backward inputs throw exceptions instead of crashing python
Diffstat (limited to 'python')
-rw-r--r-- | python/caffe/_caffe.cpp | 27 |
1 files changed, 14 insertions, 13 deletions
diff --git a/python/caffe/_caffe.cpp b/python/caffe/_caffe.cpp index e1ee652b..18b96b92 100644 --- a/python/caffe/_caffe.cpp +++ b/python/caffe/_caffe.cpp @@ -162,15 +162,12 @@ struct CaffeNet { // Check that an array is acceptable for blob assignment // as described in the preface to Forward(). inline void check_array_against_blob( - PyArrayObject* arr, Blob<float>* blob) { - CHECK(PyArray_FLAGS(arr) & NPY_ARRAY_C_CONTIGUOUS); - CHECK_EQ(PyArray_NDIM(arr), 4); - CHECK_EQ(PyArray_ITEMSIZE(arr), 4); - npy_intp* dims = PyArray_DIMS(arr); - CHECK_EQ(dims[0], blob->num()); - CHECK_EQ(dims[1], blob->channels()); - CHECK_EQ(dims[2], blob->height()); - CHECK_EQ(dims[3], blob->width()); + PyArrayObject* arr, Blob<float>* blob, string name) { + check_contiguous_array(arr, name, blob->channels(), blob->height(), + blob->width()); + if (PyArray_DIMS(arr)[0] != blob->num()) { + throw std::runtime_error(name + " has wrong batch size"); + } } // generate Python exceptions for badly shaped or discontiguous arrays @@ -207,7 +204,8 @@ struct CaffeNet { for (int i = 0; i < input_blobs.size(); ++i) { object elem = bottom[i]; PyArrayObject* arr = reinterpret_cast<PyArrayObject*>(elem.ptr()); - check_array_against_blob(arr, input_blobs[i]); + check_array_against_blob(arr, input_blobs[i], + net_->blob_names()[net_->input_blob_indices()[i]]); switch (Caffe::mode()) { case Caffe::CPU: memcpy(input_blobs[i]->mutable_cpu_data(), PyArray_DATA(arr), @@ -227,7 +225,8 @@ struct CaffeNet { for (int i = 0; i < output_blobs.size(); ++i) { object elem = top[i]; PyArrayObject* arr = reinterpret_cast<PyArrayObject*>(elem.ptr()); - check_array_against_blob(arr, output_blobs[i]); + check_array_against_blob(arr, output_blobs[i], + net_->blob_names()[net_->input_blob_indices()[i]]); switch (Caffe::mode()) { case Caffe::CPU: memcpy(PyArray_DATA(arr), output_blobs[i]->cpu_data(), @@ -252,7 +251,8 @@ struct CaffeNet { for (int i = 0; i < output_blobs.size(); ++i) { object elem = top_diff[i]; PyArrayObject* arr = reinterpret_cast<PyArrayObject*>(elem.ptr()); - check_array_against_blob(arr, output_blobs[i]); + check_array_against_blob(arr, output_blobs[i], + net_->blob_names()[net_->input_blob_indices()[i]]); switch (Caffe::mode()) { case Caffe::CPU: memcpy(output_blobs[i]->mutable_cpu_diff(), PyArray_DATA(arr), @@ -272,7 +272,8 @@ struct CaffeNet { for (int i = 0; i < input_blobs.size(); ++i) { object elem = bottom_diff[i]; PyArrayObject* arr = reinterpret_cast<PyArrayObject*>(elem.ptr()); - check_array_against_blob(arr, input_blobs[i]); + check_array_against_blob(arr, input_blobs[i], + net_->blob_names()[net_->input_blob_indices()[i]]); switch (Caffe::mode()) { case Caffe::CPU: memcpy(PyArray_DATA(arr), input_blobs[i]->cpu_diff(), |