diff options
author	Yangqing Jia <jiayq84@gmail.com>	2013-10-10 14:45:10 -0700
committer	Yangqing Jia <jiayq84@gmail.com>	2013-10-10 14:46:02 -0700
commit	7d3fcf92abaabd3e2acea9b1d294cc01dbb855e6 (patch)
tree	91a0ddbc06244194b795c539fc6b389cb4d79380 /src
parent	3f0efbd2af1a8f23407e90b67e62df9cd6fae583 (diff)
download	caffeonacl-7d3fcf92abaabd3e2acea9b1d294cc01dbb855e6.tar.gz
	caffeonacl-7d3fcf92abaabd3e2acea9b1d294cc01dbb855e6.tar.bz2
	caffeonacl-7d3fcf92abaabd3e2acea9b1d294cc01dbb855e6.zip
misc update
Diffstat (limited to 'src')
-rw-r--r--	src/Makefile	5
-rw-r--r--	src/caffe/layers/data_layer.cpp	53
-rw-r--r--	src/caffe/optimization/solver.cpp	14
-rw-r--r--	src/caffe/proto/caffe.proto	2
-rw-r--r--	src/caffe/util/im2col.cu	15
-rw-r--r--	src/programs/imagenet.prototxt	332
-rw-r--r--	src/programs/imagenet_solver.prototxt	10
-rw-r--r--	src/programs/net_speed_benchmark.cpp	62
-rw-r--r--	src/programs/train_alexnet.cpp	73
-rw-r--r--	src/programs/train_net.cpp	39
10 files changed, 509 insertions, 96 deletions
diff --git a/src/Makefile b/src/Makefile
index d78e99bc..27fd7603 100644
--- a/src/Makefile
+++ b/src/Makefile
@@ -42,12 +42,13 @@ LIBRARIES := cuda cudart cublas protobuf glog mkl_rt mkl_intel_thread curand \
 	leveldb snappy opencv_core opencv_highgui pthread tcmalloc
 WARNINGS := -Wall
 
-CXXFLAGS += -fPIC $(foreach includedir,$(INCLUDE_DIRS),-I$(includedir))
+CXXFLAGS += -pthread -fPIC -O2 $(foreach includedir,$(INCLUDE_DIRS),-I$(includedir))
+NVCCFLAGS := -Xcompiler -fPIC -O2 $(foreach includedir,$(INCLUDE_DIRS),-I$(includedir))
 LDFLAGS += $(foreach librarydir,$(LIBRARY_DIRS),-L$(librarydir))
 LDFLAGS += $(foreach library,$(LIBRARIES),-l$(library))
 LINK = $(CXX) $(CXXFLAGS) $(CPPFLAGS) $(LDFLAGS) $(WARNINGS)
 
-NVCC = nvcc ${CXXFLAGS:-fPIC=-Xcompiler -fPIC} $(CPPFLAGS) $(CUDA_ARCH)
+NVCC = nvcc $(NVCCFLAGS) $(CPPFLAGS) $(CUDA_ARCH)
 
 .PHONY: all test clean distclean linecount program
diff --git a/src/caffe/layers/data_layer.cpp b/src/caffe/layers/data_layer.cpp
index 5b957701..7993a43c 100644
--- a/src/caffe/layers/data_layer.cpp
+++ b/src/caffe/layers/data_layer.cpp
@@ -24,39 +24,64 @@ void* DataLayerPrefetch(void* layer_pointer) {
   const Dtype subtraction = layer->layer_param_.subtraction();
   const int batchsize = layer->layer_param_.batchsize();
   const int cropsize = layer->layer_param_.cropsize();
+  const bool mirror = layer->layer_param_.mirror();
+
+  if (mirror && cropsize == 0) {
+    LOG(FATAL) << "Current implementation requires mirror and cropsize to be "
+        << "set at the same time.";
+  }
+  // datum scales
+  const int channels = layer->datum_channels_;
+  const int height = layer->datum_height_;
+  const int width = layer->datum_width_;
+  const int size = layer->datum_size_;
   for (int itemid = 0; itemid < batchsize; ++itemid) {
     // get a blob
     datum.ParseFromString(layer->iter_->value().ToString());
     const string& data = datum.data();
     if (cropsize) {
       CHECK(data.size()) << "Image cropping only support uint8 data";
-      int h_offset = rand() % (layer->datum_height_ - cropsize);
-      int w_offset = rand() % (layer->datum_width_ - cropsize);
-      for (int c = 0; c < layer->datum_channels_; ++c) {
-        for (int h = 0; h < cropsize; ++h) {
-          for (int w = 0; w < cropsize; ++w) {
-            top_data[((itemid * layer->datum_channels_ + c) * cropsize + h) * cropsize + w] =
-                static_cast<Dtype>((uint8_t)data[
-                    (c * layer->datum_height_ + h + h_offset) * layer->datum_width_ + w + w_offset]
-                ) * scale - subtraction;
+      int h_offset = rand() % (height - cropsize);
+      int w_offset = rand() % (width - cropsize);
+      if (mirror && rand() % 2) {
+        // Copy mirrored version
+        for (int c = 0; c < channels; ++c) {
+          for (int h = 0; h < cropsize; ++h) {
+            for (int w = 0; w < cropsize; ++w) {
+              top_data[((itemid * channels + c) * cropsize + h) * cropsize + cropsize - 1 - w] =
+                  static_cast<Dtype>((uint8_t)data[
+                      (c * height + h + h_offset) * width + w + w_offset]
+                  ) * scale - subtraction;
+            }
+          }
+        }
+      } else {
+        // Normal copy
+        for (int c = 0; c < channels; ++c) {
+          for (int h = 0; h < cropsize; ++h) {
+            for (int w = 0; w < cropsize; ++w) {
+              top_data[((itemid * channels + c) * cropsize + h) * cropsize + w] =
+                  static_cast<Dtype>((uint8_t)data[
+                      (c * height + h + h_offset) * width + w + w_offset]
+                  ) * scale - subtraction;
+            }
           }
         }
       }
     } else {
       // we will prefer to use data() first, and then try float_data()
       if (data.size()) {
-        for (int j = 0; j < layer->datum_size_; ++j) {
-          top_data[itemid * layer->datum_size_ + j] =
+        for (int j = 0; j < size; ++j) {
+          top_data[itemid * size + j] =
              (static_cast<Dtype>((uint8_t)data[j]) * scale) - subtraction;
         }
       } else {
-        for (int j = 0; j < layer->datum_size_; ++j) {
-          top_data[itemid * layer->datum_size_ + j] =
+        for (int j = 0; j < size; ++j) {
+          top_data[itemid * size + j] =
              (datum.float_data(j) * scale) - subtraction;
         }
       }
     }
+    top_label[itemid] = datum.label();
     // go to the next iter
     layer->iter_->Next();
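In the data_layer.cpp change above, mirroring flips only the innermost (width) index of the destination; the source crop window itself is unchanged. Note also that rand() % (height - cropsize) draws offsets from [0, height - cropsize - 1], so the bottom- and right-most crop positions are never sampled; whether that is intended is not clear from the patch. A standalone sketch (not part of the commit; the 256/227 sizes are made-up examples) that checks the index arithmetic:

// Sketch: verify the crop/mirror indexing used in DataLayerPrefetch above.
#include <cassert>
#include <cstdlib>

int main() {
  const int channels = 3, height = 256, width = 256, cropsize = 227;
  const int h_offset = std::rand() % (height - cropsize);
  const int w_offset = std::rand() % (width - cropsize);
  for (int c = 0; c < channels; ++c) {
    for (int h = 0; h < cropsize; ++h) {
      for (int w = 0; w < cropsize; ++w) {
        // Source index into the CHW-packed datum always stays in bounds.
        const int src = (c * height + h + h_offset) * width + w + w_offset;
        assert(src >= 0 && src < channels * height * width);
        // Mirroring flips only the fastest-varying (width) axis.
        const int dst_normal = (c * cropsize + h) * cropsize + w;
        const int dst_mirror = (c * cropsize + h) * cropsize + (cropsize - 1 - w);
        assert(dst_mirror == dst_normal + (cropsize - 1 - 2 * w));
      }
    }
  }
  return 0;
}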
diff --git a/src/caffe/optimization/solver.cpp b/src/caffe/optimization/solver.cpp
index 1afe2936..d9ab2c1b 100644
--- a/src/caffe/optimization/solver.cpp
+++ b/src/caffe/optimization/solver.cpp
@@ -1,5 +1,7 @@
 // Copyright Yangqing Jia 2013
 
+#include <cstdio>
+
 #include <algorithm>
 #include <string>
 #include <vector>
@@ -34,7 +36,7 @@ void Solver<Dtype>::Solve(Net<Dtype>* net) {
     if (param_.snapshot() > 0 && iter_ % param_.snapshot() == 0) {
       Snapshot(false);
     }
-    if (param_.display() && iter_ % param_.display()) {
+    if (param_.display() && iter_ % param_.display() == 0) {
       LOG(ERROR) << "Iteration " << iter_ << ", loss = " << loss;
     }
   }
@@ -47,14 +49,14 @@ void Solver<Dtype>::Snapshot(bool is_final) {
   NetParameter net_param;
   // For intermediate results, we will also dump the gradient values.
   net_->ToProto(&net_param, !is_final);
-  stringstream ss;
-  ss << param_.snapshot_prefix();
+  string filename(param_.snapshot_prefix());
   if (is_final) {
-    ss << "_final";
+    filename += "_final";
   } else {
-    ss << "_iter_" << iter_;
+    char iter_str_buffer[20];
+    sprintf(iter_str_buffer, "_iter_%d", iter_);
+    filename += iter_str_buffer;
   }
-  string filename = ss.str();
   LOG(ERROR) << "Snapshotting to " << filename;
   WriteProtoToBinaryFile(net_param, filename.c_str());
 }
diff --git a/src/caffe/proto/caffe.proto b/src/caffe/proto/caffe.proto
index 048144c8..afefccab 100644
--- a/src/caffe/proto/caffe.proto
+++ b/src/caffe/proto/caffe.proto
@@ -66,6 +66,8 @@ message LayerParameter {
   optional uint32 batchsize = 19;
   // For data layers, specify if we would like to randomly crop an image.
   optional uint32 cropsize = 20 [default = 0];
+  // For data layers, specify if we want to randomly mirror data.
+  optional bool mirror = 21 [default = false];
 
   // The blobs containing the numeric parameters of the layer
   repeated BlobProto blobs = 50;
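Two small solver fixes above: the display condition previously fired on every iteration that was *not* a multiple of display() (iter_ % display() is truthy for non-multiples), and snapshot naming moved from a stringstream to sprintf into a fixed buffer. A sketch (not the commit's code) of the resulting naming scheme, using snprintf for an explicit bound; "_iter_" plus at most 10 digits and a NUL fits the 20-byte buffer for a 32-bit iteration counter:

// Sketch of the snapshot filename logic with a bounded formatter.
#include <cstdio>
#include <string>

std::string SnapshotFilename(const std::string& prefix, int iter, bool is_final) {
  if (is_final) return prefix + "_final";
  char buf[20];
  std::snprintf(buf, sizeof(buf), "_iter_%d", iter);
  return prefix + buf;
}
// SnapshotFilename("alexnet_train", 15000, false) -> "alexnet_train_iter_15000"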
diff --git a/src/caffe/util/im2col.cu b/src/caffe/util/im2col.cu
index 81ac3c27..0b0c8b83 100644
--- a/src/caffe/util/im2col.cu
+++ b/src/caffe/util/im2col.cu
@@ -9,6 +9,7 @@
 
 namespace caffe {
 
+
 template <typename Dtype>
 __global__ void im2col_gpu_kernel(const int n, const Dtype* data_im,
     const int height, const int width, const int ksize,
@@ -48,6 +49,7 @@ void im2col_gpu(const Dtype* data_im, const int channels,
   CUDA_POST_KERNEL_CHECK;
 }
 
+// Explicit instantiation
 template void im2col_gpu<float>(const float* data_im, const int channels,
     const int height, const int width, const int ksize, const int stride,
@@ -71,13 +73,24 @@ __global__ void col2im_gpu_kernel(const int n, const Dtype* data_col,
     int w_col_end = min(w / stride + 1, width_col);
     int h_col_start = (h < ksize) ? 0 : (h - ksize) / stride + 1;
     int h_col_end = min(h / stride + 1, height_col);
+    /*
     for (int h_col = h_col_start; h_col < h_col_end; ++h_col) {
       for (int w_col = w_col_start; w_col < w_col_end; ++w_col) {
         // the col location: [c * width * height + h_out, w_out]
-        int c_col = c * ksize * ksize + (h - h_col * stride) * ksize + (w - w_col * stride);
+        int c_col = c * ksize * ksize + (h - h_col * stride) * ksize + (w - w_col * stride);
         val += data_col[(c_col * height_col + h_col) * width_col + w_col];
       }
     }
+    */
+    // equivalent implementation
+    int offset = (c * ksize * ksize + h * ksize + w) * height_col * width_col;
+    int coeff_h_col = (1 - stride * ksize * height_col) * width_col;
+    int coeff_w_col = (1 - stride * height_col * width_col);
+    for (int h_col = h_col_start; h_col < h_col_end; ++h_col) {
+      for (int w_col = w_col_start; w_col < w_col_end; ++w_col) {
+        val += data_col[offset + h_col * coeff_h_col + w_col * coeff_w_col];
+      }
+    }
     data_im[index] = val;
   }
 }
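The rewritten col2im loop hoists the loop-invariant part of the index computation out of the inner loops. Substituting c_col into the original expression (c_col * height_col + h_col) * width_col + w_col and expanding gives (c*ksize*ksize + h*ksize + w)*height_col*width_col + h_col*(1 - stride*ksize*height_col)*width_col + w_col*(1 - stride*height_col*width_col), which is exactly offset + h_col*coeff_h_col + w_col*coeff_w_col. A standalone check (not part of the commit; the sizes are arbitrary example values):

// Sketch: confirm the hoisted-coefficient index equals the nested expression.
#include <algorithm>
#include <cassert>

int main() {
  const int ksize = 11, stride = 4, height_col = 55, width_col = 55;
  const int c = 2, h = 30, w = 17;  // arbitrary in-range output pixel
  const int offset = (c * ksize * ksize + h * ksize + w) * height_col * width_col;
  const int coeff_h_col = (1 - stride * ksize * height_col) * width_col;
  const int coeff_w_col = (1 - stride * height_col * width_col);
  const int w_col_start = (w < ksize) ? 0 : (w - ksize) / stride + 1;
  const int w_col_end = std::min(w / stride + 1, width_col);
  const int h_col_start = (h < ksize) ? 0 : (h - ksize) / stride + 1;
  const int h_col_end = std::min(h / stride + 1, height_col);
  for (int h_col = h_col_start; h_col < h_col_end; ++h_col) {
    for (int w_col = w_col_start; w_col < w_col_end; ++w_col) {
      const int c_col = c * ksize * ksize + (h - h_col * stride) * ksize
          + (w - w_col * stride);
      const int orig = (c_col * height_col + h_col) * width_col + w_col;
      assert(orig == offset + h_col * coeff_h_col + w_col * coeff_w_col);
    }
  }
  return 0;
}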
"pad5" + type: "padding" + pad: 1 + } + bottom: "relu4" + top: "pad5" +} +layers { + layer { + name: "conv5" + type: "conv" + num_output: 256 + group: 2 + kernelsize: 3 + weight_filler { + type: "gaussian" + std: 0.01 + } + bias_filler { + type: "constant" + value: 1 + } + } + bottom: "pad5" + top: "conv5" +} +layers { + layer { + name: "relu5" + type: "relu" + } + bottom: "conv5" + top: "relu5" +} +layers { + layer { + name: "pool5" + type: "pool" + kernelsize: 3 + pool: MAX + stride: 2 + } + bottom: "relu5" + top: "pool5" +} +layers { + layer { + name: "fc6" + type: "innerproduct" + num_output: 4096 + weight_filler { + type: "gaussian" + std: 0.005 + } + bias_filler { + type: "constant" + value: 1 + } + } + bottom: "pool5" + top: "fc6" +} +layers { + layer { + name: "relu6" + type: "relu" + } + bottom: "fc6" + top: "relu6" +} +layers { + layer { + name: "drop6" + type: "dropout" + dropout_ratio: 0.5 + } + bottom: "relu6" + top: "drop6" +} +layers { + layer { + name: "fc7" + type: "innerproduct" + num_output: 4096 + weight_filler { + type: "gaussian" + std: 0.005 + } + bias_filler { + type: "constant" + value: 1 + } + } + bottom: "drop6" + top: "fc7" +} +layers { + layer { + name: "relu7" + type: "relu" + } + bottom: "fc7" + top: "relu7" +} +layers { + layer { + name: "drop7" + type: "dropout" + dropout_ratio: 0.5 + } + bottom: "relu7" + top: "drop7" +} +layers { + layer { + name: "fc8" + type: "innerproduct" + num_output: 1000 + weight_filler { + type: "gaussian" + std: 0.01 + } + bias_filler { + type: "constant" + value: 0 + } + } + bottom: "drop7" + top: "fc8" +} +layers { + layer { + name: "loss" + type: "softmax_loss" + } + bottom: "fc8" + bottom: "label" +}
diff --git a/src/programs/imagenet_solver.prototxt b/src/programs/imagenet_solver.prototxt
new file mode 100644
index 00000000..58b0dfef
--- /dev/null
+++ b/src/programs/imagenet_solver.prototxt
@@ -0,0 +1,10 @@
+base_lr: 0.02
+lr_policy: "step"
+gamma: 0.1
+stepsize: 340000
+display: 100
+max_iter: 1200000
+momentum: 0.9
+weight_decay: 0.0005
+snapshot: 15000
+snapshot_prefix: "alexnet_train"
\ No newline at end of file
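A note on the solver settings: with lr_policy "step", the effective rate is presumably base_lr * gamma^floor(iter / stepsize) (the policy implementation itself is not part of this diff), i.e. 0.02 until iteration 340k, then 0.002, 2e-4, and finally 2e-5 through max_iter. A sketch:

// Sketch of the "step" learning-rate schedule these settings imply.
#include <cmath>
#include <cstdio>

double StepLearningRate(int iter) {
  const double base_lr = 0.02, gamma = 0.1;
  const int stepsize = 340000;
  return base_lr * std::pow(gamma, iter / stepsize);  // integer division floors
}

int main() {
  for (int iter : {0, 339999, 340000, 680000, 1020000, 1199999})
    std::printf("iter %7d -> lr %g\n", iter, StepLearningRate(iter));
  return 0;
}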
diff --git a/src/programs/net_speed_benchmark.cpp b/src/programs/net_speed_benchmark.cpp
new file mode 100644
index 00000000..560c5d87
--- /dev/null
+++ b/src/programs/net_speed_benchmark.cpp
@@ -0,0 +1,62 @@
+// Copyright 2013 Yangqing Jia
+
+#include <cuda_runtime.h>
+#include <fcntl.h>
+#include <google/protobuf/text_format.h>
+
+#include <cstring>
+#include <ctime>
+
+#include "caffe/blob.hpp"
+#include "caffe/common.hpp"
+#include "caffe/net.hpp"
+#include "caffe/filler.hpp"
+#include "caffe/proto/caffe.pb.h"
+#include "caffe/util/io.hpp"
+#include "caffe/optimization/solver.hpp"
+
+using namespace caffe;
+
+int main(int argc, char** argv) {
+  cudaSetDevice(1);
+  Caffe::set_mode(Caffe::GPU);
+  Caffe::set_phase(Caffe::TRAIN);
+  int repeat = 100;
+
+  NetParameter net_param;
+  ReadProtoFromTextFile(argv[1],
+      &net_param);
+  vector<Blob<float>*> bottom_vec;
+  Net<float> caffe_net(net_param, bottom_vec);
+
+  // Run the network without training.
+  LOG(ERROR) << "Performing Forward";
+  caffe_net.Forward(bottom_vec);
+  LOG(ERROR) << "Performing Backward";
+  LOG(ERROR) << "Initial loss: " << caffe_net.Backward();
+
+  const vector<shared_ptr<Layer<float> > >& layers = caffe_net.layers();
+  vector<vector<Blob<float>*> >& bottom_vecs = caffe_net.bottom_vecs();
+  vector<vector<Blob<float>*> >& top_vecs = caffe_net.top_vecs();
+  LOG(ERROR) << "*** Benchmark begins ***";
+  for (int i = 0; i < layers.size(); ++i) {
+    const string& layername = layers[i]->layer_param().name();
+    clock_t start = clock();
+    for (int j = 0; j < repeat; ++j) {
+      layers[i]->Forward(bottom_vecs[i], &top_vecs[i]);
+    }
+    LOG(ERROR) << layername << "\tforward: "
+        << float(clock() - start) / CLOCKS_PER_SEC << " seconds.";
+  }
+  for (int i = layers.size() - 1; i >= 0; --i) {
+    const string& layername = layers[i]->layer_param().name();
+    clock_t start = clock();
+    for (int j = 0; j < repeat; ++j) {
+      layers[i]->Backward(top_vecs[i], true, &bottom_vecs[i]);
+    }
+    LOG(ERROR) << layername << "\tbackward: "
+        << float(clock() - start) / CLOCKS_PER_SEC << " seconds.";
+  }
+  LOG(ERROR) << "*** Benchmark ends ***";
+  return 0;
+}
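net_speed_benchmark times layers with clock(), which measures CPU time; since CUDA kernel launches are asynchronous, per-layer numbers can be skewed unless something in the timed path forces synchronization. An alternative sketch using CUDA events (an assumption about how one might tighten this, not what the benchmark above does; error handling omitted for brevity):

// Sketch: time GPU work with CUDA events instead of clock().
#include <cuda_runtime.h>

float TimeGpuMs(void (*run)()) {
  cudaEvent_t start, stop;
  cudaEventCreate(&start);
  cudaEventCreate(&stop);
  cudaEventRecord(start, 0);
  run();  // launch the kernels to be measured
  cudaEventRecord(stop, 0);
  cudaEventSynchronize(stop);  // wait for the GPU to finish
  float ms = 0.0f;
  cudaEventElapsedTime(&ms, start, stop);
  cudaEventDestroy(start);
  cudaEventDestroy(stop);
  return ms;
}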
- LOG(ERROR) << "Performing Forward"; - caffe_net.Forward(bottom_vec); - LOG(ERROR) << "Performing Backward"; - LOG(ERROR) << "Initial loss: " << caffe_net.Backward(); - - /* - // Now, let's dump all the layers - string output_prefix("alexnet_initial_dump_"); - const vector<string>& blob_names = caffe_net.blob_names(); - const vector<shared_ptr<Blob<float> > >& blobs = caffe_net.blobs(); - for (int blobid = 0; blobid < caffe_net.blobs().size(); ++blobid) { - // Serialize blob - LOG(ERROR) << "Dumping " << blob_names[blobid]; - BlobProto output_blob_proto; - blobs[blobid]->ToProto(&output_blob_proto); - WriteProtoToBinaryFile(output_blob_proto, output_prefix + blob_names[blobid]); - } - */ - - SolverParameter solver_param; - solver_param.set_base_lr(0.01); - solver_param.set_display(1); - solver_param.set_max_iter(60000); - solver_param.set_lr_policy("fixed"); - solver_param.set_momentum(0.9); - solver_param.set_weight_decay(0.0005); - solver_param.set_snapshot(5000); - solver_param.set_snapshot_prefix("alexnet"); - - LOG(ERROR) << "Starting Optimization"; - SGDSolver<float> solver(solver_param); - solver.Solve(&caffe_net); - LOG(ERROR) << "Optimization Done."; - - // Run the network after training. - LOG(ERROR) << "Performing Forward"; - caffe_net.Forward(bottom_vec); - LOG(ERROR) << "Performing Backward"; - float loss = caffe_net.Backward(); - LOG(ERROR) << "Final loss: " << loss; - - return 0; -} diff --git a/src/programs/train_net.cpp b/src/programs/train_net.cpp new file mode 100644 index 00000000..41110430 --- /dev/null +++ b/src/programs/train_net.cpp @@ -0,0 +1,39 @@ +// Copyright 2013 Yangqing Jia + +#include <cuda_runtime.h> +#include <fcntl.h> +#include <google/protobuf/text_format.h> + +#include <cstring> + +#include "caffe/blob.hpp" +#include "caffe/common.hpp" +#include "caffe/net.hpp" +#include "caffe/filler.hpp" +#include "caffe/proto/caffe.pb.h" +#include "caffe/util/io.hpp" +#include "caffe/optimization/solver.hpp" + +using namespace caffe; + +int main(int argc, char** argv) { + cudaSetDevice(0); + Caffe::set_mode(Caffe::GPU); + Caffe::set_phase(Caffe::TRAIN); + + NetParameter net_param; + ReadProtoFromTextFile(argv[1], + &net_param); + vector<Blob<float>*> bottom_vec; + Net<float> caffe_net(net_param, bottom_vec); + + SolverParameter solver_param; + ReadProtoFromTextFile(argv[2], &solver_param); + + LOG(ERROR) << "Starting Optimization"; + SGDSolver<float> solver(solver_param); + solver.Solve(&caffe_net); + LOG(ERROR) << "Optimization Done."; + + return 0; +} |