diff options
Diffstat (limited to 'src/caffe/solver.cpp')
-rw-r--r-- | src/caffe/solver.cpp | 37 |
1 file changed, 31 insertions, 6 deletions
diff --git a/src/caffe/solver.cpp b/src/caffe/solver.cpp index 5c34e487..886c3cc6 100644 --- a/src/caffe/solver.cpp +++ b/src/caffe/solver.cpp @@ -158,6 +158,7 @@ template <typename Dtype> void Solver<Dtype>::Solve(const char* resume_file) { Caffe::set_phase(Caffe::TRAIN); LOG(INFO) << "Solving " << net_->name(); + LOG(INFO) << "Learning Rate Policy: " << param_.lr_policy(); PreSolve(); iter_ = 0; @@ -257,7 +258,6 @@ void Solver<Dtype>::TestAll() { } } - template <typename Dtype> void Solver<Dtype>::Test(const int test_net_id) { LOG(INFO) << "Iteration " << iter_ @@ -336,6 +336,7 @@ void Solver<Dtype>::Snapshot() { SnapshotSolverState(&state); state.set_iter(iter_); state.set_learned_net(model_filename); + state.set_current_step(current_step_); snapshot_filename = filename + ".solverstate"; LOG(INFO) << "Snapshotting solver state to " << snapshot_filename; WriteProtoToBinaryFile(state, snapshot_filename.c_str()); @@ -351,6 +352,7 @@ void Solver<Dtype>::Restore(const char* state_file) { net_->CopyTrainedLayersFrom(net_param); } iter_ = state.iter(); + current_step_ = state.current_step(); RestoreSolverState(state); } @@ -361,8 +363,15 @@ void Solver<Dtype>::Restore(const char* state_file) { // - step: return base_lr * gamma ^ (floor(iter / step)) // - exp: return base_lr * gamma ^ iter // - inv: return base_lr * (1 + gamma * iter) ^ (- power) -// where base_lr, gamma, step and power are defined in the solver parameter -// protocol buffer, and iter is the current iteration. +// - multistep: similar to step but it allows non uniform steps defined by +// stepvalue +// - poly: the effective learning rate follows a polynomial decay, to be +// zero by the max_iter. 
return base_lr (1 - iter/max_iter) ^ (power) +// - sigmoid: the effective learning rate follows a sigmoid decay +// return base_lr ( 1/(1 + exp(-gamma * (iter - stepsize)))) +// +// where base_lr, max_iter, gamma, step, stepvalue and power are defined +// in the solver parameter protocol buffer, and iter is the current iteration. template <typename Dtype> Dtype SGDSolver<Dtype>::GetLearningRate() { Dtype rate; @@ -370,22 +379,38 @@ Dtype SGDSolver<Dtype>::GetLearningRate() { if (lr_policy == "fixed") { rate = this->param_.base_lr(); } else if (lr_policy == "step") { - int current_step = this->iter_ / this->param_.stepsize(); + this->current_step_ = this->iter_ / this->param_.stepsize(); rate = this->param_.base_lr() * - pow(this->param_.gamma(), current_step); + pow(this->param_.gamma(), this->current_step_); } else if (lr_policy == "exp") { rate = this->param_.base_lr() * pow(this->param_.gamma(), this->iter_); } else if (lr_policy == "inv") { rate = this->param_.base_lr() * pow(Dtype(1) + this->param_.gamma() * this->iter_, - this->param_.power()); + } else if (lr_policy == "multistep") { + if (this->current_step_ < this->param_.stepvalue_size() && + this->iter_ >= this->param_.stepvalue(this->current_step_)) { + this->current_step_++; + LOG(INFO) << "MultiStep Status: Iteration " << + this->iter_ << ", step = " << this->current_step_; + } + rate = this->param_.base_lr() * + pow(this->param_.gamma(), this->current_step_); + } else if (lr_policy == "poly") { + rate = this->param_.base_lr() * pow(Dtype(1.) - + (Dtype(this->iter_) / Dtype(this->param_.max_iter())), + this->param_.power()); + } else if (lr_policy == "sigmoid") { + rate = this->param_.base_lr() * (Dtype(1.) / + (Dtype(1.) + exp(-this->param_.gamma() * (Dtype(this->iter_) - + Dtype(this->param_.stepsize()))))); } else { LOG(FATAL) << "Unknown learning rate policy: " << lr_policy; } return rate; } - template <typename Dtype> void SGDSolver<Dtype>::PreSolve() { // Initialize the history |