summaryrefslogtreecommitdiff
path: root/python
diff options
context:
space:
mode:
authorEvan Shelhamer <shelhamer@imaginarynumber.net>2014-05-13 19:56:14 -0700
committerEvan Shelhamer <shelhamer@imaginarynumber.net>2014-05-14 13:44:02 -0700
commit96cd02dd538bcdb793070b4c9320eadfc9c7962d (patch)
treef03677e137d5ddf75401b8b6aeff84445aca9608 /python
parent56ca978c4e14740d76cdadd5bccc019edbc6d235 (diff)
downloadcaffe-96cd02dd538bcdb793070b4c9320eadfc9c7962d.tar.gz
caffe-96cd02dd538bcdb793070b4c9320eadfc9c7962d.tar.bz2
caffe-96cd02dd538bcdb793070b4c9320eadfc9c7962d.zip
set input preprocessing per blob in python
Diffstat (limited to 'python')
-rw-r--r--python/caffe/pycaffe.py86
1 files changed, 64 insertions, 22 deletions
diff --git a/python/caffe/pycaffe.py b/python/caffe/pycaffe.py
index 83906678..a7bc2783 100644
--- a/python/caffe/pycaffe.py
+++ b/python/caffe/pycaffe.py
@@ -12,12 +12,10 @@ from ._caffe import Net, SGDSolver
# inheritance) so that nets created by caffe (e.g., by SGDSolver) will
# automatically have the improved interface
-Net.input = property(lambda self: self.blobs.values()[0])
-Net.input_scale = None # for a model that expects data = input * input_scale
-
-Net.output = property(lambda self: self.blobs.values()[-1])
-
-Net.mean = None # image mean (ndarray, input dimensional or broadcastable)
+# Input preprocessing
+Net.mean = {} # image mean (ndarray, input dimensional or broadcastable)
+Net.input_scale = {} # for a model that expects data = input * input_scale
+Net.channel_swap = {} # for RGB -> BGR and the like
@property
@@ -44,33 +42,69 @@ def _Net_params(self):
Net.params = _Net_params
-def _Net_set_mean(self, mean_f, mode='image'):
+def _Net_set_mean(self, input_, mean_f, mode='image'):
"""
Set the mean to subtract for data centering.
Take
+ input_: which input to assign this mean.
mean_f: path to mean .npy
mode: image = use the whole-image mean (and check dimensions)
channel = channel constant (i.e. mean pixel instead of mean image)
"""
+ if input_ not in self.inputs:
+ raise Exception('Input not in {}'.format(self.inputs))
mean = np.load(mean_f)
if mode == 'image':
if mean.shape != self.input.data.shape[1:]:
raise Exception('The mean shape does not match the input shape.')
- self.mean = mean
+ self.mean[input_] = mean
elif mode == 'channel':
- self.mean = mean.mean(1).mean(1)
+ self.mean[input_] = mean.mean(1).mean(1)
else:
raise Exception('Mode not in {}'.format(['image', 'channel']))
Net.set_mean = _Net_set_mean
-def _Net_format_image(self, image):
+def _Net_set_input_scale(self, input_, scale):
+ """
+ Set the input feature scaling factor s.t. input blob = input * scale.
+
+ Take
+ input_: which input to assign this scale factor
+ scale: scale coefficient
+ """
+ if input_ not in self.inputs:
+ raise Exception('Input not in {}'.format(self.inputs))
+ self.input_scale[input_] = scale
+
+Net.set_input_scale = _Net_set_input_scale
+
+
+def _Net_set_channel_swap(self, input_, order):
+ """
+ Set the input channel order for e.g. RGB to BGR conversion
+ as needed for the reference ImageNet model.
+
+ Take
+ input_: which input to assign this channel order
+ order: the order to take the channels. (2,1,0) maps RGB to BGR for example.
+ """
+ if input_ not in self.inputs:
+ raise Exception('Input not in {}'.format(self.inputs))
+ self.channel_swap[input_] = order
+
+Net.set_channel_swap = _Net_set_channel_swap
+
+
+def _Net_format_image(self, input_, image):
"""
Format image for input to Caffe:
- convert to single
- - reorder color to BGR
+ - scale feature
+ - reorder channels (for instance color to BGR)
+ - subtract mean
- reshape to 1 x K x H x W
Take
@@ -80,11 +114,15 @@ def _Net_format_image(self, image):
image: (K x H x W) ndarray
"""
caf_image = image.astype(np.float32)
- if self.input_scale:
- caf_image *= self.input_scale
- caf_image = caf_image[:, :, ::-1]
- if self.mean is not None:
- caf_image -= self.mean
+ input_scale = self.input_scale.get(input_)
+ channel_order = self.channel_swap.get(input_)
+ mean = self.mean.get(input_)
+ if input_scale:
+ caf_image *= input_scale
+ if channel_order:
+ caf_image = caf_image[:, :, channel_order]
+ if mean:
+ caf_image -= mean
caf_image = caf_image.transpose((2, 0, 1))
caf_image = caf_image[np.newaxis, :, :, :]
return caf_image
@@ -92,17 +130,21 @@ def _Net_format_image(self, image):
Net.format_image = _Net_format_image
-def _Net_decaffeinate_image(self, image):
+def _Net_decaffeinate_image(self, input_, image):
"""
Invert Caffe formatting; see _Net_format_image().
"""
decaf_image = image.squeeze()
decaf_image = decaf_image.transpose((1,2,0))
- if self.mean is not None:
- decaf_image += self.mean
- decaf_image = decaf_image[:, :, ::-1]
- if self.input_scale:
- decaf_image /= self.input_scale
+ input_scale = self.input_scale.get(input_)
+ channel_order = self.channel_swap.get(input_)
+ mean = self.mean.get(input_)
+ if mean:
+ decaf_image += mean
+ if channel_order:
+ decaf_image = decaf_image[:, :, channel_order[::-1]]
+ if input_scale:
+ decaf_image /= input_scale
return decaf_image
Net.decaffeinate_image = _Net_decaffeinate_image