summaryrefslogtreecommitdiff
path: root/python/caffe/io.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/caffe/io.py')
-rw-r--r--python/caffe/io.py154
1 files changed, 154 insertions, 0 deletions
diff --git a/python/caffe/io.py b/python/caffe/io.py
new file mode 100644
index 00000000..0bd2f812
--- /dev/null
+++ b/python/caffe/io.py
@@ -0,0 +1,154 @@
+import numpy as np
+import skimage.io
+import skimage.transform
+
+from caffe.proto import caffe_pb2
+
+
def load_image(filename):
    """
    Load an image converting from grayscale or alpha as needed.

    Take
    filename: string, passed straight to skimage.io.imread.

    Give
    image: an image of size (H x W x 3) with RGB channels of type float32
        (NOT uint8: the pixels go through skimage.img_as_float, which is
        expected to rescale integer images to [0, 1], and are then cast
        to np.float32).
    """
    img = skimage.img_as_float(skimage.io.imread(filename)).astype(np.float32)
    if img.ndim == 2:
        # Grayscale: replicate the single plane into three channels.
        img = np.tile(img[:, :, np.newaxis], (1, 1, 3))
    elif img.shape[2] == 4:
        # Four channels: keep the first three, dropping the alpha plane.
        img = img[:, :, :3]
    return img
+
+
def resize_image(im, new_dims, interp_order=1):
    """
    Rescale an image array to new spatial dimensions by interpolation.

    Take
    im: (H x W x K) ndarray
    new_dims: (height, width) tuple giving the output size.
    interp_order: interpolation order forwarded to skimage as `order`;
        the default of 1 is linear.

    Give
    im: resized ndarray with shape (new_dims[0], new_dims[1], K)
    """
    resized = skimage.transform.resize(im, new_dims, order=interp_order)
    return resized
+
+
def oversample(images, crop_dims):
    """
    Crop images into the four corners, center, and their mirrored versions.

    Take
    images: iterable of (H x W x K) ndarrays, all of the same shape. Any
        iterable works (list, generator, 4-D ndarray); it is materialized
        once.
    crop_dims: (height, width) tuple for the crops.

    Give
    crops: (10*N x crop_height x crop_width x K) float32 ndarray for
        number of inputs N. For each input the order is: top-left,
        top-right, bottom-left, bottom-right, center, followed by the
        same five crops mirrored along the width axis.

    Raise
    ValueError: if `images` is empty or the crop is larger than the image.
    """
    # Materialize so generators can be indexed and counted.
    images = list(images)
    if not images:
        raise ValueError('oversample requires at least one image.')

    # Dimensions and center.
    im_shape = np.array(images[0].shape)
    crop_dims = np.array(crop_dims)
    if np.any(crop_dims > im_shape[:2]):
        # Without this check the corner offsets go negative and the slices
        # below silently produce wrongly sized (or broadcast) crops.
        raise ValueError('Crop dimensions exceed the image dimensions.')
    im_center = im_shape[:2] / 2.0

    # Make crop coordinates as rows of (top, left, bottom, right).
    h_indices = (0, im_shape[0] - crop_dims[0])
    w_indices = (0, im_shape[1] - crop_dims[1])
    crops_ix = np.empty((5, 4), dtype=int)
    crops_ix[:4] = [(i, j, i + crop_dims[0], j + crop_dims[1])
                    for i in h_indices for j in w_indices]
    # Center crop; fractional coordinates are truncated by the int dtype.
    crops_ix[4] = np.concatenate([im_center - crop_dims / 2.0,
                                  im_center + crop_dims / 2.0])

    # Extract crops
    crops = np.empty((10 * len(images), crop_dims[0], crop_dims[1],
                      im_shape[-1]), dtype=np.float32)
    for n, im in enumerate(images):
        base = 10 * n
        for k, (top, left, bottom, right) in enumerate(crops_ix):
            crops[base + k] = im[top:bottom, left:right, :]
        # Mirrors are written from the (non-overlapping) first five crops,
        # so no self-overlapping assignment is involved.
        crops[base + 5:base + 10] = crops[base:base + 5, :, ::-1, :]
    return crops
+
+
def blobproto_to_array(blob, return_diff=False):
    """Turn a blob proto into a 4-d array.

    The blob's data is returned by default; pass return_diff=True to get
    its diff instead. Either way the values are reshaped to
    (num, channels, height, width) as recorded on the blob.
    """
    values = blob.diff if return_diff else blob.data
    shape = (blob.num, blob.channels, blob.height, blob.width)
    return np.array(values).reshape(shape)
+
+
def array_to_blobproto(arr, diff=None):
    """Build a blob proto from a 4-dimensional array.

    When diff is supplied it is stored as the blob's diff. The caller is
    responsible for arr and diff having the same shape; no sanity check is
    performed here. Raises ValueError if arr is not 4-dimensional.
    """
    if arr.ndim != 4:
        raise ValueError('Incorrect array shape.')

    proto = caffe_pb2.BlobProto()
    (proto.num, proto.channels, proto.height, proto.width) = arr.shape
    proto.data.extend(arr.astype(float).flat)
    if diff is None:
        return proto

    proto.diff.extend(diff.astype(float).flat)
    return proto
+
+
def arraylist_to_blobprotovecor_str(arraylist):
    """Serialize a list of arrays as a blobprotovec string.

    Each array is converted with array_to_blobproto; the resulting string
    can then be handed to a network for processing.
    """
    protos = []
    for arr in arraylist:
        protos.append(array_to_blobproto(arr))

    container = caffe_pb2.BlobProtoVector()
    container.blobs.extend(protos)
    return container.SerializeToString()
+
+
def blobprotovector_str_to_arraylist(str):
    """Parse a serialized blobprotovec and return its blobs as a list of
    arrays (data only, one array per blob).
    """
    container = caffe_pb2.BlobProtoVector()
    container.ParseFromString(str)

    arrays = []
    for proto in container.blobs:
        arrays.append(blobproto_to_array(proto))
    return arrays
+
+
def array_to_datum(arr, label=0):
    """Converts a 3-dimensional array to datum. If the array has dtype uint8,
    the output data will be encoded as a string. Otherwise, the output data
    will be stored in float format.

    The array shape is recorded as (channels, height, width) and `label`
    is stored on the datum. Raises ValueError if arr is not 3-dimensional.
    """
    if arr.ndim != 3:
        raise ValueError('Incorrect array shape.')
    datum = caffe_pb2.Datum()
    datum.channels, datum.height, datum.width = arr.shape
    if arr.dtype == np.uint8:
        # tobytes() is the supported spelling of the deprecated tostring()
        # (removed in NumPy 2.0) and yields the same C-order raw bytes.
        datum.data = arr.tobytes()
    else:
        datum.float_data.extend(arr.flat)
    datum.label = label
    return datum
+
+
def datum_to_array(datum):
    """Converts a datum to an array. Note that the label is not returned,
    as one can easily get it by calling datum.label.

    If the datum carries byte data it is decoded as uint8; otherwise its
    float_data is used. The result has shape (channels, height, width).
    """
    shape = (datum.channels, datum.height, datum.width)
    if len(datum.data):
        # np.fromstring is deprecated for binary input; frombuffer is the
        # replacement. It returns a read-only view of the bytes, so copy to
        # hand back a writable array as before.
        return np.frombuffer(datum.data, dtype=np.uint8).reshape(shape).copy()
    return np.array(datum.float_data).astype(float).reshape(shape)