author    Sergey Karayev <sergeykarayev@gmail.com>    2014-07-12 09:25:23 -0700
committer Sergey Karayev <sergeykarayev@gmail.com>    2014-07-12 09:25:23 -0700
commit    9c61462cac1cea0856d180bc024e9937f357b274 (patch)
tree      4a9fdd254d329ff8e4b653405c04a70015b8a6f2 /examples
parent    e07a2c0eea903e8d7c7050e14d177bbd0f8f611a (diff)
parent    dd292da2485570488dc65464c8a9336bfc5906b3 (diff)
back-merging [docs] changes and web demo [example] addition; updating
net_surgery example to new format

Conflicts:
    docs/getting_pretrained_models.md
    docs/index.md
Diffstat (limited to 'examples')
-rw-r--r--  examples/cifar10/readme.md              |  98
-rw-r--r--  examples/detection.ipynb                |   6
-rw-r--r--  examples/feature_extraction/readme.md   |  74
-rw-r--r--  examples/filter_visualization.ipynb     |   4
-rw-r--r--  examples/imagenet/readme.md             | 105
-rw-r--r--  examples/imagenet_classification.ipynb  |   6
-rw-r--r--  examples/mnist/readme.md                | 266
-rw-r--r--  examples/net_surgery.ipynb              |   6
-rw-r--r--  examples/web_demo/app.py                | 215
-rw-r--r--  examples/web_demo/exifutil.py           |  33
-rw-r--r--  examples/web_demo/readme.md             |  30
-rw-r--r--  examples/web_demo/templates/index.html  | 138
12 files changed, 974 insertions, 7 deletions
diff --git a/examples/cifar10/readme.md b/examples/cifar10/readme.md
new file mode 100644
index 00000000..9d5bd7b2
--- /dev/null
+++ b/examples/cifar10/readme.md
@@ -0,0 +1,98 @@
+---
+title: CIFAR-10 tutorial
+category: example
+description: Train and test Caffe on CIFAR-10 data.
+include_in_docs: true
+layout: default
+---
+
+Alex's CIFAR-10 tutorial, Caffe style
+=====================================
+
+Alex Krizhevsky's [cuda-convnet](https://code.google.com/p/cuda-convnet/) details the model definitions, parameters, and training procedure for good performance on CIFAR-10. This example reproduces his results in Caffe.
+
+We will assume that you have Caffe successfully compiled. If not, please refer to the [Installation page](installation.html). In this tutorial, we will assume that your Caffe installation is located at `CAFFE_ROOT`.
+
+We thank @chyojn for the pull request that defined the model schemas and solver configurations.
+
+*This example is a work-in-progress. It would be nice to further explain details of the network and training choices and benchmark the full training.*
+
+Prepare the Dataset
+-------------------
+
+You will first need to download and convert the data format from the [CIFAR-10 website](http://www.cs.toronto.edu/~kriz/cifar.html). To do this, simply run the following commands:
+
+ cd $CAFFE_ROOT/data/cifar10
+ ./get_cifar10.sh
+ cd $CAFFE_ROOT/examples/cifar10
+ ./create_cifar10.sh
+
+If it complains that `wget` or `gunzip` is not installed, install them first. After running the script, there should be the dataset, `./cifar10-leveldb`, and the dataset image mean, `./mean.binaryproto`.
+
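+As a quick sanity check, you can load the computed image mean from Python. This is just a sketch, assuming pycaffe and the protobuf Python bindings are built (`blobproto_to_array` is part of the Python wrapper):
+
+    import caffe
+    from caffe.proto import caffe_pb2
+
+    blob = caffe_pb2.BlobProto()
+    blob.ParseFromString(open('mean.binaryproto', 'rb').read())
+    mean = caffe.io.blobproto_to_array(blob)
+    print(mean.shape)  # expect (1, 3, 32, 32) for CIFAR-10
+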
+The Model
+---------
+
+The CIFAR-10 model is a CNN that composes layers of convolution, pooling, rectified linear unit (ReLU) nonlinearities, and local contrast normalization, with a linear classifier on top of it all. The model is defined in `CAFFE_ROOT/examples/cifar10/cifar10_quick_train.prototxt`.
+
+Training and Testing the "Quick" Model
+--------------------------------------
+
+Training the model is simple after you have written the network definition protobuf and solver protobuf files. Simply run `train_quick.sh`, or the following command directly:
+
+ cd $CAFFE_ROOT/examples/cifar10
+ ./train_quick.sh
+
+`train_quick.sh` is a simple script, so have a look inside. `GLOG_logtostderr=1` is the Google logging flag that prints all logging messages directly to stderr. The main tool for training is `train_net.bin`, with the solver protobuf text file as its argument.
+
+When you run the code, you will see a lot of messages flying by like this:
+
+ I0317 21:52:48.945710 2008298256 net.cpp:74] Creating Layer conv1
+ I0317 21:52:48.945716 2008298256 net.cpp:84] conv1 <- data
+ I0317 21:52:48.945725 2008298256 net.cpp:110] conv1 -> conv1
+ I0317 21:52:49.298691 2008298256 net.cpp:125] Top shape: 100 32 32 32 (3276800)
+ I0317 21:52:49.298719 2008298256 net.cpp:151] conv1 needs backward computation.
+
+These messages tell you the details about each layer, its connections and its output shape, which may be helpful in debugging. After the initialization, the training will start:
+
+ I0317 21:52:49.309370 2008298256 net.cpp:166] Network initialization done.
+ I0317 21:52:49.309376 2008298256 net.cpp:167] Memory required for Data 23790808
+ I0317 21:52:49.309422 2008298256 solver.cpp:36] Solver scaffolding done.
+ I0317 21:52:49.309447 2008298256 solver.cpp:47] Solving CIFAR10_quick_train
+
+Based on the solver setting, we will print the training loss function every 100 iterations, and test the network every 500 iterations. You will see messages like this:
+
+ I0317 21:53:12.179772 2008298256 solver.cpp:208] Iteration 100, lr = 0.001
+ I0317 21:53:12.185698 2008298256 solver.cpp:65] Iteration 100, loss = 1.73643
+ ...
+ I0317 21:54:41.150030 2008298256 solver.cpp:87] Iteration 500, Testing net
+ I0317 21:54:47.129461 2008298256 solver.cpp:114] Test score #0: 0.5504
+ I0317 21:54:47.129500 2008298256 solver.cpp:114] Test score #1: 1.27805
+
+For each training iteration, `lr` is the learning rate of that iteration, and `loss` is the training loss. For the output of the testing phase, **score 0 is the accuracy**, and **score 1 is the testing loss**.
+
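+If you want to plot the loss curve, you can scrape these lines from the log. A minimal sketch, assuming you redirected stderr to a hypothetical `train.log` file (the exact log format may vary between Caffe versions):
+
+    import re
+
+    iters, losses = [], []
+    for line in open('train.log'):
+        m = re.search(r'Iteration (\d+), loss = ([\d.]+)', line)
+        if m:
+            iters.append(int(m.group(1)))
+            losses.append(float(m.group(2)))
+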
+And after making yourself a cup of coffee, you are done!
+
+ I0317 22:12:19.666914 2008298256 solver.cpp:87] Iteration 5000, Testing net
+ I0317 22:12:25.580330 2008298256 solver.cpp:114] Test score #0: 0.7533
+ I0317 22:12:25.580379 2008298256 solver.cpp:114] Test score #1: 0.739837
+ I0317 22:12:25.587262 2008298256 solver.cpp:130] Snapshotting to cifar10_quick_iter_5000
+ I0317 22:12:25.590215 2008298256 solver.cpp:137] Snapshotting solver state to cifar10_quick_iter_5000.solverstate
+ I0317 22:12:25.592813 2008298256 solver.cpp:81] Optimization Done.
+
+Our model achieved ~75% test accuracy. The model parameters are stored in binary protobuf format in
+
+ cifar10_quick_iter_5000
+
+which is ready-to-deploy in CPU or GPU mode! Refer to the `CAFFE_ROOT/examples/cifar10/cifar10_quick.prototxt` for the deployment model definition that can be called on new data.
+
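+To classify new images from Python, you can load the snapshot with the `caffe.Classifier` wrapper (the same interface the web demo in `examples/web_demo/app.py` uses). A sketch, run from `examples/cifar10`, with a hypothetical `some_image.png`; note that `caffe.Classifier` takes the mean as a numpy array, so `mean.binaryproto` would first need converting via `caffe.io.blobproto_to_array` (omitted here):
+
+    import caffe
+
+    net = caffe.Classifier('cifar10_quick.prototxt',
+                           'cifar10_quick_iter_5000',
+                           image_dims=(32, 32))
+    scores = net.predict([caffe.io.load_image('some_image.png')])
+    print(scores[0].argmax())  # index of the predicted class
+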
+Why train on a GPU?
+-------------------
+
+CIFAR-10, while still small, has enough data to make GPU training attractive.
+
+To compare CPU vs. GPU training speed, simply change one line in each of the `cifar*solver.prototxt` files:
+
+ # solver mode: CPU or GPU
+ solver_mode: CPU
+
+and you will be using CPU for training.
diff --git a/examples/detection.ipynb b/examples/detection.ipynb
index ff0b7a7b..3f2cf71a 100644
--- a/examples/detection.ipynb
+++ b/examples/detection.ipynb
@@ -1,6 +1,8 @@
{
"metadata": {
- "name": ""
+ "name": "ImageNet detection",
+ "description": "Run a pretrained model as a detector in Python.",
+ "include_in_docs": true
},
"nbformat": 3,
"nbformat_minor": 0,
@@ -836,4 +838,4 @@
"metadata": {}
}
]
-} \ No newline at end of file
+}
diff --git a/examples/feature_extraction/readme.md b/examples/feature_extraction/readme.md
new file mode 100644
index 00000000..c336e718
--- /dev/null
+++ b/examples/feature_extraction/readme.md
@@ -0,0 +1,74 @@
+---
+title: Feature extraction with Caffe C++ code.
+description: Extract AlexNet features using the Caffe binary.
+category: example
+include_in_docs: true
+layout: default
+---
+
+Extracting Features
+===================
+
+In this tutorial, we will extract features using a pre-trained model.
+Follow instructions for [setting up caffe](installation.html) and for [getting](getting_pretrained_models.html) the pre-trained ImageNet model.
+If you need detailed information about the tools below, please consult their source code, in which additional documentation is usually provided.
+
+Select data to run on
+---------------------
+
+We'll make a temporary folder to store things into.
+
+ mkdir examples/_temp
+
+Generate a list of the files to process.
+We're going to use the images that ship with caffe.
+
+ find `pwd`/examples/images -type f -exec echo {} \; > examples/_temp/temp.txt
+
+The `ImageDataLayer` we'll use expects a label after each filename, so let's add a 0 to the end of each line:
+
+ sed "s/$/ 0/" examples/_temp/temp.txt > examples/_temp/file_list.txt
+
+Define the Feature Extraction Network Architecture
+--------------------------------------------------
+
+In practice, subtracting the mean image from a dataset significantly improves classification accuracies.
+Download the mean image of the ILSVRC dataset.
+
+ data/ilsvrc12/get_ilsvrc_aux.sh
+
+We will use `data/ilsvrc12/imagenet_mean.binaryproto` in the network definition prototxt.
+
+Let's copy and modify the network definition.
+We'll be using the `ImageDataLayer`, which will load and resize images for us.
+
+ cp examples/feature_extraction/imagenet_val.prototxt examples/_temp
+
+Edit `examples/_temp/imagenet_val.prototxt` to use the correct paths for your setup (replace `$CAFFE_DIR`).
+
+Extract Features
+----------------
+
+Now everything necessary is in place.
+
+ build/tools/extract_features.bin examples/imagenet/caffe_reference_imagenet_model examples/_temp/imagenet_val.prototxt fc7 examples/_temp/features 10
+
+The name of the feature blob that you extract is `fc7`, which represents the highest-level feature of the reference model.
+We could use any other layer as well, such as `conv5` or `pool3`.
+
+The last parameter above is the number of data mini-batches.
+
+The features are stored to LevelDB `examples/_temp/features`, ready for access by some other code.
+
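+For example, here is a sketch of reading the features back in Python, assuming the `leveldb` Python package is installed. Each value is a serialized `Datum` protobuf whose `float_data` field holds the feature vector:
+
+    import leveldb
+    from caffe.proto import caffe_pb2
+
+    db = leveldb.LevelDB('examples/_temp/features')
+    for key, value in db.RangeIter():
+        datum = caffe_pb2.Datum()
+        datum.ParseFromString(value)
+        print(key, len(datum.float_data))  # 4096 floats for fc7
+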
+If you encounter the error "Check failed: status.ok() Failed to open leveldb examples/_temp/features", it is because the directory `examples/_temp/features` was created the last time you ran the command. Remove it and run again.
+
+ rm -rf examples/_temp/features/
+
+If you'd like to use the Python wrapper for extracting features, check out the [layer visualization notebook](http://nbviewer.ipython.org/github/BVLC/caffe/blob/master/examples/filter_visualization.ipynb).
+
+Clean Up
+--------
+
+Let's remove the temporary directory now.
+
+ rm -r examples/_temp
diff --git a/examples/filter_visualization.ipynb b/examples/filter_visualization.ipynb
index 130df970..56f6b8d3 100644
--- a/examples/filter_visualization.ipynb
+++ b/examples/filter_visualization.ipynb
@@ -1,6 +1,8 @@
{
"metadata": {
- "name": ""
+ "name": "Filter visualization",
+ "description": "Extracting features and visualizing trained filters with an example image, viewed layer-by-layer.",
+ "include_in_docs": true
},
"nbformat": 3,
"nbformat_minor": 0,
diff --git a/examples/imagenet/readme.md b/examples/imagenet/readme.md
new file mode 100644
index 00000000..e74e6b86
--- /dev/null
+++ b/examples/imagenet/readme.md
@@ -0,0 +1,105 @@
+---
+title: ImageNet tutorial
+description: Train and test "CaffeNet" on ImageNet challenge data.
+category: example
+include_in_docs: true
+layout: default
+---
+
+Yangqing's Recipe on Brewing ImageNet
+=====================================
+
+ "All your braincells are belong to us."
+ - Caffeine
+
+We are going to describe a reference implementation for the approach first proposed by Krizhevsky, Sutskever, and Hinton in their [NIPS 2012 paper](http://books.nips.cc/papers/files/nips25/NIPS2012_0534.pdf). Since training the whole model takes some time and energy, we provide a model, trained in the same way as we describe here, to help fight global warming. If you would like to simply use the pretrained model, check out the [Pretrained ImageNet](getting_pretrained_models.html) page. *Note that the pretrained model is for academic research / non-commercial use only*.
+
+To clarify, by ImageNet we actually mean the ILSVRC12 challenge, but you can easily train on the whole of ImageNet as well, just with more disk space, and a little longer training time.
+
+(If you don't get the quote, visit [Yann LeCun's fun page](http://yann.lecun.com/ex/fun/).)
+
+Data Preparation
+----------------
+
+We assume that you have already downloaded the ImageNet training data and validation data, and that they are stored on your disk like:
+
+ /path/to/imagenet/train/n01440764/n01440764_10026.JPEG
+ /path/to/imagenet/val/ILSVRC2012_val_00000001.JPEG
+
+You will first need to prepare some auxiliary data for training. This data can be downloaded by:
+
+ cd $CAFFE_ROOT/data/ilsvrc12/
+ ./get_ilsvrc_aux.sh
+
+The training and validation input are described in `train.txt` and `val.txt` as text listing all the files and their labels. Note that we use a different indexing for labels than the ILSVRC devkit: we sort the synset names in their ASCII order, and then label them from 0 to 999. See `synset_words.txt` for the synset/name mapping.
+
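+In Python terms, the label of a synset can be recovered like this (a sketch; `synsets.txt` is among the downloaded auxiliary files):
+
+    synsets = sorted(l.strip() for l in open('synsets.txt'))
+    label_of = dict((s, i) for i, s in enumerate(synsets))
+    print(label_of['n01440764'])  # 0, the first synset in ASCII order
+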
+You may want to resize the images to 256x256 in advance. By default, we do not explicitly do this, because in a cluster environment one may benefit from resizing images in a parallel fashion, using MapReduce. For example, Yangqing used his lightweight [mincepie](https://github.com/Yangqing/mincepie) package to do MapReduce on the Berkeley cluster. If you would rather keep things simple and straightforward, you can also use shell commands, something like:
+
+ for name in /path/to/imagenet/val/*.JPEG; do
+ convert -resize 256x256\! $name $name
+ done
+
+Go to `$CAFFE_ROOT/examples/imagenet/` for the rest of this guide.
+
+Take a look at `create_imagenet.sh`. Set the paths to the train and val dirs as needed, and set `RESIZE=true` to resize all images to 256x256 if you haven't resized them in advance. Now simply create the leveldbs with `./create_imagenet.sh`. Note that `imagenet_train_leveldb` and `imagenet_val_leveldb` should not exist before this execution; they will be created by the script. `GLOG_logtostderr=1` simply dumps more information for you to inspect, and you can safely ignore it.
+
+Compute Image Mean
+------------------
+
+The model requires us to subtract the image mean from each image, so we have to compute the mean. `tools/compute_image_mean.cpp` implements that - it is also a good example of how to manipulate the multiple components, such as protocol buffers, leveldbs, and logging, if you are not familiar with them. Anyway, the mean computation can be carried out as:
+
+ ./make_imagenet_mean.sh
+
+which will make `data/ilsvrc12/imagenet_mean.binaryproto`.
+
+Network Definition
+------------------
+
+The network definition strictly follows the one in Krizhevsky et al. You can find the detailed definition at `examples/imagenet/imagenet_train.prototxt`. Note the paths in the data layer - if you have not followed the exact paths in this guide, you will need to change the following lines:
+
+    source: "ilsvrc12_train_leveldb"
+ mean_file: "../../data/ilsvrc12/imagenet_mean.binaryproto"
+
+to point to your own leveldb and image mean. Likewise, do the same for `examples/imagenet/imagenet_val.prototxt`.
+
+If you look carefully at `imagenet_train.prototxt` and `imagenet_val.prototxt`, you will notice that they are largely the same, with the only difference being the data layer sources, and the last layer: in training, we will be using a `softmax_loss` layer to compute the loss function and to initialize the backpropagation, while in validation we will be using an `accuracy` layer to inspect how well we do in terms of accuracy.
+
+We will also lay out a protocol buffer for running the solver. Let's make a few plans:
+* We will run in batches of 256, and run a total of 4,500,000 iterations (about 90 epochs).
+* For every 1,000 iterations, we test the learned net on the validation data.
+* We set the initial learning rate to 0.01, and decrease it every 100,000 iterations (about 20 epochs).
+* Information will be displayed every 20 iterations.
+* The network will be trained with momentum 0.9 and a weight decay of 0.0005.
+* For every 10,000 iterations, we will take a snapshot of the current status.
+
+Sound good? This is implemented in `examples/imagenet/imagenet_solver.prototxt`. Again, you will need to change the first two lines:
+
+ train_net: "imagenet_train.prototxt"
+ test_net: "imagenet_val.prototxt"
+
+to point to the actual path if you have changed them.
+
+Training ImageNet
+-----------------
+
+Ready? Let's train.
+
+ ./train_imagenet.sh
+
+Sit back and enjoy! On my K20 machine, every 20 iterations take about 36 seconds to run, so effectively about 7 ms per image for the full forward-backward pass (36 s / 20 iterations / 256 images per batch ≈ 7 ms). About 2.5 ms of this is on forward, and the rest is backward. If you are interested in dissecting the computation time, you can look at `examples/net_speed_benchmark.cpp`, but it was written purely for debugging purposes, so you may need to figure a few things out yourself.
+
+Resume Training?
+----------------
+
+We all experience times when the power goes out, or we feel like rewarding ourselves a little by playing Battlefield (does anyone still remember Quake?). Since we snapshot intermediate results during training, we can resume from those snapshots. This is as easy as:
+
+ ./resume_training.sh
+
+where in the script `caffe_imagenet_train_1000.solverstate` is the solver state snapshot that stores all the necessary information to recover the exact solver state (including the parameters, momentum history, etc.).
+
+Parting Words
+-------------
+
+Hope you liked this recipe! Many researchers have gone further since the ILSVRC 2012 challenge, changing the network architecture and/or fine-tuning the various parameters in the network. The recent ILSVRC 2013 challenge suggests that there is quite some room for improvement. **Caffe allows one to explore different network choices more easily, by simply writing different prototxt files** - isn't that exciting?
+
+And now that you have a trained network, check out how to use it: [Running Pretrained ImageNet](getting_pretrained_models.html). This time we will use Python, but if you have wrappers for other languages, please kindly send a pull request!
diff --git a/examples/imagenet_classification.ipynb b/examples/imagenet_classification.ipynb
index 0e0e06bb..7ac140d9 100644
--- a/examples/imagenet_classification.ipynb
+++ b/examples/imagenet_classification.ipynb
@@ -1,6 +1,8 @@
{
"metadata": {
- "name": ""
+ "description": "Use the pre-trained ImageNet model to classify images with the Python interface.",
+ "name": "ImageNet Classification",
+ "include_in_docs": true
},
"nbformat": 3,
"nbformat_minor": 0,
@@ -407,4 +409,4 @@
"metadata": {}
}
]
-} \ No newline at end of file
+}
diff --git a/examples/mnist/readme.md b/examples/mnist/readme.md
new file mode 100644
index 00000000..d609cfff
--- /dev/null
+++ b/examples/mnist/readme.md
@@ -0,0 +1,266 @@
+---
+title: MNIST Tutorial
+description: Train and test "LeNet" on MNIST data.
+category: example
+include_in_docs: true
+layout: default
+---
+
+# Training MNIST with Caffe
+
+We will assume that you have Caffe successfully compiled. If not, please refer to the [Installation page](installation.html). In this tutorial, we will assume that your Caffe installation is located at `CAFFE_ROOT`.
+
+## Prepare Datasets
+
+You will first need to download and convert the data format from the MNIST website. To do this, simply run the following commands:
+
+ cd $CAFFE_ROOT/data/mnist
+ ./get_mnist.sh
+ cd $CAFFE_ROOT/examples/mnist
+ ./create_mnist.sh
+
+If it complains that `wget` or `gunzip` is not installed, install them first. After running the script there should be two datasets, `mnist-train-leveldb` and `mnist-test-leveldb`.
+
+## LeNet: the MNIST Classification Model
+
+Before we actually run the training program, let's explain what will happen. We will use the [LeNet](http://yann.lecun.com/exdb/publis/pdf/lecun-01a.pdf) network, which is known to work well on digit classification tasks. We will use a slightly different version from the original LeNet implementation, replacing the sigmoid activations with Rectified Linear Unit (ReLU) activations for the neurons.
+
+The design of LeNet contains the essence of CNNs that are still used in larger models such as the ones in ImageNet. In general, it consists of a convolutional layer followed by a pooling layer, another convolution layer followed by a pooling layer, and then two fully connected layers similar to the conventional multilayer perceptrons. We have defined the layers in `CAFFE_ROOT/data/lenet.prototxt`.
+
+## Define the MNIST Network
+
+This section explains the prototxt file `lenet_train.prototxt` used in the MNIST demo. We assume that you are familiar with [Google Protobuf](https://developers.google.com/protocol-buffers/docs/overview), and assume that you have read the protobuf definitions used by Caffe, which can be found at [src/caffe/proto/caffe.proto](https://github.com/Yangqing/caffe/blob/master/src/caffe/proto/caffe.proto).
+
+Specifically, we will write a `caffe::NetParameter` (or in Python, `caffe.proto.caffe_pb2.NetParameter`) protobuf. We will start by giving the network a name:
+
+ name: "LeNet"
+
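+To see how such a prototxt maps onto the protobuf classes, you can parse it from Python. A sketch, assuming pycaffe and the protobuf Python bindings are built:
+
+    from google.protobuf import text_format
+    from caffe.proto import caffe_pb2
+
+    net_param = caffe_pb2.NetParameter()
+    text_format.Merge(open('lenet_train.prototxt').read(), net_param)
+    print(net_param.name)  # "LeNet"
+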
+### Writing the Data Layer
+
+Currently, we will read the MNIST data from the leveldb we created earlier in the demo. This is defined by a data layer:
+
+ layers {
+ name: "mnist"
+ type: DATA
+ data_param {
+ source: "mnist-train-leveldb"
+ batch_size: 64
+ scale: 0.00390625
+ }
+ top: "data"
+ top: "label"
+ }
+
+Specifically, this layer has name `mnist`, type `data`, and it reads the data from the given leveldb source. We will use a batch size of 64, and scale the incoming pixels so that they are in the range \[0,1\). Why 0.00390625? It is 1 divided by 256. And finally, this layer produces two blobs, one is the `data` blob, and one is the `label` blob.
+
+### Writing the Convolution Layer
+
+Let's define the first convolution layer:
+
+ layers {
+ name: "conv1"
+ type: CONVOLUTION
+ blobs_lr: 1.
+ blobs_lr: 2.
+ convolution_param {
+ num_output: 20
+ kernelsize: 5
+ stride: 1
+ weight_filler {
+ type: "xavier"
+ }
+ bias_filler {
+ type: "constant"
+ }
+ }
+ bottom: "data"
+ top: "conv1"
+ }
+
+This layer takes the `data` blob (it is provided by the data layer), and produces the `conv1` blob. It produces outputs of 20 channels, with a convolutional kernel size of 5 and a stride of 1.
+
+The fillers allow us to randomly initialize the value of the weights and bias. For the weight filler, we will use the `xavier` algorithm that automatically determines the scale of initialization based on the number of input and output neurons. For the bias filler, we will simply initialize it as constant, with the default filling value 0.
+
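+Roughly speaking, the `xavier` filler draws each weight uniformly from `[-scale, scale]`, with the scale computed from the fan-in. A numpy sketch of the idea (not Caffe's code):
+
+    import numpy as np
+
+    fan_in = 1 * 5 * 5  # input channels * kernel height * kernel width
+    scale = np.sqrt(3.0 / fan_in)
+    W = np.random.uniform(-scale, scale, size=(20, 1, 5, 5))
+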
+`blobs_lr` are the learning rate adjustments for the layer's learnable parameters. In this case, we will set the weight learning rate to be the same as the learning rate given by the solver during runtime, and the bias learning rate to be twice as large as that - this usually leads to better convergence rates.
+
+### Writing the Pooling Layer
+
+Phew. Pooling layers are actually much easier to define:
+
+ layers {
+ name: "pool1"
+ type: POOLING
+ pooling_param {
+ kernel_size: 2
+ stride: 2
+ pool: MAX
+ }
+ bottom: "conv1"
+ top: "pool1"
+ }
+
+This says we will perform max pooling with a pool kernel size of 2 and a stride of 2 (so there is no overlap between neighboring pooling regions).
+
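+The effect is easy to see with a numpy sketch:
+
+    import numpy as np
+
+    x = np.arange(16).reshape(4, 4)
+    # non-overlapping 2x2 windows, stride 2: take the max within each block
+    pooled = x.reshape(2, 2, 2, 2).max(axis=(1, 3))
+    print(pooled)  # [[ 5  7] [13 15]]
+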
+Similarly, you can write up the second convolution and pooling layers. Check `data/lenet.prototxt` for details.
+
+### Writing the Fully Connected Layer
+
+Writing a fully connected layer is also simple:
+
+ layers {
+ name: "ip1"
+ type: INNER_PRODUCT
+ blobs_lr: 1.
+ blobs_lr: 2.
+ inner_product_param {
+ num_output: 500
+ weight_filler {
+ type: "xavier"
+ }
+ bias_filler {
+ type: "constant"
+ }
+ }
+ bottom: "pool2"
+ top: "ip1"
+ }
+
+This defines a fully connected layer (for some legacy reason, Caffe calls it an `innerproduct` layer) with 500 outputs. All other lines look familiar, right?
+
+### Writing the ReLU Layer
+
+A ReLU Layer is also simple:
+
+ layers {
+ name: "relu1"
+ type: RELU
+ bottom: "ip1"
+ top: "ip1"
+ }
+
+Since ReLU is an element-wise operation, we can do *in-place* operations to save some memory. This is achieved by simply giving the same name to the bottom and top blobs. Of course, do NOT use duplicated blob names for other layer types!
+
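+In numpy terms, an in-place ReLU is simply:
+
+    import numpy as np
+
+    x = np.random.randn(4)
+    np.maximum(x, 0, x)  # the third argument writes the result back into x
+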
+After the ReLU layer, we will write another innerproduct layer:
+
+ layers {
+ name: "ip2"
+ type: INNER_PRODUCT
+ blobs_lr: 1.
+ blobs_lr: 2.
+ inner_product_param {
+ num_output: 10
+ weight_filler {
+ type: "xavier"
+ }
+ bias_filler {
+ type: "constant"
+ }
+ }
+ bottom: "ip1"
+ top: "ip2"
+ }
+
+### Writing the Loss Layer
+
+Finally, we will write the loss!
+
+ layers {
+ name: "loss"
+ type: SOFTMAX_LOSS
+ bottom: "ip2"
+ bottom: "label"
+ }
+
+The `softmax_loss` layer implements both the softmax and the multinomial logistic loss (which saves time and improves numerical stability). It takes two blobs, the first being the prediction and the second being the `label` provided by the data layer (remember it?). It does not produce any outputs - all it does is compute the loss function value, report it when backpropagation starts, and initiate the gradient with respect to `ip2`. This is where all the magic starts.
+
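+For reference, here is a numpy sketch of the combined computation, using the usual max-shift for numerical stability (the math, not Caffe's code):
+
+    import numpy as np
+
+    def softmax_loss(scores, labels):
+        # scores: (N, K) raw predictions; labels: (N,) integer class ids
+        shifted = scores - scores.max(axis=1, keepdims=True)
+        log_prob = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
+        return -log_prob[np.arange(len(labels)), labels].mean()
+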
+## Define the MNIST Solver
+
+Check out the comments explaining each line in the prototxt:
+
+ # The training protocol buffer definition
+ train_net: "lenet_train.prototxt"
+ # The testing protocol buffer definition
+ test_net: "lenet_test.prototxt"
+ # test_iter specifies how many forward passes the test should carry out.
+ # In the case of MNIST, we have test batch size 100 and 100 test iterations,
+ # covering the full 10,000 testing images.
+ test_iter: 100
+ # Carry out testing every 500 training iterations.
+ test_interval: 500
+ # The base learning rate, momentum and the weight decay of the network.
+ base_lr: 0.01
+ momentum: 0.9
+ weight_decay: 0.0005
+ # The learning rate policy
+ lr_policy: "inv"
+ gamma: 0.0001
+ power: 0.75
+ # Display every 100 iterations
+ display: 100
+ # The maximum number of iterations
+ max_iter: 10000
+ # snapshot intermediate results
+ snapshot: 5000
+ snapshot_prefix: "lenet"
+ # solver mode: 0 for CPU and 1 for GPU
+ solver_mode: 1
+
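+The `inv` learning rate policy computes `base_lr * (1 + gamma * iter) ^ (-power)`. A quick check in Python reproduces the learning rates you will see in the training log below:
+
+    base_lr, gamma, power = 0.01, 0.0001, 0.75
+    lr = lambda it: base_lr * (1 + gamma * it) ** -power
+    print(lr(100))  # ~0.00992565
+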
+## Training and Testing the Model
+
+Training the model is simple after you have written the network definition protobuf and solver protobuf files. Simply run `train_lenet.sh`, or the following commands directly:
+
+ cd $CAFFE_ROOT/examples/mnist
+ ./train_lenet.sh
+
+`train_lenet.sh` is a simple script, but here are a few explanations: `GLOG_logtostderr=1` is the Google logging flag that prints all logging messages directly to stderr. The main tool for training is `train_net.bin`, with the solver protobuf text file as its argument.
+
+When you run the code, you will see a lot of messages flying by like this:
+
+ I1203 net.cpp:66] Creating Layer conv1
+ I1203 net.cpp:76] conv1 <- data
+ I1203 net.cpp:101] conv1 -> conv1
+ I1203 net.cpp:116] Top shape: 20 24 24
+ I1203 net.cpp:127] conv1 needs backward computation.
+
+These messages tell you the details about each layer, its connections and its output shape, which may be helpful in debugging. After the initialization, the training will start:
+
+ I1203 net.cpp:142] Network initialization done.
+ I1203 solver.cpp:36] Solver scaffolding done.
+ I1203 solver.cpp:44] Solving LeNet
+
+Based on the solver setting, we will print the training loss function every 100 iterations, and test the network every 1000 iterations. You will see messages like this:
+
+ I1203 solver.cpp:204] Iteration 100, lr = 0.00992565
+ I1203 solver.cpp:66] Iteration 100, loss = 0.26044
+ ...
+ I1203 solver.cpp:84] Testing net
+ I1203 solver.cpp:111] Test score #0: 0.9785
+ I1203 solver.cpp:111] Test score #1: 0.0606671
+
+For each training iteration, `lr` is the learning rate of that iteration, and `loss` is the training loss. For the output of the testing phase, score 0 is the accuracy, and score 1 is the testing loss.
+
+And after a few minutes, you are done!
+
+ I1203 solver.cpp:84] Testing net
+ I1203 solver.cpp:111] Test score #0: 0.9897
+ I1203 solver.cpp:111] Test score #1: 0.0324599
+ I1203 solver.cpp:126] Snapshotting to lenet_iter_10000
+ I1203 solver.cpp:133] Snapshotting solver state to lenet_iter_10000.solverstate
+ I1203 solver.cpp:78] Optimization Done.
+
+The final model is stored as a binary protobuf file at
+
+ lenet_iter_10000
+
+which you can deploy as a trained model in your application, if you are training on a real-world application dataset.
+
+### Um... How about GPU training?
+
+You just did! All the training was carried out on the GPU. In fact, if you would like to do training on CPU, you can simply change one line in `lenet_solver.prototxt`:
+
+ # solver mode: CPU or GPU
+ solver_mode: CPU
+
+and you will be using CPU for training. Isn't that easy?
+
+MNIST is a small dataset, so training with a GPU does not really bring much benefit due to communication overhead. On larger datasets with more complex models, such as ImageNet, the computation speed difference will be more significant.
diff --git a/examples/net_surgery.ipynb b/examples/net_surgery.ipynb
index 6cba8bb9..550f4112 100644
--- a/examples/net_surgery.ipynb
+++ b/examples/net_surgery.ipynb
@@ -1,6 +1,8 @@
{
"metadata": {
- "name": ""
+ "name": "Editing model parameters",
+ "description": "How to do net surgery and manually change model parameters.",
+ "include_in_docs": true
},
"nbformat": 3,
"nbformat_minor": 0,
@@ -324,4 +326,4 @@
"metadata": {}
}
]
-} \ No newline at end of file
+}
diff --git a/examples/web_demo/app.py b/examples/web_demo/app.py
new file mode 100644
index 00000000..9bc4ed5c
--- /dev/null
+++ b/examples/web_demo/app.py
@@ -0,0 +1,215 @@
+import os
+import time
+import cPickle
+import datetime
+import logging
+import flask
+import werkzeug
+import optparse
+import tornado.wsgi
+import tornado.httpserver
+import numpy as np
+import pandas as pd
+from PIL import Image as PILImage
+import cStringIO as StringIO
+import urllib
+import caffe
+import exifutil
+
+REPO_DIRNAME = os.path.abspath(os.path.dirname(__file__) + '/../..')
+UPLOAD_FOLDER = '/tmp/caffe_demos_uploads'
+ALLOWED_IMAGE_EXTENSIONS = set(['png', 'bmp', 'jpg', 'jpe', 'jpeg', 'gif'])
+
+# Obtain the flask app object
+app = flask.Flask(__name__)
+
+
+@app.route('/')
+def index():
+ return flask.render_template('index.html', has_result=False)
+
+
+@app.route('/classify_url', methods=['GET'])
+def classify_url():
+ imageurl = flask.request.args.get('imageurl', '')
+ try:
+ string_buffer = StringIO.StringIO(
+ urllib.urlopen(imageurl).read())
+ image = caffe.io.load_image(string_buffer)
+
+ except Exception as err:
+ # For any exception we encounter in reading the image, we will just
+ # not continue.
+ logging.info('URL Image open error: %s', err)
+ return flask.render_template(
+ 'index.html', has_result=True,
+ result=(False, 'Cannot open image from URL.')
+ )
+
+ logging.info('Image: %s', imageurl)
+ result = app.clf.classify_image(image)
+ return flask.render_template(
+ 'index.html', has_result=True, result=result, imagesrc=imageurl)
+
+
+@app.route('/classify_upload', methods=['POST'])
+def classify_upload():
+ try:
+ # We will save the file to disk for possible data collection.
+ imagefile = flask.request.files['imagefile']
+ filename_ = str(datetime.datetime.now()).replace(' ', '_') + \
+ werkzeug.secure_filename(imagefile.filename)
+ filename = os.path.join(UPLOAD_FOLDER, filename_)
+ imagefile.save(filename)
+ logging.info('Saving to %s.', filename)
+ image = exifutil.open_oriented_im(filename)
+
+ except Exception as err:
+ logging.info('Uploaded image open error: %s', err)
+ return flask.render_template(
+ 'index.html', has_result=True,
+ result=(False, 'Cannot open uploaded image.')
+ )
+
+ result = app.clf.classify_image(image)
+ return flask.render_template(
+ 'index.html', has_result=True, result=result,
+ imagesrc=embed_image_html(image)
+ )
+
+
+def embed_image_html(image):
+ """Creates an image embedded in HTML base64 format."""
+ image_pil = PILImage.fromarray((255 * image).astype('uint8'))
+ image_pil = image_pil.resize((256, 256))
+ string_buf = StringIO.StringIO()
+ image_pil.save(string_buf, format='png')
+ data = string_buf.getvalue().encode('base64').replace('\n', '')
+ return 'data:image/png;base64,' + data
+
+
+def allowed_file(filename):
+ return (
+ '.' in filename and
+ filename.rsplit('.', 1)[1] in ALLOWED_IMAGE_EXTENSIONS
+ )
+
+
+class ImagenetClassifier(object):
+ default_args = {
+ 'model_def_file': (
+ '{}/examples/imagenet/imagenet_deploy.prototxt'.format(REPO_DIRNAME)),
+ 'pretrained_model_file': (
+ '{}/examples/imagenet/caffe_reference_imagenet_model'.format(REPO_DIRNAME)),
+ 'mean_file': (
+ '{}/python/caffe/imagenet/ilsvrc_2012_mean.npy'.format(REPO_DIRNAME)),
+ 'class_labels_file': (
+ '{}/data/ilsvrc12/synset_words.txt'.format(REPO_DIRNAME)),
+ 'bet_file': (
+ '{}/data/ilsvrc12/imagenet.bet.pickle'.format(REPO_DIRNAME)),
+ }
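+    # Note: this check runs when the class body is executed, i.e. at import time.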
+ for key, val in default_args.iteritems():
+ if not os.path.exists(val):
+ raise Exception(
+ "File for {} is missing. Should be at: {}".format(key, val))
+ default_args['image_dim'] = 227
+ default_args['gpu_mode'] = True
+
+ def __init__(self, model_def_file, pretrained_model_file, mean_file,
+ class_labels_file, bet_file, image_dim, gpu_mode=False):
+ logging.info('Loading net and associated files...')
+ self.net = caffe.Classifier(
+ model_def_file, pretrained_model_file, input_scale=255,
+ image_dims=(image_dim, image_dim), gpu=gpu_mode,
+ mean_file=mean_file, channel_swap=(2, 1, 0)
+ )
+
+ with open(class_labels_file) as f:
+ labels_df = pd.DataFrame([
+ {
+ 'synset_id': l.strip().split(' ')[0],
+ 'name': ' '.join(l.strip().split(' ')[1:]).split(',')[0]
+ }
+ for l in f.readlines()
+ ])
+ self.labels = labels_df.sort('synset_id')['name'].values
+
+ self.bet = cPickle.load(open(bet_file))
+ # A bias to prefer children nodes in single-chain paths
+ # I am setting the value to 0.1 as a quick, simple model.
+ # We could use better psychological models here...
+ self.bet['infogain'] -= np.array(self.bet['preferences']) * 0.1
+
+ def classify_image(self, image):
+ try:
+ starttime = time.time()
+ scores = self.net.predict([image], oversample=True).flatten()
+ endtime = time.time()
+
+ indices = (-scores).argsort()[:5]
+ predictions = self.labels[indices]
+
+ # In addition to the prediction text, we will also produce
+ # the length for the progress bar visualization.
+ meta = [
+ (p, '%.5f' % scores[i])
+ for i, p in zip(indices, predictions)
+ ]
+ logging.info('result: %s', str(meta))
+
+ # Compute expected information gain
+ expected_infogain = np.dot(
+ self.bet['probmat'], scores[self.bet['idmapping']])
+ expected_infogain *= self.bet['infogain']
+
+ # sort the scores
+ infogain_sort = expected_infogain.argsort()[::-1]
+ bet_result = [(self.bet['words'][v], '%.5f' % expected_infogain[v])
+ for v in infogain_sort[:5]]
+ logging.info('bet result: %s', str(bet_result))
+
+ return (True, meta, bet_result, '%.3f' % (endtime - starttime))
+
+ except Exception as err:
+ logging.info('Classification error: %s', err)
+ return (False, 'Something went wrong when classifying the '
+ 'image. Maybe try another one?')
+
+
+def start_tornado(app, port=5000):
+ http_server = tornado.httpserver.HTTPServer(
+ tornado.wsgi.WSGIContainer(app))
+ http_server.listen(port)
+ print("Tornado server starting on port {}".format(port))
+ tornado.ioloop.IOLoop.instance().start()
+
+
+def start_from_terminal(app):
+ """
+ Parse command line options and start the server.
+ """
+ parser = optparse.OptionParser()
+ parser.add_option(
+ '-d', '--debug',
+ help="enable debug mode",
+ action="store_true", default=False)
+ parser.add_option(
+ '-p', '--port',
+ help="which port to serve content on",
+ type='int', default=5000)
+ opts, args = parser.parse_args()
+
+ # Initialize classifier
+ app.clf = ImagenetClassifier(**ImagenetClassifier.default_args)
+
+ if opts.debug:
+ app.run(debug=True, host='0.0.0.0', port=opts.port)
+ else:
+ start_tornado(app, opts.port)
+
+
+if __name__ == '__main__':
+ logging.getLogger().setLevel(logging.INFO)
+ if not os.path.exists(UPLOAD_FOLDER):
+ os.makedirs(UPLOAD_FOLDER)
+ start_from_terminal(app)
diff --git a/examples/web_demo/exifutil.py b/examples/web_demo/exifutil.py
new file mode 100644
index 00000000..8c07aa88
--- /dev/null
+++ b/examples/web_demo/exifutil.py
@@ -0,0 +1,33 @@
+"""
+Utilities for loading an image with its EXIF orientation applied
+(works around skimage ignoring EXIF orientation when loading).
+"""
+
+from PIL import Image
+import numpy as np
+
+ORIENTATIONS = { # used in apply_orientation
+ 2: (Image.FLIP_LEFT_RIGHT,),
+ 3: (Image.ROTATE_180,),
+ 4: (Image.FLIP_TOP_BOTTOM,),
+ 5: (Image.FLIP_LEFT_RIGHT, Image.ROTATE_90),
+ 6: (Image.ROTATE_270,),
+ 7: (Image.FLIP_LEFT_RIGHT, Image.ROTATE_270),
+ 8: (Image.ROTATE_90,)
+}
+
+
+def open_oriented_im(im_path):
+ im = Image.open(im_path)
+ if hasattr(im, '_getexif'):
+ exif = im._getexif()
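+        # EXIF tag 274 holds the image orientation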
+ if exif is not None and 274 in exif:
+ orientation = exif[274]
+ im = apply_orientation(im, orientation)
+ return np.asarray(im).astype(np.float32) / 255.
+
+
+def apply_orientation(im, orientation):
+ if orientation in ORIENTATIONS:
+ for method in ORIENTATIONS[orientation]:
+ im = im.transpose(method)
+ return im
diff --git a/examples/web_demo/readme.md b/examples/web_demo/readme.md
new file mode 100644
index 00000000..559c41e0
--- /dev/null
+++ b/examples/web_demo/readme.md
@@ -0,0 +1,30 @@
+---
+title: Web demo
+description: Image classification demo running as a Flask web server.
+category: example
+layout: default
+include_in_docs: true
+---
+
+# Web Demo
+
+## Requirements
+
+The demo server requires Python with some dependencies.
+To make sure you have the dependencies, please run `pip install -r examples/web_demo/requirements.txt`, and also make sure that you've compiled the Python Caffe interface and that it is on your `PYTHONPATH` (see [installation instructions](/installation.html)).
+
+Make sure that you have obtained the Caffe Reference ImageNet Model and the ImageNet Auxiliary Data ([instructions](/getting_pretrained_models.html)).
+NOTE: if you run into trouble, try re-downloading the auxiliary files.
+
+## Run
+
+Running `python examples/web_demo/app.py` will bring up the demo server, accessible at `http://0.0.0.0:5000`.
+You can enable debug mode of the web server, or switch to a different port:
+
+ % python examples/web_demo/app.py -h
+ Usage: app.py [options]
+
+ Options:
+ -h, --help show this help message and exit
+ -d, --debug enable debug mode
+ -p PORT, --port=PORT which port to serve content on
diff --git a/examples/web_demo/templates/index.html b/examples/web_demo/templates/index.html
new file mode 100644
index 00000000..87893341
--- /dev/null
+++ b/examples/web_demo/templates/index.html
@@ -0,0 +1,138 @@
+<!DOCTYPE html>
+<html lang="en">
+ <head>
+ <meta charset="utf-8">
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
+ <meta name="description" content="Caffe demos">
+ <meta name="author" content="BVLC (http://bvlc.eecs.berkeley.edu/)">
+
+ <title>Caffe Demos</title>
+
+ <link href="//netdna.bootstrapcdn.com/bootstrap/3.1.1/css/bootstrap.min.css" rel="stylesheet">
+
+ <script type="text/javascript" src="//code.jquery.com/jquery-2.1.1.js"></script>
+ <script src="//netdna.bootstrapcdn.com/bootstrap/3.1.1/js/bootstrap.min.js"></script>
+
+ <!-- Script to instantly classify an image once it is uploaded. -->
+ <script type="text/javascript">
+ $(document).ready(
+ function(){
+ $('#classifyfile').attr('disabled',true);
+ $('#imagefile').change(
+ function(){
+ if ($(this).val()){
+ $('#formupload').submit();
+ }
+ }
+ );
+ }
+ );
+ </script>
+
+ <style>
+ body {
+ font-family: "Helvetica Neue", Helvetica, Arial, sans-serif;
+ line-height:1.5em;
+ color: #232323;
+ -webkit-font-smoothing: antialiased;
+ }
+
+ h1, h2, h3 {
+ font-family: Times, serif;
+ line-height:1.5em;
+ border-bottom: 1px solid #ccc;
+ }
+ </style>
+ </head>
+
+ <body>
+ <!-- Begin page content -->
+ <div class="container">
+ <div class="page-header">
+ <h1><a href="/">Caffe Demos</a></h1>
+ <p>
+ The <a href="http://caffe.berkeleyvision.org">Caffe</a> neural network library makes implementing state-of-the-art computer vision systems easy.
+ </p>
+ </div>
+
+ <div>
+ <h2>Classification</h2>
+ <a href="/classify_url?imageurl=http%3A%2F%2Fi.telegraph.co.uk%2Fmultimedia%2Farchive%2F02351%2Fcross-eyed-cat_2351472k.jpg">Click for a Quick Example</a>
+ </div>
+
+ {% if has_result %}
+ {% if not result[0] %}
+ <!-- we have error in the result. -->
+ <div class="alert alert-danger">{{ result[1] }} Did you provide a valid URL or a valid image file? </div>
+ {% else %}
+ <div class="media">
+ <a class="pull-left" href="#"><img class="media-object" width="192" height="192" src={{ imagesrc }}></a>
+ <div class="media-body">
+ <div class="bs-example bs-example-tabs">
+ <ul id="myTab" class="nav nav-tabs">
+ <li class="active"><a href="#infopred" data-toggle="tab">Maximally accurate</a></li>
+ <li><a href="#flatpred" data-toggle="tab">Maximally specific</a></li>
+ </ul>
+ <div id="myTabContent" class="tab-content">
+ <div class="tab-pane fade in active" id="infopred">
+ <ul class="list-group">
+ {% for single_pred in result[2] %}
+ <li class="list-group-item">
+ <span class="badge">{{ single_pred[1] }}</span>
+ <h4 class="list-group-item-heading">
+ <a href="https://www.google.com/#q={{ single_pred[0] }}" target="_blank">{{ single_pred[0] }}</a>
+ </h4>
+ </li>
+ {% endfor %}
+ </ul>
+ </div>
+ <div class="tab-pane fade" id="flatpred">
+ <ul class="list-group">
+ {% for single_pred in result[1] %}
+ <li class="list-group-item">
+ <span class="badge">{{ single_pred[1] }}</span>
+ <h4 class="list-group-item-heading">
+ <a href="https://www.google.com/#q={{ single_pred[0] }}" target="_blank">{{ single_pred[0] }}</a>
+ </h4>
+ </li>
+ {% endfor %}
+ </ul>
+ </div>
+ </div>
+ </div>
+
+ </div>
+ </div>
+ <p> CNN took {{ result[3] }} seconds. </p>
+ {% endif %}
+ <hr>
+ {% endif %}
+
+ <form role="form" action="classify_url" method="get">
+ <div class="form-group">
+ <div class="input-group">
+ <input type="text" class="form-control" name="imageurl" id="imageurl" placeholder="Provide an image URL">
+ <span class="input-group-btn">
+ <input class="btn btn-primary" value="Classify URL" type="submit" id="classifyurl"></input>
+ </span>
+ </div><!-- /input-group -->
+ </div>
+ </form>
+
+ <form id="formupload" class="form-inline" role="form" action="classify_upload" method="post" enctype="multipart/form-data">
+ <div class="form-group">
+ <label for="imagefile">Or upload an image:</label>
+ <input type="file" name="imagefile" id="imagefile">
+ </div>
+ <!--<input type="submit" class="btn btn-primary" value="Classify File" id="classifyfile"></input>-->
+ </form>
+ </div>
+
+ <hr>
+ <div id="footer">
+ <div class="container">
+ <p>&copy; BVLC 2014</p>
+ </div>
+ </div>
+ </body>
+</html>