diff options
author | Evan Shelhamer <shelhamer@imaginarynumber.net> | 2014-01-25 20:53:00 -0800 |
---|---|---|
committer | Evan Shelhamer <shelhamer@imaginarynumber.net> | 2014-01-25 20:53:00 -0800 |
commit | ec010d549e929ff0fa420c99a096642ae8d769ab (patch) | |
tree | e74122abbebe6ecf39913611ab7c3408dace836e /python | |
parent | 6408ac1098b10f6f31b5231c46db23d08a5c3964 (diff) | |
download | caffeonacl-ec010d549e929ff0fa420c99a096642ae8d769ab.tar.gz caffeonacl-ec010d549e929ff0fa420c99a096642ae8d769ab.tar.bz2 caffeonacl-ec010d549e929ff0fa420c99a096642ae8d769ab.zip |
automagically set detection batch size from network
Diffstat (limited to 'python')
-rw-r--r-- | python/caffe/detection/detector.py | 30 |
1 files changed, 14 insertions, 16 deletions
diff --git a/python/caffe/detection/detector.py b/python/caffe/detection/detector.py index 2b713369..26515389 100644 --- a/python/caffe/detection/detector.py +++ b/python/caffe/detection/detector.py @@ -36,10 +36,12 @@ IMAGE_CENTER = None IMAGE_MEAN = None CROPPED_IMAGE_MEAN = None +BATCH_SIZE = None NUM_OUTPUT = None CROP_MODES = ['center_only', 'corners', 'selective_search'] + def load_image(filename): """ Input: @@ -181,7 +183,7 @@ def _assemble_images_selective_search(image_fnames): return images_df -def assemble_batches(image_fnames, crop_mode='center_only', batch_size=10): +def assemble_batches(image_fnames, crop_mode='center_only'): """ Assemble DataFrame of image crops for feature computation. @@ -195,7 +197,7 @@ def assemble_batches(image_fnames, crop_mode='center_only', batch_size=10): image, and take each enclosing subwindow. Output: - df_batches: list of DataFrames, each one of batch_size rows. + df_batches: list of DataFrames, each one of BATCH_SIZE rows. Each row has 'image', 'filename', and 'window' info. Column 'image' contains (X x 3 x 227 x 227) ndarrays. Column 'filename' contains source filenames. @@ -216,23 +218,23 @@ def assemble_batches(image_fnames, crop_mode='center_only', batch_size=10): else: raise Exception("Unknown mode: not in {}".format(CROP_MODES)) - # Make sure the DataFrame has a multiple of batch_size rows: + # Make sure the DataFrame has a multiple of BATCH_SIZE rows: # just fill the extra rows with NaN filenames and all-zero images. N = images_df.shape[0] - remainder = N % batch_size + remainder = N % BATCH_SIZE if remainder > 0: zero_image = np.zeros_like(images_df['image'].iloc[0]) remainder_df = pd.DataFrame([{ 'filename': None, 'image': zero_image, 'window': [0, 0, 0, 0] - }] * (batch_size - remainder)) + }] * (BATCH_SIZE - remainder)) images_df = images_df.append(remainder_df) N = images_df.shape[0] - # Split into batches of batch_size. - ind = np.arange(N) / batch_size - df_batches = [images_df[ind == i] for i in range(N / batch_size)] + # Split into batches of BATCH_SIZE. + ind = np.arange(N) / BATCH_SIZE + df_batches = [images_df[ind == i] for i in range(N / BATCH_SIZE)] return df_batches @@ -254,7 +256,7 @@ def compute_feats(images_df): def config(model_def, pretrained_model, gpu, image_dim, image_mean_file): global IMAGE_DIM, CROPPED_DIM, IMAGE_CENTER, IMAGE_MEAN, CROPPED_IMAGE_MEAN - global NET, NUM_OUTPUT + global NET, BATCH_SIZE, NUM_OUTPUT # Initialize network by loading model definition and weights. t = time.time() @@ -273,11 +275,11 @@ def config(model_def, pretrained_model, gpu, image_dim, image_mean_file): # Load the data set mean file IMAGE_MEAN = np.load(image_mean_file) - CROPPED_IMAGE_MEAN = IMAGE_MEAN[IMAGE_CENTER:IMAGE_CENTER + CROPPED_DIM, IMAGE_CENTER:IMAGE_CENTER + CROPPED_DIM, :] - NUM_OUTPUT = NET.blobs()[-1].channels # number of output classes + BATCH_SIZE = NET.blobs()[0].num # network batch size + NUM_OUTPUT = NET.blobs()[-1].channels # number of output classes if __name__ == "__main__": @@ -293,8 +295,6 @@ if __name__ == "__main__": gflags.DEFINE_string( "images_file", "", "Image filenames file.") gflags.DEFINE_string( - "batch_size", 10, "Number of image crops to let through in one go") - gflags.DEFINE_string( "output_file", "", "Output DataFrame HDF5 filename.") gflags.DEFINE_string( "images_dim", 256, "Canonical dimension of (square) images.") @@ -305,7 +305,6 @@ if __name__ == "__main__": FLAGS = gflags.FLAGS FLAGS(sys.argv) - # Configure network, input, output config(FLAGS.model_def, FLAGS.pretrained_model, FLAGS.gpu, FLAGS.images_dim, FLAGS.images_mean_file) @@ -315,8 +314,7 @@ if __name__ == "__main__": print('Assembling batches...') with open(FLAGS.images_file) as f: image_fnames = [_.strip() for _ in f.readlines()] - image_batches = assemble_batches(image_fnames, FLAGS.crop_mode, - FLAGS.batch_size) + image_batches = assemble_batches(image_fnames, FLAGS.crop_mode) print('{} batches assembled in {:.3f} s'.format(len(image_batches), time.time() - t)) |