diff options
author | Evan Shelhamer <shelhamer@imaginarynumber.net> | 2014-05-14 23:45:33 -0700 |
---|---|---|
committer | Evan Shelhamer <shelhamer@imaginarynumber.net> | 2014-05-14 23:48:01 -0700 |
commit | 1b236802449227c1e6b841368868d1edb3cda732 (patch) | |
tree | 2865f3b6d66f180f807d0fd9ef071c0286e00548 /python | |
parent | 8af33e8ca2f56a3bef23935990b1c9ed65629918 (diff) | |
download | caffe-1b236802449227c1e6b841368868d1edb3cda732.tar.gz caffe-1b236802449227c1e6b841368868d1edb3cda732.tar.bz2 caffe-1b236802449227c1e6b841368868d1edb3cda732.zip |
batch inputs in python by forward_all() and forward_backward_all()
Diffstat (limited to 'python')
-rw-r--r-- | python/caffe/imagenet/wrapper.py | 4 | ||||
-rw-r--r-- | python/caffe/pycaffe.py | 100 |
2 files changed, 102 insertions, 2 deletions
diff --git a/python/caffe/imagenet/wrapper.py b/python/caffe/imagenet/wrapper.py
index 4a5b6ed8..dd505e42 100644
--- a/python/caffe/imagenet/wrapper.py
+++ b/python/caffe/imagenet/wrapper.py
@@ -30,7 +30,6 @@ def oversample(image, center_only=False):
     Output:
         images: the output of size (10 x 3 x 227 x 227)
     """
-    image = image.swapaxes(1, 2).swapaxes(0, 1)
     indices = [0, IMAGE_DIM - CROPPED_DIM]
     center = int(indices[1] / 2)
     if center_only:
@@ -58,8 +57,9 @@ def prepare_image(filename, center_only=False):
         img = np.tile(img[:, :, np.newaxis], (1, 1, 3))
     elif img.shape[2] == 4:
         img = img[:, :, :3]
-    # Resize and convert to BGR
+    # Resize, convert to BGR, and permute axes to caffe order
     img_reshape = (transform.resize(img, (IMAGE_DIM,IMAGE_DIM)) * 255)[:, :, ::-1]
+    img_reshape = img_reshape.swapaxes(1, 2).swapaxes(0, 1)
     # subtract main
     img_reshape -= IMAGENET_MEAN
     return oversample(img_reshape, center_only)
diff --git a/python/caffe/pycaffe.py b/python/caffe/pycaffe.py
index d965227d..6f6cedd5 100644
--- a/python/caffe/pycaffe.py
+++ b/python/caffe/pycaffe.py
@@ -4,6 +4,7 @@ interface.
 """
 
 from collections import OrderedDict
+from itertools import izip_longest
 import numpy as np
 
 from ._caffe import Net, SGDSolver
@@ -125,6 +126,76 @@ def _Net_backward(self, diffs=None, **kwargs):
 Net.backward = _Net_backward
 
 
+def _Net_forward_all(self, blobs=None, **kwargs):
+    """
+    Run net forward in batches.
+
+    Take
+    blobs: list of blobs to extract as in forward() (default: outputs only)
+    kwargs: Keys are input blob names and values are lists of blobs.
+            Refer to forward().
+
+    Give
+    all_outs: {blob name: list of blobs} dict.
+    """
+    # Collect outputs from batches (blobs=None adds no extra keys).
+    all_outs = {out: [] for out in self.outputs + (blobs or [])}
+    for batch in self._batch(kwargs):
+        outs = self.forward(blobs=blobs, **batch)
+        for out, out_blobs in outs.items():
+            all_outs[out].extend(out_blobs)
+    # Discard padding at the end.
+    pad = len(all_outs.itervalues().next()) - len(kwargs.itervalues().next())
+    if pad:
+        for out in all_outs:
+            del all_outs[out][-pad:]
+    return all_outs
+
+Net.forward_all = _Net_forward_all
+
+
+def _Net_forward_backward_all(self, blobs=None, diffs=None, **kwargs):
+    """
+    Run net forward + backward in batches.
+
+    Take
+    blobs: list of blobs to extract as in forward()
+    diffs: list of diffs to extract as in backward()
+    kwargs: Keys are input (for forward) and output (for backward) blob names
+            and values are lists of blobs. Refer to forward() and backward().
+            Prefilled variants are called for lack of input or output blobs.
+
+    Give
+    all_outs: {blob name: list of blobs} dict.
+    all_diffs: {blob name: list of diffs} dict.
+    """
+    # Batch blobs and diffs.
+    all_outs = {out: [] for out in self.outputs + (blobs or [])}
+    all_diffs = {diff: [] for diff in self.inputs + (diffs or [])}
+    forward_batches = self._batch({in_: kwargs[in_]
+                                   for in_ in self.inputs if in_ in kwargs})
+    backward_batches = self._batch({out: kwargs[out]
+                                    for out in self.outputs if out in kwargs})
+    # Collect outputs from batches (and heed lack of forward/backward batches).
+    for fb, bb in izip_longest(forward_batches, backward_batches, fillvalue={}):
+        batch_blobs = self.forward(blobs=blobs, **fb)
+        batch_diffs = self.backward(diffs=diffs, **bb)
+        for out, out_blobs in batch_blobs.items():
+            all_outs[out].extend(out_blobs)
+        for diff, out_diffs in batch_diffs.items():
+            all_diffs[diff].extend(out_diffs)
+    # Discard padding at the end.
+    pad = len(all_outs.itervalues().next()) - len(kwargs.itervalues().next())
+    if pad:
+        for out in all_outs:
+            del all_outs[out][-pad:]
+        for diff in all_diffs:
+            del all_diffs[diff][-pad:]
+    return all_outs, all_diffs
+
+Net.forward_backward_all = _Net_forward_backward_all
+
+
 def _Net_set_mean(self, input_, mean_f, mode='image'):
     """
     Set the mean to subtract for data centering.
@@ -244,3 +315,32 @@ def _Net_set_input_arrays(self, data, labels):
     return self._set_input_arrays(data, labels)
 
 Net.set_input_arrays = _Net_set_input_arrays
+
+
+def _Net_batch(self, blobs):
+    """
+    Batch blob lists according to net's batch size.
+
+    Take
+    blobs: Keys are blob names and values are lists of blobs (of any length).
+           Naturally, all the lists should have the same length.
+
+    Give (yield)
+    batch: {blob name: list of blobs} dict for a single batch.
+    """
+    num = len(blobs.itervalues().next())
+    batch_size = self.blobs.itervalues().next().num
+    remainder = num % batch_size
+
+    # Yield every full batch; i is the start index of each one.
+    for i in range(0, num - remainder, batch_size):
+        yield {name: blobs[name][i:i + batch_size] for name in blobs}
+
+    # Yield last batch, if any, zero-padded up to the full batch size.
+    if remainder > 0:
+        pad = batch_size - remainder
+        yield {name: blobs[name][-remainder:] +
+                     [np.zeros_like(blobs[name][0])] * pad
+               for name in blobs}
+
+Net._batch = _Net_batch