author    Yangqing Jia <jiayq84@gmail.com>    2013-10-14 11:29:54 -0700
committer Yangqing Jia <jiayq84@gmail.com>    2013-10-14 11:29:54 -0700
commit    d566c0bcbac805ff747563ded7ce860b9bde49c9 (patch)
tree      e17f5efb02c7a80ebede21b4d9dfb1b28def009c /src
parent    20dde61d73785f25ac19acf542539b6cf2ed297c (diff)
parent    8500c5537f7a1193b1437f41b34478af77b051fd (diff)
Merge branch 'master' of github.com:Yangqing/caffe
Diffstat (limited to 'src')
-rw-r--r--  src/Makefile                                         8
-rw-r--r--  src/caffe/common.hpp                                 8
-rw-r--r--  src/caffe/layers/data_layer.cpp                     45
-rw-r--r--  src/caffe/layers/inner_product_layer.cpp             2
-rw-r--r--  src/caffe/layers/pooling_layer.cpp                   7
-rw-r--r--  src/caffe/layers/pooling_layer.cu                    5
-rw-r--r--  src/caffe/layers/relu_layer.cu                       2
-rw-r--r--  src/caffe/net.cpp                                   18
-rw-r--r--  src/caffe/net.hpp                                    7
-rw-r--r--  src/caffe/optimization/solver.cpp                   19
-rw-r--r--  src/caffe/proto/caffe.proto                          3
-rw-r--r--  src/caffe/test/test_simple_conv.cpp                 54
-rw-r--r--  src/caffe/test/test_solver_linear_regression.cpp    65
-rw-r--r--  src/caffe/util/blob_math.cpp.working                 0
-rw-r--r--  src/caffe/util/blob_math.hpp                       104
-rw-r--r--  src/caffe/util/im2col.hpp                            2
-rw-r--r--  src/caffe/util/io.cpp                               75
-rw-r--r--  src/caffe/util/io.hpp                               20
-rw-r--r--  src/programs/convert_dataset.cpp                    81
-rw-r--r--  src/programs/demo_mnist.cpp                          4
-rw-r--r--  src/programs/dump_network.cpp                       43
-rw-r--r--  src/programs/imagenet.prototxt                      18
-rw-r--r--  src/programs/imagenet_solver.prototxt                6
23 files changed, 184 insertions, 412 deletions
diff --git a/src/Makefile b/src/Makefile
index 6f06866f..60aa139e 100644
--- a/src/Makefile
+++ b/src/Makefile
@@ -31,7 +31,7 @@ TEST_BINS := ${TEST_OBJS:.o=.testbin}
# define third-party library paths
CUDA_DIR := /usr/local/cuda
-CUDA_ARCH := -arch=sm_20
+CUDA_ARCH := -arch=sm_30
MKL_DIR := /opt/intel/mkl
CUDA_INCLUDE_DIR := $(CUDA_DIR)/include
@@ -43,7 +43,7 @@ MKL_LIB_DIR := $(MKL_DIR)/lib $(MKL_DIR)/lib/intel64
INCLUDE_DIRS := . /usr/local/include $(CUDA_INCLUDE_DIR) $(MKL_INCLUDE_DIR)
LIBRARY_DIRS := . /usr/lib /usr/local/lib $(CUDA_LIB_DIR) $(MKL_LIB_DIR)
LIBRARIES := cuda cudart cublas protobuf glog mkl_rt mkl_intel_thread curand \
- leveldb snappy opencv_core opencv_highgui pthread tcmalloc
+ leveldb snappy pthread tcmalloc
WARNINGS := -Wall
COMMON_FLAGS := $(foreach includedir,$(INCLUDE_DIRS),-I$(includedir))
@@ -80,6 +80,10 @@ $(TEST_BINS): %.testbin : %.o $(GTEST_OBJ) $(STATIC_NAME)
$(PROGRAM_BINS): %.bin : %.o $(STATIC_NAME)
$(CXX) $< $(STATIC_NAME) -o $@ $(LDFLAGS) $(WARNINGS)
+$(OBJS): $(PROTO_GEN_CC)
+
+$(PROGRAM_OBJS): $(PROTO_GEN_CC)
+
$(CU_OBJS): %.cuo: %.cu
$(NVCC) -c $< -o $@
diff --git a/src/caffe/common.hpp b/src/caffe/common.hpp
index 39e417f7..8eb79876 100644
--- a/src/caffe/common.hpp
+++ b/src/caffe/common.hpp
@@ -43,18 +43,14 @@ private:\
namespace caffe {
-// Two classes whose purpose are solely for instantiating blob template
-// functions.
-class GPUBrewer {};
-class CPUBrewer {};
// We will use the boost shared_ptr instead of the new C++11 one mainly
// because cuda does not work (at least now) well with C++11 features.
using boost::shared_ptr;
-// For backward compatibility we will just use 512 threads per block
-const int CAFFE_CUDA_NUM_THREADS = 512;
+// We will use 1024 threads per block, which requires cuda sm_2x or above.
+const int CAFFE_CUDA_NUM_THREADS = 1024;
inline int CAFFE_GET_BLOCKS(const int N) {
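Note: the 512 -> 1024 bump only takes effect through CAFFE_GET_BLOCKS, whose body falls outside this hunk. A minimal sketch, assuming the usual ceiling-division grid size (the actual body is not shown in this diff):

    // Sketch only -- not part of the patch.
    const int CAFFE_CUDA_NUM_THREADS = 1024;  // requires sm_2x or newer hardware
    inline int CAFFE_GET_BLOCKS(const int N) {
      // Round up so that every one of the N elements gets a thread.
      return (N + CAFFE_CUDA_NUM_THREADS - 1) / CAFFE_CUDA_NUM_THREADS;
    }
    // A kernel launch over N elements then reads:
    //   SomeKernel<<<CAFFE_GET_BLOCKS(N), CAFFE_CUDA_NUM_THREADS>>>(N, ...);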
diff --git a/src/caffe/layers/data_layer.cpp b/src/caffe/layers/data_layer.cpp
index c44fbd6f..daa65e6e 100644
--- a/src/caffe/layers/data_layer.cpp
+++ b/src/caffe/layers/data_layer.cpp
@@ -47,10 +47,12 @@ void* DataLayerPrefetch(void* layer_pointer) {
for (int c = 0; c < channels; ++c) {
for (int h = 0; h < cropsize; ++h) {
for (int w = 0; w < cropsize; ++w) {
- top_data[((itemid * channels + c) * cropsize + h) * cropsize + cropsize - 1 - w] =
- static_cast<Dtype>((uint8_t)data[
- (c * height + h + h_offset) * width + w + w_offset]
- ) * scale - subtraction;
+ top_data[((itemid * channels + c) * cropsize + h) * cropsize
+ + cropsize - 1 - w] =
+ static_cast<Dtype>(
+ (uint8_t)data[(c * height + h + h_offset) * width
+ + w + w_offset])
+ * scale - subtraction;
}
}
}
@@ -59,10 +61,11 @@ void* DataLayerPrefetch(void* layer_pointer) {
for (int c = 0; c < channels; ++c) {
for (int h = 0; h < cropsize; ++h) {
for (int w = 0; w < cropsize; ++w) {
- top_data[((itemid * channels + c) * cropsize + h) * cropsize + w] =
- static_cast<Dtype>((uint8_t)data[
- (c * height + h + h_offset) * width + w + w_offset]
- ) * scale - subtraction;
+ top_data[((itemid * channels + c) * cropsize + h) * cropsize + w]
+ = static_cast<Dtype>(
+ (uint8_t)data[(c * height + h + h_offset) * width
+ + w + w_offset])
+ * scale - subtraction;
}
}
}
@@ -146,10 +149,10 @@ void DataLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
CHECK_GT(datum_height_, cropsize);
CHECK_GT(datum_width_, cropsize);
// Now, start the prefetch thread.
- //LOG(INFO) << "Initializing prefetch";
- CHECK(!pthread_create(&thread_, NULL, DataLayerPrefetch<Dtype>, (void*)this))
- << "Pthread execution failed.";
- //LOG(INFO) << "Prefetch initialized.";
+ // LOG(INFO) << "Initializing prefetch";
+ CHECK(!pthread_create(&thread_, NULL, DataLayerPrefetch<Dtype>,
+ reinterpret_cast<void*>(this))) << "Pthread execution failed.";
+ // LOG(INFO) << "Prefetch initialized.";
}
template <typename Dtype>
@@ -163,8 +166,8 @@ void DataLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
memcpy((*top)[1]->mutable_cpu_data(), prefetch_label_->cpu_data(),
sizeof(Dtype) * prefetch_label_->count());
// Start a new prefetch thread
- CHECK(!pthread_create(&thread_, NULL, DataLayerPrefetch<Dtype>, (void*)this))
- << "Pthread execution failed.";
+ CHECK(!pthread_create(&thread_, NULL, DataLayerPrefetch<Dtype>,
+ reinterpret_cast<void*>(this))) << "Pthread execution failed.";
}
template <typename Dtype>
@@ -173,13 +176,15 @@ void DataLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
// First, join the thread
CHECK(!pthread_join(thread_, NULL)) << "Pthread joining failed.";
// Copy the data
- CUDA_CHECK(cudaMemcpy((*top)[0]->mutable_gpu_data(), prefetch_data_->cpu_data(),
- sizeof(Dtype) * prefetch_data_->count(), cudaMemcpyHostToDevice));
- CUDA_CHECK(cudaMemcpy((*top)[1]->mutable_gpu_data(), prefetch_label_->cpu_data(),
- sizeof(Dtype) * prefetch_label_->count(), cudaMemcpyHostToDevice));
+ CUDA_CHECK(cudaMemcpy((*top)[0]->mutable_gpu_data(),
+ prefetch_data_->cpu_data(), sizeof(Dtype) * prefetch_data_->count(),
+ cudaMemcpyHostToDevice));
+ CUDA_CHECK(cudaMemcpy((*top)[1]->mutable_gpu_data(),
+ prefetch_label_->cpu_data(), sizeof(Dtype) * prefetch_label_->count(),
+ cudaMemcpyHostToDevice));
// Start a new prefetch thread
- CHECK(!pthread_create(&thread_, NULL, DataLayerPrefetch<Dtype>, (void*)this))
- << "Pthread execution failed.";
+ CHECK(!pthread_create(&thread_, NULL, DataLayerPrefetch<Dtype>,
+ reinterpret_cast<void*>(this))) << "Pthread execution failed.";
}
// The backward operations are dummy - they do not carry any computation.
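Note: the reformatted assignments above are plain row-major addressing of a (num, channels, cropsize, cropsize) destination blob against a (channels, height, width) source datum; the mirrored branch simply writes column cropsize-1-w instead of w. A standalone sketch of the same arithmetic, with hypothetical sizes:

    #include <cstdio>

    // Hypothetical sizes, for illustration only.
    const int channels = 3, height = 256, width = 256, cropsize = 227;
    const int h_offset = 10, w_offset = 20;

    // Destination index inside the (itemid, c, h, w) crop; mirroring flips w.
    inline int top_index(int itemid, int c, int h, int w, bool mirror) {
      const int col = mirror ? (cropsize - 1 - w) : w;
      return ((itemid * channels + c) * cropsize + h) * cropsize + col;
    }

    // Source index of the corresponding pixel in the full datum.
    inline int datum_index(int c, int h, int w) {
      return (c * height + h + h_offset) * width + w + w_offset;
    }

    int main() {
      // Crop element (item 0, channel 1, row 2, col 3) is read from the
      // source pixel at (channel 1, row 2 + h_offset, col 3 + w_offset).
      std::printf("%d <- %d\n", top_index(0, 1, 2, 3, false), datum_index(1, 2, 3));
      return 0;
    }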
diff --git a/src/caffe/layers/inner_product_layer.cpp b/src/caffe/layers/inner_product_layer.cpp
index ef985936..18f1df0d 100644
--- a/src/caffe/layers/inner_product_layer.cpp
+++ b/src/caffe/layers/inner_product_layer.cpp
@@ -49,7 +49,7 @@ void InnerProductLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
GetFiller<Dtype>(this->layer_param_.bias_filler()));
bias_filler->Fill(this->blobs_[1].get());
}
- } // parameter initialization
+ } // parameter initialization
// Setting up the bias multiplier
if (biasterm_) {
bias_multiplier_.reset(new SyncedMemory(M_ * sizeof(Dtype)));
diff --git a/src/caffe/layers/pooling_layer.cpp b/src/caffe/layers/pooling_layer.cpp
index 7de2a643..59ce3fe7 100644
--- a/src/caffe/layers/pooling_layer.cpp
+++ b/src/caffe/layers/pooling_layer.cpp
@@ -8,8 +8,6 @@
#include "caffe/vision_layers.hpp"
#include "caffe/util/math_functions.hpp"
-#define CAFFE_MAX_POOLING_THRESHOLD 1e-8f
-
using std::max;
using std::min;
@@ -135,9 +133,8 @@ Dtype PoolingLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
for (int w = wstart; w < wend; ++w) {
bottom_diff[h * WIDTH_ + w] +=
top_diff[ph * POOLED_WIDTH_ + pw] *
- (bottom_data[h * WIDTH_ + w] >=
- top_data[ph * POOLED_WIDTH_ + pw] -
- CAFFE_MAX_POOLING_THRESHOLD);
+ (bottom_data[h * WIDTH_ + w] ==
+ top_data[ph * POOLED_WIDTH_ + pw]);
}
}
}
diff --git a/src/caffe/layers/pooling_layer.cu b/src/caffe/layers/pooling_layer.cu
index 706ee156..9d15c534 100644
--- a/src/caffe/layers/pooling_layer.cu
+++ b/src/caffe/layers/pooling_layer.cu
@@ -6,8 +6,6 @@
#include "caffe/vision_layers.hpp"
#include "caffe/util/math_functions.hpp"
-#define CAFFE_MAX_POOLING_THRESHOLD 1e-8f
-
using std::max;
using std::min;
@@ -116,8 +114,7 @@ __global__ void MaxPoolBackward(const int nthreads, const Dtype* bottom_data,
for (int ph = phstart; ph < phend; ++ph) {
for (int pw = pwstart; pw < pwend; ++pw) {
gradient += top_diff[ph * pooled_width + pw] *
- (bottom_datum >= top_data[ph * pooled_width + pw] -
- CAFFE_MAX_POOLING_THRESHOLD);
+ (bottom_datum == top_data[ph * pooled_width + pw]);
}
}
bottom_diff[index] = gradient;
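Note: both the CPU and GPU hunks above replace the 1e-8 threshold test with an exact equality against the pooled maximum, so the gradient is routed only to positions that actually produced the max (ties still all receive it). A minimal CPU sketch of the routing for a single pooling window, under that reading:

    // Sketch only -- gradient routing for one max-pooling window.
    // Every bottom element equal to the window max gets the full top gradient.
    void MaxPoolWindowBackward(const float* bottom, int n,
                               float top_max, float top_diff,
                               float* bottom_diff) {
      for (int i = 0; i < n; ++i) {
        bottom_diff[i] += top_diff * (bottom[i] == top_max ? 1.0f : 0.0f);
      }
    }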
diff --git a/src/caffe/layers/relu_layer.cu b/src/caffe/layers/relu_layer.cu
index 8613b3bc..b0fc46ef 100644
--- a/src/caffe/layers/relu_layer.cu
+++ b/src/caffe/layers/relu_layer.cu
@@ -39,7 +39,7 @@ template <typename Dtype>
__global__ void ReLUForward(const int n, const Dtype* in, Dtype* out) {
int index = threadIdx.x + blockIdx.x * blockDim.x;
if (index < n) {
- out[index] = max(in[index], Dtype(0.));
+ out[index] = in[index] > 0 ? in[index] : 0;
}
}
diff --git a/src/caffe/net.cpp b/src/caffe/net.cpp
index 8190d8ee..ff1cca4b 100644
--- a/src/caffe/net.cpp
+++ b/src/caffe/net.cpp
@@ -85,7 +85,7 @@ Net<Dtype>::Net(const NetParameter& param,
net_output_blobs_.push_back(blobs_[blob_name_to_idx[*it]].get());
}
- LOG(ERROR) << "Setting up the layers.";
+ LOG(INFO) << "Setting up the layers.";
for (int i = 0; i < layers_.size(); ++i) {
LOG(INFO) << "Setting up " << layer_names_[i];
layers_[i]->SetUp(bottom_vecs_[i], &top_vecs_[i]);
@@ -93,14 +93,26 @@ Net<Dtype>::Net(const NetParameter& param,
for (int j = 0; j < layer_blobs.size(); ++j) {
params_.push_back(layer_blobs[j]);
}
+ // push the learning rate multipliers
+ if (layers_[i]->layer_param().blobs_lr_size()) {
+ CHECK_EQ(layers_[i]->layer_param().blobs_lr_size(), layer_blobs.size());
+ for (int j = 0; j < layer_blobs.size(); ++j) {
+ float local_lr = layers_[i]->layer_param().blobs_lr(j);
+ CHECK_GT(local_lr, 0.);
+ params_lr_.push_back(local_lr);
+ }
+ } else {
+ for (int j = 0; j < layer_blobs.size(); ++j) {
+ params_lr_.push_back(1.);
+ }
+ }
for (int topid = 0; topid < top_vecs_[i].size(); ++topid) {
LOG(INFO) << "Top shape: " << top_vecs_[i][topid]->channels() << " "
<< top_vecs_[i][topid]->height() << " "
<< top_vecs_[i][topid]->width();
}
}
-
- LOG(ERROR) << "Network initialization done.";
+ LOG(INFO) << "Network initialization done.";
}
template <typename Dtype>
diff --git a/src/caffe/net.hpp b/src/caffe/net.hpp
index 24bef4bb..c27442b8 100644
--- a/src/caffe/net.hpp
+++ b/src/caffe/net.hpp
@@ -57,7 +57,9 @@ class Net {
inline vector<vector<Blob<Dtype>*> >& bottom_vecs() { return bottom_vecs_; }
inline vector<vector<Blob<Dtype>*> >& top_vecs() { return top_vecs_; }
// returns the parameters
- vector<shared_ptr<Blob<Dtype> > >& params() { return params_; }
+ inline vector<shared_ptr<Blob<Dtype> > >& params() { return params_; }
+ // returns the parameter learning rate multipliers
+ inline vector<float>& params_lr() {return params_lr_; }
// Updates the network
void Update();
@@ -82,7 +84,8 @@ class Net {
string name_;
// The parameters in the network.
vector<shared_ptr<Blob<Dtype> > > params_;
-
+ // the learning rate multipliers
+ vector<float> params_lr_;
DISABLE_COPY_AND_ASSIGN(Net);
};
diff --git a/src/caffe/optimization/solver.cpp b/src/caffe/optimization/solver.cpp
index d9ab2c1b..b2a57600 100644
--- a/src/caffe/optimization/solver.cpp
+++ b/src/caffe/optimization/solver.cpp
@@ -17,7 +17,6 @@ using std::min;
namespace caffe {
-
template <typename Dtype>
void Solver<Dtype>::Solve(Net<Dtype>* net) {
net_ = net;
@@ -112,22 +111,28 @@ void SGDSolver<Dtype>::PreSolve() {
template <typename Dtype>
void SGDSolver<Dtype>::ComputeUpdateValue() {
vector<shared_ptr<Blob<Dtype> > >& net_params = this->net_->params();
+ vector<float>& net_params_lr = this->net_->params_lr();
// get the learning rate
Dtype rate = GetLearningRate();
+ if (this->param_.display() && this->iter_ % this->param_.display() == 0) {
+ LOG(ERROR) << "Iteration " << this->iter_ << ", lr = " << rate;
+ }
Dtype momentum = this->param_.momentum();
Dtype weight_decay = this->param_.weight_decay();
// LOG(ERROR) << "rate:" << rate << " momentum:" << momentum
- // << " weight_decay:" << weight_decay;
+ // << " weight_decay:" << weight_decay;
switch (Caffe::mode()) {
case Caffe::CPU:
for (int param_id = 0; param_id < net_params.size(); ++param_id) {
// Compute the value to history, and then copy them to the blob's diff.
- caffe_axpby(net_params[param_id]->count(), rate,
+ Dtype local_rate = rate * net_params_lr[param_id];
+ caffe_axpby(net_params[param_id]->count(), local_rate,
net_params[param_id]->cpu_diff(), momentum,
history_[param_id]->mutable_cpu_data());
if (weight_decay) {
// add weight decay
- caffe_axpy(net_params[param_id]->count(), weight_decay * rate,
+ caffe_axpy(net_params[param_id]->count(),
+ weight_decay * local_rate,
net_params[param_id]->cpu_data(),
history_[param_id]->mutable_cpu_data());
}
@@ -140,12 +145,14 @@ void SGDSolver<Dtype>::ComputeUpdateValue() {
case Caffe::GPU:
for (int param_id = 0; param_id < net_params.size(); ++param_id) {
// Compute the value to history, and then copy them to the blob's diff.
- caffe_gpu_axpby(net_params[param_id]->count(), rate,
+ Dtype local_rate = rate * net_params_lr[param_id];
+ caffe_gpu_axpby(net_params[param_id]->count(), local_rate,
net_params[param_id]->gpu_diff(), momentum,
history_[param_id]->mutable_gpu_data());
if (weight_decay) {
// add weight decay
- caffe_gpu_axpy(net_params[param_id]->count(), weight_decay * rate,
+ caffe_gpu_axpy(net_params[param_id]->count(),
+ weight_decay * local_rate,
net_params[param_id]->gpu_data(),
history_[param_id]->mutable_gpu_data());
}
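Note: with the multipliers wired through, each parameter blob steps at local_rate = rate * blobs_lr, so the blobs_lr: 1. / blobs_lr: 2. pattern added to imagenet.prototxt below moves biases twice as fast as weights. A free-standing sketch of the history update both branches perform (the axpby followed by the optional weight-decay axpy):

    #include <vector>

    // Sketch only -- the per-blob SGD history update, CPU form:
    //   history = local_rate * diff + momentum * history  (+ weight-decay term)
    void SgdHistoryUpdate(std::vector<float>& history,
                          const std::vector<float>& diff,
                          const std::vector<float>& data,
                          float rate, float lr_mult,
                          float momentum, float weight_decay) {
      const float local_rate = rate * lr_mult;
      for (size_t i = 0; i < history.size(); ++i) {
        history[i] = local_rate * diff[i] + momentum * history[i];
        if (weight_decay > 0.f) {
          history[i] += weight_decay * local_rate * data[i];
        }
      }
    }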
diff --git a/src/caffe/proto/caffe.proto b/src/caffe/proto/caffe.proto
index afefccab..87f2c2cc 100644
--- a/src/caffe/proto/caffe.proto
+++ b/src/caffe/proto/caffe.proto
@@ -71,6 +71,9 @@ message LayerParameter {
// The blobs containing the numeric parameters of the layer
repeated BlobProto blobs = 50;
+ // The ratio that is multiplied on the global learning rate. If you want to set
+ // the learning ratio for one blob, you need to set it for all blobs.
+ repeated float blobs_lr = 51;
}
message LayerConnection {
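Note: in a layer definition the new field is written once per parameter blob, weights first and bias second, exactly as the imagenet.prototxt additions further down do. A hypothetical fragment of a layer's parameters, for illustration only:

    # Hypothetical LayerParameter fragment.
    name: "conv_example"
    type: "conv"
    num_output: 96
    blobs_lr: 1.   # multiplier for the weight blob
    blobs_lr: 2.   # multiplier for the bias blob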
diff --git a/src/caffe/test/test_simple_conv.cpp b/src/caffe/test/test_simple_conv.cpp
deleted file mode 100644
index f5fe489e..00000000
--- a/src/caffe/test/test_simple_conv.cpp
+++ /dev/null
@@ -1,54 +0,0 @@
-// Copyright 2013 Yangqing Jia
-
-#include <gtest/gtest.h>
-
-#include <cstring>
-
-#include "caffe/common.hpp"
-#include "caffe/blob.hpp"
-#include "caffe/net.hpp"
-#include "caffe/proto/caffe.pb.h"
-#include "caffe/util/io.hpp"
-
-#include "caffe/test/test_caffe_main.hpp"
-
-namespace caffe {
-
-template <typename Dtype>
-class NetProtoTest : public ::testing::Test {};
-
-typedef ::testing::Types<float> Dtypes;
-TYPED_TEST_CASE(NetProtoTest, Dtypes);
-
-TYPED_TEST(NetProtoTest, TestLoadFromText) {
- NetParameter net_param;
- ReadProtoFromTextFile("data/simple_conv.prototxt", &net_param);
- Blob<TypeParam> lena_image;
- ReadImageToBlob<TypeParam>(string("data/lena_256.jpg"), &lena_image);
- vector<Blob<TypeParam>*> bottom_vec;
- bottom_vec.push_back(&lena_image);
-
- for (int i = 0; i < lena_image.count(); ++i) {
- EXPECT_GE(lena_image.cpu_data()[i], 0);
- EXPECT_LE(lena_image.cpu_data()[i], 1);
- }
-
- Caffe::set_mode(Caffe::CPU);
- // Initialize the network, and then does smoothing
- Net<TypeParam> caffe_net(net_param, bottom_vec);
- LOG(ERROR) << "Start Forward.";
- const vector<Blob<TypeParam>*>& output = caffe_net.Forward(bottom_vec);
- LOG(ERROR) << "Forward Done.";
- EXPECT_EQ(output[0]->num(), 1);
- EXPECT_EQ(output[0]->channels(), 1);
- EXPECT_EQ(output[0]->height(), 252);
- EXPECT_EQ(output[0]->width(), 252);
- for (int i = 0; i < output[0]->count(); ++i) {
- EXPECT_GE(output[0]->cpu_data()[i], 0);
- EXPECT_LE(output[0]->cpu_data()[i], 1);
- }
- WriteBlobToImage<TypeParam>(string("lena_smoothed.png"), *output[0]);
-}
-
-
-} // namespace caffe
diff --git a/src/caffe/test/test_solver_linear_regression.cpp b/src/caffe/test/test_solver_linear_regression.cpp
index fbd53f32..8fd504b0 100644
--- a/src/caffe/test/test_solver_linear_regression.cpp
+++ b/src/caffe/test/test_solver_linear_regression.cpp
@@ -26,7 +26,7 @@ class SolverTest : public ::testing::Test {};
typedef ::testing::Types<float, double> Dtypes;
TYPED_TEST_CASE(SolverTest, Dtypes);
-TYPED_TEST(SolverTest, TestSolve) {
+TYPED_TEST(SolverTest, TestSolveGPU) {
Caffe::set_mode(Caffe::GPU);
NetParameter net_param;
@@ -41,10 +41,10 @@ TYPED_TEST(SolverTest, TestSolve) {
EXPECT_EQ(caffe_net.blob_names().size(), 3);
// Run the network without training.
- LOG(ERROR) << "Performing Forward";
+ LOG(INFO) << "Performing Forward";
caffe_net.Forward(bottom_vec);
- LOG(ERROR) << "Performing Backward";
- LOG(ERROR) << "Initial loss: " << caffe_net.Backward();
+ LOG(INFO) << "Performing Backward";
+ LOG(INFO) << "Initial loss: " << caffe_net.Backward();
SolverParameter solver_param;
solver_param.set_base_lr(0.1);
@@ -55,13 +55,62 @@ TYPED_TEST(SolverTest, TestSolve) {
solver_param.set_power(0.75);
solver_param.set_momentum(0.9);
- LOG(ERROR) << "Starting Optimization";
+ LOG(INFO) << "Starting Optimization";
SGDSolver<TypeParam> solver(solver_param);
solver.Solve(&caffe_net);
- LOG(ERROR) << "Optimization Done.";
- LOG(ERROR) << "Weight: " << caffe_net.params()[0]->cpu_data()[0] << ", "
+ LOG(INFO) << "Optimization Done.";
+ LOG(INFO) << "Weight: " << caffe_net.params()[0]->cpu_data()[0] << ", "
<< caffe_net.params()[0]->cpu_data()[1];
- LOG(ERROR) << "Bias: " << caffe_net.params()[1]->cpu_data()[0];
+ LOG(INFO) << "Bias: " << caffe_net.params()[1]->cpu_data()[0];
+
+ EXPECT_GE(caffe_net.params()[0]->cpu_data()[0], 0.3);
+ EXPECT_LE(caffe_net.params()[0]->cpu_data()[0], 0.35);
+
+ EXPECT_GE(caffe_net.params()[0]->cpu_data()[1], 0.3);
+ EXPECT_LE(caffe_net.params()[0]->cpu_data()[1], 0.35);
+
+ EXPECT_GE(caffe_net.params()[1]->cpu_data()[0], -0.01);
+ EXPECT_LE(caffe_net.params()[1]->cpu_data()[0], 0.01);
+}
+
+
+
+TYPED_TEST(SolverTest, TestSolveCPU) {
+ Caffe::set_mode(Caffe::CPU);
+
+ NetParameter net_param;
+ ReadProtoFromTextFile("data/linear_regression.prototxt",
+ &net_param);
+ // check if things are right
+ EXPECT_EQ(net_param.layers_size(), 3);
+ EXPECT_EQ(net_param.input_size(), 0);
+ vector<Blob<TypeParam>*> bottom_vec;
+ Net<TypeParam> caffe_net(net_param, bottom_vec);
+ EXPECT_EQ(caffe_net.layer_names().size(), 3);
+ EXPECT_EQ(caffe_net.blob_names().size(), 3);
+
+ // Run the network without training.
+ LOG(INFO) << "Performing Forward";
+ caffe_net.Forward(bottom_vec);
+ LOG(INFO) << "Performing Backward";
+ LOG(INFO) << "Initial loss: " << caffe_net.Backward();
+
+ SolverParameter solver_param;
+ solver_param.set_base_lr(0.1);
+ solver_param.set_display(0);
+ solver_param.set_max_iter(100);
+ solver_param.set_lr_policy("inv");
+ solver_param.set_gamma(1.);
+ solver_param.set_power(0.75);
+ solver_param.set_momentum(0.9);
+
+ LOG(INFO) << "Starting Optimization";
+ SGDSolver<TypeParam> solver(solver_param);
+ solver.Solve(&caffe_net);
+ LOG(INFO) << "Optimization Done.";
+ LOG(INFO) << "Weight: " << caffe_net.params()[0]->cpu_data()[0] << ", "
+ << caffe_net.params()[0]->cpu_data()[1];
+ LOG(INFO) << "Bias: " << caffe_net.params()[1]->cpu_data()[0];
EXPECT_GE(caffe_net.params()[0]->cpu_data()[0], 0.3);
EXPECT_LE(caffe_net.params()[0]->cpu_data()[0], 0.35);
diff --git a/src/caffe/util/blob_math.cpp.working b/src/caffe/util/blob_math.cpp.working
deleted file mode 100644
index e69de29b..00000000
--- a/src/caffe/util/blob_math.cpp.working
+++ /dev/null
diff --git a/src/caffe/util/blob_math.hpp b/src/caffe/util/blob_math.hpp
deleted file mode 100644
index 414d6eb0..00000000
--- a/src/caffe/util/blob_math.hpp
+++ /dev/null
@@ -1,104 +0,0 @@
-// Copyright Yangqing Jia 2013
-//
-// This is a working version of the math functions that would hopefully replace
-// the cpu and gpu separate version, that would eventually replace the old
-// math_functions wrapper.
-
-#include "caffe/common.hpp"
-#include "caffe/syncedmem.hpp"
-
-namespace caffe {
-
-namespace blobmath {
-
-
-// Decaf gemm provides a simpler interface to the gemm functions, with the
-// limitation that the data has to be contiguous in memory.
-template <class Brewer, typename Dtype>
-void gemm(const CBLAS_TRANSPOSE TransA,
- const CBLAS_TRANSPOSE TransB, const int M, const int N, const int K,
- const Dtype alpha, const Dtype* A, const Dtype* B, const Dtype beta,
- Dtype* C);
-
-
-
-
-template <typename Dtype>
-void caffe_cpu_gemv(const CBLAS_TRANSPOSE TransA, const int M, const int N,
- const Dtype alpha, const Dtype* A, const Dtype* x, const Dtype beta,
- Dtype* y);
-
-template <typename Dtype>
-void caffe_gpu_gemv(const CBLAS_TRANSPOSE TransA, const int M, const int N,
- const Dtype alpha, const Dtype* A, const Dtype* x, const Dtype beta,
- Dtype* y);
-
-template <typename Dtype>
-void caffe_axpy(const int N, const Dtype alpha, const Dtype* X,
- Dtype* Y);
-
-template <typename Dtype>
-void caffe_gpu_axpy(const int N, const Dtype alpha, const Dtype* X,
- Dtype* Y);
-
-template <typename Dtype>
-void caffe_axpby(const int N, const Dtype alpha, const Dtype* X,
- const Dtype beta, Dtype* Y);
-
-template <typename Dtype>
-void caffe_gpu_axpby(const int N, const Dtype alpha, const Dtype* X,
- const Dtype beta, Dtype* Y);
-
-template <typename Dtype>
-void caffe_copy(const int N, const Dtype *X, Dtype *Y);
-
-template <typename Dtype>
-void caffe_gpu_copy(const int N, const Dtype *X, Dtype *Y);
-
-template <typename Dtype>
-void caffe_scal(const int N, const Dtype alpha, Dtype *X);
-
-template <typename Dtype>
-void caffe_gpu_scal(const int N, const Dtype alpha, Dtype *X);
-
-template <typename Dtype>
-void caffe_sqr(const int N, const Dtype* a, Dtype* y);
-
-template <typename Dtype>
-void caffe_add(const int N, const Dtype* a, const Dtype* b, Dtype* y);
-
-template <typename Dtype>
-void caffe_sub(const int N, const Dtype* a, const Dtype* b, Dtype* y);
-
-template <typename Dtype>
-void caffe_mul(const int N, const Dtype* a, const Dtype* b, Dtype* y);
-
-template <typename Dtype>
-void caffe_gpu_mul(const int N, const Dtype* a, const Dtype* b, Dtype* y);
-
-template <typename Dtype>
-void caffe_div(const int N, const Dtype* a, const Dtype* b, Dtype* y);
-
-template <typename Dtype>
-void caffe_powx(const int n, const Dtype* a, const Dtype b, Dtype* y);
-
-template <typename Dtype>
-void caffe_vRngUniform(const int n, Dtype* r, const Dtype a, const Dtype b);
-
-template <typename Dtype>
-void caffe_vRngGaussian(const int n, Dtype* r, const Dtype a,
- const Dtype sigma);
-
-template <typename Dtype>
-void caffe_exp(const int n, const Dtype* a, Dtype* y);
-
-template <typename Dtype>
-Dtype caffe_cpu_dot(const int n, const Dtype* x, const Dtype* y);
-
-template <typename Dtype>
-void caffe_gpu_dot(const int n, const Dtype* x, const Dtype* y, Dtype* out);
-
-
-} // namespace blobmath
-
-} // namespace caffe \ No newline at end of file
diff --git a/src/caffe/util/im2col.hpp b/src/caffe/util/im2col.hpp
index 4e79dba3..83c01dda 100644
--- a/src/caffe/util/im2col.hpp
+++ b/src/caffe/util/im2col.hpp
@@ -1,6 +1,6 @@
// Copyright 2013 Yangqing Jia
-#ifndef _CAFFE_UTIL__IM2COL_HPP_
+#ifndef _CAFFE_UTIL_IM2COL_HPP_
#define _CAFFE_UTIL_IM2COL_HPP_
namespace caffe {
diff --git a/src/caffe/util/io.cpp b/src/caffe/util/io.cpp
index 9e91000f..5a88b9d2 100644
--- a/src/caffe/util/io.cpp
+++ b/src/caffe/util/io.cpp
@@ -5,8 +5,6 @@
#include <google/protobuf/text_format.h>
#include <google/protobuf/io/zero_copy_stream_impl.h>
#include <google/protobuf/io/coded_stream.h>
-#include <opencv2/core/core.hpp>
-#include <opencv2/highgui/highgui.hpp>
#include <algorithm>
#include <string>
@@ -17,8 +15,6 @@
#include "caffe/util/io.hpp"
#include "caffe/proto/caffe.pb.h"
-using cv::Mat;
-using cv::Vec3b;
using std::fstream;
using std::ios;
using std::max;
@@ -32,77 +28,6 @@ using google::protobuf::io::CodedOutputStream;
namespace caffe {
-void ReadImageToProto(const string& filename, BlobProto* proto) {
- Mat cv_img;
- cv_img = cv::imread(filename, CV_LOAD_IMAGE_COLOR);
- CHECK(cv_img.data) << "Could not open or find the image.";
- DCHECK_EQ(cv_img.channels(), 3);
- proto->set_num(1);
- proto->set_channels(3);
- proto->set_height(cv_img.rows);
- proto->set_width(cv_img.cols);
- proto->clear_data();
- proto->clear_diff();
- for (int c = 0; c < 3; ++c) {
- for (int h = 0; h < cv_img.rows; ++h) {
- for (int w = 0; w < cv_img.cols; ++w) {
- proto->add_data(static_cast<float>(cv_img.at<Vec3b>(h, w)[c]) / 255.);
- }
- }
- }
-}
-
-void ReadImageToDatum(const string& filename, const int label, Datum* datum) {
- Mat cv_img;
- cv_img = cv::imread(filename, CV_LOAD_IMAGE_COLOR);
- CHECK(cv_img.data) << "Could not open or find the image.";
- DCHECK_EQ(cv_img.channels(), 3);
- datum->set_channels(3);
- datum->set_height(cv_img.rows);
- datum->set_width(cv_img.cols);
- datum->set_label(label);
- datum->clear_data();
- datum->clear_float_data();
- string* datum_string = datum->mutable_data();
- for (int c = 0; c < 3; ++c) {
- for (int h = 0; h < cv_img.rows; ++h) {
- for (int w = 0; w < cv_img.cols; ++w) {
- datum_string->push_back(static_cast<char>(cv_img.at<Vec3b>(h, w)[c]));
- }
- }
- }
-}
-
-
-void WriteProtoToImage(const string& filename, const BlobProto& proto) {
- CHECK_EQ(proto.num(), 1);
- CHECK(proto.channels() == 3 || proto.channels() == 1);
- CHECK_GT(proto.height(), 0);
- CHECK_GT(proto.width(), 0);
- Mat cv_img(proto.height(), proto.width(), CV_8UC3);
- if (proto.channels() == 1) {
- for (int c = 0; c < 3; ++c) {
- for (int h = 0; h < cv_img.rows; ++h) {
- for (int w = 0; w < cv_img.cols; ++w) {
- cv_img.at<Vec3b>(h, w)[c] =
- uint8_t(proto.data(h * cv_img.cols + w) * 255.);
- }
- }
- }
- } else {
- for (int c = 0; c < 3; ++c) {
- for (int h = 0; h < cv_img.rows; ++h) {
- for (int w = 0; w < cv_img.cols; ++w) {
- cv_img.at<Vec3b>(h, w)[c] =
- uint8_t(proto.data((c * cv_img.rows + h) * cv_img.cols + w)
- * 255.);
- }
- }
- }
- }
- CHECK(cv::imwrite(filename, cv_img));
-}
-
void ReadProtoFromTextFile(const char* filename,
::google::protobuf::Message* proto) {
int fd = open(filename, O_RDONLY);
diff --git a/src/caffe/util/io.hpp b/src/caffe/util/io.hpp
index 201e729f..0dce4e7e 100644
--- a/src/caffe/util/io.hpp
+++ b/src/caffe/util/io.hpp
@@ -15,26 +15,6 @@ using ::google::protobuf::Message;
namespace caffe {
-void ReadImageToProto(const string& filename, BlobProto* proto);
-
-template <typename Dtype>
-inline void ReadImageToBlob(const string& filename, Blob<Dtype>* blob) {
- BlobProto proto;
- ReadImageToProto(filename, &proto);
- blob->FromProto(proto);
-}
-
-void WriteProtoToImage(const string& filename, const BlobProto& proto);
-
-template <typename Dtype>
-inline void WriteBlobToImage(const string& filename, const Blob<Dtype>& blob) {
- BlobProto proto;
- blob.ToProto(&proto);
- WriteProtoToImage(filename, proto);
-}
-
-void ReadImageToDatum(const string& filename, const int label, Datum* datum);
-
void ReadProtoFromTextFile(const char* filename,
Message* proto);
inline void ReadProtoFromTextFile(const string& filename,
diff --git a/src/programs/convert_dataset.cpp b/src/programs/convert_dataset.cpp
deleted file mode 100644
index 3bf7794f..00000000
--- a/src/programs/convert_dataset.cpp
+++ /dev/null
@@ -1,81 +0,0 @@
-// Copyright 2013 Yangqing Jia
-// This program converts a set of images to a leveldb by storing them as Datum
-// proto buffers.
-// Usage:
-// convert_dataset ROOTFOLDER LISTFILE DB_NAME
-// where ROOTFOLDER is the root folder that holds all the images, and LISTFILE
-// should be a list of files as well as their labels, in the format as
-// subfolder1/file1.JPEG 0
-// ....
-// You are responsible for shuffling the files yourself.
-
-#include <glog/logging.h>
-#include <leveldb/db.h>
-#include <leveldb/write_batch.h>
-
-#include <string>
-#include <iostream>
-#include <fstream>
-
-#include "caffe/proto/caffe.pb.h"
-#include "caffe/util/io.hpp"
-
-using namespace caffe;
-using std::string;
-using std::stringstream;
-
-// A utility function to generate random strings
-void GenerateRandomPrefix(const int n, string* key) {
- const char* kCHARS = "abcdefghijklmnopqrstuvwxyz";
- key->clear();
- for (int i = 0; i < n; ++i) {
- key->push_back(kCHARS[rand() % 26]);
- }
- key->push_back('_');
-}
-
-int main(int argc, char** argv) {
- ::google::InitGoogleLogging(argv[0]);
- std::ifstream infile(argv[2]);
- leveldb::DB* db;
- leveldb::Options options;
- options.error_if_exists = true;
- options.create_if_missing = true;
- options.create_if_missing = true;
- options.write_buffer_size = 268435456;
- LOG(INFO) << "Opening leveldb " << argv[3];
- leveldb::Status status = leveldb::DB::Open(
- options, argv[3], &db);
- CHECK(status.ok()) << "Failed to open leveldb " << argv[3];
-
- string root_folder(argv[1]);
- string filename;
- int label;
- Datum datum;
- int count = 0;
- char key_cstr[100];
- leveldb::WriteBatch* batch = new leveldb::WriteBatch();
- while (infile >> filename >> label) {
- ReadImageToDatum(root_folder + filename, label, &datum);
- // sequential
- sprintf(key_cstr, "%08d_%s", count, filename.c_str());
- string key(key_cstr);
- // random
- // string key;
- // GenerateRandomPrefix(8, &key);
- // key += filename;
- string value;
- // get the value
- datum.SerializeToString(&value);
- batch->Put(key, value);
- if (++count % 1000 == 0) {
- db->Write(leveldb::WriteOptions(), batch);
- LOG(ERROR) << "Processed " << count << " files.";
- delete batch;
- batch = new leveldb::WriteBatch();
- }
- }
-
- delete db;
- return 0;
-}
diff --git a/src/programs/demo_mnist.cpp b/src/programs/demo_mnist.cpp
index 37e697e9..284b671f 100644
--- a/src/programs/demo_mnist.cpp
+++ b/src/programs/demo_mnist.cpp
@@ -17,7 +17,7 @@
using namespace caffe;
int main(int argc, char** argv) {
- cudaSetDevice(0);
+ cudaSetDevice(1);
Caffe::set_mode(Caffe::GPU);
Caffe::set_phase(Caffe::TRAIN);
@@ -35,7 +35,7 @@ int main(int argc, char** argv) {
SolverParameter solver_param;
solver_param.set_base_lr(0.01);
- solver_param.set_display(1);
+ solver_param.set_display(100);
solver_param.set_max_iter(6000);
solver_param.set_lr_policy("inv");
solver_param.set_gamma(0.0001);
diff --git a/src/programs/dump_network.cpp b/src/programs/dump_network.cpp
index 35071000..8dd8b0df 100644
--- a/src/programs/dump_network.cpp
+++ b/src/programs/dump_network.cpp
@@ -4,7 +4,10 @@
// all the intermediate blobs produced by the net to individual binary
// files stored in protobuffer binary formats.
// Usage:
-// dump_network input_net_param trained_net_param input_blob output_prefix
+// dump_network input_net_param trained_net_param input_blob output_prefix 0/1
+// if input_net_param is 'none', we will directly load the network from
+// trained_net_param. If the last argv is 1, we will do a forward-backward pass
+// before dumping everything, and also dump the whole network.
#include <cuda_runtime.h>
#include <fcntl.h>
@@ -29,26 +32,41 @@ int main(int argc, char** argv) {
NetParameter net_param;
NetParameter trained_net_param;
- ReadProtoFromTextFile(argv[1], &net_param);
- ReadProtoFromBinaryFile(argv[2], &trained_net_param);
- BlobProto input_blob_proto;
- ReadProtoFromBinaryFile(argv[3], &input_blob_proto);
- shared_ptr<Blob<float> > input_blob(new Blob<float>());
- input_blob->FromProto(input_blob_proto);
+ if (strcmp(argv[1], "none") == 0) {
+ // We directly load the net param from trained file
+ ReadProtoFromBinaryFile(argv[2], &net_param);
+ } else {
+ ReadProtoFromTextFile(argv[1], &net_param);
+ }
+ ReadProtoFromBinaryFile(argv[2], &trained_net_param);
+
vector<Blob<float>* > input_vec;
- input_vec.push_back(input_blob.get());
- // For implementational reasons, we need to first set up the net, and
- // then copy the trained parameters.
+ if (strcmp(argv[3], "none") != 0) {
+ BlobProto input_blob_proto;
+ ReadProtoFromBinaryFile(argv[3], &input_blob_proto);
+ shared_ptr<Blob<float> > input_blob(new Blob<float>());
+ input_blob->FromProto(input_blob_proto);
+ input_vec.push_back(input_blob.get());
+ }
+
shared_ptr<Net<float> > caffe_net(new Net<float>(net_param, input_vec));
caffe_net->CopyTrainedLayersFrom(trained_net_param);
+ string output_prefix(argv[4]);
// Run the network without training.
LOG(ERROR) << "Performing Forward";
caffe_net->Forward(input_vec);
-
+ if (argc > 4 && strcmp(argv[4], "1")) {
+ LOG(ERROR) << "Performing Backward";
+ caffe_net->Backward();
+ // Dump the network
+ NetParameter output_net_param;
+ caffe_net->ToProto(&output_net_param, true);
+ WriteProtoToBinaryFile(output_net_param, output_prefix + output_net_param.name());
+ }
// Now, let's dump all the layers
- string output_prefix(argv[4]);
+
const vector<string>& blob_names = caffe_net->blob_names();
const vector<shared_ptr<Blob<float> > >& blobs = caffe_net->blobs();
for (int blobid = 0; blobid < caffe_net->blobs().size(); ++blobid) {
@@ -59,6 +77,5 @@ int main(int argc, char** argv) {
WriteProtoToBinaryFile(output_blob_proto, output_prefix + blob_names[blobid]);
}
- // Dump results.
return 0;
}
diff --git a/src/programs/imagenet.prototxt b/src/programs/imagenet.prototxt
index 65b7432c..53295de2 100644
--- a/src/programs/imagenet.prototxt
+++ b/src/programs/imagenet.prototxt
@@ -4,7 +4,7 @@ layers {
name: "data"
type: "data"
source: "/home/jiayq/caffe-train-leveldb"
- batchsize: 96
+ batchsize: 128
subtraction: 114
cropsize: 227
mirror: true
@@ -27,6 +27,8 @@ layers {
type: "constant"
value: 0
}
+ blobs_lr: 1.
+ blobs_lr: 2.
}
bottom: "data"
top: "conv1"
@@ -85,6 +87,8 @@ layers {
type: "constant"
value: 1
}
+ blobs_lr: 1.
+ blobs_lr: 2.
}
bottom: "pad2"
top: "conv2"
@@ -142,6 +146,8 @@ layers {
type: "constant"
value: 0
}
+ blobs_lr: 1.
+ blobs_lr: 2.
}
bottom: "pad3"
top: "conv3"
@@ -178,6 +184,8 @@ layers {
type: "constant"
value: 1
}
+ blobs_lr: 1.
+ blobs_lr: 2.
}
bottom: "pad4"
top: "conv4"
@@ -214,6 +222,8 @@ layers {
type: "constant"
value: 1
}
+ blobs_lr: 1.
+ blobs_lr: 2.
}
bottom: "pad5"
top: "conv5"
@@ -250,6 +260,8 @@ layers {
type: "constant"
value: 1
}
+ blobs_lr: 1.
+ blobs_lr: 2.
}
bottom: "pool5"
top: "fc6"
@@ -284,6 +296,8 @@ layers {
type: "constant"
value: 1
}
+ blobs_lr: 1.
+ blobs_lr: 2.
}
bottom: "drop6"
top: "fc7"
@@ -318,6 +332,8 @@ layers {
type: "constant"
value: 0
}
+ blobs_lr: 1.
+ blobs_lr: 2.
}
bottom: "drop7"
top: "fc8"
diff --git a/src/programs/imagenet_solver.prototxt b/src/programs/imagenet_solver.prototxt
index 58b0dfef..6e583638 100644
--- a/src/programs/imagenet_solver.prototxt
+++ b/src/programs/imagenet_solver.prototxt
@@ -1,10 +1,10 @@
-base_lr: 0.02
+base_lr: 0.01
lr_policy: "step"
gamma: 0.1
stepsize: 340000
-display: 100
+display: 20
max_iter: 1200000
momentum: 0.9
weight_decay: 0.0005
-snapshot: 15000
+snapshot: 5000
snapshot_prefix: "alexnet_train" \ No newline at end of file
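Note: with lr_policy: "step", the effective rate is conventionally base_lr * gamma^floor(iter / stepsize); under the new settings that is 0.01 until iteration 340000, then 0.001, and so on, with a snapshot every 5000 iterations. A one-function sketch, assuming that conventional definition:

    #include <cmath>

    // Sketch only -- assumes the conventional "step" policy definition.
    inline float StepRate(float base_lr, float gamma, int stepsize, int iter) {
      return base_lr * std::pow(gamma, iter / stepsize);  // integer division == floor
    }
    // StepRate(0.01f, 0.1f, 340000, 100000)  -> 0.01
    // StepRate(0.01f, 0.1f, 340000, 400000)  -> 0.001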