summaryrefslogtreecommitdiff
path: root/python
diff options
context:
space:
mode:
authorEvan Shelhamer <shelhamer@imaginarynumber.net>2014-04-09 18:37:32 -0700
committerEvan Shelhamer <shelhamer@imaginarynumber.net>2014-05-13 18:10:41 -0700
commit8da2a3209c8a64e58c5cbbbdd2040c37e6e22673 (patch)
treeca4e64ef8c485f27b9ae4b8844892c6b96b87b05 /python
parent47ec9ace417b8fc8a086783648631906be5097d9 (diff)
downloadcaffe-8da2a3209c8a64e58c5cbbbdd2040c37e6e22673.tar.gz
caffe-8da2a3209c8a64e58c5cbbbdd2040c37e6e22673.tar.bz2
caffe-8da2a3209c8a64e58c5cbbbdd2040c37e6e22673.zip
add python io getters, mean helper, and image caffeinator/decaffeinator
Diffstat (limited to 'python')
-rw-r--r--python/caffe/pycaffe.py73
1 files changed, 73 insertions, 0 deletions
diff --git a/python/caffe/pycaffe.py b/python/caffe/pycaffe.py
index 6b7b00e0..83906678 100644
--- a/python/caffe/pycaffe.py
+++ b/python/caffe/pycaffe.py
@@ -12,6 +12,14 @@ from ._caffe import Net, SGDSolver
# inheritance) so that nets created by caffe (e.g., by SGDSolver) will
# automatically have the improved interface
+Net.input = property(lambda self: self.blobs.values()[0])
+Net.input_scale = None # for a model that expects data = input * input_scale
+
+Net.output = property(lambda self: self.blobs.values()[-1])
+
+Net.mean = None # image mean (ndarray, input dimensional or broadcastable)
+
+
@property
def _Net_blobs(self):
"""
@@ -35,6 +43,71 @@ def _Net_params(self):
Net.params = _Net_params
+
+def _Net_set_mean(self, mean_f, mode='image'):
+ """
+ Set the mean to subtract for data centering.
+
+ Take
+ mean_f: path to mean .npy
+ mode: image = use the whole-image mean (and check dimensions)
+ channel = channel constant (i.e. mean pixel instead of mean image)
+ """
+ mean = np.load(mean_f)
+ if mode == 'image':
+ if mean.shape != self.input.data.shape[1:]:
+ raise Exception('The mean shape does not match the input shape.')
+ self.mean = mean
+ elif mode == 'channel':
+ self.mean = mean.mean(1).mean(1)
+ else:
+ raise Exception('Mode not in {}'.format(['image', 'channel']))
+
+Net.set_mean = _Net_set_mean
+
+
+def _Net_format_image(self, image):
+ """
+ Format image for input to Caffe:
+ - convert to single
+ - reorder color to BGR
+ - reshape to 1 x K x H x W
+
+ Take
+ image: (H x W x K) ndarray
+
+ Give
+ image: (K x H x W) ndarray
+ """
+ caf_image = image.astype(np.float32)
+ if self.input_scale:
+ caf_image *= self.input_scale
+ caf_image = caf_image[:, :, ::-1]
+ if self.mean is not None:
+ caf_image -= self.mean
+ caf_image = caf_image.transpose((2, 0, 1))
+ caf_image = caf_image[np.newaxis, :, :, :]
+ return caf_image
+
+Net.format_image = _Net_format_image
+
+
+def _Net_decaffeinate_image(self, image):
+ """
+ Invert Caffe formatting; see _Net_format_image().
+ """
+ decaf_image = image.squeeze()
+ decaf_image = decaf_image.transpose((1,2,0))
+ if self.mean is not None:
+ decaf_image += self.mean
+ decaf_image = decaf_image[:, :, ::-1]
+ if self.input_scale:
+ decaf_image /= self.input_scale
+ return decaf_image
+
+Net.decaffeinate_image = _Net_decaffeinate_image
+
+
def _Net_set_input_arrays(self, data, labels):
"""
Set input arrays of the in-memory MemoryDataLayer.