summaryrefslogtreecommitdiff
path: root/python
diff options
context:
space:
mode:
authorEvan Shelhamer <shelhamer@imaginarynumber.net>2014-07-31 16:19:20 -0700
committerEvan Shelhamer <shelhamer@imaginarynumber.net>2014-08-05 23:17:59 -0700
commitd5c3cef47155b5c2aad146465187364a2b41fd99 (patch)
tree8ebaf286f71556f3f182251664f95997b6661120 /python
parentf1eb9821ba717a55b684d42ef8c87125e855b402 (diff)
downloadcaffeonacl-d5c3cef47155b5c2aad146465187364a2b41fd99.tar.gz
caffeonacl-d5c3cef47155b5c2aad146465187364a2b41fd99.tar.bz2
caffeonacl-d5c3cef47155b5c2aad146465187364a2b41fd99.zip
fix pycaffe input processing
- load an image as [0,1] single / np.float32 according to Python convention - fix input scaling during preprocessing: - scale input for preprocessing by `raw_scale` e.g. to map an image to [0, 255] for the CaffeNet and AlexNet ImageNet models - scale feature space by `input_scale` after mean subtraction - switch examples to raw scale for ImageNet models - fix #525 - preserve type after resizing. - resize 1, 3, or K channel images with special casing between skimage.transform (1 and 3) and scipy.ndimage (K) for speed
Diffstat (limited to 'python')
-rw-r--r--python/caffe/_caffe.cpp2
-rw-r--r--python/caffe/classifier.py15
-rw-r--r--python/caffe/detector.py17
-rw-r--r--python/caffe/io.py18
-rw-r--r--python/caffe/pycaffe.py53
-rwxr-xr-xpython/classify.py8
-rwxr-xr-xpython/detect.py11
7 files changed, 88 insertions, 36 deletions
diff --git a/python/caffe/_caffe.cpp b/python/caffe/_caffe.cpp
index 30c86aeb..59317727 100644
--- a/python/caffe/_caffe.cpp
+++ b/python/caffe/_caffe.cpp
@@ -278,6 +278,7 @@ struct CaffeNet {
// Input preprocessing configuration attributes.
dict mean_;
dict input_scale_;
+ dict raw_scale_;
dict channel_swap_;
// if taking input from an ndarray, we need to hold references
object input_data_;
@@ -329,6 +330,7 @@ BOOST_PYTHON_MODULE(_caffe) {
.add_property("outputs", &CaffeNet::outputs)
.add_property("mean", &CaffeNet::mean_)
.add_property("input_scale", &CaffeNet::input_scale_)
+ .add_property("raw_scale", &CaffeNet::raw_scale_)
.add_property("channel_swap", &CaffeNet::channel_swap_)
.def("_set_input_arrays", &CaffeNet::set_input_arrays)
.def("save", &CaffeNet::save);
diff --git a/python/caffe/classifier.py b/python/caffe/classifier.py
index f347be42..48835bab 100644
--- a/python/caffe/classifier.py
+++ b/python/caffe/classifier.py
@@ -14,13 +14,14 @@ class Classifier(caffe.Net):
by scaling, center cropping, or oversampling.
"""
def __init__(self, model_file, pretrained_file, image_dims=None,
- gpu=False, mean_file=None, input_scale=None, channel_swap=None):
+ gpu=False, mean_file=None, input_scale=None, raw_scale=None,
+ channel_swap=None):
"""
Take
image_dims: dimensions to scale input for cropping/sampling.
- Default is to scale to net input size for whole-image crop.
- gpu, mean_file, input_scale, channel_swap: convenience params for
- setting mode, mean, input scale, and channel order.
+ Default is to scale to net input size for whole-image crop.
+ gpu, mean_file, input_scale, raw_scale, channel_swap: params for
+ preprocessing options.
"""
caffe.Net.__init__(self, model_file, pretrained_file)
self.set_phase_test()
@@ -32,9 +33,11 @@ class Classifier(caffe.Net):
if mean_file:
self.set_mean(self.inputs[0], mean_file)
- if input_scale:
+ if input_scale is not None:
self.set_input_scale(self.inputs[0], input_scale)
- if channel_swap:
+ if raw_scale is not None:
+ self.set_raw_scale(self.inputs[0], raw_scale)
+ if channel_swap is not None:
self.set_channel_swap(self.inputs[0], channel_swap)
self.crop_dims = np.array(self.blobs[self.inputs[0]].data.shape[2:])
diff --git a/python/caffe/detector.py b/python/caffe/detector.py
index 56c26aef..a9b06cd1 100644
--- a/python/caffe/detector.py
+++ b/python/caffe/detector.py
@@ -25,11 +25,12 @@ class Detector(caffe.Net):
selective search proposals.
"""
def __init__(self, model_file, pretrained_file, gpu=False, mean_file=None,
- input_scale=None, channel_swap=None, context_pad=None):
+ input_scale=None, raw_scale=None, channel_swap=None,
+ context_pad=None):
"""
Take
- gpu, mean_file, input_scale, channel_swap: convenience params for
- setting mode, mean, input scale, and channel order.
+ gpu, mean_file, input_scale, raw_scale, channel_swap: params for
+ preprocessing options.
context_pad: amount of surrounding context to take s.t. a `context_pad`
sized border of pixels in the network input image is context, as in
R-CNN feature extraction.
@@ -44,9 +45,11 @@ class Detector(caffe.Net):
if mean_file:
self.set_mean(self.inputs[0], mean_file)
- if input_scale:
+ if input_scale is not None:
self.set_input_scale(self.inputs[0], input_scale)
- if channel_swap:
+ if raw_scale is not None:
+ self.set_raw_scale(self.inputs[0], raw_scale)
+ if channel_swap is not None:
self.set_channel_swap(self.inputs[0], channel_swap)
self.configure_crop(context_pad)
@@ -180,7 +183,7 @@ class Detector(caffe.Net):
"""
self.context_pad = context_pad
if self.context_pad:
- input_scale = self.input_scale.get(self.inputs[0])
+ raw_scale = self.raw_scale.get(self.inputs[0])
channel_order = self.channel_swap.get(self.inputs[0])
# Padding context crops needs the mean in unprocessed input space.
self.crop_mean = self.mean[self.inputs[0]].copy()
@@ -188,4 +191,4 @@ class Detector(caffe.Net):
channel_order_inverse = [channel_order.index(i)
for i in range(self.crop_mean.shape[2])]
self.crop_mean = self.crop_mean[:,:, channel_order_inverse]
- self.crop_mean /= input_scale
+ self.crop_mean /= raw_scale
diff --git a/python/caffe/io.py b/python/caffe/io.py
index 1fc97231..aabcfddb 100644
--- a/python/caffe/io.py
+++ b/python/caffe/io.py
@@ -1,6 +1,7 @@
import numpy as np
import skimage.io
-import skimage.transform
+from scipy.ndimage import zoom
+from skimage.transform import resize
from caffe.proto import caffe_pb2
@@ -15,7 +16,8 @@ def load_image(filename, color=True):
loads as intensity (if image is already grayscale).
Give
- image: an image with type np.float32 of size (H x W x 3) in RGB or
+ image: an image with type np.float32 in range [0, 1]
+ of size (H x W x 3) in RGB or
of size (H x W x 1) in grayscale.
"""
img = skimage.img_as_float(skimage.io.imread(filename)).astype(np.float32)
@@ -40,7 +42,17 @@ def resize_image(im, new_dims, interp_order=1):
Give
im: resized ndarray with shape (new_dims[0], new_dims[1], K)
"""
- return skimage.transform.resize(im, new_dims, order=interp_order)
+ if im.shape[-1] == 1 or im.shape[-1] == 3:
+ # skimage is fast but only understands {1,3} channel images in [0, 1].
+ im_min, im_max = im.min(), im.max()
+ im_std = (im - im_min) / (im_max - im_min)
+ resized_std = resize(im_std, new_dims, order=interp_order)
+ resized_im = resized_std * (im_max - im_min) + im_min
+ else:
+ # ndimage interpolates anything but more slowly.
+ scale = tuple(np.array(new_dims) / np.array(im.shape[:2]))
+ resized_im = zoom(im, scale + (1,), order=interp_order)
+ return resized_im.astype(np.float32)
def oversample(images, crop_dims):
diff --git a/python/caffe/pycaffe.py b/python/caffe/pycaffe.py
index 64747f30..43648d0a 100644
--- a/python/caffe/pycaffe.py
+++ b/python/caffe/pycaffe.py
@@ -216,12 +216,10 @@ def _Net_set_mean(self, input_, mean_f, mode='elementwise'):
in_shape = self.blobs[input_].data.shape
mean = np.load(mean_f)
if mode == 'elementwise':
- if mean.shape != in_shape[1:]:
- # Resize mean (which requires H x W x K input in range [0,1]).
- m_min, m_max = mean.min(), mean.max()
- normal_mean = (mean - m_min) / (m_max - m_min)
- mean = caffe.io.resize_image(normal_mean.transpose((1,2,0)),
- in_shape[2:]).transpose((2,0,1)) * (m_max - m_min) + m_min
+ if mean.shape[1:] != in_shape[2:]:
+ # Resize mean (which requires H x W x K input).
+ mean = caffe.io.resize_image(mean.transpose((1,2,0)),
+ in_shape[2:]).transpose((2,0,1))
self.mean[input_] = mean
elif mode == 'channel':
self.mean[input_] = mean.mean(1).mean(1).reshape((in_shape[1], 1, 1))
@@ -229,10 +227,11 @@ def _Net_set_mean(self, input_, mean_f, mode='elementwise'):
raise Exception('Mode not in {}'.format(['elementwise', 'channel']))
-
def _Net_set_input_scale(self, input_, scale):
"""
- Set the input feature scaling factor s.t. input blob = input * scale.
+ Set the scale of preprocessed inputs s.t. the blob = blob * scale.
+ N.B. input_scale is done AFTER mean subtraction and other preprocessing
+ while raw_scale is done BEFORE.
Take
input_: which input to assign this scale factor
@@ -243,6 +242,22 @@ def _Net_set_input_scale(self, input_, scale):
self.input_scale[input_] = scale
+def _Net_set_raw_scale(self, input_, scale):
+ """
+ Set the scale of raw features s.t. the input blob = input * scale.
+ While Python represents images in [0, 1], certain Caffe models
+ like CaffeNet and AlexNet represent images in [0, 255] so the raw_scale
+ of these models must be 255.
+
+ Take
+ input_: which input to assign this scale factor
+ scale: scale coefficient
+ """
+ if input_ not in self.inputs:
+ raise Exception('Input not in {}'.format(self.inputs))
+ self.raw_scale[input_] = scale
+
+
def _Net_set_channel_swap(self, input_, order):
"""
Set the input channel order for e.g. RGB to BGR conversion
@@ -263,10 +278,11 @@ def _Net_preprocess(self, input_name, input_):
Format input for Caffe:
- convert to single
- resize to input dimensions (preserving number of channels)
- - scale feature
- reorder channels (for instance color to BGR)
- - subtract mean
+ - scale raw input (e.g. from [0, 1] to [0, 255] for ImageNet models)
- transpose dimensions to K x H x W
+ - subtract mean
+ - scale feature
Take
input_name: name of input blob to preprocess for
@@ -275,20 +291,23 @@ def _Net_preprocess(self, input_name, input_):
Give
caffe_inputs: (K x H x W) ndarray
"""
- caffe_in = input_.astype(np.float32)
+ caffe_in = input_.astype(np.float32, copy=False)
mean = self.mean.get(input_name)
input_scale = self.input_scale.get(input_name)
+ raw_scale = self.raw_scale.get(input_name)
channel_order = self.channel_swap.get(input_name)
in_size = self.blobs[input_name].data.shape[2:]
if caffe_in.shape[:2] != in_size:
caffe_in = caffe.io.resize_image(caffe_in, in_size)
- if input_scale is not None:
- caffe_in *= input_scale
if channel_order is not None:
caffe_in = caffe_in[:, :, channel_order]
caffe_in = caffe_in.transpose((2, 0, 1))
+ if raw_scale is not None:
+ caffe_in *= raw_scale
if mean is not None:
caffe_in -= mean
+ if input_scale is not None:
+ caffe_in *= input_scale
return caffe_in
@@ -299,16 +318,19 @@ def _Net_deprocess(self, input_name, input_):
decaf_in = input_.copy().squeeze()
mean = self.mean.get(input_name)
input_scale = self.input_scale.get(input_name)
+ raw_scale = self.raw_scale.get(input_name)
channel_order = self.channel_swap.get(input_name)
+ if input_scale is not None:
+ decaf_in /= input_scale
if mean is not None:
decaf_in += mean
+ if raw_scale is not None:
+ decaf_in /= raw_scale
decaf_in = decaf_in.transpose((1,2,0))
if channel_order is not None:
channel_order_inverse = [channel_order.index(i)
for i in range(decaf_in.shape[2])]
decaf_in = decaf_in[:, :, channel_order_inverse]
- if input_scale is not None:
- decaf_in /= input_scale
return decaf_in
@@ -364,6 +386,7 @@ Net.forward_all = _Net_forward_all
Net.forward_backward_all = _Net_forward_backward_all
Net.set_mean = _Net_set_mean
Net.set_input_scale = _Net_set_input_scale
+Net.set_raw_scale = _Net_set_raw_scale
Net.set_channel_swap = _Net_set_channel_swap
Net.preprocess = _Net_preprocess
Net.deprocess = _Net_deprocess
diff --git a/python/classify.py b/python/classify.py
index fdaeeb01..417f8b54 100755
--- a/python/classify.py
+++ b/python/classify.py
@@ -66,8 +66,12 @@ def main(argv):
parser.add_argument(
"--input_scale",
type=float,
- default=255,
- help="Multiply input features by this scale before input to net"
+ help="Multiply input features by this scale to finish input preprocessing."
+ )
+ parser.add_argument(
+ "--raw_scale",
+ type=float,
+ help="Multiply raw input by this scale before preprocessing."
)
parser.add_argument(
"--channel_swap",
diff --git a/python/detect.py b/python/detect.py
index a3bee5c5..4cfe0825 100755
--- a/python/detect.py
+++ b/python/detect.py
@@ -76,8 +76,12 @@ def main(argv):
parser.add_argument(
"--input_scale",
type=float,
- default=255,
- help="Multiply input features by this scale before input to net"
+ help="Multiply input features by this scale to finish input preprocessing."
+ )
+ parser.add_argument(
+ "--raw_scale",
+ type=float,
+ help="Multiply raw input by this scale before preprocessing."
)
parser.add_argument(
"--channel_swap",
@@ -99,7 +103,8 @@ def main(argv):
# Make detector.
detector = caffe.Detector(args.model_def, args.pretrained_model,
gpu=args.gpu, mean_file=args.mean_file,
- input_scale=args.input_scale, channel_swap=channel_swap,
+ input_scale=args.input_scale, raw_scale=args.raw_scale,
+ channel_swap=channel_swap,
context_pad=args.context_pad)
if args.gpu: