diff options
author | Evan Shelhamer <shelhamer@imaginarynumber.net> | 2014-07-28 14:14:03 -0700 |
---|---|---|
committer | Evan Shelhamer <shelhamer@imaginarynumber.net> | 2014-07-28 14:14:03 -0700 |
commit | d842f4a24f2895f0569f615072d1e66c0e08ea57 (patch) | |
tree | a7206f23a6934e28052c14782fe3fd6d22ecd0a7 /python | |
parent | 86cc3e91ae01d397f624b040f7592c0f5aaea088 (diff) | |
parent | fb2f7c1c2757e2a7e48860f75b6a091f7351fc68 (diff) | |
download | caffeonacl-d842f4a24f2895f0569f615072d1e66c0e08ea57.tar.gz caffeonacl-d842f4a24f2895f0569f615072d1e66c0e08ea57.tar.bz2 caffeonacl-d842f4a24f2895f0569f615072d1e66c0e08ea57.zip |
Merge pull request #733 from longjon/pycaffe-tweaks
pycaffe fixes
Diffstat (limited to 'python')
-rw-r--r-- | python/caffe/pycaffe.py | 26 |
1 files changed, 12 insertions, 14 deletions
diff --git a/python/caffe/pycaffe.py b/python/caffe/pycaffe.py index 0ac18686..870dec4f 100644 --- a/python/caffe/pycaffe.py +++ b/python/caffe/pycaffe.py @@ -70,10 +70,10 @@ def _Net_forward(self, blobs=None, start=None, end=None, **kwargs): # Set input according to defined shapes and make arrays single and # C-contiguous as Caffe expects. for in_, blob in kwargs.iteritems(): - if blob.shape[0] != self.blobs[in_].num: - raise Exception('Input is not batch sized') if blob.ndim != 4: raise Exception('{} blob is not 4-d'.format(in_)) + if blob.shape[0] != self.blobs[in_].num: + raise Exception('Input is not batch sized') self.blobs[in_].data[...] = blob self._forward(start_ind, end_ind) @@ -117,10 +117,10 @@ def _Net_backward(self, diffs=None, start=None, end=None, **kwargs): # Set top diffs according to defined shapes and make arrays single and # C-contiguous as Caffe expects. for top, diff in kwargs.iteritems(): - if diff.shape[0] != self.blobs[top].num: - raise Exception('Diff is not batch sized') if diff.ndim != 4: raise Exception('{} diff is not 4-d'.format(top)) + if diff.shape[0] != self.blobs[top].num: + raise Exception('Diff is not batch sized') self.blobs[top].diff[...] = diff self._backward(start_ind, end_ind) @@ -284,17 +284,16 @@ def _Net_preprocess(self, input_name, input_): caffe_in = input_.astype(np.float32) input_scale = self.input_scale.get(input_name) channel_order = self.channel_swap.get(input_name) - mean = self.mean.get(input_name) in_size = self.blobs[input_name].data.shape[2:] if caffe_in.shape[:2] != in_size: caffe_in = caffe.io.resize_image(caffe_in, in_size) - if input_scale: + if input_scale is not None: caffe_in *= input_scale - if channel_order: + if channel_order is not None: caffe_in = caffe_in[:, :, channel_order] caffe_in = caffe_in.transpose((2, 0, 1)) - if mean is not None: - caffe_in -= mean + if hasattr(self, 'mean'): + caffe_in -= self.mean.get(input_name, 0) return caffe_in @@ -305,15 +304,14 @@ def _Net_deprocess(self, input_name, input_): decaf_in = input_.copy().squeeze() input_scale = self.input_scale.get(input_name) channel_order = self.channel_swap.get(input_name) - mean = self.mean.get(input_name) - if mean is not None: - decaf_in += mean + if hasattr(self, 'mean'): + decaf_in += self.mean.get(input_name, 0) decaf_in = decaf_in.transpose((1,2,0)) - if channel_order: + if channel_order is not None: channel_order_inverse = [channel_order.index(i) for i in range(decaf_in.shape[2])] decaf_in = decaf_in[:, :, channel_order_inverse] - if input_scale: + if input_scale is not None: decaf_in /= input_scale return decaf_in |