summary | refs | log | tree | commit | diff
path: root/python
diff options
context:
space:
mode:
author: Evan Shelhamer <shelhamer@imaginarynumber.net> 2014-05-14 23:45:33 -0700
committer: Evan Shelhamer <shelhamer@imaginarynumber.net> 2014-05-14 23:48:01 -0700
commit: 1b236802449227c1e6b841368868d1edb3cda732 (patch)
tree: 2865f3b6d66f180f807d0fd9ef071c0286e00548 /python
parent: 8af33e8ca2f56a3bef23935990b1c9ed65629918 (diff)
download: caffe-1b236802449227c1e6b841368868d1edb3cda732.tar.gz
caffe-1b236802449227c1e6b841368868d1edb3cda732.tar.bz2
caffe-1b236802449227c1e6b841368868d1edb3cda732.zip
batch inputs in python by forward_all() and forward_backward_all()
Diffstat (limited to 'python')
-rw-r--r-- python/caffe/imagenet/wrapper.py 4
-rw-r--r-- python/caffe/pycaffe.py 100
2 files changed, 102 insertions, 2 deletions
diff --git a/python/caffe/imagenet/wrapper.py b/python/caffe/imagenet/wrapper.py
index 4a5b6ed8..dd505e42 100644
--- a/python/caffe/imagenet/wrapper.py
+++ b/python/caffe/imagenet/wrapper.py
@@ -30,7 +30,6 @@ def oversample(image, center_only=False):
Output:
images: the output of size (10 x 3 x 227 x 227)
"""
- image = image.swapaxes(1, 2).swapaxes(0, 1)
indices = [0, IMAGE_DIM - CROPPED_DIM]
center = int(indices[1] / 2)
if center_only:
@@ -58,8 +57,9 @@ def prepare_image(filename, center_only=False):
img = np.tile(img[:, :, np.newaxis], (1, 1, 3))
elif img.shape[2] == 4:
img = img[:, :, :3]
- # Resize and convert to BGR
+ # Resize, convert to BGR, and permute axes to caffe order
img_reshape = (transform.resize(img, (IMAGE_DIM,IMAGE_DIM)) * 255)[:, :, ::-1]
+ img_reshape = img_reshape.swapaxes(1, 2).swapaxes(0, 1)
# subtract main
img_reshape -= IMAGENET_MEAN
return oversample(img_reshape, center_only)
diff --git a/python/caffe/pycaffe.py b/python/caffe/pycaffe.py
index d965227d..6f6cedd5 100644
--- a/python/caffe/pycaffe.py
+++ b/python/caffe/pycaffe.py
@@ -4,6 +4,7 @@ interface.
"""
from collections import OrderedDict
+from itertools import izip_longest
import numpy as np
from ._caffe import Net, SGDSolver
@@ -125,6 +126,76 @@ def _Net_backward(self, diffs=None, **kwargs):
Net.backward = _Net_backward
+def _Net_forward_all(self, blobs=None, **kwargs):
+    """
+    Run net forward in batches.
+
+    Take
+    blobs: list of blobs to extract as in forward(). May be left as None,
+           in which case only the net outputs are collected.
+    kwargs: Keys are input blob names and values are lists of blobs.
+            All lists are expected to have the same length.
+            Refer to forward().
+
+    Give
+    all_outs: {blob name: list of blobs} dict, with the padding of the
+              last batch removed.
+    """
+    # `blobs` defaults to None, so guard the concatenation (the same way
+    # forward_backward_all() does) instead of raising a TypeError.
+    all_outs = {out: [] for out in self.outputs + (blobs or [])}
+    # Collect outputs from batches.
+    for batch in self._batch(kwargs):
+        outs = self.forward(blobs=blobs, **batch)
+        for out, out_blobs in outs.items():
+            all_outs[out].extend(out_blobs)
+    # Discard padding at the end: whatever was collected beyond the number
+    # of blobs the caller passed in comes from the padded last batch.
+    num_collected = len(next(iter(all_outs.values())))
+    num_given = len(next(iter(kwargs.values())))
+    pad = num_collected - num_given
+    if pad:
+        for out in all_outs:
+            del all_outs[out][-pad:]
+    return all_outs
+
+Net.forward_all = _Net_forward_all
+
+
+def _Net_forward_backward_all(self, blobs=None, diffs=None, **kwargs):
+    """
+    Run the net forward and then backward, one batch at a time.
+
+    Take
+    blobs: list of blobs to extract as in forward()
+    diffs: list of diffs to extract as in backward()
+    kwargs: Keys are input (for forward) and output (for backward) blob
+            names and values are lists of blobs. Refer to forward() and
+            backward(). When no inputs (or no outputs) are passed,
+            forward() (or backward()) is called without blob arguments.
+
+    Give
+    all_blobs: {blob name: list of blobs} dict.
+    all_diffs: {blob name: list of diffs} dict.
+    """
+    # One accumulator list per blob / diff name that will be collected.
+    all_outs = dict((name, []) for name in self.outputs + (blobs or []))
+    all_diffs = dict((name, []) for name in self.inputs + (diffs or []))
+    # Route kwargs: net inputs feed forward(), net outputs feed backward().
+    forward_kwargs = dict((name, kwargs[name])
+                          for name in self.inputs if name in kwargs)
+    forward_batches = self._batch(forward_kwargs)
+    backward_kwargs = dict((name, kwargs[name])
+                           for name in self.outputs if name in kwargs)
+    backward_batches = self._batch(backward_kwargs)
+    # Pair the batches up; a side that runs out (or is empty) is replaced
+    # by an empty dict so the corresponding call gets no blob arguments.
+    batch_pairs = izip_longest(forward_batches, backward_batches,
+                               fillvalue={})
+    for forward_batch, backward_batch in batch_pairs:
+        batch_blobs = self.forward(blobs=blobs, **forward_batch)
+        batch_diffs = self.backward(diffs=diffs, **backward_batch)
+        for name in batch_blobs:
+            all_outs[name].extend(batch_blobs[name])
+        for name in batch_diffs:
+            all_diffs[name].extend(batch_diffs[name])
+    # Discard padding at the end.
+    first_collected = all_outs.itervalues().next()
+    first_given = kwargs.itervalues().next()
+    pad = len(first_collected) - len(first_given)
+    if pad:
+        for collected in (all_outs, all_diffs):
+            for name in collected:
+                del collected[name][-pad:]
+    return all_outs, all_diffs
+
+Net.forward_backward_all = _Net_forward_backward_all
+
+
def _Net_set_mean(self, input_, mean_f, mode='image'):
"""
Set the mean to subtract for data centering.
@@ -244,3 +315,32 @@ def _Net_set_input_arrays(self, data, labels):
return self._set_input_arrays(data, labels)
Net.set_input_arrays = _Net_set_input_arrays
+
+
+def _Net_batch(self, blobs):
+    """
+    Batch blob lists according to net's batch size.
+
+    Take
+    blobs: Keys are blob names and values are lists of blobs (of any length).
+           Naturally, all the lists should have the same length.
+
+    Give (yield)
+    batch: {blob name: list of blobs} dict for a single batch. Every batch
+           holds exactly batch_size blobs; the last one is filled up with
+           zero blobs when the input length is not a multiple of batch_size.
+    """
+    # Nothing to batch: yield no batches at all. forward_backward_all()
+    # passes an empty dict when no inputs (or outputs) were given.
+    if not blobs:
+        return
+
+    num = len(next(iter(blobs.values())))
+    # The batch size is read from the `num` of the net's first blob.
+    batch_size = next(iter(self.blobs.values())).num
+    remainder = num % batch_size
+    # Floor division: only complete batches are counted here.
+    num_batches = num // batch_size
+
+    # Yield full batches.
+    for b in range(num_batches):
+        i = b * batch_size
+        yield {name: blobs[name][i:i + batch_size] for name in blobs}
+
+    # Yield last padded batch, if any: the trailing `remainder` blobs plus
+    # enough zero blobs to fill the batch up to batch_size.
+    if remainder > 0:
+        padding = batch_size - remainder
+        yield {name: blobs[name][-remainder:] +
+                     [np.zeros_like(blobs[name][0])] * padding
+               for name in blobs}
+
+Net._batch = _Net_batch