summaryrefslogtreecommitdiff
path: root/python
diff options
context:
space:
mode:
authorEvan Shelhamer <shelhamer@imaginarynumber.net>2014-05-14 14:02:54 -0700
committerEvan Shelhamer <shelhamer@imaginarynumber.net>2014-05-14 14:02:54 -0700
commit9d4324e5e7f0187027c4cf6634d8b00116ffb8ce (patch)
treea753e46df03e70b5f910d3076fb9a6f819034ec0 /python
parent0e5a5cf50e9d17dbfe96b8269145b934e99b29a5 (diff)
downloadcaffe-9d4324e5e7f0187027c4cf6634d8b00116ffb8ce.tar.gz
caffe-9d4324e5e7f0187027c4cf6634d8b00116ffb8ce.tar.bz2
caffe-9d4324e5e7f0187027c4cf6634d8b00116ffb8ce.zip
bad forward/backward inputs throw exceptions instead of crashing python
Diffstat (limited to 'python')
-rw-r--r--python/caffe/_caffe.cpp27
1 files changed, 14 insertions, 13 deletions
diff --git a/python/caffe/_caffe.cpp b/python/caffe/_caffe.cpp
index e1ee652b..18b96b92 100644
--- a/python/caffe/_caffe.cpp
+++ b/python/caffe/_caffe.cpp
@@ -162,15 +162,12 @@ struct CaffeNet {
// Check that an array is acceptable for blob assignment
// as described in the preface to Forward().
inline void check_array_against_blob(
- PyArrayObject* arr, Blob<float>* blob) {
- CHECK(PyArray_FLAGS(arr) & NPY_ARRAY_C_CONTIGUOUS);
- CHECK_EQ(PyArray_NDIM(arr), 4);
- CHECK_EQ(PyArray_ITEMSIZE(arr), 4);
- npy_intp* dims = PyArray_DIMS(arr);
- CHECK_EQ(dims[0], blob->num());
- CHECK_EQ(dims[1], blob->channels());
- CHECK_EQ(dims[2], blob->height());
- CHECK_EQ(dims[3], blob->width());
+ PyArrayObject* arr, Blob<float>* blob, string name) {
+ check_contiguous_array(arr, name, blob->channels(), blob->height(),
+ blob->width());
+ if (PyArray_DIMS(arr)[0] != blob->num()) {
+ throw std::runtime_error(name + " has wrong batch size");
+ }
}
// generate Python exceptions for badly shaped or discontiguous arrays
@@ -207,7 +204,8 @@ struct CaffeNet {
for (int i = 0; i < input_blobs.size(); ++i) {
object elem = bottom[i];
PyArrayObject* arr = reinterpret_cast<PyArrayObject*>(elem.ptr());
- check_array_against_blob(arr, input_blobs[i]);
+ check_array_against_blob(arr, input_blobs[i],
+ net_->blob_names()[net_->input_blob_indices()[i]]);
switch (Caffe::mode()) {
case Caffe::CPU:
memcpy(input_blobs[i]->mutable_cpu_data(), PyArray_DATA(arr),
@@ -227,7 +225,8 @@ struct CaffeNet {
for (int i = 0; i < output_blobs.size(); ++i) {
object elem = top[i];
PyArrayObject* arr = reinterpret_cast<PyArrayObject*>(elem.ptr());
- check_array_against_blob(arr, output_blobs[i]);
+ check_array_against_blob(arr, output_blobs[i],
+ net_->blob_names()[net_->input_blob_indices()[i]]);
switch (Caffe::mode()) {
case Caffe::CPU:
memcpy(PyArray_DATA(arr), output_blobs[i]->cpu_data(),
@@ -252,7 +251,8 @@ struct CaffeNet {
for (int i = 0; i < output_blobs.size(); ++i) {
object elem = top_diff[i];
PyArrayObject* arr = reinterpret_cast<PyArrayObject*>(elem.ptr());
- check_array_against_blob(arr, output_blobs[i]);
+ check_array_against_blob(arr, output_blobs[i],
+ net_->blob_names()[net_->input_blob_indices()[i]]);
switch (Caffe::mode()) {
case Caffe::CPU:
memcpy(output_blobs[i]->mutable_cpu_diff(), PyArray_DATA(arr),
@@ -272,7 +272,8 @@ struct CaffeNet {
for (int i = 0; i < input_blobs.size(); ++i) {
object elem = bottom_diff[i];
PyArrayObject* arr = reinterpret_cast<PyArrayObject*>(elem.ptr());
- check_array_against_blob(arr, input_blobs[i]);
+ check_array_against_blob(arr, input_blobs[i],
+ net_->blob_names()[net_->input_blob_indices()[i]]);
switch (Caffe::mode()) {
case Caffe::CPU:
memcpy(PyArray_DATA(arr), input_blobs[i]->cpu_diff(),