summaryrefslogtreecommitdiff
path: root/src/caffe
diff options
context:
space:
mode:
authorRonghang Hu <huronghang@hotmail.com>2015-09-24 17:11:07 -0700
committerRonghang Hu <huronghang@hotmail.com>2015-10-16 22:32:12 -0700
commitb822a702d19d4fbebbc91198a991f91c34e60650 (patch)
tree031db585a1701544bad8d190839db3ef481a0337 /src/caffe
parentdfcdb721c6286b7ff40aba3589df1d8e9d281bd9 (diff)
downloadcaffeonacl-b822a702d19d4fbebbc91198a991f91c34e60650.tar.gz
caffeonacl-b822a702d19d4fbebbc91198a991f91c34e60650.tar.bz2
caffeonacl-b822a702d19d4fbebbc91198a991f91c34e60650.zip
Split solver code into one file per solver class
Diffstat (limited to 'src/caffe')
-rw-r--r--src/caffe/solver.cpp811
-rw-r--r--src/caffe/solver_factory.cpp32
-rw-r--r--src/caffe/solvers/adadelta_solver.cpp155
-rw-r--r--src/caffe/solvers/adagrad_solver.cpp88
-rw-r--r--src/caffe/solvers/adam_solver.cpp112
-rw-r--r--src/caffe/solvers/nesterov_solver.cpp70
-rw-r--r--src/caffe/solvers/rmsprop_solver.cpp84
-rw-r--r--src/caffe/solvers/sgd_solver.cpp347
-rw-r--r--src/caffe/test/test_gradient_based_solver.cpp2
-rw-r--r--src/caffe/test/test_solver.cpp1
10 files changed, 890 insertions, 812 deletions
diff --git a/src/caffe/solver.cpp b/src/caffe/solver.cpp
index 12c13dd8..016a0288 100644
--- a/src/caffe/solver.cpp
+++ b/src/caffe/solver.cpp
@@ -1,18 +1,11 @@
#include <cstdio>
-#include <algorithm>
#include <string>
#include <vector>
-#include "hdf5.h"
-#include "hdf5_hl.h"
-
-#include "caffe/net.hpp"
-#include "caffe/proto/caffe.pb.h"
#include "caffe/solver.hpp"
#include "caffe/util/hdf5.hpp"
#include "caffe/util/io.hpp"
-#include "caffe/util/math_functions.hpp"
#include "caffe/util/upgrade_proto.hpp"
namespace caffe {
@@ -492,810 +485,6 @@ void Solver<Dtype>::Restore(const char* state_file) {
}
}
-// Return the current learning rate. The currently implemented learning rate
-// policies are as follows:
-// - fixed: always return base_lr.
-// - step: return base_lr * gamma ^ (floor(iter / step))
-// - exp: return base_lr * gamma ^ iter
-// - inv: return base_lr * (1 + gamma * iter) ^ (- power)
-// - multistep: similar to step but it allows non uniform steps defined by
-// stepvalue
-// - poly: the effective learning rate follows a polynomial decay, to be
-// zero by the max_iter. return base_lr (1 - iter/max_iter) ^ (power)
-// - sigmoid: the effective learning rate follows a sigmod decay
-// return base_lr ( 1/(1 + exp(-gamma * (iter - stepsize))))
-//
-// where base_lr, max_iter, gamma, step, stepvalue and power are defined
-// in the solver parameter protocol buffer, and iter is the current iteration.
-template <typename Dtype>
-Dtype SGDSolver<Dtype>::GetLearningRate() {
- Dtype rate;
- const string& lr_policy = this->param_.lr_policy();
- if (lr_policy == "fixed") {
- rate = this->param_.base_lr();
- } else if (lr_policy == "step") {
- this->current_step_ = this->iter_ / this->param_.stepsize();
- rate = this->param_.base_lr() *
- pow(this->param_.gamma(), this->current_step_);
- } else if (lr_policy == "exp") {
- rate = this->param_.base_lr() * pow(this->param_.gamma(), this->iter_);
- } else if (lr_policy == "inv") {
- rate = this->param_.base_lr() *
- pow(Dtype(1) + this->param_.gamma() * this->iter_,
- - this->param_.power());
- } else if (lr_policy == "multistep") {
- if (this->current_step_ < this->param_.stepvalue_size() &&
- this->iter_ >= this->param_.stepvalue(this->current_step_)) {
- this->current_step_++;
- LOG(INFO) << "MultiStep Status: Iteration " <<
- this->iter_ << ", step = " << this->current_step_;
- }
- rate = this->param_.base_lr() *
- pow(this->param_.gamma(), this->current_step_);
- } else if (lr_policy == "poly") {
- rate = this->param_.base_lr() * pow(Dtype(1.) -
- (Dtype(this->iter_) / Dtype(this->param_.max_iter())),
- this->param_.power());
- } else if (lr_policy == "sigmoid") {
- rate = this->param_.base_lr() * (Dtype(1.) /
- (Dtype(1.) + exp(-this->param_.gamma() * (Dtype(this->iter_) -
- Dtype(this->param_.stepsize())))));
- } else {
- LOG(FATAL) << "Unknown learning rate policy: " << lr_policy;
- }
- return rate;
-}
-
-template <typename Dtype>
-void SGDSolver<Dtype>::PreSolve() {
- // Initialize the history
- const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params();
- history_.clear();
- update_.clear();
- temp_.clear();
- for (int i = 0; i < net_params.size(); ++i) {
- const vector<int>& shape = net_params[i]->shape();
- history_.push_back(shared_ptr<Blob<Dtype> >(new Blob<Dtype>(shape)));
- update_.push_back(shared_ptr<Blob<Dtype> >(new Blob<Dtype>(shape)));
- temp_.push_back(shared_ptr<Blob<Dtype> >(new Blob<Dtype>(shape)));
- }
-}
-
-template <typename Dtype>
-void SGDSolver<Dtype>::ClipGradients() {
- const Dtype clip_gradients = this->param_.clip_gradients();
- if (clip_gradients < 0) { return; }
- const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params();
- Dtype sumsq_diff = 0;
- for (int i = 0; i < net_params.size(); ++i) {
- sumsq_diff += net_params[i]->sumsq_diff();
- }
- const Dtype l2norm_diff = std::sqrt(sumsq_diff);
- if (l2norm_diff > clip_gradients) {
- Dtype scale_factor = clip_gradients / l2norm_diff;
- LOG(INFO) << "Gradient clipping: scaling down gradients (L2 norm "
- << l2norm_diff << " > " << clip_gradients << ") "
- << "by scale factor " << scale_factor;
- for (int i = 0; i < net_params.size(); ++i) {
- net_params[i]->scale_diff(scale_factor);
- }
- }
-}
-
-template <typename Dtype>
-void SGDSolver<Dtype>::ApplyUpdate() {
- CHECK(Caffe::root_solver());
- Dtype rate = GetLearningRate();
- if (this->param_.display() && this->iter_ % this->param_.display() == 0) {
- LOG(INFO) << "Iteration " << this->iter_ << ", lr = " << rate;
- }
- ClipGradients();
- for (int param_id = 0; param_id < this->net_->learnable_params().size();
- ++param_id) {
- Normalize(param_id);
- Regularize(param_id);
- ComputeUpdateValue(param_id, rate);
- }
- this->net_->Update();
-}
-
-template <typename Dtype>
-void SGDSolver<Dtype>::Normalize(int param_id) {
- if (this->param_.iter_size() == 1) { return; }
- // Scale gradient to counterbalance accumulation.
- const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params();
- const Dtype accum_normalization = Dtype(1.) / this->param_.iter_size();
- switch (Caffe::mode()) {
- case Caffe::CPU: {
- caffe_scal(net_params[param_id]->count(), accum_normalization,
- net_params[param_id]->mutable_cpu_diff());
- break;
- }
- case Caffe::GPU: {
-#ifndef CPU_ONLY
- caffe_gpu_scal(net_params[param_id]->count(), accum_normalization,
- net_params[param_id]->mutable_gpu_diff());
-#else
- NO_GPU;
-#endif
- break;
- }
- default:
- LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode();
- }
-}
-
-template <typename Dtype>
-void SGDSolver<Dtype>::Regularize(int param_id) {
- const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params();
- const vector<float>& net_params_weight_decay =
- this->net_->params_weight_decay();
- Dtype weight_decay = this->param_.weight_decay();
- string regularization_type = this->param_.regularization_type();
- Dtype local_decay = weight_decay * net_params_weight_decay[param_id];
- switch (Caffe::mode()) {
- case Caffe::CPU: {
- if (local_decay) {
- if (regularization_type == "L2") {
- // add weight decay
- caffe_axpy(net_params[param_id]->count(),
- local_decay,
- net_params[param_id]->cpu_data(),
- net_params[param_id]->mutable_cpu_diff());
- } else if (regularization_type == "L1") {
- caffe_cpu_sign(net_params[param_id]->count(),
- net_params[param_id]->cpu_data(),
- temp_[param_id]->mutable_cpu_data());
- caffe_axpy(net_params[param_id]->count(),
- local_decay,
- temp_[param_id]->cpu_data(),
- net_params[param_id]->mutable_cpu_diff());
- } else {
- LOG(FATAL) << "Unknown regularization type: " << regularization_type;
- }
- }
- break;
- }
- case Caffe::GPU: {
-#ifndef CPU_ONLY
- if (local_decay) {
- if (regularization_type == "L2") {
- // add weight decay
- caffe_gpu_axpy(net_params[param_id]->count(),
- local_decay,
- net_params[param_id]->gpu_data(),
- net_params[param_id]->mutable_gpu_diff());
- } else if (regularization_type == "L1") {
- caffe_gpu_sign(net_params[param_id]->count(),
- net_params[param_id]->gpu_data(),
- temp_[param_id]->mutable_gpu_data());
- caffe_gpu_axpy(net_params[param_id]->count(),
- local_decay,
- temp_[param_id]->gpu_data(),
- net_params[param_id]->mutable_gpu_diff());
- } else {
- LOG(FATAL) << "Unknown regularization type: " << regularization_type;
- }
- }
-#else
- NO_GPU;
-#endif
- break;
- }
- default:
- LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode();
- }
-}
-
-template <typename Dtype>
-void SGDSolver<Dtype>::ComputeUpdateValue(int param_id, Dtype rate) {
- const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params();
- const vector<float>& net_params_lr = this->net_->params_lr();
- Dtype momentum = this->param_.momentum();
- Dtype local_rate = rate * net_params_lr[param_id];
- // Compute the update to history, then copy it to the parameter diff.
- switch (Caffe::mode()) {
- case Caffe::CPU: {
- caffe_cpu_axpby(net_params[param_id]->count(), local_rate,
- net_params[param_id]->cpu_diff(), momentum,
- history_[param_id]->mutable_cpu_data());
- caffe_copy(net_params[param_id]->count(),
- history_[param_id]->cpu_data(),
- net_params[param_id]->mutable_cpu_diff());
- break;
- }
- case Caffe::GPU: {
-#ifndef CPU_ONLY
- caffe_gpu_axpby(net_params[param_id]->count(), local_rate,
- net_params[param_id]->gpu_diff(), momentum,
- history_[param_id]->mutable_gpu_data());
- caffe_copy(net_params[param_id]->count(),
- history_[param_id]->gpu_data(),
- net_params[param_id]->mutable_gpu_diff());
-#else
- NO_GPU;
-#endif
- break;
- }
- default:
- LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode();
- }
-}
-
-template <typename Dtype>
-void SGDSolver<Dtype>::SnapshotSolverState(const string& model_filename) {
- switch (this->param_.snapshot_format()) {
- case caffe::SolverParameter_SnapshotFormat_BINARYPROTO:
- SnapshotSolverStateToBinaryProto(model_filename);
- break;
- case caffe::SolverParameter_SnapshotFormat_HDF5:
- SnapshotSolverStateToHDF5(model_filename);
- break;
- default:
- LOG(FATAL) << "Unsupported snapshot format.";
- }
-}
-
-template <typename Dtype>
-void SGDSolver<Dtype>::SnapshotSolverStateToBinaryProto(
- const string& model_filename) {
- SolverState state;
- state.set_iter(this->iter_);
- state.set_learned_net(model_filename);
- state.set_current_step(this->current_step_);
- state.clear_history();
- for (int i = 0; i < history_.size(); ++i) {
- // Add history
- BlobProto* history_blob = state.add_history();
- history_[i]->ToProto(history_blob);
- }
- string snapshot_filename = Solver<Dtype>::SnapshotFilename(".solverstate");
- LOG(INFO)
- << "Snapshotting solver state to binary proto file " << snapshot_filename;
- WriteProtoToBinaryFile(state, snapshot_filename.c_str());
-}
-
-template <typename Dtype>
-void SGDSolver<Dtype>::SnapshotSolverStateToHDF5(
- const string& model_filename) {
- string snapshot_filename =
- Solver<Dtype>::SnapshotFilename(".solverstate.h5");
- LOG(INFO) << "Snapshotting solver state to HDF5 file " << snapshot_filename;
- hid_t file_hid = H5Fcreate(snapshot_filename.c_str(), H5F_ACC_TRUNC,
- H5P_DEFAULT, H5P_DEFAULT);
- CHECK_GE(file_hid, 0)
- << "Couldn't open " << snapshot_filename << " to save solver state.";
- hdf5_save_int(file_hid, "iter", this->iter_);
- hdf5_save_string(file_hid, "learned_net", model_filename);
- hdf5_save_int(file_hid, "current_step", this->current_step_);
- hid_t history_hid = H5Gcreate2(file_hid, "history", H5P_DEFAULT, H5P_DEFAULT,
- H5P_DEFAULT);
- CHECK_GE(history_hid, 0)
- << "Error saving solver state to " << snapshot_filename << ".";
- for (int i = 0; i < history_.size(); ++i) {
- ostringstream oss;
- oss << i;
- hdf5_save_nd_dataset<Dtype>(history_hid, oss.str(), *history_[i]);
- }
- H5Gclose(history_hid);
- H5Fclose(file_hid);
-}
-
-template <typename Dtype>
-void SGDSolver<Dtype>::RestoreSolverStateFromBinaryProto(
- const string& state_file) {
- SolverState state;
- ReadProtoFromBinaryFile(state_file, &state);
- this->iter_ = state.iter();
- if (state.has_learned_net()) {
- NetParameter net_param;
- ReadNetParamsFromBinaryFileOrDie(state.learned_net().c_str(), &net_param);
- this->net_->CopyTrainedLayersFrom(net_param);
- }
- this->current_step_ = state.current_step();
- CHECK_EQ(state.history_size(), history_.size())
- << "Incorrect length of history blobs.";
- LOG(INFO) << "SGDSolver: restoring history";
- for (int i = 0; i < history_.size(); ++i) {
- history_[i]->FromProto(state.history(i));
- }
-}
-
-template <typename Dtype>
-void SGDSolver<Dtype>::RestoreSolverStateFromHDF5(const string& state_file) {
- hid_t file_hid = H5Fopen(state_file.c_str(), H5F_ACC_RDONLY, H5P_DEFAULT);
- CHECK_GE(file_hid, 0) << "Couldn't open solver state file " << state_file;
- this->iter_ = hdf5_load_int(file_hid, "iter");
- if (H5LTfind_dataset(file_hid, "learned_net")) {
- string learned_net = hdf5_load_string(file_hid, "learned_net");
- this->net_->CopyTrainedLayersFrom(learned_net);
- }
- this->current_step_ = hdf5_load_int(file_hid, "current_step");
- hid_t history_hid = H5Gopen2(file_hid, "history", H5P_DEFAULT);
- CHECK_GE(history_hid, 0) << "Error reading history from " << state_file;
- int state_history_size = hdf5_get_num_links(history_hid);
- CHECK_EQ(state_history_size, history_.size())
- << "Incorrect length of history blobs.";
- for (int i = 0; i < history_.size(); ++i) {
- ostringstream oss;
- oss << i;
- hdf5_load_nd_dataset<Dtype>(history_hid, oss.str().c_str(), 0,
- kMaxBlobAxes, history_[i].get());
- }
- H5Gclose(history_hid);
- H5Fclose(file_hid);
-}
-
-template <typename Dtype>
-void NesterovSolver<Dtype>::ComputeUpdateValue(int param_id, Dtype rate) {
- CHECK(Caffe::root_solver());
- const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params();
- const vector<float>& net_params_lr = this->net_->params_lr();
- Dtype momentum = this->param_.momentum();
- Dtype local_rate = rate * net_params_lr[param_id];
- switch (Caffe::mode()) {
- case Caffe::CPU: {
- // save history momentum for stepping back
- caffe_copy(net_params[param_id]->count(),
- this->history_[param_id]->cpu_data(),
- this->update_[param_id]->mutable_cpu_data());
-
- // update history
- caffe_cpu_axpby(net_params[param_id]->count(), local_rate,
- net_params[param_id]->cpu_diff(), momentum,
- this->history_[param_id]->mutable_cpu_data());
-
- // compute update: step back then over step
- caffe_cpu_axpby(net_params[param_id]->count(), Dtype(1) + momentum,
- this->history_[param_id]->cpu_data(), -momentum,
- this->update_[param_id]->mutable_cpu_data());
-
- // copy
- caffe_copy(net_params[param_id]->count(),
- this->update_[param_id]->cpu_data(),
- net_params[param_id]->mutable_cpu_diff());
- break;
- }
- case Caffe::GPU: {
-#ifndef CPU_ONLY
- // save history momentum for stepping back
- caffe_copy(net_params[param_id]->count(),
- this->history_[param_id]->gpu_data(),
- this->update_[param_id]->mutable_gpu_data());
-
- // update history
- caffe_gpu_axpby(net_params[param_id]->count(), local_rate,
- net_params[param_id]->gpu_diff(), momentum,
- this->history_[param_id]->mutable_gpu_data());
-
- // compute update: step back then over step
- caffe_gpu_axpby(net_params[param_id]->count(), Dtype(1) + momentum,
- this->history_[param_id]->gpu_data(), -momentum,
- this->update_[param_id]->mutable_gpu_data());
-
- // copy
- caffe_copy(net_params[param_id]->count(),
- this->update_[param_id]->gpu_data(),
- net_params[param_id]->mutable_gpu_diff());
-#else
- NO_GPU;
-#endif
- break;
- }
- default:
- LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode();
- }
-}
-
-template <typename Dtype>
-void AdaGradSolver<Dtype>::ComputeUpdateValue(int param_id, Dtype rate) {
- CHECK(Caffe::root_solver());
- const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params();
- const vector<float>& net_params_lr = this->net_->params_lr();
- Dtype delta = this->param_.delta();
- Dtype local_rate = rate * net_params_lr[param_id];
- switch (Caffe::mode()) {
- case Caffe::CPU: {
- // compute square of gradient in update
- caffe_powx(net_params[param_id]->count(),
- net_params[param_id]->cpu_diff(), Dtype(2),
- this->update_[param_id]->mutable_cpu_data());
-
- // update history
- caffe_add(net_params[param_id]->count(),
- this->update_[param_id]->cpu_data(),
- this->history_[param_id]->cpu_data(),
- this->history_[param_id]->mutable_cpu_data());
-
- // prepare update
- caffe_powx(net_params[param_id]->count(),
- this->history_[param_id]->cpu_data(), Dtype(0.5),
- this->update_[param_id]->mutable_cpu_data());
-
- caffe_add_scalar(net_params[param_id]->count(),
- delta, this->update_[param_id]->mutable_cpu_data());
-
- caffe_div(net_params[param_id]->count(),
- net_params[param_id]->cpu_diff(),
- this->update_[param_id]->cpu_data(),
- this->update_[param_id]->mutable_cpu_data());
-
- // scale and copy
- caffe_cpu_axpby(net_params[param_id]->count(), local_rate,
- this->update_[param_id]->cpu_data(), Dtype(0),
- net_params[param_id]->mutable_cpu_diff());
- break;
- }
- case Caffe::GPU: {
-#ifndef CPU_ONLY
- // compute square of gradient in update
- caffe_gpu_powx(net_params[param_id]->count(),
- net_params[param_id]->gpu_diff(), Dtype(2),
- this->update_[param_id]->mutable_gpu_data());
-
- // update history
- caffe_gpu_add(net_params[param_id]->count(),
- this->update_[param_id]->gpu_data(),
- this->history_[param_id]->gpu_data(),
- this->history_[param_id]->mutable_gpu_data());
-
- // prepare update
- caffe_gpu_powx(net_params[param_id]->count(),
- this->history_[param_id]->gpu_data(), Dtype(0.5),
- this->update_[param_id]->mutable_gpu_data());
-
- caffe_gpu_add_scalar(net_params[param_id]->count(),
- delta, this->update_[param_id]->mutable_gpu_data());
-
- caffe_gpu_div(net_params[param_id]->count(),
- net_params[param_id]->gpu_diff(),
- this->update_[param_id]->gpu_data(),
- this->update_[param_id]->mutable_gpu_data());
-
- // scale and copy
- caffe_gpu_axpby(net_params[param_id]->count(), local_rate,
- this->update_[param_id]->gpu_data(), Dtype(0),
- net_params[param_id]->mutable_gpu_diff());
-#else
- NO_GPU;
-#endif
- break;
- }
- default:
- LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode();
- }
-}
-
-template <typename Dtype>
-void RMSPropSolver<Dtype>::ComputeUpdateValue(int param_id, Dtype rate) {
- const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params();
- const vector<float>& net_params_lr = this->net_->params_lr();
-
- // get the learning rate
- Dtype delta = this->param_.delta();
- Dtype rms_decay = this->param_.rms_decay();
- Dtype local_rate = rate * net_params_lr[param_id];
-
- switch (Caffe::mode()) {
- case Caffe::CPU:
- // compute square of gradient in update
- caffe_powx(net_params[param_id]->count(),
- net_params[param_id]->cpu_diff(), Dtype(2),
- this->update_[param_id]->mutable_cpu_data());
-
- // update history
- caffe_cpu_axpby(net_params[param_id] -> count(),
- Dtype(1-rms_decay), this->update_[param_id]->cpu_data(),
- rms_decay, this->history_[param_id]-> mutable_cpu_data());
-
- // prepare update
- caffe_powx(net_params[param_id]->count(),
- this->history_[param_id]->cpu_data(), Dtype(0.5),
- this->update_[param_id]->mutable_cpu_data());
-
- caffe_add_scalar(net_params[param_id]->count(),
- delta, this->update_[param_id]->mutable_cpu_data());
-
- caffe_div(net_params[param_id]->count(),
- net_params[param_id]->cpu_diff(), this->update_[param_id]->cpu_data(),
- this->update_[param_id]->mutable_cpu_data());
-
- // scale and copy
- caffe_cpu_axpby(net_params[param_id]->count(), local_rate,
- this->update_[param_id]->cpu_data(), Dtype(0),
- net_params[param_id]->mutable_cpu_diff());
- break;
- case Caffe::GPU:
-#ifndef CPU_ONLY
- // compute square of gradient in update
- caffe_gpu_powx(net_params[param_id]->count(),
- net_params[param_id]->gpu_diff(), Dtype(2),
- this->update_[param_id]->mutable_gpu_data());
-
- // update history
- caffe_gpu_axpby(net_params[param_id] -> count(),
- Dtype(1-rms_decay), this->update_[param_id]->gpu_data(),
- rms_decay, this->history_[param_id]-> mutable_gpu_data());
-
- // prepare update
- caffe_gpu_powx(net_params[param_id]->count(),
- this->history_[param_id]->gpu_data(), Dtype(0.5),
- this->update_[param_id]->mutable_gpu_data());
-
- caffe_gpu_add_scalar(net_params[param_id]->count(),
- delta, this->update_[param_id]->mutable_gpu_data());
-
- caffe_gpu_div(net_params[param_id]->count(),
- net_params[param_id]->gpu_diff(), this->update_[param_id]->gpu_data(),
- this->update_[param_id]->mutable_gpu_data());
-
- caffe_gpu_axpby(net_params[param_id]->count(), local_rate,
- this->update_[param_id]->gpu_data(), Dtype(0),
- net_params[param_id]->mutable_gpu_diff());
-#else
- NO_GPU;
-#endif
- break;
- default:
- LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode();
- }
-}
-
-template <typename Dtype>
-void AdaDeltaSolver<Dtype>::AdaDeltaPreSolve() {
- // Add the extra history entries for AdaDelta after those from
- // SGDSolver::PreSolve
- const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params();
- for (int i = 0; i < net_params.size(); ++i) {
- const vector<int>& shape = net_params[i]->shape();
- this->history_.push_back(
- shared_ptr<Blob<Dtype> >(new Blob<Dtype>(shape)));
- }
-}
-
-template <typename Dtype>
-void AdaDeltaSolver<Dtype>::ComputeUpdateValue(int param_id, Dtype rate) {
- const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params();
- const vector<float>& net_params_lr = this->net_->params_lr();
- Dtype delta = this->param_.delta();
- Dtype momentum = this->param_.momentum();
- Dtype local_rate = rate * net_params_lr[param_id];
- size_t update_history_offset = net_params.size();
- switch (Caffe::mode()) {
- case Caffe::CPU: {
- // compute square of gradient in update
- caffe_powx(net_params[param_id]->count(),
- net_params[param_id]->cpu_diff(), Dtype(2),
- this->update_[param_id]->mutable_cpu_data());
-
- // update history of gradients
- caffe_cpu_axpby(net_params[param_id]->count(), Dtype(1) - momentum,
- this->update_[param_id]->cpu_data(), momentum,
- this->history_[param_id]->mutable_cpu_data());
-
- // add delta to history to guard against dividing by zero later
- caffe_set(net_params[param_id]->count(), delta,
- this->temp_[param_id]->mutable_cpu_data());
-
- caffe_add(net_params[param_id]->count(),
- this->temp_[param_id]->cpu_data(),
- this->history_[update_history_offset + param_id]->cpu_data(),
- this->update_[param_id]->mutable_cpu_data());
-
- caffe_add(net_params[param_id]->count(),
- this->temp_[param_id]->cpu_data(),
- this->history_[param_id]->cpu_data(),
- this->temp_[param_id]->mutable_cpu_data());
-
- // divide history of updates by history of gradients
- caffe_div(net_params[param_id]->count(),
- this->update_[param_id]->cpu_data(),
- this->temp_[param_id]->cpu_data(),
- this->update_[param_id]->mutable_cpu_data());
-
- // jointly compute the RMS of both for update and gradient history
- caffe_powx(net_params[param_id]->count(),
- this->update_[param_id]->cpu_data(), Dtype(0.5),
- this->update_[param_id]->mutable_cpu_data());
-
- // compute the update
- caffe_mul(net_params[param_id]->count(),
- net_params[param_id]->cpu_diff(),
- this->update_[param_id]->cpu_data(),
- net_params[param_id]->mutable_cpu_diff());
-
- // compute square of update
- caffe_powx(net_params[param_id]->count(),
- net_params[param_id]->cpu_diff(), Dtype(2),
- this->update_[param_id]->mutable_cpu_data());
-
- // update history of updates
- caffe_cpu_axpby(net_params[param_id]->count(), Dtype(1) - momentum,
- this->update_[param_id]->cpu_data(), momentum,
- this->history_[update_history_offset + param_id]->mutable_cpu_data());
-
- // apply learning rate
- caffe_cpu_scale(net_params[param_id]->count(), local_rate,
- net_params[param_id]->cpu_diff(),
- net_params[param_id]->mutable_cpu_diff());
- break;
- }
- case Caffe::GPU: {
-#ifndef CPU_ONLY
- // compute square of gradient in update
- caffe_gpu_powx(net_params[param_id]->count(),
- net_params[param_id]->gpu_diff(), Dtype(2),
- this->update_[param_id]->mutable_gpu_data());
-
- // update history of gradients
- caffe_gpu_axpby(net_params[param_id]->count(), Dtype(1) - momentum,
- this->update_[param_id]->gpu_data(), momentum,
- this->history_[param_id]->mutable_gpu_data());
-
- // add delta to history to guard against dividing by zero later
- caffe_gpu_set(net_params[param_id]->count(), delta,
- this->temp_[param_id]->mutable_gpu_data());
-
- caffe_gpu_add(net_params[param_id]->count(),
- this->temp_[param_id]->gpu_data(),
- this->history_[update_history_offset + param_id]->gpu_data(),
- this->update_[param_id]->mutable_gpu_data());
-
- caffe_gpu_add(net_params[param_id]->count(),
- this->temp_[param_id]->gpu_data(),
- this->history_[param_id]->gpu_data(),
- this->temp_[param_id]->mutable_gpu_data());
-
- // divide history of updates by history of gradients
- caffe_gpu_div(net_params[param_id]->count(),
- this->update_[param_id]->gpu_data(),
- this->temp_[param_id]->gpu_data(),
- this->update_[param_id]->mutable_gpu_data());
-
- // jointly compute the RMS of both for update and gradient history
- caffe_gpu_powx(net_params[param_id]->count(),
- this->update_[param_id]->gpu_data(), Dtype(0.5),
- this->update_[param_id]->mutable_gpu_data());
-
- // compute the update and copy to net_diff
- caffe_gpu_mul(net_params[param_id]->count(),
- net_params[param_id]->gpu_diff(),
- this->update_[param_id]->gpu_data(),
- net_params[param_id]->mutable_gpu_diff());
-
- // compute square of update
- caffe_gpu_powx(net_params[param_id]->count(),
- net_params[param_id]->gpu_diff(), Dtype(2),
- this->update_[param_id]->mutable_gpu_data());
-
- // update history of updates
- caffe_gpu_axpby(net_params[param_id]->count(), Dtype(1) - momentum,
- this->update_[param_id]->gpu_data(), momentum,
- this->history_[update_history_offset + param_id]->mutable_gpu_data());
-
- // apply learning rate
- caffe_gpu_scale(net_params[param_id]->count(), local_rate,
- net_params[param_id]->gpu_diff(),
- net_params[param_id]->mutable_gpu_diff());
-#else
- NO_GPU;
-#endif
- break;
- }
- default:
- LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode();
- }
-}
-
-template <typename Dtype>
-void AdamSolver<Dtype>::AdamPreSolve() {
- // Add the extra history entries for Adam after those from
- // SGDSolver::PreSolve
- const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params();
- for (int i = 0; i < net_params.size(); ++i) {
- const vector<int>& shape = net_params[i]->shape();
- this->history_.push_back(
- shared_ptr<Blob<Dtype> >(new Blob<Dtype>(shape)));
- }
-}
-
-template <typename Dtype>
-void AdamSolver<Dtype>::ComputeUpdateValue(int param_id, Dtype rate) {
- const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params();
- const vector<float>& net_params_lr = this->net_->params_lr();
- Dtype local_rate = rate * net_params_lr[param_id];
- const Dtype beta1 = this->param_.momentum();
- const Dtype beta2 = this->param_.momentum2();
-
- // we create aliases for convenience
- size_t update_history_offset = net_params.size();
- Blob<Dtype>* val_m = this->history_[param_id].get();
- Blob<Dtype>* val_v = this->history_[param_id + update_history_offset].get();
- Blob<Dtype>* val_t = this->temp_[param_id].get();
-
- const int t = this->iter_ + 1;
- const Dtype correction = std::sqrt(Dtype(1) - pow(beta2, t)) /
- (Dtype(1.) - pow(beta1, t));
- const int N = net_params[param_id]->count();
- const Dtype eps_hat = this->param_.delta();
-
- switch (Caffe::mode()) {
- case Caffe::CPU: {
- // update m <- \beta_1 m_{t-1} + (1-\beta_1)g_t
- caffe_cpu_axpby(N, Dtype(1)-beta1,
- net_params[param_id]->cpu_diff(), beta1,
- val_m->mutable_cpu_data());
-
- // update v <- \beta_2 m_{t-1} + (1-\beta_2)g_t^2
- caffe_mul(N,
- net_params[param_id]->cpu_diff(),
- net_params[param_id]->cpu_diff(),
- val_t->mutable_cpu_data());
- caffe_cpu_axpby(N, Dtype(1)-beta2,
- val_t->cpu_data(), beta2,
- val_v->mutable_cpu_data());
-
- // set update
- caffe_powx(N,
- val_v->cpu_data(), Dtype(0.5),
- val_t->mutable_cpu_data());
- caffe_add_scalar(N, eps_hat, val_t->mutable_cpu_data());
- caffe_div(N,
- val_m->cpu_data(),
- val_t->cpu_data(),
- val_t->mutable_cpu_data());
-
- caffe_cpu_scale(N, local_rate*correction,
- val_t->cpu_data(),
- net_params[param_id]->mutable_cpu_diff());
- break;
- }
- case Caffe::GPU: {
-#ifndef CPU_ONLY
- // update m <- \beta_1 m_{t-1} + (1-\beta_1)g_t
- caffe_gpu_axpby(N, Dtype(1)-beta1,
- net_params[param_id]->gpu_diff(), beta1,
- val_m->mutable_gpu_data());
-
- // update v <- \beta_2 m_{t-1} + (1-\beta_2)g_t^2
- caffe_gpu_mul(N,
- net_params[param_id]->gpu_diff(),
- net_params[param_id]->gpu_diff(),
- val_t->mutable_gpu_data());
- caffe_gpu_axpby(N, Dtype(1)-beta2,
- val_t->gpu_data(), beta2,
- val_v->mutable_gpu_data());
-
- // set update
- caffe_gpu_powx(N,
- val_v->gpu_data(), Dtype(0.5),
- val_t->mutable_gpu_data());
- caffe_gpu_add_scalar(N, eps_hat,
- val_t->mutable_gpu_data());
- caffe_gpu_div(N,
- val_m->gpu_data(),
- val_t->gpu_data(),
- val_t->mutable_gpu_data());
-
- caffe_gpu_scale(N, local_rate*correction,
- val_t->gpu_data(),
- net_params[param_id]->mutable_gpu_diff());
-#else
- NO_GPU;
-#endif
- break;
- }
- default:
- LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode();
- }
-}
-
INSTANTIATE_CLASS(Solver);
-INSTANTIATE_CLASS(SGDSolver);
-INSTANTIATE_CLASS(NesterovSolver);
-INSTANTIATE_CLASS(AdaGradSolver);
-INSTANTIATE_CLASS(RMSPropSolver);
-INSTANTIATE_CLASS(AdaDeltaSolver);
-INSTANTIATE_CLASS(AdamSolver);
} // namespace caffe
diff --git a/src/caffe/solver_factory.cpp b/src/caffe/solver_factory.cpp
new file mode 100644
index 00000000..f78fab28
--- /dev/null
+++ b/src/caffe/solver_factory.cpp
@@ -0,0 +1,32 @@
+#include "caffe/solver.hpp"
+#include "caffe/sgd_solvers.hpp"
+
+namespace caffe {
+
+template <typename Dtype>
+Solver<Dtype>* GetSolver(const SolverParameter& param) {
+ SolverParameter_SolverType type = param.solver_type();
+
+ switch (type) {
+ case SolverParameter_SolverType_SGD:
+ return new SGDSolver<Dtype>(param);
+ case SolverParameter_SolverType_NESTEROV:
+ return new NesterovSolver<Dtype>(param);
+ case SolverParameter_SolverType_ADAGRAD:
+ return new AdaGradSolver<Dtype>(param);
+ case SolverParameter_SolverType_RMSPROP:
+ return new RMSPropSolver<Dtype>(param);
+ case SolverParameter_SolverType_ADADELTA:
+ return new AdaDeltaSolver<Dtype>(param);
+ case SolverParameter_SolverType_ADAM:
+ return new AdamSolver<Dtype>(param);
+ default:
+ LOG(FATAL) << "Unknown SolverType: " << type;
+ }
+ return (Solver<Dtype>*) NULL;
+}
+
+template Solver<float>* GetSolver(const SolverParameter& param);
+template Solver<double>* GetSolver(const SolverParameter& param);
+
+} // namespace caffe
diff --git a/src/caffe/solvers/adadelta_solver.cpp b/src/caffe/solvers/adadelta_solver.cpp
new file mode 100644
index 00000000..45cd4eb2
--- /dev/null
+++ b/src/caffe/solvers/adadelta_solver.cpp
@@ -0,0 +1,155 @@
+#include <vector>
+
+#include "caffe/sgd_solvers.hpp"
+
+namespace caffe {
+
+template <typename Dtype>
+void AdaDeltaSolver<Dtype>::AdaDeltaPreSolve() {
+ // Add the extra history entries for AdaDelta after those from
+ // SGDSolver::PreSolve
+ const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params();
+ for (int i = 0; i < net_params.size(); ++i) {
+ const vector<int>& shape = net_params[i]->shape();
+ this->history_.push_back(
+ shared_ptr<Blob<Dtype> >(new Blob<Dtype>(shape)));
+ }
+}
+
+template <typename Dtype>
+void AdaDeltaSolver<Dtype>::ComputeUpdateValue(int param_id, Dtype rate) {
+ const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params();
+ const vector<float>& net_params_lr = this->net_->params_lr();
+ Dtype delta = this->param_.delta();
+ Dtype momentum = this->param_.momentum();
+ Dtype local_rate = rate * net_params_lr[param_id];
+ size_t update_history_offset = net_params.size();
+ switch (Caffe::mode()) {
+ case Caffe::CPU: {
+ // compute square of gradient in update
+ caffe_powx(net_params[param_id]->count(),
+ net_params[param_id]->cpu_diff(), Dtype(2),
+ this->update_[param_id]->mutable_cpu_data());
+
+ // update history of gradients
+ caffe_cpu_axpby(net_params[param_id]->count(), Dtype(1) - momentum,
+ this->update_[param_id]->cpu_data(), momentum,
+ this->history_[param_id]->mutable_cpu_data());
+
+ // add delta to history to guard against dividing by zero later
+ caffe_set(net_params[param_id]->count(), delta,
+ this->temp_[param_id]->mutable_cpu_data());
+
+ caffe_add(net_params[param_id]->count(),
+ this->temp_[param_id]->cpu_data(),
+ this->history_[update_history_offset + param_id]->cpu_data(),
+ this->update_[param_id]->mutable_cpu_data());
+
+ caffe_add(net_params[param_id]->count(),
+ this->temp_[param_id]->cpu_data(),
+ this->history_[param_id]->cpu_data(),
+ this->temp_[param_id]->mutable_cpu_data());
+
+ // divide history of updates by history of gradients
+ caffe_div(net_params[param_id]->count(),
+ this->update_[param_id]->cpu_data(),
+ this->temp_[param_id]->cpu_data(),
+ this->update_[param_id]->mutable_cpu_data());
+
+ // jointly compute the RMS of both for update and gradient history
+ caffe_powx(net_params[param_id]->count(),
+ this->update_[param_id]->cpu_data(), Dtype(0.5),
+ this->update_[param_id]->mutable_cpu_data());
+
+ // compute the update
+ caffe_mul(net_params[param_id]->count(),
+ net_params[param_id]->cpu_diff(),
+ this->update_[param_id]->cpu_data(),
+ net_params[param_id]->mutable_cpu_diff());
+
+ // compute square of update
+ caffe_powx(net_params[param_id]->count(),
+ net_params[param_id]->cpu_diff(), Dtype(2),
+ this->update_[param_id]->mutable_cpu_data());
+
+ // update history of updates
+ caffe_cpu_axpby(net_params[param_id]->count(), Dtype(1) - momentum,
+ this->update_[param_id]->cpu_data(), momentum,
+ this->history_[update_history_offset + param_id]->mutable_cpu_data());
+
+ // apply learning rate
+ caffe_cpu_scale(net_params[param_id]->count(), local_rate,
+ net_params[param_id]->cpu_diff(),
+ net_params[param_id]->mutable_cpu_diff());
+ break;
+ }
+ case Caffe::GPU: {
+#ifndef CPU_ONLY
+ // compute square of gradient in update
+ caffe_gpu_powx(net_params[param_id]->count(),
+ net_params[param_id]->gpu_diff(), Dtype(2),
+ this->update_[param_id]->mutable_gpu_data());
+
+ // update history of gradients
+ caffe_gpu_axpby(net_params[param_id]->count(), Dtype(1) - momentum,
+ this->update_[param_id]->gpu_data(), momentum,
+ this->history_[param_id]->mutable_gpu_data());
+
+ // add delta to history to guard against dividing by zero later
+ caffe_gpu_set(net_params[param_id]->count(), delta,
+ this->temp_[param_id]->mutable_gpu_data());
+
+ caffe_gpu_add(net_params[param_id]->count(),
+ this->temp_[param_id]->gpu_data(),
+ this->history_[update_history_offset + param_id]->gpu_data(),
+ this->update_[param_id]->mutable_gpu_data());
+
+ caffe_gpu_add(net_params[param_id]->count(),
+ this->temp_[param_id]->gpu_data(),
+ this->history_[param_id]->gpu_data(),
+ this->temp_[param_id]->mutable_gpu_data());
+
+ // divide history of updates by history of gradients
+ caffe_gpu_div(net_params[param_id]->count(),
+ this->update_[param_id]->gpu_data(),
+ this->temp_[param_id]->gpu_data(),
+ this->update_[param_id]->mutable_gpu_data());
+
+ // jointly compute the RMS of both for update and gradient history
+ caffe_gpu_powx(net_params[param_id]->count(),
+ this->update_[param_id]->gpu_data(), Dtype(0.5),
+ this->update_[param_id]->mutable_gpu_data());
+
+ // compute the update and copy to net_diff
+ caffe_gpu_mul(net_params[param_id]->count(),
+ net_params[param_id]->gpu_diff(),
+ this->update_[param_id]->gpu_data(),
+ net_params[param_id]->mutable_gpu_diff());
+
+ // compute square of update
+ caffe_gpu_powx(net_params[param_id]->count(),
+ net_params[param_id]->gpu_diff(), Dtype(2),
+ this->update_[param_id]->mutable_gpu_data());
+
+ // update history of updates
+ caffe_gpu_axpby(net_params[param_id]->count(), Dtype(1) - momentum,
+ this->update_[param_id]->gpu_data(), momentum,
+ this->history_[update_history_offset + param_id]->mutable_gpu_data());
+
+ // apply learning rate
+ caffe_gpu_scale(net_params[param_id]->count(), local_rate,
+ net_params[param_id]->gpu_diff(),
+ net_params[param_id]->mutable_gpu_diff());
+#else
+ NO_GPU;
+#endif
+ break;
+ }
+ default:
+ LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode();
+ }
+}
+
+INSTANTIATE_CLASS(AdaDeltaSolver);
+
+} // namespace caffe
diff --git a/src/caffe/solvers/adagrad_solver.cpp b/src/caffe/solvers/adagrad_solver.cpp
new file mode 100644
index 00000000..627d816a
--- /dev/null
+++ b/src/caffe/solvers/adagrad_solver.cpp
@@ -0,0 +1,88 @@
+#include <vector>
+
+#include "caffe/sgd_solvers.hpp"
+
+namespace caffe {
+
+template <typename Dtype>
+void AdaGradSolver<Dtype>::ComputeUpdateValue(int param_id, Dtype rate) {
+ CHECK(Caffe::root_solver());
+ const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params();
+ const vector<float>& net_params_lr = this->net_->params_lr();
+ Dtype delta = this->param_.delta();
+ Dtype local_rate = rate * net_params_lr[param_id];
+ switch (Caffe::mode()) {
+ case Caffe::CPU: {
+ // compute square of gradient in update
+ caffe_powx(net_params[param_id]->count(),
+ net_params[param_id]->cpu_diff(), Dtype(2),
+ this->update_[param_id]->mutable_cpu_data());
+
+ // update history
+ caffe_add(net_params[param_id]->count(),
+ this->update_[param_id]->cpu_data(),
+ this->history_[param_id]->cpu_data(),
+ this->history_[param_id]->mutable_cpu_data());
+
+ // prepare update
+ caffe_powx(net_params[param_id]->count(),
+ this->history_[param_id]->cpu_data(), Dtype(0.5),
+ this->update_[param_id]->mutable_cpu_data());
+
+ caffe_add_scalar(net_params[param_id]->count(),
+ delta, this->update_[param_id]->mutable_cpu_data());
+
+ caffe_div(net_params[param_id]->count(),
+ net_params[param_id]->cpu_diff(),
+ this->update_[param_id]->cpu_data(),
+ this->update_[param_id]->mutable_cpu_data());
+
+ // scale and copy
+ caffe_cpu_axpby(net_params[param_id]->count(), local_rate,
+ this->update_[param_id]->cpu_data(), Dtype(0),
+ net_params[param_id]->mutable_cpu_diff());
+ break;
+ }
+ case Caffe::GPU: {
+#ifndef CPU_ONLY
+ // compute square of gradient in update
+ caffe_gpu_powx(net_params[param_id]->count(),
+ net_params[param_id]->gpu_diff(), Dtype(2),
+ this->update_[param_id]->mutable_gpu_data());
+
+ // update history
+ caffe_gpu_add(net_params[param_id]->count(),
+ this->update_[param_id]->gpu_data(),
+ this->history_[param_id]->gpu_data(),
+ this->history_[param_id]->mutable_gpu_data());
+
+ // prepare update
+ caffe_gpu_powx(net_params[param_id]->count(),
+ this->history_[param_id]->gpu_data(), Dtype(0.5),
+ this->update_[param_id]->mutable_gpu_data());
+
+ caffe_gpu_add_scalar(net_params[param_id]->count(),
+ delta, this->update_[param_id]->mutable_gpu_data());
+
+ caffe_gpu_div(net_params[param_id]->count(),
+ net_params[param_id]->gpu_diff(),
+ this->update_[param_id]->gpu_data(),
+ this->update_[param_id]->mutable_gpu_data());
+
+ // scale and copy
+ caffe_gpu_axpby(net_params[param_id]->count(), local_rate,
+ this->update_[param_id]->gpu_data(), Dtype(0),
+ net_params[param_id]->mutable_gpu_diff());
+#else
+ NO_GPU;
+#endif
+ break;
+ }
+ default:
+ LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode();
+ }
+}
+
+INSTANTIATE_CLASS(AdaGradSolver);
+
+} // namespace caffe
diff --git a/src/caffe/solvers/adam_solver.cpp b/src/caffe/solvers/adam_solver.cpp
new file mode 100644
index 00000000..8c334f66
--- /dev/null
+++ b/src/caffe/solvers/adam_solver.cpp
@@ -0,0 +1,112 @@
+#include <vector>
+
+#include "caffe/sgd_solvers.hpp"
+
+namespace caffe {
+
+template <typename Dtype>
+void AdamSolver<Dtype>::AdamPreSolve() {
+ // Add the extra history entries for Adam after those from
+ // SGDSolver::PreSolve
+ const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params();
+ for (int i = 0; i < net_params.size(); ++i) {
+ const vector<int>& shape = net_params[i]->shape();
+ this->history_.push_back(
+ shared_ptr<Blob<Dtype> >(new Blob<Dtype>(shape)));
+ }
+}
+
+template <typename Dtype>
+void AdamSolver<Dtype>::ComputeUpdateValue(int param_id, Dtype rate) {
+ const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params();
+ const vector<float>& net_params_lr = this->net_->params_lr();
+ Dtype local_rate = rate * net_params_lr[param_id];
+ const Dtype beta1 = this->param_.momentum();
+ const Dtype beta2 = this->param_.momentum2();
+
+ // we create aliases for convenience
+ size_t update_history_offset = net_params.size();
+ Blob<Dtype>* val_m = this->history_[param_id].get();
+ Blob<Dtype>* val_v = this->history_[param_id + update_history_offset].get();
+ Blob<Dtype>* val_t = this->temp_[param_id].get();
+
+ const int t = this->iter_ + 1;
+ const Dtype correction = std::sqrt(Dtype(1) - pow(beta2, t)) /
+ (Dtype(1.) - pow(beta1, t));
+ const int N = net_params[param_id]->count();
+ const Dtype eps_hat = this->param_.delta();
+
+ switch (Caffe::mode()) {
+ case Caffe::CPU: {
+ // update m <- \beta_1 m_{t-1} + (1-\beta_1)g_t
+ caffe_cpu_axpby(N, Dtype(1)-beta1,
+ net_params[param_id]->cpu_diff(), beta1,
+ val_m->mutable_cpu_data());
+
+ // update v <- \beta_2 m_{t-1} + (1-\beta_2)g_t^2
+ caffe_mul(N,
+ net_params[param_id]->cpu_diff(),
+ net_params[param_id]->cpu_diff(),
+ val_t->mutable_cpu_data());
+ caffe_cpu_axpby(N, Dtype(1)-beta2,
+ val_t->cpu_data(), beta2,
+ val_v->mutable_cpu_data());
+
+ // set update
+ caffe_powx(N,
+ val_v->cpu_data(), Dtype(0.5),
+ val_t->mutable_cpu_data());
+ caffe_add_scalar(N, eps_hat, val_t->mutable_cpu_data());
+ caffe_div(N,
+ val_m->cpu_data(),
+ val_t->cpu_data(),
+ val_t->mutable_cpu_data());
+
+ caffe_cpu_scale(N, local_rate*correction,
+ val_t->cpu_data(),
+ net_params[param_id]->mutable_cpu_diff());
+ break;
+ }
+ case Caffe::GPU: {
+#ifndef CPU_ONLY
+ // update m <- \beta_1 m_{t-1} + (1-\beta_1)g_t
+ caffe_gpu_axpby(N, Dtype(1)-beta1,
+ net_params[param_id]->gpu_diff(), beta1,
+ val_m->mutable_gpu_data());
+
+ // update v <- \beta_2 m_{t-1} + (1-\beta_2)g_t^2
+ caffe_gpu_mul(N,
+ net_params[param_id]->gpu_diff(),
+ net_params[param_id]->gpu_diff(),
+ val_t->mutable_gpu_data());
+ caffe_gpu_axpby(N, Dtype(1)-beta2,
+ val_t->gpu_data(), beta2,
+ val_v->mutable_gpu_data());
+
+ // set update
+ caffe_gpu_powx(N,
+ val_v->gpu_data(), Dtype(0.5),
+ val_t->mutable_gpu_data());
+ caffe_gpu_add_scalar(N, eps_hat,
+ val_t->mutable_gpu_data());
+ caffe_gpu_div(N,
+ val_m->gpu_data(),
+ val_t->gpu_data(),
+ val_t->mutable_gpu_data());
+
+ caffe_gpu_scale(N, local_rate*correction,
+ val_t->gpu_data(),
+ net_params[param_id]->mutable_gpu_diff());
+#else
+ NO_GPU;
+#endif
+ break;
+ }
+ default:
+ LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode();
+ }
+}
+
+INSTANTIATE_CLASS(AdamSolver);
+
+} // namespace caffe
diff --git a/src/caffe/solvers/nesterov_solver.cpp b/src/caffe/solvers/nesterov_solver.cpp
new file mode 100644
index 00000000..8135ee2c
--- /dev/null
+++ b/src/caffe/solvers/nesterov_solver.cpp
@@ -0,0 +1,70 @@
+#include <vector>
+
+#include "caffe/sgd_solvers.hpp"
+
+namespace caffe {
+
+template <typename Dtype>
+void NesterovSolver<Dtype>::ComputeUpdateValue(int param_id, Dtype rate) {
+ CHECK(Caffe::root_solver());
+ const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params();
+ const vector<float>& net_params_lr = this->net_->params_lr();
+ Dtype momentum = this->param_.momentum();
+ Dtype local_rate = rate * net_params_lr[param_id];
+ switch (Caffe::mode()) {
+ case Caffe::CPU: {
+ // save history momentum for stepping back
+ caffe_copy(net_params[param_id]->count(),
+ this->history_[param_id]->cpu_data(),
+ this->update_[param_id]->mutable_cpu_data());
+
+ // update history
+ caffe_cpu_axpby(net_params[param_id]->count(), local_rate,
+ net_params[param_id]->cpu_diff(), momentum,
+ this->history_[param_id]->mutable_cpu_data());
+
+ // compute update: step back then over step
+ caffe_cpu_axpby(net_params[param_id]->count(), Dtype(1) + momentum,
+ this->history_[param_id]->cpu_data(), -momentum,
+ this->update_[param_id]->mutable_cpu_data());
+
+ // copy
+ caffe_copy(net_params[param_id]->count(),
+ this->update_[param_id]->cpu_data(),
+ net_params[param_id]->mutable_cpu_diff());
+ break;
+ }
+ case Caffe::GPU: {
+#ifndef CPU_ONLY
+ // save history momentum for stepping back
+ caffe_copy(net_params[param_id]->count(),
+ this->history_[param_id]->gpu_data(),
+ this->update_[param_id]->mutable_gpu_data());
+
+ // update history
+ caffe_gpu_axpby(net_params[param_id]->count(), local_rate,
+ net_params[param_id]->gpu_diff(), momentum,
+ this->history_[param_id]->mutable_gpu_data());
+
+ // compute update: step back then over step
+ caffe_gpu_axpby(net_params[param_id]->count(), Dtype(1) + momentum,
+ this->history_[param_id]->gpu_data(), -momentum,
+ this->update_[param_id]->mutable_gpu_data());
+
+ // copy
+ caffe_copy(net_params[param_id]->count(),
+ this->update_[param_id]->gpu_data(),
+ net_params[param_id]->mutable_gpu_diff());
+#else
+ NO_GPU;
+#endif
+ break;
+ }
+ default:
+ LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode();
+ }
+}
+
+INSTANTIATE_CLASS(NesterovSolver);
+
+} // namespace caffe
diff --git a/src/caffe/solvers/rmsprop_solver.cpp b/src/caffe/solvers/rmsprop_solver.cpp
new file mode 100644
index 00000000..96d1b3dd
--- /dev/null
+++ b/src/caffe/solvers/rmsprop_solver.cpp
@@ -0,0 +1,84 @@
+#include <vector>
+
+#include "caffe/sgd_solvers.hpp"
+
+namespace caffe {
+
+template <typename Dtype>
+void RMSPropSolver<Dtype>::ComputeUpdateValue(int param_id, Dtype rate) {
+ const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params();
+ const vector<float>& net_params_lr = this->net_->params_lr();
+
+ // get the learning rate
+ Dtype delta = this->param_.delta();
+ Dtype rms_decay = this->param_.rms_decay();
+ Dtype local_rate = rate * net_params_lr[param_id];
+
+ switch (Caffe::mode()) {
+ case Caffe::CPU:
+ // compute square of gradient in update
+ caffe_powx(net_params[param_id]->count(),
+ net_params[param_id]->cpu_diff(), Dtype(2),
+ this->update_[param_id]->mutable_cpu_data());
+
+ // update history
+ caffe_cpu_axpby(net_params[param_id] -> count(),
+ Dtype(1-rms_decay), this->update_[param_id]->cpu_data(),
+ rms_decay, this->history_[param_id]-> mutable_cpu_data());
+
+ // prepare update
+ caffe_powx(net_params[param_id]->count(),
+ this->history_[param_id]->cpu_data(), Dtype(0.5),
+ this->update_[param_id]->mutable_cpu_data());
+
+ caffe_add_scalar(net_params[param_id]->count(),
+ delta, this->update_[param_id]->mutable_cpu_data());
+
+ caffe_div(net_params[param_id]->count(),
+ net_params[param_id]->cpu_diff(), this->update_[param_id]->cpu_data(),
+ this->update_[param_id]->mutable_cpu_data());
+
+ // scale and copy
+ caffe_cpu_axpby(net_params[param_id]->count(), local_rate,
+ this->update_[param_id]->cpu_data(), Dtype(0),
+ net_params[param_id]->mutable_cpu_diff());
+ break;
+ case Caffe::GPU:
+#ifndef CPU_ONLY
+ // compute square of gradient in update
+ caffe_gpu_powx(net_params[param_id]->count(),
+ net_params[param_id]->gpu_diff(), Dtype(2),
+ this->update_[param_id]->mutable_gpu_data());
+
+ // update history
+ caffe_gpu_axpby(net_params[param_id] -> count(),
+ Dtype(1-rms_decay), this->update_[param_id]->gpu_data(),
+ rms_decay, this->history_[param_id]-> mutable_gpu_data());
+
+ // prepare update
+ caffe_gpu_powx(net_params[param_id]->count(),
+ this->history_[param_id]->gpu_data(), Dtype(0.5),
+ this->update_[param_id]->mutable_gpu_data());
+
+ caffe_gpu_add_scalar(net_params[param_id]->count(),
+ delta, this->update_[param_id]->mutable_gpu_data());
+
+ caffe_gpu_div(net_params[param_id]->count(),
+ net_params[param_id]->gpu_diff(), this->update_[param_id]->gpu_data(),
+ this->update_[param_id]->mutable_gpu_data());
+
+ caffe_gpu_axpby(net_params[param_id]->count(), local_rate,
+ this->update_[param_id]->gpu_data(), Dtype(0),
+ net_params[param_id]->mutable_gpu_diff());
+#else
+ NO_GPU;
+#endif
+ break;
+ default:
+ LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode();
+ }
+}
+
+INSTANTIATE_CLASS(RMSPropSolver);
+
+} // namespace caffe
diff --git a/src/caffe/solvers/sgd_solver.cpp b/src/caffe/solvers/sgd_solver.cpp
new file mode 100644
index 00000000..89ef5ec4
--- /dev/null
+++ b/src/caffe/solvers/sgd_solver.cpp
@@ -0,0 +1,347 @@
+#include <string>
+#include <vector>
+
+#include "caffe/sgd_solvers.hpp"
+#include "caffe/util/hdf5.hpp"
+#include "caffe/util/io.hpp"
+#include "caffe/util/upgrade_proto.hpp"
+
+namespace caffe {
+
+// Return the current learning rate. The currently implemented learning rate
+// policies are as follows:
+// - fixed: always return base_lr.
+// - step: return base_lr * gamma ^ (floor(iter / step))
+// - exp: return base_lr * gamma ^ iter
+// - inv: return base_lr * (1 + gamma * iter) ^ (- power)
+// - multistep: similar to step but it allows non uniform steps defined by
+// stepvalue
+// - poly: the effective learning rate follows a polynomial decay, to be
+// zero by the max_iter. return base_lr (1 - iter/max_iter) ^ (power)
+// - sigmoid: the effective learning rate follows a sigmod decay
+// return base_lr ( 1/(1 + exp(-gamma * (iter - stepsize))))
+//
+// where base_lr, max_iter, gamma, step, stepvalue and power are defined
+// in the solver parameter protocol buffer, and iter is the current iteration.
+template <typename Dtype>
+Dtype SGDSolver<Dtype>::GetLearningRate() {
+ Dtype rate;
+ const string& lr_policy = this->param_.lr_policy();
+ if (lr_policy == "fixed") {
+ rate = this->param_.base_lr();
+ } else if (lr_policy == "step") {
+ this->current_step_ = this->iter_ / this->param_.stepsize();
+ rate = this->param_.base_lr() *
+ pow(this->param_.gamma(), this->current_step_);
+ } else if (lr_policy == "exp") {
+ rate = this->param_.base_lr() * pow(this->param_.gamma(), this->iter_);
+ } else if (lr_policy == "inv") {
+ rate = this->param_.base_lr() *
+ pow(Dtype(1) + this->param_.gamma() * this->iter_,
+ - this->param_.power());
+ } else if (lr_policy == "multistep") {
+ if (this->current_step_ < this->param_.stepvalue_size() &&
+ this->iter_ >= this->param_.stepvalue(this->current_step_)) {
+ this->current_step_++;
+ LOG(INFO) << "MultiStep Status: Iteration " <<
+ this->iter_ << ", step = " << this->current_step_;
+ }
+ rate = this->param_.base_lr() *
+ pow(this->param_.gamma(), this->current_step_);
+ } else if (lr_policy == "poly") {
+ rate = this->param_.base_lr() * pow(Dtype(1.) -
+ (Dtype(this->iter_) / Dtype(this->param_.max_iter())),
+ this->param_.power());
+ } else if (lr_policy == "sigmoid") {
+ rate = this->param_.base_lr() * (Dtype(1.) /
+ (Dtype(1.) + exp(-this->param_.gamma() * (Dtype(this->iter_) -
+ Dtype(this->param_.stepsize())))));
+ } else {
+ LOG(FATAL) << "Unknown learning rate policy: " << lr_policy;
+ }
+ return rate;
+}
+
+template <typename Dtype>
+void SGDSolver<Dtype>::PreSolve() {
+ // Initialize the history
+ const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params();
+ history_.clear();
+ update_.clear();
+ temp_.clear();
+ for (int i = 0; i < net_params.size(); ++i) {
+ const vector<int>& shape = net_params[i]->shape();
+ history_.push_back(shared_ptr<Blob<Dtype> >(new Blob<Dtype>(shape)));
+ update_.push_back(shared_ptr<Blob<Dtype> >(new Blob<Dtype>(shape)));
+ temp_.push_back(shared_ptr<Blob<Dtype> >(new Blob<Dtype>(shape)));
+ }
+}
+
+template <typename Dtype>
+void SGDSolver<Dtype>::ClipGradients() {
+ const Dtype clip_gradients = this->param_.clip_gradients();
+ if (clip_gradients < 0) { return; }
+ const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params();
+ Dtype sumsq_diff = 0;
+ for (int i = 0; i < net_params.size(); ++i) {
+ sumsq_diff += net_params[i]->sumsq_diff();
+ }
+ const Dtype l2norm_diff = std::sqrt(sumsq_diff);
+ if (l2norm_diff > clip_gradients) {
+ Dtype scale_factor = clip_gradients / l2norm_diff;
+ LOG(INFO) << "Gradient clipping: scaling down gradients (L2 norm "
+ << l2norm_diff << " > " << clip_gradients << ") "
+ << "by scale factor " << scale_factor;
+ for (int i = 0; i < net_params.size(); ++i) {
+ net_params[i]->scale_diff(scale_factor);
+ }
+ }
+}
+
+template <typename Dtype>
+void SGDSolver<Dtype>::ApplyUpdate() {
+ CHECK(Caffe::root_solver());
+ Dtype rate = GetLearningRate();
+ if (this->param_.display() && this->iter_ % this->param_.display() == 0) {
+ LOG(INFO) << "Iteration " << this->iter_ << ", lr = " << rate;
+ }
+ ClipGradients();
+ for (int param_id = 0; param_id < this->net_->learnable_params().size();
+ ++param_id) {
+ Normalize(param_id);
+ Regularize(param_id);
+ ComputeUpdateValue(param_id, rate);
+ }
+ this->net_->Update();
+}
+
+template <typename Dtype>
+void SGDSolver<Dtype>::Normalize(int param_id) {
+ if (this->param_.iter_size() == 1) { return; }
+ // Scale gradient to counterbalance accumulation.
+ const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params();
+ const Dtype accum_normalization = Dtype(1.) / this->param_.iter_size();
+ switch (Caffe::mode()) {
+ case Caffe::CPU: {
+ caffe_scal(net_params[param_id]->count(), accum_normalization,
+ net_params[param_id]->mutable_cpu_diff());
+ break;
+ }
+ case Caffe::GPU: {
+#ifndef CPU_ONLY
+ caffe_gpu_scal(net_params[param_id]->count(), accum_normalization,
+ net_params[param_id]->mutable_gpu_diff());
+#else
+ NO_GPU;
+#endif
+ break;
+ }
+ default:
+ LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode();
+ }
+}
+
+template <typename Dtype>
+void SGDSolver<Dtype>::Regularize(int param_id) {
+ const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params();
+ const vector<float>& net_params_weight_decay =
+ this->net_->params_weight_decay();
+ Dtype weight_decay = this->param_.weight_decay();
+ string regularization_type = this->param_.regularization_type();
+ Dtype local_decay = weight_decay * net_params_weight_decay[param_id];
+ switch (Caffe::mode()) {
+ case Caffe::CPU: {
+ if (local_decay) {
+ if (regularization_type == "L2") {
+ // add weight decay
+ caffe_axpy(net_params[param_id]->count(),
+ local_decay,
+ net_params[param_id]->cpu_data(),
+ net_params[param_id]->mutable_cpu_diff());
+ } else if (regularization_type == "L1") {
+ caffe_cpu_sign(net_params[param_id]->count(),
+ net_params[param_id]->cpu_data(),
+ temp_[param_id]->mutable_cpu_data());
+ caffe_axpy(net_params[param_id]->count(),
+ local_decay,
+ temp_[param_id]->cpu_data(),
+ net_params[param_id]->mutable_cpu_diff());
+ } else {
+ LOG(FATAL) << "Unknown regularization type: " << regularization_type;
+ }
+ }
+ break;
+ }
+ case Caffe::GPU: {
+#ifndef CPU_ONLY
+ if (local_decay) {
+ if (regularization_type == "L2") {
+ // add weight decay
+ caffe_gpu_axpy(net_params[param_id]->count(),
+ local_decay,
+ net_params[param_id]->gpu_data(),
+ net_params[param_id]->mutable_gpu_diff());
+ } else if (regularization_type == "L1") {
+ caffe_gpu_sign(net_params[param_id]->count(),
+ net_params[param_id]->gpu_data(),
+ temp_[param_id]->mutable_gpu_data());
+ caffe_gpu_axpy(net_params[param_id]->count(),
+ local_decay,
+ temp_[param_id]->gpu_data(),
+ net_params[param_id]->mutable_gpu_diff());
+ } else {
+ LOG(FATAL) << "Unknown regularization type: " << regularization_type;
+ }
+ }
+#else
+ NO_GPU;
+#endif
+ break;
+ }
+ default:
+ LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode();
+ }
+}
+
+template <typename Dtype>
+void SGDSolver<Dtype>::ComputeUpdateValue(int param_id, Dtype rate) {
+ const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params();
+ const vector<float>& net_params_lr = this->net_->params_lr();
+ Dtype momentum = this->param_.momentum();
+ Dtype local_rate = rate * net_params_lr[param_id];
+ // Compute the update to history, then copy it to the parameter diff.
+ switch (Caffe::mode()) {
+ case Caffe::CPU: {
+ caffe_cpu_axpby(net_params[param_id]->count(), local_rate,
+ net_params[param_id]->cpu_diff(), momentum,
+ history_[param_id]->mutable_cpu_data());
+ caffe_copy(net_params[param_id]->count(),
+ history_[param_id]->cpu_data(),
+ net_params[param_id]->mutable_cpu_diff());
+ break;
+ }
+ case Caffe::GPU: {
+#ifndef CPU_ONLY
+ caffe_gpu_axpby(net_params[param_id]->count(), local_rate,
+ net_params[param_id]->gpu_diff(), momentum,
+ history_[param_id]->mutable_gpu_data());
+ caffe_copy(net_params[param_id]->count(),
+ history_[param_id]->gpu_data(),
+ net_params[param_id]->mutable_gpu_diff());
+#else
+ NO_GPU;
+#endif
+ break;
+ }
+ default:
+ LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode();
+ }
+}
+
+template <typename Dtype>
+void SGDSolver<Dtype>::SnapshotSolverState(const string& model_filename) {
+ switch (this->param_.snapshot_format()) {
+ case caffe::SolverParameter_SnapshotFormat_BINARYPROTO:
+ SnapshotSolverStateToBinaryProto(model_filename);
+ break;
+ case caffe::SolverParameter_SnapshotFormat_HDF5:
+ SnapshotSolverStateToHDF5(model_filename);
+ break;
+ default:
+ LOG(FATAL) << "Unsupported snapshot format.";
+ }
+}
+
+template <typename Dtype>
+void SGDSolver<Dtype>::SnapshotSolverStateToBinaryProto(
+ const string& model_filename) {
+ SolverState state;
+ state.set_iter(this->iter_);
+ state.set_learned_net(model_filename);
+ state.set_current_step(this->current_step_);
+ state.clear_history();
+ for (int i = 0; i < history_.size(); ++i) {
+ // Add history
+ BlobProto* history_blob = state.add_history();
+ history_[i]->ToProto(history_blob);
+ }
+ string snapshot_filename = Solver<Dtype>::SnapshotFilename(".solverstate");
+ LOG(INFO)
+ << "Snapshotting solver state to binary proto file " << snapshot_filename;
+ WriteProtoToBinaryFile(state, snapshot_filename.c_str());
+}
+
+template <typename Dtype>
+void SGDSolver<Dtype>::SnapshotSolverStateToHDF5(
+ const string& model_filename) {
+ string snapshot_filename =
+ Solver<Dtype>::SnapshotFilename(".solverstate.h5");
+ LOG(INFO) << "Snapshotting solver state to HDF5 file " << snapshot_filename;
+ hid_t file_hid = H5Fcreate(snapshot_filename.c_str(), H5F_ACC_TRUNC,
+ H5P_DEFAULT, H5P_DEFAULT);
+ CHECK_GE(file_hid, 0)
+ << "Couldn't open " << snapshot_filename << " to save solver state.";
+ hdf5_save_int(file_hid, "iter", this->iter_);
+ hdf5_save_string(file_hid, "learned_net", model_filename);
+ hdf5_save_int(file_hid, "current_step", this->current_step_);
+ hid_t history_hid = H5Gcreate2(file_hid, "history", H5P_DEFAULT, H5P_DEFAULT,
+ H5P_DEFAULT);
+ CHECK_GE(history_hid, 0)
+ << "Error saving solver state to " << snapshot_filename << ".";
+ for (int i = 0; i < history_.size(); ++i) {
+ ostringstream oss;
+ oss << i;
+ hdf5_save_nd_dataset<Dtype>(history_hid, oss.str(), *history_[i]);
+ }
+ H5Gclose(history_hid);
+ H5Fclose(file_hid);
+}
+
+template <typename Dtype>
+void SGDSolver<Dtype>::RestoreSolverStateFromBinaryProto(
+ const string& state_file) {
+ SolverState state;
+ ReadProtoFromBinaryFile(state_file, &state);
+ this->iter_ = state.iter();
+ if (state.has_learned_net()) {
+ NetParameter net_param;
+ ReadNetParamsFromBinaryFileOrDie(state.learned_net().c_str(), &net_param);
+ this->net_->CopyTrainedLayersFrom(net_param);
+ }
+ this->current_step_ = state.current_step();
+ CHECK_EQ(state.history_size(), history_.size())
+ << "Incorrect length of history blobs.";
+ LOG(INFO) << "SGDSolver: restoring history";
+ for (int i = 0; i < history_.size(); ++i) {
+ history_[i]->FromProto(state.history(i));
+ }
+}
+
+template <typename Dtype>
+void SGDSolver<Dtype>::RestoreSolverStateFromHDF5(const string& state_file) {
+ hid_t file_hid = H5Fopen(state_file.c_str(), H5F_ACC_RDONLY, H5P_DEFAULT);
+ CHECK_GE(file_hid, 0) << "Couldn't open solver state file " << state_file;
+ this->iter_ = hdf5_load_int(file_hid, "iter");
+ if (H5LTfind_dataset(file_hid, "learned_net")) {
+ string learned_net = hdf5_load_string(file_hid, "learned_net");
+ this->net_->CopyTrainedLayersFrom(learned_net);
+ }
+ this->current_step_ = hdf5_load_int(file_hid, "current_step");
+ hid_t history_hid = H5Gopen2(file_hid, "history", H5P_DEFAULT);
+ CHECK_GE(history_hid, 0) << "Error reading history from " << state_file;
+ int state_history_size = hdf5_get_num_links(history_hid);
+ CHECK_EQ(state_history_size, history_.size())
+ << "Incorrect length of history blobs.";
+ for (int i = 0; i < history_.size(); ++i) {
+ ostringstream oss;
+ oss << i;
+ hdf5_load_nd_dataset<Dtype>(history_hid, oss.str().c_str(), 0,
+ kMaxBlobAxes, history_[i].get());
+ }
+ H5Gclose(history_hid);
+ H5Fclose(file_hid);
+}
+
+INSTANTIATE_CLASS(SGDSolver);
+
+} // namespace caffe
diff --git a/src/caffe/test/test_gradient_based_solver.cpp b/src/caffe/test/test_gradient_based_solver.cpp
index 7ad7467f..1767ad3f 100644
--- a/src/caffe/test/test_gradient_based_solver.cpp
+++ b/src/caffe/test/test_gradient_based_solver.cpp
@@ -10,7 +10,7 @@
#include "caffe/common.hpp"
#include "caffe/parallel.hpp"
#include "caffe/proto/caffe.pb.h"
-#include "caffe/solver.hpp"
+#include "caffe/sgd_solvers.hpp"
#include "caffe/util/io.hpp"
#include "caffe/test/test_caffe_main.hpp"
diff --git a/src/caffe/test/test_solver.cpp b/src/caffe/test/test_solver.cpp
index ceabc9cd..b1816426 100644
--- a/src/caffe/test/test_solver.cpp
+++ b/src/caffe/test/test_solver.cpp
@@ -7,6 +7,7 @@
#include "caffe/common.hpp"
#include "caffe/proto/caffe.pb.h"
+#include "caffe/sgd_solvers.hpp"
#include "caffe/solver.hpp"
#include "caffe/test/test_caffe_main.hpp"