summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--python/caffe/io.py14
1 files changed, 10 insertions, 4 deletions
diff --git a/python/caffe/io.py b/python/caffe/io.py
index aabcfddb..a8354021 100644
--- a/python/caffe/io.py
+++ b/python/caffe/io.py
@@ -43,11 +43,17 @@ def resize_image(im, new_dims, interp_order=1):
im: resized ndarray with shape (new_dims[0], new_dims[1], K)
"""
if im.shape[-1] == 1 or im.shape[-1] == 3:
- # skimage is fast but only understands {1,3} channel images in [0, 1].
im_min, im_max = im.min(), im.max()
- im_std = (im - im_min) / (im_max - im_min)
- resized_std = resize(im_std, new_dims, order=interp_order)
- resized_im = resized_std * (im_max - im_min) + im_min
+ if im_max > im_min:
+ # skimage is fast but only understands {1,3} channel images in [0, 1].
+ im_std = (im - im_min) / (im_max - im_min)
+ resized_std = resize(im_std, new_dims, order=interp_order)
+ resized_im = resized_std * (im_max - im_min) + im_min
+ else:
+ # the image is a constant -- avoid divide by 0
+ ret = np.empty((new_dims[0], new_dims[1], im.shape[-1]), dtype=np.float32)
+ ret.fill(im_min)
+ return ret
else:
# ndimage interpolates anything but more slowly.
scale = tuple(np.array(new_dims) / np.array(im.shape[:2]))