summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorYangqing Jia <jiayq84@gmail.com>2013-09-30 14:08:39 -0700
committerYangqing Jia <jiayq84@gmail.com>2013-09-30 14:08:39 -0700
commitf171796995d17e3db4b4752a116e532d499dc91f (patch)
treee98d6b54c5a1cbe1c1a9e157940c974a04d50d6e /src
parent45cbe8e3177abfd22b6590a1bf3eeef55960457d (diff)
downloadcaffeonacl-f171796995d17e3db4b4752a116e532d499dc91f.tar.gz
caffeonacl-f171796995d17e3db4b4752a116e532d499dc91f.tar.bz2
caffeonacl-f171796995d17e3db4b4752a116e532d499dc91f.zip
solver
Diffstat (limited to 'src')
-rw-r--r--src/caffe/blob.cpp20
-rw-r--r--src/caffe/layers/data_layer.cpp20
-rw-r--r--src/caffe/optimization/solver.cpp54
-rw-r--r--src/caffe/proto/caffe.proto4
-rw-r--r--src/caffe/pyutil/convert.py7
-rw-r--r--src/caffe/test/data/simple_linear_regression_data.py16
-rw-r--r--src/caffe/test/test_solver_linear_regression.cpp76
-rw-r--r--src/caffe/util/math_functions.cpp37
-rw-r--r--src/caffe/util/math_functions.hpp11
9 files changed, 221 insertions, 24 deletions
diff --git a/src/caffe/blob.cpp b/src/caffe/blob.cpp
index d31ba72b..68380367 100644
--- a/src/caffe/blob.cpp
+++ b/src/caffe/blob.cpp
@@ -6,6 +6,7 @@
#include "caffe/blob.hpp"
#include "caffe/common.hpp"
#include "caffe/syncedmem.hpp"
+#include "caffe/util/math_functions.hpp"
namespace caffe {
@@ -86,9 +87,24 @@ Dtype* Blob<Dtype>::mutable_gpu_diff() {
template <typename Dtype>
void Blob<Dtype>::Update() {
- // not implemented yet.
- LOG(FATAL) << "not implemented";
// We will perform update based on where the data is located.
+ switch (data_->head()) {
+ case SyncedMemory::HEAD_AT_CPU:
+ // perform computation on CPU
+ caffe_axpy<Dtype>(count_, Dtype(-1),
+ reinterpret_cast<const Dtype*>(diff_->cpu_data()),
+ reinterpret_cast<Dtype*>(data_->mutable_cpu_data()));
+ break;
+ case SyncedMemory::HEAD_AT_GPU:
+ case SyncedMemory::SYNCED:
+ // perform computation on GPU
+ caffe_gpu_axpy<Dtype>(count_, Dtype(-1),
+ reinterpret_cast<const Dtype*>(diff_->gpu_data()),
+ reinterpret_cast<Dtype*>(data_->mutable_gpu_data()));
+ break;
+ default:
+ LOG(FATAL) << "Syncedmem not initialized.";
+ }
}
template <typename Dtype>
diff --git a/src/caffe/layers/data_layer.cpp b/src/caffe/layers/data_layer.cpp
index e527bd3a..d42a8104 100644
--- a/src/caffe/layers/data_layer.cpp
+++ b/src/caffe/layers/data_layer.cpp
@@ -40,7 +40,7 @@ void DataLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
// label
(*top)[1]->Reshape(this->layer_param_.batchsize(), 1, 1, 1);
// datum size
- datum_size_ = datum.data().size();
+ datum_size_ = datum.channels() * datum.height() * datum.width();
}
template <typename Dtype>
@@ -51,13 +51,25 @@ void DataLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
Dtype* top_label = (*top)[1]->mutable_cpu_data();
const Dtype scale = this->layer_param_.scale();
const Dtype subtraction = this->layer_param_.subtraction();
+ // LOG(ERROR) << "Debug code on";
+ // if (true) {
+ // iter_->SeekToFirst();
+ // }
for (int i = 0; i < this->layer_param_.batchsize(); ++i) {
// get a blob
datum.ParseFromString(iter_->value().ToString());
const string& data = datum.data();
- for (int j = 0; j < datum_size_; ++j) {
- top_data[i * datum_size_ + j] =
- (static_cast<Dtype>((uint8_t)data[j]) * scale) - subtraction;
+ // we will prefer to use data() first, and then try float_data()
+ if (data.size()) {
+ for (int j = 0; j < datum_size_; ++j) {
+ top_data[i * datum_size_ + j] =
+ (static_cast<Dtype>((uint8_t)data[j]) * scale) - subtraction;
+ }
+ } else {
+ for (int j = 0; j < datum_size_; ++j) {
+ top_data[i * datum_size_ + j] =
+ (datum.float_data(j) * scale) - subtraction;
+ }
}
top_label[i] = datum.label();
// go to the next iter
diff --git a/src/caffe/optimization/solver.cpp b/src/caffe/optimization/solver.cpp
index 4f343c81..2df872ed 100644
--- a/src/caffe/optimization/solver.cpp
+++ b/src/caffe/optimization/solver.cpp
@@ -1,12 +1,16 @@
// Copyright Yangqing Jia 2013
+#include <algorithm>
#include <fstream>
#include <string>
#include "caffe/proto/caffe.pb.h"
#include "caffe/net.hpp"
+#include "caffe/util/math_functions.hpp"
#include "caffe/optimization/solver.hpp"
+using std::max;
+using std::min;
using std::stringstream;
using std::ofstream;
@@ -23,13 +27,16 @@ void Solver<Dtype>::Solve(Net<Dtype>* net) {
while (iter_++ < param_.max_iter()) {
Dtype loss = net_->ForwardBackWard(bottom_vec);
ComputeUpdateValue();
- net->Update();
+ net_->Update();
// Check if we need to do snapshot
- if (iter_ % param_.snapshot()) {
+ if (param_.snapshot() > 0 && iter_ % param_.snapshot() == 0) {
// TODO(Yangqing): snapshot
+ NOT_IMPLEMENTED;
+ }
+ if (param_.display()) {
+ LOG(ERROR) << "Iteration " << iter_ << ", loss = " << loss;
}
- LOG(INFO) << "Iteration" << iter_ << ", loss=" << loss;
}
LOG(INFO) << "Optimization Done.";
}
@@ -62,20 +69,20 @@ Dtype SGDSolver<Dtype>::GetLearningRate() {
} else if (lr_policy == "inv") {
rate = this->param_.base_lr() *
pow(Dtype(1) + this->param_.gamma() * this->iter_,
- this->param_.power());
+ - this->param_.power());
} else {
LOG(FATAL) << "Unknown learning rate policy: " << lr_policy;
}
- rate = min(max(rate, this->param_.min_pr()), this->param_.max_lr());
+ rate = min(max(rate, Dtype(this->param_.min_lr())),
+ Dtype(this->param_.max_lr()));
return rate;
}
template <typename Dtype>
void SGDSolver<Dtype>::ComputeUpdateValue() {
// First of all, see if we need to initialize the history
- vector<shared_ptr<Blob<Dtype> > >& net_params = this->net_.params();
- if (this->iter_ == 1 && this->param_.momentum() > 0) {
- LOG(INFO) << "Using momentum " << this->param_.momentum();
+ vector<shared_ptr<Blob<Dtype> > >& net_params = this->net_->params();
+ if (history_.size() == 0 && this->param_.momentum() > 0) {
for (int i = 0; i < net_params.size(); ++i) {
const Blob<Dtype>* net_param = net_params[i].get();
history_.push_back(shared_ptr<Blob<Dtype> >(new Blob<Dtype>(
@@ -85,28 +92,47 @@ void SGDSolver<Dtype>::ComputeUpdateValue() {
}
// get the learning rate
Dtype rate = GetLearningRate();
- if (this->param_.momentum == 0) {
+ if (this->param_.momentum() == 0) {
for (int i = 0; i < net_params.size(); ++i) {
switch (Caffe::mode()) {
case Caffe::CPU:
caffe_scal(net_params[i]->count(), rate,
- net_params[i]->mutable_cpu_data());
+ net_params[i]->mutable_cpu_diff());
break;
case Caffe::GPU:
caffe_gpu_scal(net_params[i]->count(), rate,
- net_params[i]->mutable_gpu_data());
+ net_params[i]->mutable_gpu_diff());
break;
default:
LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode();
}
}
} else {
- NOT_IMPLEMENTED;
+ // Need to maintain momentum
+ for (int i = 0; i < net_params.size(); ++i) {
+ switch (Caffe::mode()) {
+ case Caffe::CPU:
+ caffe_axpby(net_params[i]->count(), rate,
+ net_params[i]->cpu_diff(), Dtype(this->param_.momentum()),
+ history_[i]->mutable_cpu_data());
+ caffe_copy(net_params[i]->count(), history_[i]->cpu_data(),
+ net_params[i]->mutable_cpu_diff());
+ break;
+ case Caffe::GPU:
+ caffe_gpu_axpby(net_params[i]->count(), rate,
+ net_params[i]->gpu_diff(), Dtype(this->param_.momentum()),
+ history_[i]->mutable_gpu_data());
+ caffe_gpu_copy(net_params[i]->count(), history_[i]->gpu_data(),
+ net_params[i]->mutable_gpu_diff());
+ break;
+ default:
+ LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode();
+ }
+ }
}
}
-
-
INSTANTIATE_CLASS(Solver);
+INSTANTIATE_CLASS(SGDSolver);
} // namespace caffe \ No newline at end of file
diff --git a/src/caffe/proto/caffe.proto b/src/caffe/proto/caffe.proto
index 012ad56e..c3632c11 100644
--- a/src/caffe/proto/caffe.proto
+++ b/src/caffe/proto/caffe.proto
@@ -18,6 +18,8 @@ message Datum {
// the actual image data, in bytes
optional bytes data = 4;
optional int32 label = 5;
+ // Optionally, the datum could also hold float data.
+ repeated float float_data = 6;
}
message FillerParameter {
@@ -84,7 +86,7 @@ message SolverParameter {
optional float base_lr = 1; // The base learning rate
optional int32 display = 2; // display options. 0 = no display
optional int32 max_iter = 3; // the maximum number of iterations
- optional int32 snapshot = 4; // The snapshot interval
+ optional int32 snapshot = 4 [default = 0]; // The snapshot interval
optional string lr_policy = 5; // The learning rate decay policy.
optional float min_lr = 6 [default = 0]; // The minimum learning rate
optional float max_lr = 7 [default = 1e10]; // The maximum learning rate
diff --git a/src/caffe/pyutil/convert.py b/src/caffe/pyutil/convert.py
index efcea42e..8a76a508 100644
--- a/src/caffe/pyutil/convert.py
+++ b/src/caffe/pyutil/convert.py
@@ -20,9 +20,10 @@ def array_to_blobproto(arr):
def array_to_datum(arr):
if arr.ndim != 3:
raise ValueError('Incorrect array shape.')
- if arr.dtype != np.uint8:
- raise TypeError('Input array has to be of type uint8.')
datum = caffe_pb2.Datum()
datum.channels, datum.height, datum.width = arr.shape
- datum.data = arr.tostring()
+ if arr.dtype == np.uint8:
+ datum.data = arr.tostring()
+ else:
+ datum.float_data.extend(arr.flat)
return datum
diff --git a/src/caffe/test/data/simple_linear_regression_data.py b/src/caffe/test/data/simple_linear_regression_data.py
new file mode 100644
index 00000000..e8fe840d
--- /dev/null
+++ b/src/caffe/test/data/simple_linear_regression_data.py
@@ -0,0 +1,16 @@
+"""This script generates the simple linear regression leveldb used in the
+test.
+"""
+from caffe.pyutil import convert
+import numpy as np
+import leveldb
+
+db = leveldb.LevelDB('simple-linear-regression-leveldb')
+
+for i in range(1000):
+ label = np.random.randint(2) * 2 - 1
+ arr = np.random.randn(2,1,1) + label
+ datum = convert.array_to_datum(arr)
+ datum.label = label
+ db.Put('%d' % (i), datum.SerializeToString())
+del db
diff --git a/src/caffe/test/test_solver_linear_regression.cpp b/src/caffe/test/test_solver_linear_regression.cpp
new file mode 100644
index 00000000..f3689081
--- /dev/null
+++ b/src/caffe/test/test_solver_linear_regression.cpp
@@ -0,0 +1,76 @@
+// Copyright 2013 Yangqing Jia
+
+#include <cuda_runtime.h>
+#include <fcntl.h>
+#include <google/protobuf/text_format.h>
+#include <google/protobuf/io/zero_copy_stream_impl.h>
+#include <gtest/gtest.h>
+
+#include <cstring>
+
+#include "caffe/blob.hpp"
+#include "caffe/common.hpp"
+#include "caffe/net.hpp"
+#include "caffe/filler.hpp"
+#include "caffe/proto/caffe.pb.h"
+#include "caffe/util/io.hpp"
+#include "caffe/optimization/solver.hpp"
+
+#include "caffe/test/test_caffe_main.hpp"
+
+namespace caffe {
+
+template <typename Dtype>
+class SolverTest : public ::testing::Test {};
+
+typedef ::testing::Types<float, double> Dtypes;
+TYPED_TEST_CASE(SolverTest, Dtypes);
+
+TYPED_TEST(SolverTest, TestSolve) {
+ Caffe::set_mode(Caffe::GPU);
+
+ NetParameter net_param;
+ ReadProtoFromTextFile("caffe/test/data/linear_regression.prototxt",
+ &net_param);
+ // check if things are right
+ EXPECT_EQ(net_param.layers_size(), 3);
+ EXPECT_EQ(net_param.input_size(), 0);
+ vector<Blob<TypeParam>*> bottom_vec;
+ Net<TypeParam> caffe_net(net_param, bottom_vec);
+ EXPECT_EQ(caffe_net.layer_names().size(), 3);
+ EXPECT_EQ(caffe_net.blob_names().size(), 3);
+
+ // Run the network without training.
+ LOG(ERROR) << "Performing Forward";
+ caffe_net.Forward(bottom_vec);
+ LOG(ERROR) << "Performing Backward";
+ LOG(ERROR) << "Initial loss: " << caffe_net.Backward();
+
+ SolverParameter solver_param;
+ solver_param.set_base_lr(0.1);
+ solver_param.set_display(0);
+ solver_param.set_max_iter(100);
+ solver_param.set_lr_policy("inv");
+ solver_param.set_gamma(1.);
+ solver_param.set_power(0.75);
+ solver_param.set_momentum(0.9);
+
+ LOG(ERROR) << "Starting Optimization";
+ SGDSolver<TypeParam> solver(solver_param);
+ solver.Solve(&caffe_net);
+ LOG(ERROR) << "Optimization Done.";
+ LOG(ERROR) << "Weight: " << caffe_net.params()[0]->cpu_data()[0] << ", "
+ << caffe_net.params()[0]->cpu_data()[1];
+ LOG(ERROR) << "Bias: " << caffe_net.params()[1]->cpu_data()[0];
+
+ EXPECT_GE(caffe_net.params()[0]->cpu_data()[0], 0.3);
+ EXPECT_LE(caffe_net.params()[0]->cpu_data()[0], 0.35);
+
+ EXPECT_GE(caffe_net.params()[0]->cpu_data()[1], 0.3);
+ EXPECT_LE(caffe_net.params()[0]->cpu_data()[1], 0.35);
+
+ EXPECT_GE(caffe_net.params()[1]->cpu_data()[0], -0.01);
+ EXPECT_LE(caffe_net.params()[1]->cpu_data()[0], 0.01);
+}
+
+} // namespace caffe
diff --git a/src/caffe/util/math_functions.cpp b/src/caffe/util/math_functions.cpp
index a3a94cd5..a0745456 100644
--- a/src/caffe/util/math_functions.cpp
+++ b/src/caffe/util/math_functions.cpp
@@ -103,6 +103,19 @@ template <>
void caffe_axpy<double>(const int N, const double alpha, const double* X,
double* Y) { cblas_daxpy(N, alpha, X, 1, Y, 1); }
+
+template <>
+void caffe_gpu_axpy<float>(const int N, const float alpha, const float* X,
+ float* Y) {
+ CUBLAS_CHECK(cublasSaxpy(Caffe::cublas_handle(), N, &alpha, X, 1, Y, 1));
+}
+
+template <>
+void caffe_gpu_axpy<double>(const int N, const double alpha, const double* X,
+ double* Y) {
+ CUBLAS_CHECK(cublasDaxpy(Caffe::cublas_handle(), N, &alpha, X, 1, Y, 1));
+}
+
template <>
void caffe_axpby<float>(const int N, const float alpha, const float* X,
const float beta, float* Y) {
@@ -126,6 +139,16 @@ void caffe_copy<double>(const int N, const double* X, double* Y) {
}
template <>
+void caffe_gpu_copy<float>(const int N, const float* X, float* Y) {
+ CUBLAS_CHECK(cublasScopy(Caffe::cublas_handle(), N, X, 1, Y, 1));
+}
+
+template <>
+void caffe_gpu_copy<double>(const int N, const double* X, double* Y) {
+ CUBLAS_CHECK(cublasDcopy(Caffe::cublas_handle(), N, X, 1, Y, 1));
+}
+
+template <>
void caffe_scal<float>(const int N, const float alpha, float *X) {
cblas_sscal(N, alpha, X, 1);
}
@@ -146,6 +169,20 @@ void caffe_gpu_scal<double>(const int N, const double alpha, double *X) {
}
template <>
+void caffe_gpu_axpby<float>(const int N, const float alpha, const float* X,
+ const float beta, float* Y) {
+ caffe_gpu_scal<float>(N, beta, Y);
+ caffe_gpu_axpy<float>(N, alpha, X, Y);
+}
+
+template <>
+void caffe_gpu_axpby<double>(const int N, const double alpha, const double* X,
+ const double beta, double* Y) {
+ caffe_gpu_scal<double>(N, beta, Y);
+ caffe_gpu_axpy<double>(N, alpha, X, Y);
+}
+
+template <>
void caffe_sqr<float>(const int n, const float* a, float* y) {
vsSqr(n, a, y);
}
diff --git a/src/caffe/util/math_functions.hpp b/src/caffe/util/math_functions.hpp
index e3ace98d..a71f28ed 100644
--- a/src/caffe/util/math_functions.hpp
+++ b/src/caffe/util/math_functions.hpp
@@ -40,13 +40,24 @@ void caffe_axpy(const int N, const Dtype alpha, const Dtype* X,
Dtype* Y);
template <typename Dtype>
+void caffe_gpu_axpy(const int N, const Dtype alpha, const Dtype* X,
+ Dtype* Y);
+
+template <typename Dtype>
void caffe_axpby(const int N, const Dtype alpha, const Dtype* X,
const Dtype beta, Dtype* Y);
template <typename Dtype>
+void caffe_gpu_axpby(const int N, const Dtype alpha, const Dtype* X,
+ const Dtype beta, Dtype* Y);
+
+template <typename Dtype>
void caffe_copy(const int N, const Dtype *X, Dtype *Y);
template <typename Dtype>
+void caffe_gpu_copy(const int N, const Dtype *X, Dtype *Y);
+
+template <typename Dtype>
void caffe_scal(const int N, const Dtype alpha, Dtype *X);
template <typename Dtype>