summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--examples/lenet/lenet_multistep_solver.prototxt33
-rw-r--r--examples/lenet/lenet_stepearly_solver.prototxt28
-rw-r--r--include/caffe/solver.hpp1
-rw-r--r--src/caffe/proto/caffe.proto8
-rw-r--r--src/caffe/solver.cpp37
5 files changed, 99 insertions, 8 deletions
diff --git a/examples/lenet/lenet_multistep_solver.prototxt b/examples/lenet/lenet_multistep_solver.prototxt
new file mode 100644
index 00000000..fadd7c90
--- /dev/null
+++ b/examples/lenet/lenet_multistep_solver.prototxt
@@ -0,0 +1,33 @@
+# The training protocol buffer definition
+train_net: "lenet_train.prototxt"
+# The testing protocol buffer definition
+test_net: "lenet_test.prototxt"
+# test_iter specifies how many forward passes the test should carry out.
+# In the case of MNIST, we have test batch size 100 and 100 test iterations,
+# covering the full 10,000 testing images.
+test_iter: 100
+# Carry out testing every 500 training iterations.
+test_interval: 500
+# The base learning rate, momentum and the weight decay of the network.
+base_lr: 0.01
+momentum: 0.9
+weight_decay: 0.0005
+# The learning rate policy
+lr_policy: "multistep"
+gamma: 0.9
+stepvalue: 1000
+stepvalue: 2000
+stepvalue: 2500
+stepvalue: 3000
+stepvalue: 3500
+stepvalue: 4000
+# Display every 100 iterations
+display: 100
+# The maximum number of iterations
+max_iter: 10000
+# snapshot intermediate results
+snapshot: 5000
+snapshot_prefix: "lenet"
+# solver mode: 0 for CPU and 1 for GPU
+solver_mode: 1
+device_id: 1
diff --git a/examples/lenet/lenet_stepearly_solver.prototxt b/examples/lenet/lenet_stepearly_solver.prototxt
new file mode 100644
index 00000000..efc6a335
--- /dev/null
+++ b/examples/lenet/lenet_stepearly_solver.prototxt
@@ -0,0 +1,28 @@
+# The training protocol buffer definition
+train_net: "lenet_train.prototxt"
+# The testing protocol buffer definition
+test_net: "lenet_test.prototxt"
+# test_iter specifies how many forward passes the test should carry out.
+# In the case of MNIST, we have test batch size 100 and 100 test iterations,
+# covering the full 10,000 testing images.
+test_iter: 100
+# Carry out testing every 500 training iterations.
+test_interval: 500
+# The base learning rate, momentum and the weight decay of the network.
+base_lr: 0.01
+momentum: 0.9
+weight_decay: 0.0005
+# The learning rate policy
+lr_policy: "stepearly"
+gamma: 0.9
+stepearly: 1
+# Display every 100 iterations
+display: 100
+# The maximum number of iterations
+max_iter: 10000
+# snapshot intermediate results
+snapshot: 5000
+snapshot_prefix: "lenet"
+# solver mode: 0 for CPU and 1 for GPU
+solver_mode: 1
+device_id: 1
diff --git a/include/caffe/solver.hpp b/include/caffe/solver.hpp
index 6fd159d0..51aebb32 100644
--- a/include/caffe/solver.hpp
+++ b/include/caffe/solver.hpp
@@ -56,6 +56,7 @@ class Solver {
SolverParameter param_;
int iter_;
+ int current_step_;
shared_ptr<Net<Dtype> > net_;
vector<shared_ptr<Net<Dtype> > > test_nets_;
diff --git a/src/caffe/proto/caffe.proto b/src/caffe/proto/caffe.proto
index a789aeef..88f670fd 100644
--- a/src/caffe/proto/caffe.proto
+++ b/src/caffe/proto/caffe.proto
@@ -63,7 +63,7 @@ message NetParameter {
// NOTE
// Update the next available ID when you add a new SolverParameter field.
//
-// SolverParameter next available ID: 34 (last added: average_loss)
+// SolverParameter next available ID: 35 (last added: stepvalue)
message SolverParameter {
//////////////////////////////////////////////////////////////////////////////
// Specifying the train and test networks
@@ -124,7 +124,10 @@ message SolverParameter {
// regularization types supported: L1 and L2
// controlled by weight_decay
optional string regularization_type = 29 [default = "L2"];
- optional int32 stepsize = 13; // the stepsize for learning rate policy "step"
+ // the stepsize for learning rate policy "step"
+ optional int32 stepsize = 13;
+ // the stepsize for learning rate policy "multistep"
+ repeated int32 stepvalue = 34;
optional int32 snapshot = 14 [default = 0]; // The snapshot interval
optional string snapshot_prefix = 15; // The prefix for the snapshot.
// whether to snapshot diff in the results or not. Snapshotting diff will help
@@ -166,6 +169,7 @@ message SolverState {
optional int32 iter = 1; // The current iteration
optional string learned_net = 2; // The file that stores the learned net.
repeated BlobProto history = 3; // The history for sgd solvers
+ optional int32 current_step = 4 [default = 0]; // The current step for learning rate
}
enum Phase {
diff --git a/src/caffe/solver.cpp b/src/caffe/solver.cpp
index 5c34e487..886c3cc6 100644
--- a/src/caffe/solver.cpp
+++ b/src/caffe/solver.cpp
@@ -158,6 +158,7 @@ template <typename Dtype>
void Solver<Dtype>::Solve(const char* resume_file) {
Caffe::set_phase(Caffe::TRAIN);
LOG(INFO) << "Solving " << net_->name();
+ LOG(INFO) << "Learning Rate Policy: " << param_.lr_policy();
PreSolve();
iter_ = 0;
@@ -257,7 +258,6 @@ void Solver<Dtype>::TestAll() {
}
}
-
template <typename Dtype>
void Solver<Dtype>::Test(const int test_net_id) {
LOG(INFO) << "Iteration " << iter_
@@ -336,6 +336,7 @@ void Solver<Dtype>::Snapshot() {
SnapshotSolverState(&state);
state.set_iter(iter_);
state.set_learned_net(model_filename);
+ state.set_current_step(current_step_);
snapshot_filename = filename + ".solverstate";
LOG(INFO) << "Snapshotting solver state to " << snapshot_filename;
WriteProtoToBinaryFile(state, snapshot_filename.c_str());
@@ -351,6 +352,7 @@ void Solver<Dtype>::Restore(const char* state_file) {
net_->CopyTrainedLayersFrom(net_param);
}
iter_ = state.iter();
+ current_step_ = state.current_step();
RestoreSolverState(state);
}
@@ -361,8 +363,15 @@ void Solver<Dtype>::Restore(const char* state_file) {
// - step: return base_lr * gamma ^ (floor(iter / step))
// - exp: return base_lr * gamma ^ iter
// - inv: return base_lr * (1 + gamma * iter) ^ (- power)
-// where base_lr, gamma, step and power are defined in the solver parameter
-// protocol buffer, and iter is the current iteration.
+// - multistep: similar to step but it allows non uniform steps defined by
+// stepvalue
+// - poly: the effective learning rate follows a polynomial decay, to be
+// zero by the max_iter. return base_lr (1 - iter/max_iter) ^ (power)
+// - sigmoid: the effective learning rate follows a sigmod decay
+// return base_lr ( 1/(1 + exp(-gamma * (iter - stepsize))))
+//
+// where base_lr, max_iter, gamma, step, stepvalue and power are defined
+// in the solver parameter protocol buffer, and iter is the current iteration.
template <typename Dtype>
Dtype SGDSolver<Dtype>::GetLearningRate() {
Dtype rate;
@@ -370,22 +379,38 @@ Dtype SGDSolver<Dtype>::GetLearningRate() {
if (lr_policy == "fixed") {
rate = this->param_.base_lr();
} else if (lr_policy == "step") {
- int current_step = this->iter_ / this->param_.stepsize();
+ this->current_step_ = this->iter_ / this->param_.stepsize();
rate = this->param_.base_lr() *
- pow(this->param_.gamma(), current_step);
+ pow(this->param_.gamma(), this->current_step_);
} else if (lr_policy == "exp") {
rate = this->param_.base_lr() * pow(this->param_.gamma(), this->iter_);
} else if (lr_policy == "inv") {
rate = this->param_.base_lr() *
pow(Dtype(1) + this->param_.gamma() * this->iter_,
- this->param_.power());
+ } else if (lr_policy == "multistep") {
+ if (this->current_step_ < this->param_.stepvalue_size() &&
+ this->iter_ >= this->param_.stepvalue(this->current_step_)) {
+ this->current_step_++;
+ LOG(INFO) << "MultiStep Status: Iteration " <<
+ this->iter_ << ", step = " << this->current_step_;
+ }
+ rate = this->param_.base_lr() *
+ pow(this->param_.gamma(), this->current_step_);
+ } else if (lr_policy == "poly") {
+ rate = this->param_.base_lr() * pow(Dtype(1.) -
+ (Dtype(this->iter_) / Dtype(this->param_.max_iter())),
+ this->param_.power());
+ } else if (lr_policy == "sigmoid") {
+ rate = this->param_.base_lr() * (Dtype(1.) /
+ (Dtype(1.) + exp(-this->param_.gamma() * (Dtype(this->iter_) -
+ Dtype(this->param_.stepsize())))));
} else {
LOG(FATAL) << "Unknown learning rate policy: " << lr_policy;
}
return rate;
}
-
template <typename Dtype>
void SGDSolver<Dtype>::PreSolve() {
// Initialize the history