summaryrefslogtreecommitdiff
path: root/python
diff options
context:
space:
mode:
authorEvan Shelhamer <shelhamer@imaginarynumber.net>2014-05-14 13:39:06 -0700
committerEvan Shelhamer <shelhamer@imaginarynumber.net>2014-05-14 13:44:02 -0700
commit0e5a5cf50e9d17dbfe96b8269145b934e99b29a5 (patch)
tree3b4a131cf7050fa4e7e26c2d258ad7f4712b1ec9 /python
parent96cd02dd538bcdb793070b4c9320eadfc9c7962d (diff)
downloadcaffe-0e5a5cf50e9d17dbfe96b8269145b934e99b29a5.tar.gz
caffe-0e5a5cf50e9d17dbfe96b8269145b934e99b29a5.tar.bz2
caffe-0e5a5cf50e9d17dbfe96b8269145b934e99b29a5.zip
pycaffe Net.forward() helper
Do forward pass by prefilled or packaging input + output blobs and returning a {output blob name: output list} dict.
Diffstat (limited to 'python')
-rw-r--r--python/caffe/pycaffe.py36
1 files changed, 36 insertions, 0 deletions
diff --git a/python/caffe/pycaffe.py b/python/caffe/pycaffe.py
index a7bc2783..40538154 100644
--- a/python/caffe/pycaffe.py
+++ b/python/caffe/pycaffe.py
@@ -42,6 +42,42 @@ def _Net_params(self):
Net.params = _Net_params
+def _Net_forward(self, **kwargs):
+ """
+ Forward pass: prepare inputs and run the net forward.
+
+ Take
+ kwargs: Keys are input blob names and values are lists of inputs.
+ Images must be (H x W x K) ndarrays.
+ If None, input is taken from data layers by ForwardPrefilled().
+
+ Give
+ out: {output blob name: list of output blobs} dict.
+ """
+ outs = {}
+ if not kwargs:
+ # Carry out prefilled forward pass and unpack output.
+ self.ForwardPrefilled()
+ out_blobs = [self.blobs[out].data for out in self.outputs]
+ else:
+ # Create input and output blobs according to net defined shapes
+ # and make arrays single and C-contiguous as Caffe expects.
+ in_blobs = [np.ascontiguousarray(np.concatenate(kwargs[in_]),
+ dtype=np.float32) for in_ in self.inputs]
+ out_blobs = [np.empty(self.blobs[out].data.shape, dtype=np.float32)
+ for out in self.outputs]
+
+ self.Forward(in_blobs, out_blobs)
+
+ # Unpack output blobs
+ for out, out_blob in zip(self.outputs, out_blobs):
+ outs[out] = [out_blob[ix, :, :, :].squeeze()
+ for ix in range(out_blob.shape[0])]
+ return outs
+
+Net.forward = _Net_forward
+
+
def _Net_set_mean(self, input_, mean_f, mode='image'):
"""
Set the mean to subtract for data centering.