summaryrefslogtreecommitdiff
path: root/python
diff options
context:
space:
mode:
authorEvan Shelhamer <shelhamer@imaginarynumber.net>2014-05-14 14:37:33 -0700
committerEvan Shelhamer <shelhamer@imaginarynumber.net>2014-05-14 14:37:33 -0700
commitac5e6fa1ce11fb8e6c8577122d8a00194d6ef6fa (patch)
tree4764291ddb959e5ec67b48ef168add4ae074dbd3 /python
parent9d4324e5e7f0187027c4cf6634d8b00116ffb8ce (diff)
downloadcaffe-ac5e6fa1ce11fb8e6c8577122d8a00194d6ef6fa.tar.gz
caffe-ac5e6fa1ce11fb8e6c8577122d8a00194d6ef6fa.tar.bz2
caffe-ac5e6fa1ce11fb8e6c8577122d8a00194d6ef6fa.zip
python Net.backward() helper and Net.BackwardPrefilled()
Diffstat (limited to 'python')
-rw-r--r--python/caffe/_caffe.cpp5
-rw-r--r--python/caffe/pycaffe.py37
2 files changed, 40 insertions, 2 deletions
diff --git a/python/caffe/_caffe.cpp b/python/caffe/_caffe.cpp
index 18b96b92..c4460b92 100644
--- a/python/caffe/_caffe.cpp
+++ b/python/caffe/_caffe.cpp
@@ -293,6 +293,10 @@ struct CaffeNet {
net_->ForwardPrefilled();
}
+ void BackwardPrefilled() {
+ net_->Backward();
+ }
+
void set_input_arrays(object data_obj, object labels_obj) {
// check that this network has an input MemoryDataLayer
shared_ptr<MemoryDataLayer<float> > md_layer =
@@ -411,6 +415,7 @@ BOOST_PYTHON_MODULE(_caffe) {
.def("Forward", &CaffeNet::Forward)
.def("ForwardPrefilled", &CaffeNet::ForwardPrefilled)
.def("Backward", &CaffeNet::Backward)
+ .def("BackwardPrefilled", &CaffeNet::BackwardPrefilled)
.def("set_mode_cpu", &CaffeNet::set_mode_cpu)
.def("set_mode_gpu", &CaffeNet::set_mode_gpu)
.def("set_phase_train", &CaffeNet::set_phase_train)
diff --git a/python/caffe/pycaffe.py b/python/caffe/pycaffe.py
index 40538154..101deabd 100644
--- a/python/caffe/pycaffe.py
+++ b/python/caffe/pycaffe.py
@@ -52,9 +52,8 @@ def _Net_forward(self, **kwargs):
If None, input is taken from data layers by ForwardPrefilled().
Give
- out: {output blob name: list of output blobs} dict.
+ outs: {output blob name: list of output blobs} dict.
"""
- outs = {}
if not kwargs:
# Carry out prefilled forward pass and unpack output.
self.ForwardPrefilled()
@@ -70,6 +69,7 @@ def _Net_forward(self, **kwargs):
self.Forward(in_blobs, out_blobs)
# Unpack output blobs
+ outs = {}
for out, out_blob in zip(self.outputs, out_blobs):
outs[out] = [out_blob[ix, :, :, :].squeeze()
for ix in range(out_blob.shape[0])]
@@ -78,6 +78,39 @@ def _Net_forward(self, **kwargs):
Net.forward = _Net_forward
+def _Net_backward(self, **kwargs):
+ """
+ Backward pass: prepare diffs and run the net backward.
+
+ Take
+ kwargs: Keys are output blob names and values are lists of diffs.
+ If None, input is taken from data layers by BackwardPrefilled().
+
+ Give
+ bottom_diffs: {input blob name: list of diffs} dict.
+ """
+ if not kwargs:
+ self.BackwardPrefilled()
+ bottom_diffs = [self.blobs[in_].diff for in_ in self.inputs]
+ else:
+ # Create top and bottom diffs according to net defined shapes
+ # and make arrays single and C-contiguous as Caffe expects.
+ top_diffs = [np.ascontiguousarray(np.concatenate(kwargs[out]),
+ dtype=np.float32) for out in self.outputs]
+ bottom_diffs = [np.empty(self.blobs[bottom].data.shape, dtype=np.float32)
+ for bottom in self.inputs]
+ self.Backward(top_diffs, bottom_diffs)
+
+ # Unpack bottom diffs
+ bottom_diffs = {}
+ for bottom, bottom_diff in zip(self.inputs, bottom_diffs):
+ bottom_diffs[bottom] = [bottom_diff[ix, :, :, :].squeeze()
+ for ix in range(bottom_diff.shape[0])]
+ return bottom_diffs
+
+Net.backward = _Net_backward
+
+
def _Net_set_mean(self, input_, mean_f, mode='image'):
"""
Set the mean to subtract for data centering.