diff options
author | Evan Shelhamer <shelhamer@imaginarynumber.net> | 2014-05-14 13:39:06 -0700 |
---|---|---|
committer | Evan Shelhamer <shelhamer@imaginarynumber.net> | 2014-05-14 13:44:02 -0700 |
commit | 0e5a5cf50e9d17dbfe96b8269145b934e99b29a5 (patch) | |
tree | 3b4a131cf7050fa4e7e26c2d258ad7f4712b1ec9 /python | |
parent | 96cd02dd538bcdb793070b4c9320eadfc9c7962d (diff) | |
download | caffe-0e5a5cf50e9d17dbfe96b8269145b934e99b29a5.tar.gz caffe-0e5a5cf50e9d17dbfe96b8269145b934e99b29a5.tar.bz2 caffe-0e5a5cf50e9d17dbfe96b8269145b934e99b29a5.zip |
pycaffe Net.forward() helper
Do forward pass by prefilled or packaging input + output blobs and
returning a {output blob name: output list} dict.
Diffstat (limited to 'python')
-rw-r--r-- | python/caffe/pycaffe.py | 36 |
1 files changed, 36 insertions, 0 deletions
diff --git a/python/caffe/pycaffe.py b/python/caffe/pycaffe.py index a7bc2783..40538154 100644 --- a/python/caffe/pycaffe.py +++ b/python/caffe/pycaffe.py @@ -42,6 +42,42 @@ def _Net_params(self): Net.params = _Net_params +def _Net_forward(self, **kwargs): + """ + Forward pass: prepare inputs and run the net forward. + + Take + kwargs: Keys are input blob names and values are lists of inputs. + Images must be (H x W x K) ndarrays. + If None, input is taken from data layers by ForwardPrefilled(). + + Give + out: {output blob name: list of output blobs} dict. + """ + outs = {} + if not kwargs: + # Carry out prefilled forward pass and unpack output. + self.ForwardPrefilled() + out_blobs = [self.blobs[out].data for out in self.outputs] + else: + # Create input and output blobs according to net defined shapes + # and make arrays single and C-contiguous as Caffe expects. + in_blobs = [np.ascontiguousarray(np.concatenate(kwargs[in_]), + dtype=np.float32) for in_ in self.inputs] + out_blobs = [np.empty(self.blobs[out].data.shape, dtype=np.float32) + for out in self.outputs] + + self.Forward(in_blobs, out_blobs) + + # Unpack output blobs + for out, out_blob in zip(self.outputs, out_blobs): + outs[out] = [out_blob[ix, :, :, :].squeeze() + for ix in range(out_blob.shape[0])] + return outs + +Net.forward = _Net_forward + + def _Net_set_mean(self, input_, mean_f, mode='image'): """ Set the mean to subtract for data centering. |