From 23a50260846f3fd4f45f81a70ab5c837a4ed0b40 Mon Sep 17 00:00:00 2001 From: Evan Shelhamer Date: Thu, 28 Aug 2014 21:45:46 -0700 Subject: [example] edit fine-tuning and train on ~2000 images, 1557 / 382 split - further detail merits of fine-tuning: less starving for itme and data - set random seed for reproducing the tutorial - 1557 train / 382 test split is more indicative of training quality than splits of 200 images --- examples/finetune_flickr_style/assemble_data.py | 90 ++++++ examples/finetune_flickr_style/flickr_style.csv.gz | Bin 0 -> 2178982 bytes .../flickr_style_solver.prototxt | 15 + .../flickr_style_train_val.prototxt | 349 +++++++++++++++++++++ examples/finetune_flickr_style/readme.md | 159 ++++++++++ .../finetuning_on_flickr_style/assemble_data.py | 89 ------ .../finetuning_on_flickr_style/flickr_style.csv.gz | Bin 2178982 -> 0 bytes .../finetuning_on_flickr_style/models/.gitignore | 0 examples/finetuning_on_flickr_style/readme.md | 100 ------ .../finetuning_on_flickr_style/solver.prototxt | 15 - .../finetuning_on_flickr_style/train_val.prototxt | 349 --------------------- 11 files changed, 613 insertions(+), 553 deletions(-) create mode 100755 examples/finetune_flickr_style/assemble_data.py create mode 100644 examples/finetune_flickr_style/flickr_style.csv.gz create mode 100644 examples/finetune_flickr_style/flickr_style_solver.prototxt create mode 100644 examples/finetune_flickr_style/flickr_style_train_val.prototxt create mode 100644 examples/finetune_flickr_style/readme.md delete mode 100644 examples/finetuning_on_flickr_style/assemble_data.py delete mode 100644 examples/finetuning_on_flickr_style/flickr_style.csv.gz delete mode 100644 examples/finetuning_on_flickr_style/models/.gitignore delete mode 100644 examples/finetuning_on_flickr_style/readme.md delete mode 100644 examples/finetuning_on_flickr_style/solver.prototxt delete mode 100644 examples/finetuning_on_flickr_style/train_val.prototxt (limited to 'examples') diff --git 
a/examples/finetune_flickr_style/assemble_data.py b/examples/finetune_flickr_style/assemble_data.py new file mode 100755 index 00000000..b4c995e8 --- /dev/null +++ b/examples/finetune_flickr_style/assemble_data.py @@ -0,0 +1,90 @@ +#!/usr/bin/env python +""" +Form a subset of the Flickr Style data, download images to dirname, and write +Caffe ImagesDataLayer training file. +""" +import os +import urllib +import hashlib +import argparse +import numpy as np +import pandas as pd +import multiprocessing + +# Flickr returns a special image if the request is unavailable. +MISSING_IMAGE_SHA1 = '6a92790b1c2a301c6e7ddef645dca1f53ea97ac2' + +example_dirname = os.path.abspath(os.path.dirname(__file__)) +caffe_dirname = os.path.abspath(os.path.join(example_dirname, '../..')) +training_dirname = os.path.join(caffe_dirname, 'data/flickr_style') + + +def download_image(args_tuple): + "For use with multiprocessing map. Returns filename on fail." + try: + url, filename = args_tuple + if not os.path.exists(filename): + urllib.urlretrieve(url, filename) + with open(filename) as f: + assert hashlib.sha1(f.read()).hexdigest() != MISSING_IMAGE_SHA1 + return True + except KeyboardInterrupt: + raise Exception() # multiprocessing doesn't catch keyboard exceptions + except: + return False + + +if __name__ == '__main__': + parser = argparse.ArgumentParser( + description='Download a subset of Flickr Style to a directory') + parser.add_argument( + '-s', '--seed', type=int, default=0, + help="random seed") + parser.add_argument( + '-i', '--images', type=int, default=-1, + help="number of images to use (-1 for all [default])", + ) + parser.add_argument( + '-w', '--workers', type=int, default=-1, + help="num workers used to download images. -x uses (all - x) cores [-1 default]." + ) + + args = parser.parse_args() + np.random.seed(args.seed) + + # Read data, shuffle order, and subsample. 
+ csv_filename = os.path.join(example_dirname, 'flickr_style.csv.gz') + df = pd.read_csv(csv_filename, index_col=0, compression='gzip') + df = df.iloc[np.random.permutation(df.shape[0])] + if args.images > 0 and args.images < df.shape[0]: + df = df.iloc[:args.images] + + # Make directory for images and get local filenames. + if training_dirname is None: + training_dirname = os.path.join(caffe_dirname, 'data/flickr_style') + images_dirname = os.path.join(training_dirname, 'images') + if not os.path.exists(images_dirname): + os.makedirs(images_dirname) + df['image_filename'] = [ + os.path.join(images_dirname, _.split('/')[-1]) for _ in df['image_url'] + ] + + # Download images. + num_workers = args.workers + if num_workers <= 0: + num_workers = multiprocessing.cpu_count() + num_workers + print('Downloading {} images with {} workers...'.format( + df.shape[0], num_workers)) + pool = multiprocessing.Pool(processes=num_workers) + map_args = zip(df['image_url'], df['image_filename']) + results = pool.map(download_image, map_args) + + # Only keep rows with valid images, and write out training file lists. 
+ df = df[results] + for split in ['train', 'test']: + split_df = df[df['_split'] == split] + filename = os.path.join(training_dirname, '{}.txt'.format(split)) + split_df[['image_filename', 'label']].to_csv( + filename, sep=' ', header=None, index=None) + print('Writing train/val for {} successfully downloaded images.'.format( + df.shape[0])) diff --git a/examples/finetune_flickr_style/flickr_style.csv.gz b/examples/finetune_flickr_style/flickr_style.csv.gz new file mode 100644 index 00000000..5a84f88a Binary files /dev/null and b/examples/finetune_flickr_style/flickr_style.csv.gz differ diff --git a/examples/finetune_flickr_style/flickr_style_solver.prototxt b/examples/finetune_flickr_style/flickr_style_solver.prototxt new file mode 100644 index 00000000..740ec39f --- /dev/null +++ b/examples/finetune_flickr_style/flickr_style_solver.prototxt @@ -0,0 +1,15 @@ +net: "examples/finetune_flickr_style/flickr_style_train_val.prototxt" +test_iter: 100 +test_interval: 1000 +# lr for fine-tuning should be lower than when starting from scratch +base_lr: 0.001 +lr_policy: "step" +gamma: 0.1 +# stepsize should also be lower, as we're closer to being done +stepsize: 20000 +display: 20 +max_iter: 100000 +momentum: 0.9 +weight_decay: 0.0005 +snapshot: 10000 +snapshot_prefix: "examples/finetune_flickr_style/flickr_style" diff --git a/examples/finetune_flickr_style/flickr_style_train_val.prototxt b/examples/finetune_flickr_style/flickr_style_train_val.prototxt new file mode 100644 index 00000000..bcb1e1ce --- /dev/null +++ b/examples/finetune_flickr_style/flickr_style_train_val.prototxt @@ -0,0 +1,349 @@ +name: "FlickrStyleCaffeNet" +layers { + name: "data" + type: IMAGE_DATA + top: "data" + top: "label" + image_data_param { + source: "data/flickr_style/train.txt" + batch_size: 50 + transform_param { + crop_size: 227 + mean_file: "data/ilsvrc12/imagenet_mean.binaryproto" + mirror: true + } + new_height: 256 + new_width: 256 + } + include: { phase: TRAIN } +} +layers { + name: 
"data" + type: IMAGE_DATA + top: "data" + top: "label" + image_data_param { + source: "data/flickr_style/train.txt" + batch_size: 50 + transform_param { + crop_size: 227 + mean_file: "data/ilsvrc12/imagenet_mean.binaryproto" + mirror: true + } + new_height: 256 + new_width: 256 + } + include: { phase: TEST } +} +layers { + name: "conv1" + type: CONVOLUTION + bottom: "data" + top: "conv1" + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + convolution_param { + num_output: 96 + kernel_size: 11 + stride: 4 + weight_filler { + type: "gaussian" + std: 0.01 + } + bias_filler { + type: "constant" + value: 0 + } + } +} +layers { + name: "relu1" + type: RELU + bottom: "conv1" + top: "conv1" +} +layers { + name: "pool1" + type: POOLING + bottom: "conv1" + top: "pool1" + pooling_param { + pool: MAX + kernel_size: 3 + stride: 2 + } +} +layers { + name: "norm1" + type: LRN + bottom: "pool1" + top: "norm1" + lrn_param { + local_size: 5 + alpha: 0.0001 + beta: 0.75 + } +} +layers { + name: "conv2" + type: CONVOLUTION + bottom: "norm1" + top: "conv2" + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + convolution_param { + num_output: 256 + pad: 2 + kernel_size: 5 + group: 2 + weight_filler { + type: "gaussian" + std: 0.01 + } + bias_filler { + type: "constant" + value: 1 + } + } +} +layers { + name: "relu2" + type: RELU + bottom: "conv2" + top: "conv2" +} +layers { + name: "pool2" + type: POOLING + bottom: "conv2" + top: "pool2" + pooling_param { + pool: MAX + kernel_size: 3 + stride: 2 + } +} +layers { + name: "norm2" + type: LRN + bottom: "pool2" + top: "norm2" + lrn_param { + local_size: 5 + alpha: 0.0001 + beta: 0.75 + } +} +layers { + name: "conv3" + type: CONVOLUTION + bottom: "norm2" + top: "conv3" + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + convolution_param { + num_output: 384 + pad: 1 + kernel_size: 3 + weight_filler { + type: "gaussian" + std: 0.01 + } + bias_filler { + type: "constant" + value: 0 + } + } +} +layers { 
+ name: "relu3" + type: RELU + bottom: "conv3" + top: "conv3" +} +layers { + name: "conv4" + type: CONVOLUTION + bottom: "conv3" + top: "conv4" + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + convolution_param { + num_output: 384 + pad: 1 + kernel_size: 3 + group: 2 + weight_filler { + type: "gaussian" + std: 0.01 + } + bias_filler { + type: "constant" + value: 1 + } + } +} +layers { + name: "relu4" + type: RELU + bottom: "conv4" + top: "conv4" +} +layers { + name: "conv5" + type: CONVOLUTION + bottom: "conv4" + top: "conv5" + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + convolution_param { + num_output: 256 + pad: 1 + kernel_size: 3 + group: 2 + weight_filler { + type: "gaussian" + std: 0.01 + } + bias_filler { + type: "constant" + value: 1 + } + } +} +layers { + name: "relu5" + type: RELU + bottom: "conv5" + top: "conv5" +} +layers { + name: "pool5" + type: POOLING + bottom: "conv5" + top: "pool5" + pooling_param { + pool: MAX + kernel_size: 3 + stride: 2 + } +} +layers { + name: "fc6" + type: INNER_PRODUCT + bottom: "pool5" + top: "fc6" + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + inner_product_param { + num_output: 4096 + weight_filler { + type: "gaussian" + std: 0.005 + } + bias_filler { + type: "constant" + value: 1 + } + } +} +layers { + name: "relu6" + type: RELU + bottom: "fc6" + top: "fc6" +} +layers { + name: "drop6" + type: DROPOUT + bottom: "fc6" + top: "fc6" + dropout_param { + dropout_ratio: 0.5 + } +} +layers { + name: "fc7" + type: INNER_PRODUCT + bottom: "fc6" + top: "fc7" + # Note that blobs_lr can be set to 0 to disable any fine-tuning of this, and any other, layer + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + inner_product_param { + num_output: 4096 + weight_filler { + type: "gaussian" + std: 0.005 + } + bias_filler { + type: "constant" + value: 1 + } + } +} +layers { + name: "relu7" + type: RELU + bottom: "fc7" + top: "fc7" +} +layers { + name: "drop7" + type: 
DROPOUT + bottom: "fc7" + top: "fc7" + dropout_param { + dropout_ratio: 0.5 + } +} +layers { + name: "fc8_flickr" + type: INNER_PRODUCT + bottom: "fc7" + top: "fc8_flickr" + # blobs_lr is set to higher than for other layers, because this layer is starting from random while the others are already trained + blobs_lr: 10 + blobs_lr: 20 + weight_decay: 1 + weight_decay: 0 + inner_product_param { + num_output: 20 + weight_filler { + type: "gaussian" + std: 0.01 + } + bias_filler { + type: "constant" + value: 0 + } + } +} +layers { + name: "loss" + type: SOFTMAX_LOSS + bottom: "fc8_flickr" + bottom: "label" +} +layers { + name: "accuracy" + type: ACCURACY + bottom: "fc8_flickr" + bottom: "label" + top: "accuracy" + include: { phase: TEST } +} diff --git a/examples/finetune_flickr_style/readme.md b/examples/finetune_flickr_style/readme.md new file mode 100644 index 00000000..82249982 --- /dev/null +++ b/examples/finetune_flickr_style/readme.md @@ -0,0 +1,159 @@ +--- +title: Fine-tuning for style recognition +description: Fine-tune the ImageNet-trained CaffeNet on the "Flickr Style" dataset. +category: example +include_in_docs: true +layout: default +priority: 5 +--- + +# Fine-tuning CaffeNet for Style Recognition on "Flickr Style" Data + +Fine-tuning takes an already learned model, adapts the architecture, and resumes training from the already learned model weights. +Let's fine-tune the BVLC-distributed CaffeNet model on a different dataset, [Flickr Style](http://sergeykarayev.com/files/1311.3715v3.pdf), to predict image style instead of object category. + +## Explanation + +The Flickr-sourced images of the Style dataset are visually very similar to the ImageNet dataset, on which the `caffe_reference_imagenet_model` was trained. +Since that model works well for object category classification, we'd like to use its architecture for our style classifier. 
+We also only have 80,000 images to train on, so we'd like to start with the parameters learned on the 1,000,000 ImageNet images, and fine-tune as needed. +If we provide the `weights` argument to the `caffe train` command, the pretrained weights will be loaded into our model, matching layers by name. + +Because we are predicting 20 classes instead of 1,000, we do need to change the last layer in the model. +Therefore, we change the name of the last layer from `fc8` to `fc8_flickr` in our prototxt. +Since there is no layer named that in the `caffe_reference_imagenet_model`, that layer will begin training with random weights. + +We will also decrease the overall learning rate `base_lr` in the solver prototxt, but boost the `blobs_lr` on the newly introduced layer. +The idea is to have the rest of the model change very slowly with new data, but let the new layer learn fast. +Additionally, we set `stepsize` in the solver to a lower value than if we were training from scratch, since we're virtually far along in training and therefore want the learning rate to go down faster. +Note that we could also entirely prevent fine-tuning of all layers other than `fc8_flickr` by setting their `blobs_lr` to 0. + +## Procedure + +All steps are to be done from the caffe root directory. + +The dataset is distributed as a list of URLs with corresponding labels. +Using a script, we will download a small subset of the data and split it into train and val sets. + + caffe % ./examples/finetune_flickr_style/assemble_data.py -h + usage: assemble_data.py [-h] [-s SEED] [-i IMAGES] [-w WORKERS] + + Download a subset of Flickr Style to a directory + + optional arguments: + -h, --help show this help message and exit + -s SEED, --seed SEED random seed + -i IMAGES, --images IMAGES + number of images to use (-1 for all) + -w WORKERS, --workers WORKERS + num workers used to download images. -x uses (all - x) + cores. 
+ + caffe % python examples/finetune_flickr_style/assemble_data.py --workers=-1 --images=2000 --seed 831486 + Downloading 2000 images with 7 workers... + Writing train/val for 1939 successfully downloaded images. + +This script downloads images and writes train/val file lists into `data/flickr_style`. +With this random seed there are 1,557 train images and 382 test images. +The prototxts in this example assume this, and also assume the presence of the ImageNet mean file (run `get_ilsvrc_aux.sh` from `data/ilsvrc12` to obtain this if you haven't yet). + +We'll also need the ImageNet-trained model, which you can obtain by running `get_caffe_reference_imagenet_model.sh` from `examples/imagenet`. + +Now we can train! (You can fine-tune in CPU mode by leaving out the `-gpu` flag.) + + caffe % ./build/tools/caffe train -solver examples/finetune_flickr_style/flickr_style_solver.prototxt -weights examples/imagenet/caffe_reference_imagenet_model -gpu 0 + + [...] + + I0828 22:10:04.025378 9718 solver.cpp:46] Solver scaffolding done. + I0828 22:10:04.025388 9718 caffe.cpp:95] Use GPU with device ID 0 + I0828 22:10:04.192004 9718 caffe.cpp:107] Finetuning from examples/imagenet/caffe_reference_imagenet_model + + [...] 
+ + I0828 22:17:48.338963 11510 solver.cpp:165] Solving FlickrStyleCaffeNet + I0828 22:17:48.339010 11510 solver.cpp:251] Iteration 0, Testing net (#0) + I0828 22:18:14.313817 11510 solver.cpp:302] Test net output #0: accuracy = 0.0416 + I0828 22:18:14.476822 11510 solver.cpp:195] Iteration 0, loss = 3.75717 + I0828 22:18:14.476878 11510 solver.cpp:397] Iteration 0, lr = 0.001 + I0828 22:18:19.700408 11510 solver.cpp:195] Iteration 20, loss = 3.1689 + I0828 22:18:19.700461 11510 solver.cpp:397] Iteration 20, lr = 0.001 + I0828 22:18:24.924685 11510 solver.cpp:195] Iteration 40, loss = 2.3549 + I0828 22:18:24.924741 11510 solver.cpp:397] Iteration 40, lr = 0.001 + I0828 22:18:30.114858 11510 solver.cpp:195] Iteration 60, loss = 2.74191 + I0828 22:18:30.114910 11510 solver.cpp:397] Iteration 60, lr = 0.001 + I0828 22:18:35.328071 11510 solver.cpp:195] Iteration 80, loss = 1.9147 + I0828 22:18:35.328127 11510 solver.cpp:397] Iteration 80, lr = 0.001 + I0828 22:18:40.588317 11510 solver.cpp:195] Iteration 100, loss = 1.81419 + I0828 22:18:40.588373 11510 solver.cpp:397] Iteration 100, lr = 0.001 + I0828 22:18:46.171576 11510 solver.cpp:195] Iteration 120, loss = 2.02105 + I0828 22:18:46.171669 11510 solver.cpp:397] Iteration 120, lr = 0.001 + I0828 22:18:51.757809 11510 solver.cpp:195] Iteration 140, loss = 1.49083 + I0828 22:18:51.757863 11510 solver.cpp:397] Iteration 140, lr = 0.001 + I0828 22:18:57.345080 11510 solver.cpp:195] Iteration 160, loss = 1.35319 + I0828 22:18:57.345135 11510 solver.cpp:397] Iteration 160, lr = 0.001 + I0828 22:19:02.928794 11510 solver.cpp:195] Iteration 180, loss = 1.11658 + I0828 22:19:02.928850 11510 solver.cpp:397] Iteration 180, lr = 0.001 + I0828 22:19:08.514497 11510 solver.cpp:195] Iteration 200, loss = 1.08851 + I0828 22:19:08.514552 11510 solver.cpp:397] Iteration 200, lr = 0.001 + + [...] 
+ + I0828 22:22:40.789010 11510 solver.cpp:195] Iteration 960, loss = 0.0844627 + I0828 22:22:40.789175 11510 solver.cpp:397] Iteration 960, lr = 0.001 + I0828 22:22:46.376626 11510 solver.cpp:195] Iteration 980, loss = 0.0110937 + I0828 22:22:46.376682 11510 solver.cpp:397] Iteration 980, lr = 0.001 + I0828 22:22:51.687258 11510 solver.cpp:251] Iteration 1000, Testing net (#0) + I0828 22:23:17.438894 11510 solver.cpp:302] Test net output #0: accuracy = 1 + +Note how rapidly the loss went down. Although the 100% accuracy is optimistic, it is evidence the model is learning quickly and well. + +For comparison, here is how the loss goes down when we do not start with a pre-trained model: + + I0828 22:24:18.624004 12919 solver.cpp:165] Solving FlickrStyleCaffeNet + I0828 22:24:18.624099 12919 solver.cpp:251] Iteration 0, Testing net (#0) + I0828 22:24:44.520992 12919 solver.cpp:302] Test net output #0: accuracy = 0.045 + I0828 22:24:44.676905 12919 solver.cpp:195] Iteration 0, loss = 3.33111 + I0828 22:24:44.677120 12919 solver.cpp:397] Iteration 0, lr = 0.001 + I0828 22:24:50.152454 12919 solver.cpp:195] Iteration 20, loss = 2.98133 + I0828 22:24:50.152509 12919 solver.cpp:397] Iteration 20, lr = 0.001 + I0828 22:24:55.736256 12919 solver.cpp:195] Iteration 40, loss = 3.02124 + I0828 22:24:55.736311 12919 solver.cpp:397] Iteration 40, lr = 0.001 + I0828 22:25:01.316514 12919 solver.cpp:195] Iteration 60, loss = 2.99509 + I0828 22:25:01.316567 12919 solver.cpp:397] Iteration 60, lr = 0.001 + I0828 22:25:06.899554 12919 solver.cpp:195] Iteration 80, loss = 2.9928 + I0828 22:25:06.899610 12919 solver.cpp:397] Iteration 80, lr = 0.001 + I0828 22:25:12.484624 12919 solver.cpp:195] Iteration 100, loss = 2.99072 + I0828 22:25:12.484678 12919 solver.cpp:397] Iteration 100, lr = 0.001 + I0828 22:25:18.069056 12919 solver.cpp:195] Iteration 120, loss = 3.01816 + I0828 22:25:18.069149 12919 solver.cpp:397] Iteration 120, lr = 0.001 + I0828 22:25:23.650928 12919 solver.cpp:195] 
Iteration 140, loss = 2.9694 + I0828 22:25:23.650984 12919 solver.cpp:397] Iteration 140, lr = 0.001 + I0828 22:25:29.235535 12919 solver.cpp:195] Iteration 160, loss = 3.00383 + I0828 22:25:29.235589 12919 solver.cpp:397] Iteration 160, lr = 0.001 + I0828 22:25:34.816898 12919 solver.cpp:195] Iteration 180, loss = 2.99802 + I0828 22:25:34.816953 12919 solver.cpp:397] Iteration 180, lr = 0.001 + I0828 22:25:40.396656 12919 solver.cpp:195] Iteration 200, loss = 2.99769 + I0828 22:25:40.396711 12919 solver.cpp:397] Iteration 200, lr = 0.001 + + [...] + + I0828 22:29:12.539094 12919 solver.cpp:195] Iteration 960, loss = 2.99314 + I0828 22:29:12.539258 12919 solver.cpp:397] Iteration 960, lr = 0.001 + I0828 22:29:18.123092 12919 solver.cpp:195] Iteration 980, loss = 2.99503 + I0828 22:29:18.123147 12919 solver.cpp:397] Iteration 980, lr = 0.001 + I0828 22:29:23.432059 12919 solver.cpp:251] Iteration 1000, Testing net (#0) + I0828 22:29:49.409044 12919 solver.cpp:302] Test net output #0: accuracy = 0.0624 + +This model is only beginning to learn. + +Fine-tuning can be feasible when training from scratch would not be for lack of time or data. +Even in CPU mode each pass through the training set takes ~100 s. GPU fine-tuning is of course faster still and can learn a useful model in minutes or hours instead of days or weeks. +Furthermore, note that the model has only trained on < 2,000 instances. Transfer learning a new task like style recognition from the ImageNet pretraining can require much less data than training from scratch. +Now try fine-tuning to your own tasks and data! + +## License + +The Flickr Style dataset as distributed here contains only URLs to images. +Some of the images may have copyright. +Training a category-recognition model for research/non-commercial use may constitute fair use of this data. 
diff --git a/examples/finetuning_on_flickr_style/assemble_data.py b/examples/finetuning_on_flickr_style/assemble_data.py deleted file mode 100644 index d8770e92..00000000 --- a/examples/finetuning_on_flickr_style/assemble_data.py +++ /dev/null @@ -1,89 +0,0 @@ -""" -Form a subset of the Flickr Style data, download images to dirname, and write -Caffe ImagesDataLayer training file. -""" -import os -import urllib -import hashlib -import argparse -import numpy as np -import pandas as pd -import multiprocessing - -# Flickr returns a special image if the request is unavailable. -MISSING_IMAGE_SHA1 = '6a92790b1c2a301c6e7ddef645dca1f53ea97ac2' - -example_dirname = os.path.abspath(os.path.dirname(__file__)) -caffe_dirname = os.path.abspath(os.path.join(example_dirname, '../..')) -training_dirname = os.path.join(caffe_dirname, 'data/flickr_style') - - -def download_image(args_tuple): - "For use with multiprocessing map. Returns filename on fail." - try: - url, filename = args_tuple - if not os.path.exists(filename): - urllib.urlretrieve(url, filename) - with open(filename) as f: - assert hashlib.sha1(f.read()).hexdigest() != MISSING_IMAGE_SHA1 - return True - except KeyboardInterrupt: - raise Exception() # multiprocessing doesn't catch keyboard exceptions - except: - return False - - -if __name__ == '__main__': - parser = argparse.ArgumentParser( - description='Download a subset of Flickr Style to a directory') - parser.add_argument( - '-s', '--seed', type=int, default=0, - help="random seed") - parser.add_argument( - '-i', '--images', type=int, default=-1, - help="number of images to use (-1 for all [default])", - ) - parser.add_argument( - '-w', '--workers', type=int, default=-1, - help="num workers used to download images. -x uses (all - x) cores [-1 default]." - ) - - args = parser.parse_args() - np.random.seed(args.seed) - - # Read data, shuffle order, and subsample. 
- csv_filename = os.path.join(example_dirname, 'flickr_style.csv.gz') - df = pd.read_csv(csv_filename, index_col=0, compression='gzip') - df = df.iloc[np.random.permutation(df.shape[0])] - if args.images > 0 and args.images < df.shape[0]: - df = df.iloc[:args.images] - - # Make directory for images and get local filenames. - if training_dirname is None: - training_dirname = os.path.join(caffe_dirname, 'data/flickr_style') - images_dirname = os.path.join(training_dirname, 'images') - if not os.path.exists(images_dirname): - os.makedirs(images_dirname) - df['image_filename'] = [ - os.path.join(images_dirname, _.split('/')[-1]) for _ in df['image_url'] - ] - - # Download images. - num_workers = args.workers - if num_workers <= 0: - num_workers = multiprocessing.cpu_count() + num_workers - print('Downloading {} images with {} workers...'.format( - df.shape[0], num_workers)) - pool = multiprocessing.Pool(processes=num_workers) - map_args = zip(df['image_url'], df['image_filename']) - results = pool.map(download_image, map_args) - - # Only keep rows with valid images, and write out training file lists. 
- df = df[results] - for split in ['train', 'test']: - split_df = df[df['_split'] == split] - filename = os.path.join(training_dirname, '{}.txt'.format(split)) - split_df[['image_filename', 'label']].to_csv( - filename, sep=' ', header=None, index=None) - print('Writing train/val for {} successfully downloaded images.'.format( - df.shape[0])) diff --git a/examples/finetuning_on_flickr_style/flickr_style.csv.gz b/examples/finetuning_on_flickr_style/flickr_style.csv.gz deleted file mode 100644 index 5a84f88a..00000000 Binary files a/examples/finetuning_on_flickr_style/flickr_style.csv.gz and /dev/null differ diff --git a/examples/finetuning_on_flickr_style/models/.gitignore b/examples/finetuning_on_flickr_style/models/.gitignore deleted file mode 100644 index e69de29b..00000000 diff --git a/examples/finetuning_on_flickr_style/readme.md b/examples/finetuning_on_flickr_style/readme.md deleted file mode 100644 index 4a164e5f..00000000 --- a/examples/finetuning_on_flickr_style/readme.md +++ /dev/null @@ -1,100 +0,0 @@ ---- -title: Fine-tuning CaffeNet on "Flickr Style" data -description: We fine-tune the ImageNet-trained CaffeNet on another dataset. -category: example -include_in_docs: true -layout: default -priority: 5 ---- - -# Fine-tuning CaffeNet on "Flickr Style" data - -This example shows how to fine-tune the BVLC-distributed CaffeNet model on a different dataset: [Flickr Style](http://sergeykarayev.com/files/1311.3715v3.pdf), which has style category labels. - -## Explanation - -The Flickr-sourced data of the Style dataset is visually very similar to the ImageNet dataset, on which the `caffe_reference_imagenet_model` was trained. -Since that model works well for object category classification, we'd like to use it architecture for our style classifier. -We also only have 80,000 images to train on, so we'd like to start with the parameters learned on the 1,000,000 ImageNet images, and fine-tune as needed. 
-If we give provide the `model` parameter to the `caffe train` command, the trained weights will be loaded into our model, matching layers by name. - -Because we are predicting 20 classes instead of a 1,000, we do need to change the last layer in the model. -Therefore, we change the name of the last layer from `fc8` to `fc8_flickr` in our prototxt. -Since there is no layer named that in the `caffe_reference_imagenet_model`, that layer will begin training with random weights. - -We will also decrease the overall learning rate `base_lr` in the solver prototxt, but boost the `blobs_lr` on the newly introduced layer. -The idea is to have the rest of the model change very slowly with new data, but the new layer needs to learn fast. -Additionally, we set `stepsize` in the solver to a lower value than if we were training from scratch, since we're virtually far along in training and therefore want the learning rate to go down faster. -Note that we could also entirely prevent fine-tuning of all layers other than `fc8_flickr` by setting their `blobs_lr` to 0. - -## Procedure - -All steps are to be done from the root caffe directory. - -The dataset is distributed as a list of URLs with corresponding labels. -Using a script, we will download a small subset of the data and split it into train and val sets. - - caffe % python examples/finetuning_on_flickr_style/assemble_data.py -h - usage: assemble_data.py [-h] [-s SEED] [-i IMAGES] [-w WORKERS] - - Download a subset of Flickr Style to a directory - - optional arguments: - -h, --help show this help message and exit - -s SEED, --seed SEED random seed - -i IMAGES, --images IMAGES - number of images to use (-1 for all) - -w WORKERS, --workers WORKERS - num workers used to download images. -x uses (all - x) - cores. - - caffe % python examples/finetuning_on_flickr_style/assemble_data.py --workers=-1 --images=200 - Downloading 200 images with 7 workers... - Writing train/val for 190 successfully downloaded images. 
- -This script downloads images and writes train/val file lists into `data/flickr_style`. -The prototxt's in this example assume this, and also assume the presence of the ImageNet mean file (run `get_ilsvrc_aux.sh` from `data/ilsvrc12` to obtain this if you haven't yet). - -We'll also need the ImageNet-trained model, which you can obtain by running `get_caffe_reference_imagenet_model.sh` from `examples/imagenet`. - -Now we can train! - - caffe % ./build/tools/caffe train -solver examples/finetuning_on_flickr_style/solver.prototxt -weights examples/imagenet/caffe_reference_imagenet_model - I0827 19:41:52.455621 2129298192 caffe.cpp:90] Starting Optimization - I0827 19:41:52.456883 2129298192 solver.cpp:32] Initializing solver from parameters: - - [...] - - I0827 19:41:55.520205 2129298192 solver.cpp:46] Solver scaffolding done. - I0827 19:41:55.520211 2129298192 caffe.cpp:99] Use CPU. - I0827 19:41:55.520217 2129298192 caffe.cpp:107] Finetuning from examples/imagenet/caffe_reference_imagenet_model - I0827 19:41:57.433044 2129298192 solver.cpp:165] Solving CaffeNet - I0827 19:41:57.433104 2129298192 solver.cpp:251] Iteration 0, Testing net (#0) - I0827 19:44:44.145447 2129298192 solver.cpp:302] Test net output #0: accuracy = 0.004 - I0827 19:44:48.774271 2129298192 solver.cpp:195] Iteration 0, loss = 3.46922 - I0827 19:44:48.774333 2129298192 solver.cpp:397] Iteration 0, lr = 0.001 - I0827 19:46:15.107447 2129298192 solver.cpp:195] Iteration 20, loss = 0.0147678 - I0827 19:46:15.107511 2129298192 solver.cpp:397] Iteration 20, lr = 0.001 - I0827 19:47:41.941119 2129298192 solver.cpp:195] Iteration 40, loss = 0.00455839 - I0827 19:47:41.941181 2129298192 solver.cpp:397] Iteration 40, lr = 0.001 - -Note how rapidly the loss went down. -For comparison, here is how the loss goes down when we do not start with a pre-trained model: - - I0827 18:57:08.496208 2129298192 solver.cpp:46] Solver scaffolding done. - I0827 18:57:08.496227 2129298192 caffe.cpp:99] Use CPU. 
- I0827 18:57:08.496235 2129298192 solver.cpp:165] Solving CaffeNet - I0827 18:57:08.496271 2129298192 solver.cpp:251] Iteration 0, Testing net (#0) - I0827 19:00:00.894336 2129298192 solver.cpp:302] Test net output #0: accuracy = 0.075 - I0827 19:00:05.825129 2129298192 solver.cpp:195] Iteration 0, loss = 3.51759 - I0827 19:00:05.825187 2129298192 solver.cpp:397] Iteration 0, lr = 0.001 - I0827 19:01:36.090224 2129298192 solver.cpp:195] Iteration 20, loss = 3.32227 - I0827 19:01:36.091948 2129298192 solver.cpp:397] Iteration 20, lr = 0.001 - I0827 19:03:08.522105 2129298192 solver.cpp:195] Iteration 40, loss = 2.97031 - I0827 19:03:08.522176 2129298192 solver.cpp:397] Iteration 40, lr = 0.001 - -## License - -The Flickr Style dataset as distributed here contains only URLs to images. -Some of the images may have copyright. -Training a category-recognition model for research/non-commercial use may constitute fair use of this data. diff --git a/examples/finetuning_on_flickr_style/solver.prototxt b/examples/finetuning_on_flickr_style/solver.prototxt deleted file mode 100644 index ed1548c0..00000000 --- a/examples/finetuning_on_flickr_style/solver.prototxt +++ /dev/null @@ -1,15 +0,0 @@ -net: "examples/finetuning_on_flickr_style/train_val.prototxt" -test_iter: 100 -test_interval: 1000 -# lr for fine-tuning should be lower than when starting from scratch -base_lr: 0.001 -lr_policy: "step" -gamma: 0.1 -# stepsize should also be lower, as we're closer to being done -stepsize: 20000 -display: 20 -max_iter: 100000 -momentum: 0.9 -weight_decay: 0.0005 -snapshot: 10000 -snapshot_prefix: "examples/finetuning_on_flickr_style/models/finetuning" diff --git a/examples/finetuning_on_flickr_style/train_val.prototxt b/examples/finetuning_on_flickr_style/train_val.prototxt deleted file mode 100644 index bcb1e1ce..00000000 --- a/examples/finetuning_on_flickr_style/train_val.prototxt +++ /dev/null @@ -1,349 +0,0 @@ -name: "FlickrStyleCaffeNet" -layers { - name: "data" - type: IMAGE_DATA 
- top: "data" - top: "label" - image_data_param { - source: "data/flickr_style/train.txt" - batch_size: 50 - transform_param { - crop_size: 227 - mean_file: "data/ilsvrc12/imagenet_mean.binaryproto" - mirror: true - } - new_height: 256 - new_width: 256 - } - include: { phase: TRAIN } -} -layers { - name: "data" - type: IMAGE_DATA - top: "data" - top: "label" - image_data_param { - source: "data/flickr_style/train.txt" - batch_size: 50 - transform_param { - crop_size: 227 - mean_file: "data/ilsvrc12/imagenet_mean.binaryproto" - mirror: true - } - new_height: 256 - new_width: 256 - } - include: { phase: TEST } -} -layers { - name: "conv1" - type: CONVOLUTION - bottom: "data" - top: "conv1" - blobs_lr: 1 - blobs_lr: 2 - weight_decay: 1 - weight_decay: 0 - convolution_param { - num_output: 96 - kernel_size: 11 - stride: 4 - weight_filler { - type: "gaussian" - std: 0.01 - } - bias_filler { - type: "constant" - value: 0 - } - } -} -layers { - name: "relu1" - type: RELU - bottom: "conv1" - top: "conv1" -} -layers { - name: "pool1" - type: POOLING - bottom: "conv1" - top: "pool1" - pooling_param { - pool: MAX - kernel_size: 3 - stride: 2 - } -} -layers { - name: "norm1" - type: LRN - bottom: "pool1" - top: "norm1" - lrn_param { - local_size: 5 - alpha: 0.0001 - beta: 0.75 - } -} -layers { - name: "conv2" - type: CONVOLUTION - bottom: "norm1" - top: "conv2" - blobs_lr: 1 - blobs_lr: 2 - weight_decay: 1 - weight_decay: 0 - convolution_param { - num_output: 256 - pad: 2 - kernel_size: 5 - group: 2 - weight_filler { - type: "gaussian" - std: 0.01 - } - bias_filler { - type: "constant" - value: 1 - } - } -} -layers { - name: "relu2" - type: RELU - bottom: "conv2" - top: "conv2" -} -layers { - name: "pool2" - type: POOLING - bottom: "conv2" - top: "pool2" - pooling_param { - pool: MAX - kernel_size: 3 - stride: 2 - } -} -layers { - name: "norm2" - type: LRN - bottom: "pool2" - top: "norm2" - lrn_param { - local_size: 5 - alpha: 0.0001 - beta: 0.75 - } -} -layers { - name: 
"conv3" - type: CONVOLUTION - bottom: "norm2" - top: "conv3" - blobs_lr: 1 - blobs_lr: 2 - weight_decay: 1 - weight_decay: 0 - convolution_param { - num_output: 384 - pad: 1 - kernel_size: 3 - weight_filler { - type: "gaussian" - std: 0.01 - } - bias_filler { - type: "constant" - value: 0 - } - } -} -layers { - name: "relu3" - type: RELU - bottom: "conv3" - top: "conv3" -} -layers { - name: "conv4" - type: CONVOLUTION - bottom: "conv3" - top: "conv4" - blobs_lr: 1 - blobs_lr: 2 - weight_decay: 1 - weight_decay: 0 - convolution_param { - num_output: 384 - pad: 1 - kernel_size: 3 - group: 2 - weight_filler { - type: "gaussian" - std: 0.01 - } - bias_filler { - type: "constant" - value: 1 - } - } -} -layers { - name: "relu4" - type: RELU - bottom: "conv4" - top: "conv4" -} -layers { - name: "conv5" - type: CONVOLUTION - bottom: "conv4" - top: "conv5" - blobs_lr: 1 - blobs_lr: 2 - weight_decay: 1 - weight_decay: 0 - convolution_param { - num_output: 256 - pad: 1 - kernel_size: 3 - group: 2 - weight_filler { - type: "gaussian" - std: 0.01 - } - bias_filler { - type: "constant" - value: 1 - } - } -} -layers { - name: "relu5" - type: RELU - bottom: "conv5" - top: "conv5" -} -layers { - name: "pool5" - type: POOLING - bottom: "conv5" - top: "pool5" - pooling_param { - pool: MAX - kernel_size: 3 - stride: 2 - } -} -layers { - name: "fc6" - type: INNER_PRODUCT - bottom: "pool5" - top: "fc6" - blobs_lr: 1 - blobs_lr: 2 - weight_decay: 1 - weight_decay: 0 - inner_product_param { - num_output: 4096 - weight_filler { - type: "gaussian" - std: 0.005 - } - bias_filler { - type: "constant" - value: 1 - } - } -} -layers { - name: "relu6" - type: RELU - bottom: "fc6" - top: "fc6" -} -layers { - name: "drop6" - type: DROPOUT - bottom: "fc6" - top: "fc6" - dropout_param { - dropout_ratio: 0.5 - } -} -layers { - name: "fc7" - type: INNER_PRODUCT - bottom: "fc6" - top: "fc7" - # Note that blobs_lr can be set to 0 to disable any fine-tuning of this, and any other, layer - blobs_lr: 1 - 
blobs_lr: 2 - weight_decay: 1 - weight_decay: 0 - inner_product_param { - num_output: 4096 - weight_filler { - type: "gaussian" - std: 0.005 - } - bias_filler { - type: "constant" - value: 1 - } - } -} -layers { - name: "relu7" - type: RELU - bottom: "fc7" - top: "fc7" -} -layers { - name: "drop7" - type: DROPOUT - bottom: "fc7" - top: "fc7" - dropout_param { - dropout_ratio: 0.5 - } -} -layers { - name: "fc8_flickr" - type: INNER_PRODUCT - bottom: "fc7" - top: "fc8_flickr" - # blobs_lr is set to higher than for other layers, because this layer is starting from random while the others are already trained - blobs_lr: 10 - blobs_lr: 20 - weight_decay: 1 - weight_decay: 0 - inner_product_param { - num_output: 20 - weight_filler { - type: "gaussian" - std: 0.01 - } - bias_filler { - type: "constant" - value: 0 - } - } -} -layers { - name: "loss" - type: SOFTMAX_LOSS - bottom: "fc8_flickr" - bottom: "label" -} -layers { - name: "accuracy" - type: ACCURACY - bottom: "fc8_flickr" - bottom: "label" - top: "accuracy" - include: { phase: TEST } -} -- cgit v1.2.3