-rw-r--r-- | examples/lenet/lenet_multistep_solver.prototxt | 33
-rw-r--r-- | examples/lenet/lenet_stepearly_solver.prototxt | 28
-rw-r--r-- | include/caffe/solver.hpp                       |  1
-rw-r--r-- | src/caffe/proto/caffe.proto                    |  8
-rw-r--r-- | src/caffe/solver.cpp                           | 37
5 files changed, 99 insertions(+), 8 deletions(-)
diff --git a/examples/lenet/lenet_multistep_solver.prototxt b/examples/lenet/lenet_multistep_solver.prototxt
new file mode 100644
index 00000000..fadd7c90
--- /dev/null
+++ b/examples/lenet/lenet_multistep_solver.prototxt
@@ -0,0 +1,33 @@
+# The training protocol buffer definition
+train_net: "lenet_train.prototxt"
+# The testing protocol buffer definition
+test_net: "lenet_test.prototxt"
+# test_iter specifies how many forward passes the test should carry out.
+# In the case of MNIST, we have test batch size 100 and 100 test iterations,
+# covering the full 10,000 testing images.
+test_iter: 100
+# Carry out testing every 500 training iterations.
+test_interval: 500
+# The base learning rate, momentum and the weight decay of the network.
+base_lr: 0.01
+momentum: 0.9
+weight_decay: 0.0005
+# The learning rate policy
+lr_policy: "multistep"
+gamma: 0.9
+stepvalue: 1000
+stepvalue: 2000
+stepvalue: 2500
+stepvalue: 3000
+stepvalue: 3500
+stepvalue: 4000
+# Display every 100 iterations
+display: 100
+# The maximum number of iterations
+max_iter: 10000
+# snapshot intermediate results
+snapshot: 5000
+snapshot_prefix: "lenet"
+# solver mode: 0 for CPU and 1 for GPU
+solver_mode: 1
+device_id: 1
diff --git a/examples/lenet/lenet_stepearly_solver.prototxt b/examples/lenet/lenet_stepearly_solver.prototxt
new file mode 100644
index 00000000..efc6a335
--- /dev/null
+++ b/examples/lenet/lenet_stepearly_solver.prototxt
@@ -0,0 +1,28 @@
+# The training protocol buffer definition
+train_net: "lenet_train.prototxt"
+# The testing protocol buffer definition
+test_net: "lenet_test.prototxt"
+# test_iter specifies how many forward passes the test should carry out.
+# In the case of MNIST, we have test batch size 100 and 100 test iterations,
+# covering the full 10,000 testing images.
+test_iter: 100
+# Carry out testing every 500 training iterations.
+test_interval: 500
+# The base learning rate, momentum and the weight decay of the network.
+base_lr: 0.01
+momentum: 0.9
+weight_decay: 0.0005
+# The learning rate policy
+lr_policy: "stepearly"
+gamma: 0.9
+stepearly: 1
+# Display every 100 iterations
+display: 100
+# The maximum number of iterations
+max_iter: 10000
+# snapshot intermediate results
+snapshot: 5000
+snapshot_prefix: "lenet"
+# solver mode: 0 for CPU and 1 for GPU
+solver_mode: 1
+device_id: 1
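
For reference, the multistep schedule defined in lenet_multistep_solver.prototxt above drops the learning rate by a factor of gamma each time training crosses one of the listed stepvalue iterations. Below is a minimal standalone C++ sketch of that rule; it is not part of the patch, and it recomputes the step from scratch on each query rather than tracking current_step_ statefully as the solver does. The constants are copied from the example prototxt.

#include <cmath>
#include <cstdio>
#include <initializer_list>
#include <vector>

// Multistep rule: rate = base_lr * gamma ^ current_step, where
// current_step counts how many stepvalue thresholds iter has passed.
double MultiStepRate(int iter, double base_lr, double gamma,
                     const std::vector<int>& stepvalues) {
  int current_step = 0;
  while (current_step < static_cast<int>(stepvalues.size()) &&
         iter >= stepvalues[current_step]) {
    ++current_step;
  }
  return base_lr * std::pow(gamma, current_step);
}

int main() {
  // Constants from lenet_multistep_solver.prototxt.
  const std::vector<int> steps = {1000, 2000, 2500, 3000, 3500, 4000};
  for (int iter : {0, 999, 1000, 2000, 2500, 4000, 9999}) {
    std::printf("iter %5d -> lr %.6f\n", iter,
                MultiStepRate(iter, 0.01, 0.9, steps));
  }
  return 0;
}

With base_lr 0.01 and gamma 0.9 this prints 0.01 up to iteration 999, 0.009 from iteration 1000, and roughly 0.0053 (0.01 * 0.9^6) once all six steps have passed.
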
diff --git a/include/caffe/solver.hpp b/include/caffe/solver.hpp
index 2bad0e2f..b2091260 100644
--- a/include/caffe/solver.hpp
+++ b/include/caffe/solver.hpp
@@ -57,6 +57,7 @@ class Solver {

   SolverParameter param_;
   int iter_;
+  int current_step_;
   shared_ptr<Net<Dtype> > net_;
   vector<shared_ptr<Net<Dtype> > > test_nets_;
diff --git a/src/caffe/proto/caffe.proto b/src/caffe/proto/caffe.proto
index f0404a09..949bafec 100644
--- a/src/caffe/proto/caffe.proto
+++ b/src/caffe/proto/caffe.proto
@@ -63,7 +63,7 @@ message NetParameter {
 // NOTE
 // Update the next available ID when you add a new SolverParameter field.
 //
-// SolverParameter next available ID: 34 (last added: average_loss)
+// SolverParameter next available ID: 35 (last added: stepvalue)
 message SolverParameter {
   //////////////////////////////////////////////////////////////////////////////
   // Specifying the train and test networks
@@ -124,7 +124,10 @@ message SolverParameter {
   // regularization types supported: L1 and L2
   // controlled by weight_decay
   optional string regularization_type = 29 [default = "L2"];
-  optional int32 stepsize = 13; // the stepsize for learning rate policy "step"
+  // the stepsize for learning rate policy "step"
+  optional int32 stepsize = 13;
+  // the stepsize for learning rate policy "multistep"
+  repeated int32 stepvalue = 34;
   optional int32 snapshot = 14 [default = 0]; // The snapshot interval
   optional string snapshot_prefix = 15; // The prefix for the snapshot.
   // whether to snapshot diff in the results or not. Snapshotting diff will help
@@ -166,6 +169,7 @@ message SolverState {
   optional int32 iter = 1; // The current iteration
   optional string learned_net = 2; // The file that stores the learned net.
   repeated BlobProto history = 3; // The history for sgd solvers
+  optional int32 current_step = 4 [default = 0]; // The current step for learning rate
 }

 enum Phase {
diff --git a/src/caffe/solver.cpp b/src/caffe/solver.cpp
index 995df06d..a13bca8b 100644
--- a/src/caffe/solver.cpp
+++ b/src/caffe/solver.cpp
@@ -158,6 +158,7 @@ template <typename Dtype>
 void Solver<Dtype>::Solve(const char* resume_file) {
   Caffe::set_phase(Caffe::TRAIN);
   LOG(INFO) << "Solving " << net_->name();
+  LOG(INFO) << "Learning Rate Policy: " << param_.lr_policy();
   PreSolve();

   iter_ = 0;
@@ -257,7 +258,6 @@ void Solver<Dtype>::TestAll() {
   }
 }

-
 template <typename Dtype>
 void Solver<Dtype>::Test(const int test_net_id) {
   LOG(INFO) << "Iteration " << iter_
@@ -336,6 +336,7 @@ void Solver<Dtype>::Snapshot() {
   SnapshotSolverState(&state);
   state.set_iter(iter_);
   state.set_learned_net(model_filename);
+  state.set_current_step(current_step_);
   snapshot_filename = filename + ".solverstate";
   LOG(INFO) << "Snapshotting solver state to " << snapshot_filename;
   WriteProtoToBinaryFile(state, snapshot_filename.c_str());
@@ -351,6 +352,7 @@ void Solver<Dtype>::Restore(const char* state_file) {
     net_->CopyTrainedLayersFrom(net_param);
   }
   iter_ = state.iter();
+  current_step_ = state.current_step();
   RestoreSolverState(state);
 }

@@ -361,8 +363,15 @@ void Solver<Dtype>::Restore(const char* state_file) {
 //    - step: return base_lr * gamma ^ (floor(iter / step))
 //    - exp: return base_lr * gamma ^ iter
 //    - inv: return base_lr * (1 + gamma * iter) ^ (- power)
-// where base_lr, gamma, step and power are defined in the solver parameter
-// protocol buffer, and iter is the current iteration.
+//    - multistep: similar to step but it allows non uniform steps defined by
+//      stepvalue
+//    - poly: the effective learning rate follows a polynomial decay, to be
+//      zero by the max_iter. return base_lr (1 - iter/max_iter) ^ (power)
+//    - sigmoid: the effective learning rate follows a sigmod decay
+//      return base_lr ( 1/(1 + exp(-gamma * (iter - stepsize))))
+//
+// where base_lr, max_iter, gamma, step, stepvalue and power are defined
+// in the solver parameter protocol buffer, and iter is the current iteration.
 template <typename Dtype>
 Dtype SGDSolver<Dtype>::GetLearningRate() {
   Dtype rate;
@@ -370,22 +379,38 @@ Dtype SGDSolver<Dtype>::GetLearningRate() {
   if (lr_policy == "fixed") {
     rate = this->param_.base_lr();
   } else if (lr_policy == "step") {
-    int current_step = this->iter_ / this->param_.stepsize();
+    this->current_step_ = this->iter_ / this->param_.stepsize();
     rate = this->param_.base_lr() *
-        pow(this->param_.gamma(), current_step);
+        pow(this->param_.gamma(), this->current_step_);
   } else if (lr_policy == "exp") {
     rate = this->param_.base_lr() * pow(this->param_.gamma(), this->iter_);
   } else if (lr_policy == "inv") {
     rate = this->param_.base_lr() *
         pow(Dtype(1) + this->param_.gamma() * this->iter_,
             - this->param_.power());
+  } else if (lr_policy == "multistep") {
+    if (this->current_step_ < this->param_.stepvalue_size() &&
+          this->iter_ >= this->param_.stepvalue(this->current_step_)) {
+      this->current_step_++;
+      LOG(INFO) << "MultiStep Status: Iteration " <<
+      this->iter_ << ", step = " << this->current_step_;
+    }
+    rate = this->param_.base_lr() *
+        pow(this->param_.gamma(), this->current_step_);
+  } else if (lr_policy == "poly") {
+    rate = this->param_.base_lr() * pow(Dtype(1.) -
+        (Dtype(this->iter_) / Dtype(this->param_.max_iter())),
+        this->param_.power());
+  } else if (lr_policy == "sigmoid") {
+    rate = this->param_.base_lr() * (Dtype(1.) /
+        (Dtype(1.) + exp(-this->param_.gamma() * (Dtype(this->iter_) -
+          Dtype(this->param_.stepsize())))));
   } else {
     LOG(FATAL) << "Unknown learning rate policy: " << lr_policy;
   }
   return rate;
 }

-
 template <typename Dtype>
 void SGDSolver<Dtype>::PreSolve() {
   // Initialize the history
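
The two other policies added in this hunk are closed-form in iter, so they are easy to sanity-check outside the solver. Below is a standalone sketch of the poly and sigmoid formulas from the patch; it is not part of the commit, and the power, gamma, and stepsize values are illustrative assumptions rather than values taken from it.

#include <cmath>
#include <cstdio>
#include <initializer_list>

// "poly": rate decays from base_lr to exactly zero at max_iter.
//   rate = base_lr * (1 - iter / max_iter) ^ power
double PolyRate(int iter, double base_lr, double power, int max_iter) {
  return base_lr *
         std::pow(1.0 - static_cast<double>(iter) / max_iter, power);
}

// "sigmoid": an S-shaped drop centered at iter == stepsize.
//   rate = base_lr * (1 / (1 + exp(-gamma * (iter - stepsize))))
double SigmoidRate(int iter, double base_lr, double gamma, int stepsize) {
  return base_lr / (1.0 + std::exp(-gamma * (iter - stepsize)));
}

int main() {
  // power = 0.5, gamma = -0.001, stepsize = 5000 are illustrative only.
  for (int iter : {0, 2500, 5000, 7500, 10000}) {
    std::printf("iter %5d  poly %.6f  sigmoid %.6f\n", iter,
                PolyRate(iter, 0.01, 0.5, 10000),
                SigmoidRate(iter, 0.01, -0.001, 5000));
  }
  return 0;
}

Note that with this formula the sigmoid policy only decays when gamma is negative; with a positive gamma the same expression ramps the rate up around iter == stepsize instead.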