summaryrefslogtreecommitdiff
path: root/python
diff options
context:
space:
mode:
authorEvan Shelhamer <shelhamer@imaginarynumber.net>2014-01-25 20:53:00 -0800
committerEvan Shelhamer <shelhamer@imaginarynumber.net>2014-01-25 20:53:00 -0800
commitec010d549e929ff0fa420c99a096642ae8d769ab (patch)
treee74122abbebe6ecf39913611ab7c3408dace836e /python
parent6408ac1098b10f6f31b5231c46db23d08a5c3964 (diff)
downloadcaffeonacl-ec010d549e929ff0fa420c99a096642ae8d769ab.tar.gz
caffeonacl-ec010d549e929ff0fa420c99a096642ae8d769ab.tar.bz2
caffeonacl-ec010d549e929ff0fa420c99a096642ae8d769ab.zip
automagically set detection batch size from network
Diffstat (limited to 'python')
-rw-r--r--python/caffe/detection/detector.py30
1 files changed, 14 insertions, 16 deletions
diff --git a/python/caffe/detection/detector.py b/python/caffe/detection/detector.py
index 2b713369..26515389 100644
--- a/python/caffe/detection/detector.py
+++ b/python/caffe/detection/detector.py
@@ -36,10 +36,12 @@ IMAGE_CENTER = None
IMAGE_MEAN = None
CROPPED_IMAGE_MEAN = None
+BATCH_SIZE = None
NUM_OUTPUT = None
CROP_MODES = ['center_only', 'corners', 'selective_search']
+
def load_image(filename):
"""
Input:
@@ -181,7 +183,7 @@ def _assemble_images_selective_search(image_fnames):
return images_df
-def assemble_batches(image_fnames, crop_mode='center_only', batch_size=10):
+def assemble_batches(image_fnames, crop_mode='center_only'):
"""
Assemble DataFrame of image crops for feature computation.
@@ -195,7 +197,7 @@ def assemble_batches(image_fnames, crop_mode='center_only', batch_size=10):
image, and take each enclosing subwindow.
Output:
- df_batches: list of DataFrames, each one of batch_size rows.
+ df_batches: list of DataFrames, each one of BATCH_SIZE rows.
Each row has 'image', 'filename', and 'window' info.
Column 'image' contains (X x 3 x 227 x 227) ndarrays.
Column 'filename' contains source filenames.
@@ -216,23 +218,23 @@ def assemble_batches(image_fnames, crop_mode='center_only', batch_size=10):
else:
raise Exception("Unknown mode: not in {}".format(CROP_MODES))
- # Make sure the DataFrame has a multiple of batch_size rows:
+ # Make sure the DataFrame has a multiple of BATCH_SIZE rows:
# just fill the extra rows with NaN filenames and all-zero images.
N = images_df.shape[0]
- remainder = N % batch_size
+ remainder = N % BATCH_SIZE
if remainder > 0:
zero_image = np.zeros_like(images_df['image'].iloc[0])
remainder_df = pd.DataFrame([{
'filename': None,
'image': zero_image,
'window': [0, 0, 0, 0]
- }] * (batch_size - remainder))
+ }] * (BATCH_SIZE - remainder))
images_df = images_df.append(remainder_df)
N = images_df.shape[0]
- # Split into batches of batch_size.
- ind = np.arange(N) / batch_size
- df_batches = [images_df[ind == i] for i in range(N / batch_size)]
+ # Split into batches of BATCH_SIZE.
+ ind = np.arange(N) / BATCH_SIZE
+ df_batches = [images_df[ind == i] for i in range(N / BATCH_SIZE)]
return df_batches
@@ -254,7 +256,7 @@ def compute_feats(images_df):
def config(model_def, pretrained_model, gpu, image_dim, image_mean_file):
global IMAGE_DIM, CROPPED_DIM, IMAGE_CENTER, IMAGE_MEAN, CROPPED_IMAGE_MEAN
- global NET, NUM_OUTPUT
+ global NET, BATCH_SIZE, NUM_OUTPUT
# Initialize network by loading model definition and weights.
t = time.time()
@@ -273,11 +275,11 @@ def config(model_def, pretrained_model, gpu, image_dim, image_mean_file):
# Load the data set mean file
IMAGE_MEAN = np.load(image_mean_file)
-
CROPPED_IMAGE_MEAN = IMAGE_MEAN[IMAGE_CENTER:IMAGE_CENTER + CROPPED_DIM,
IMAGE_CENTER:IMAGE_CENTER + CROPPED_DIM,
:]
- NUM_OUTPUT = NET.blobs()[-1].channels # number of output classes
+ BATCH_SIZE = NET.blobs()[0].num # network batch size
+ NUM_OUTPUT = NET.blobs()[-1].channels # number of output classes
if __name__ == "__main__":
@@ -293,8 +295,6 @@ if __name__ == "__main__":
gflags.DEFINE_string(
"images_file", "", "Image filenames file.")
gflags.DEFINE_string(
- "batch_size", 10, "Number of image crops to let through in one go")
- gflags.DEFINE_string(
"output_file", "", "Output DataFrame HDF5 filename.")
gflags.DEFINE_string(
"images_dim", 256, "Canonical dimension of (square) images.")
@@ -305,7 +305,6 @@ if __name__ == "__main__":
FLAGS = gflags.FLAGS
FLAGS(sys.argv)
-
# Configure network, input, output
config(FLAGS.model_def, FLAGS.pretrained_model, FLAGS.gpu, FLAGS.images_dim,
FLAGS.images_mean_file)
@@ -315,8 +314,7 @@ if __name__ == "__main__":
print('Assembling batches...')
with open(FLAGS.images_file) as f:
image_fnames = [_.strip() for _ in f.readlines()]
- image_batches = assemble_batches(image_fnames, FLAGS.crop_mode,
- FLAGS.batch_size)
+ image_batches = assemble_batches(image_fnames, FLAGS.crop_mode)
print('{} batches assembled in {:.3f} s'.format(len(image_batches),
time.time() - t))