summaryrefslogtreecommitdiff
path: root/python
diff options
context:
space:
mode:
authorEvan Shelhamer <shelhamer@imaginarynumber.net>2014-07-28 14:14:03 -0700
committerEvan Shelhamer <shelhamer@imaginarynumber.net>2014-07-28 14:14:03 -0700
commitd842f4a24f2895f0569f615072d1e66c0e08ea57 (patch)
treea7206f23a6934e28052c14782fe3fd6d22ecd0a7 /python
parent86cc3e91ae01d397f624b040f7592c0f5aaea088 (diff)
parentfb2f7c1c2757e2a7e48860f75b6a091f7351fc68 (diff)
downloadcaffeonacl-d842f4a24f2895f0569f615072d1e66c0e08ea57.tar.gz
caffeonacl-d842f4a24f2895f0569f615072d1e66c0e08ea57.tar.bz2
caffeonacl-d842f4a24f2895f0569f615072d1e66c0e08ea57.zip
Merge pull request #733 from longjon/pycaffe-tweaks
pycaffe fixes
Diffstat (limited to 'python')
-rw-r--r--python/caffe/pycaffe.py26
1 files changed, 12 insertions, 14 deletions
diff --git a/python/caffe/pycaffe.py b/python/caffe/pycaffe.py
index 0ac18686..870dec4f 100644
--- a/python/caffe/pycaffe.py
+++ b/python/caffe/pycaffe.py
@@ -70,10 +70,10 @@ def _Net_forward(self, blobs=None, start=None, end=None, **kwargs):
# Set input according to defined shapes and make arrays single and
# C-contiguous as Caffe expects.
for in_, blob in kwargs.iteritems():
- if blob.shape[0] != self.blobs[in_].num:
- raise Exception('Input is not batch sized')
if blob.ndim != 4:
raise Exception('{} blob is not 4-d'.format(in_))
+ if blob.shape[0] != self.blobs[in_].num:
+ raise Exception('Input is not batch sized')
self.blobs[in_].data[...] = blob
self._forward(start_ind, end_ind)
@@ -117,10 +117,10 @@ def _Net_backward(self, diffs=None, start=None, end=None, **kwargs):
# Set top diffs according to defined shapes and make arrays single and
# C-contiguous as Caffe expects.
for top, diff in kwargs.iteritems():
- if diff.shape[0] != self.blobs[top].num:
- raise Exception('Diff is not batch sized')
if diff.ndim != 4:
raise Exception('{} diff is not 4-d'.format(top))
+ if diff.shape[0] != self.blobs[top].num:
+ raise Exception('Diff is not batch sized')
self.blobs[top].diff[...] = diff
self._backward(start_ind, end_ind)
@@ -284,17 +284,16 @@ def _Net_preprocess(self, input_name, input_):
caffe_in = input_.astype(np.float32)
input_scale = self.input_scale.get(input_name)
channel_order = self.channel_swap.get(input_name)
- mean = self.mean.get(input_name)
in_size = self.blobs[input_name].data.shape[2:]
if caffe_in.shape[:2] != in_size:
caffe_in = caffe.io.resize_image(caffe_in, in_size)
- if input_scale:
+ if input_scale is not None:
caffe_in *= input_scale
- if channel_order:
+ if channel_order is not None:
caffe_in = caffe_in[:, :, channel_order]
caffe_in = caffe_in.transpose((2, 0, 1))
- if mean is not None:
- caffe_in -= mean
+ if hasattr(self, 'mean'):
+ caffe_in -= self.mean.get(input_name, 0)
return caffe_in
@@ -305,15 +304,14 @@ def _Net_deprocess(self, input_name, input_):
decaf_in = input_.copy().squeeze()
input_scale = self.input_scale.get(input_name)
channel_order = self.channel_swap.get(input_name)
- mean = self.mean.get(input_name)
- if mean is not None:
- decaf_in += mean
+ if hasattr(self, 'mean'):
+ decaf_in += self.mean.get(input_name, 0)
decaf_in = decaf_in.transpose((1,2,0))
- if channel_order:
+ if channel_order is not None:
channel_order_inverse = [channel_order.index(i)
for i in range(decaf_in.shape[2])]
decaf_in = decaf_in[:, :, channel_order_inverse]
- if input_scale:
+ if input_scale is not None:
decaf_in /= input_scale
return decaf_in