Diffstat (limited to 'src/caffe/solver.cpp')
-rw-r--r--  src/caffe/solver.cpp  37
1 file changed, 31 insertions(+), 6 deletions(-)
diff --git a/src/caffe/solver.cpp b/src/caffe/solver.cpp
index 5c34e487..886c3cc6 100644
--- a/src/caffe/solver.cpp
+++ b/src/caffe/solver.cpp
@@ -158,6 +158,7 @@ template <typename Dtype>
 void Solver<Dtype>::Solve(const char* resume_file) {
   Caffe::set_phase(Caffe::TRAIN);
   LOG(INFO) << "Solving " << net_->name();
+  LOG(INFO) << "Learning Rate Policy: " << param_.lr_policy();
   PreSolve();
   iter_ = 0;
@@ -257,7 +258,6 @@ void Solver<Dtype>::TestAll() {
   }
 }
-
 template <typename Dtype>
 void Solver<Dtype>::Test(const int test_net_id) {
   LOG(INFO) << "Iteration " << iter_
@@ -336,6 +336,7 @@ void Solver<Dtype>::Snapshot() {
   SnapshotSolverState(&state);
   state.set_iter(iter_);
   state.set_learned_net(model_filename);
+  state.set_current_step(current_step_);
   snapshot_filename = filename + ".solverstate";
   LOG(INFO) << "Snapshotting solver state to " << snapshot_filename;
   WriteProtoToBinaryFile(state, snapshot_filename.c_str());
@@ -351,6 +352,7 @@ void Solver<Dtype>::Restore(const char* state_file) {
     net_->CopyTrainedLayersFrom(net_param);
   }
   iter_ = state.iter();
+  current_step_ = state.current_step();
   RestoreSolverState(state);
 }
@@ -361,8 +363,15 @@ void Solver<Dtype>::Restore(const char* state_file) {
 //    - step: return base_lr * gamma ^ (floor(iter / step))
 //    - exp: return base_lr * gamma ^ iter
 //    - inv: return base_lr * (1 + gamma * iter) ^ (- power)
-// where base_lr, gamma, step and power are defined in the solver parameter
-// protocol buffer, and iter is the current iteration.
+//    - multistep: similar to step but allows non-uniform steps defined by
+//      stepvalue
+//    - poly: the effective learning rate follows a polynomial decay, reaching
+//      zero at max_iter. return base_lr * (1 - iter/max_iter) ^ (power)
+//    - sigmoid: the effective learning rate follows a sigmoid decay
+//      return base_lr * (1 / (1 + exp(-gamma * (iter - stepsize))))
+//
+// where base_lr, max_iter, gamma, step, stepvalue and power are defined
+// in the solver parameter protocol buffer, and iter is the current iteration.
 template <typename Dtype>
 Dtype SGDSolver<Dtype>::GetLearningRate() {
   Dtype rate;
@@ -370,22 +379,38 @@ Dtype SGDSolver<Dtype>::GetLearningRate() {
   if (lr_policy == "fixed") {
     rate = this->param_.base_lr();
   } else if (lr_policy == "step") {
-    int current_step = this->iter_ / this->param_.stepsize();
+    this->current_step_ = this->iter_ / this->param_.stepsize();
     rate = this->param_.base_lr() *
-        pow(this->param_.gamma(), current_step);
+        pow(this->param_.gamma(), this->current_step_);
   } else if (lr_policy == "exp") {
     rate = this->param_.base_lr() * pow(this->param_.gamma(), this->iter_);
   } else if (lr_policy == "inv") {
     rate = this->param_.base_lr() *
         pow(Dtype(1) + this->param_.gamma() * this->iter_,
             - this->param_.power());
+  } else if (lr_policy == "multistep") {
+    if (this->current_step_ < this->param_.stepvalue_size() &&
+        this->iter_ >= this->param_.stepvalue(this->current_step_)) {
+      this->current_step_++;
+      LOG(INFO) << "MultiStep Status: Iteration " <<
+          this->iter_ << ", step = " << this->current_step_;
+    }
+    rate = this->param_.base_lr() *
+        pow(this->param_.gamma(), this->current_step_);
+  } else if (lr_policy == "poly") {
+    rate = this->param_.base_lr() * pow(Dtype(1.) -
+        (Dtype(this->iter_) / Dtype(this->param_.max_iter())),
+        this->param_.power());
+  } else if (lr_policy == "sigmoid") {
+    rate = this->param_.base_lr() * (Dtype(1.) /
+        (Dtype(1.) + exp(-this->param_.gamma() * (Dtype(this->iter_) -
+          Dtype(this->param_.stepsize())))));
   } else {
     LOG(FATAL) << "Unknown learning rate policy: " << lr_policy;
   }
   return rate;
 }
-
 template <typename Dtype>
 void SGDSolver<Dtype>::PreSolve() {
   // Initialize the history
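For readers who want to see how the three new schedules behave, here is a minimal standalone sketch. It is not part of the diff or of the Caffe API: LrParams and the MultiStepRate/PolyRate/SigmoidRate helpers are hypothetical stand-ins for SolverParameter and SGDSolver::GetLearningRate, tracing the same formulas documented in the comment block above.

// Standalone illustration of the multistep / poly / sigmoid schedules.
// LrParams and the *Rate() helpers are hypothetical, not Caffe code.
#include <cmath>
#include <cstdio>
#include <vector>

struct LrParams {
  double base_lr = 0.01;
  double gamma = 0.1;                               // decay factor
  double power = 0.5;                               // used by "poly"
  int stepsize = 5000;                              // used by "sigmoid"
  int max_iter = 10000;                             // used by "poly"
  std::vector<int> stepvalue = {3000, 6000, 9000};  // used by "multistep"
};

// multistep: advance current_step once iter crosses the next stepvalue,
// then return base_lr * gamma^current_step. The step counter advances at
// most one level per call, mirroring the branch above.
double MultiStepRate(const LrParams& p, int iter, int* current_step) {
  if (*current_step < static_cast<int>(p.stepvalue.size()) &&
      iter >= p.stepvalue[*current_step]) {
    ++(*current_step);
  }
  return p.base_lr * std::pow(p.gamma, *current_step);
}

// poly: base_lr * (1 - iter/max_iter)^power, reaching zero at max_iter.
double PolyRate(const LrParams& p, int iter) {
  return p.base_lr *
      std::pow(1.0 - static_cast<double>(iter) / p.max_iter, p.power);
}

// sigmoid: base_lr / (1 + exp(-gamma * (iter - stepsize))).
// With a negative gamma the rate decays smoothly around iter == stepsize.
double SigmoidRate(const LrParams& p, int iter) {
  return p.base_lr / (1.0 + std::exp(-p.gamma * (iter - p.stepsize)));
}

int main() {
  LrParams p;
  LrParams sig = p;
  sig.gamma = -0.001;  // sigmoid needs gamma < 0 to decay as iter grows
  int current_step = 0;
  for (int iter = 0; iter <= p.max_iter; iter += 2000) {
    std::printf("iter %5d  multistep %.6f  poly %.6f  sigmoid %.6f\n",
                iter, MultiStepRate(p, iter, &current_step),
                PolyRate(p, iter), SigmoidRate(sig, iter));
  }
  return 0;
}

Printing the rates over a range of iterations makes the difference visible: multistep drops by a factor of gamma at each stepvalue, poly decays smoothly to zero at max_iter, and sigmoid produces a soft step centered on stepsize.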