diff options
author | Evan Shelhamer <shelhamer@imaginarynumber.net> | 2014-04-09 18:37:32 -0700 |
---|---|---|
committer | Evan Shelhamer <shelhamer@imaginarynumber.net> | 2014-05-13 18:10:41 -0700 |
commit | 8da2a3209c8a64e58c5cbbbdd2040c37e6e22673 (patch) | |
tree | ca4e64ef8c485f27b9ae4b8844892c6b96b87b05 /python | |
parent | 47ec9ace417b8fc8a086783648631906be5097d9 (diff) | |
download | caffe-8da2a3209c8a64e58c5cbbbdd2040c37e6e22673.tar.gz caffe-8da2a3209c8a64e58c5cbbbdd2040c37e6e22673.tar.bz2 caffe-8da2a3209c8a64e58c5cbbbdd2040c37e6e22673.zip |
add python io getters, mean helper, and image caffeinator/decaffeinator
Diffstat (limited to 'python')
-rw-r--r-- | python/caffe/pycaffe.py | 73 |
1 files changed, 73 insertions, 0 deletions
diff --git a/python/caffe/pycaffe.py b/python/caffe/pycaffe.py index 6b7b00e0..83906678 100644 --- a/python/caffe/pycaffe.py +++ b/python/caffe/pycaffe.py @@ -12,6 +12,14 @@ from ._caffe import Net, SGDSolver # inheritance) so that nets created by caffe (e.g., by SGDSolver) will # automatically have the improved interface +Net.input = property(lambda self: self.blobs.values()[0]) +Net.input_scale = None # for a model that expects data = input * input_scale + +Net.output = property(lambda self: self.blobs.values()[-1]) + +Net.mean = None # image mean (ndarray, input dimensional or broadcastable) + + @property def _Net_blobs(self): """ @@ -35,6 +43,71 @@ def _Net_params(self): Net.params = _Net_params + +def _Net_set_mean(self, mean_f, mode='image'): + """ + Set the mean to subtract for data centering. + + Take + mean_f: path to mean .npy + mode: image = use the whole-image mean (and check dimensions) + channel = channel constant (i.e. mean pixel instead of mean image) + """ + mean = np.load(mean_f) + if mode == 'image': + if mean.shape != self.input.data.shape[1:]: + raise Exception('The mean shape does not match the input shape.') + self.mean = mean + elif mode == 'channel': + self.mean = mean.mean(1).mean(1) + else: + raise Exception('Mode not in {}'.format(['image', 'channel'])) + +Net.set_mean = _Net_set_mean + + +def _Net_format_image(self, image): + """ + Format image for input to Caffe: + - convert to single + - reorder color to BGR + - reshape to 1 x K x H x W + + Take + image: (H x W x K) ndarray + + Give + image: (K x H x W) ndarray + """ + caf_image = image.astype(np.float32) + if self.input_scale: + caf_image *= self.input_scale + caf_image = caf_image[:, :, ::-1] + if self.mean is not None: + caf_image -= self.mean + caf_image = caf_image.transpose((2, 0, 1)) + caf_image = caf_image[np.newaxis, :, :, :] + return caf_image + +Net.format_image = _Net_format_image + + +def _Net_decaffeinate_image(self, image): + """ + Invert Caffe formatting; see _Net_format_image(). + """ + decaf_image = image.squeeze() + decaf_image = decaf_image.transpose((1,2,0)) + if self.mean is not None: + decaf_image += self.mean + decaf_image = decaf_image[:, :, ::-1] + if self.input_scale: + decaf_image /= self.input_scale + return decaf_image + +Net.decaffeinate_image = _Net_decaffeinate_image + + def _Net_set_input_arrays(self, data, labels): """ Set input arrays of the in-memory MemoryDataLayer. |