author    Eren Golge <erogol@hotmail.com>    2015-08-08 23:45:08 -0700
committer Ronghang Hu <huronghang@hotmail.com>    2015-08-08 23:45:08 -0700
commit    abe99e8748ad7f583c87d1a6132ff2d79e70dd9c
tree      55b196b5e10f8ed630e79b8e3aa05144ca180652
parent    eb3e1149a2fcc9c48d268ffe2319d872081e4c3b
Implement RMSProp Solver
Implement the RMSProp solver, and clean up the code to fit the new solver interface that uses accumulated gradients and refactored regularization.
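For orientation, here is a minimal per-element sketch of the update rule this commit implements (plain C++; the names are illustrative, not Caffe's, and the history buffer is assumed zero-initialized, following the MeanSquare recurrence documented in caffe.proto below):

    #include <cmath>
    #include <cstddef>

    // One RMSProp step over a flat parameter array. history holds an
    // exponential moving average of squared gradients; delta guards the
    // division, like AdaGrad's numerical-stability term.
    void rmsprop_step(std::size_t n, const float* grad, float* history,
                      float* param, float lr, float rms_decay, float delta) {
      for (std::size_t i = 0; i < n; ++i) {
        // MeanSquare(t) = rms_decay * MeanSquare(t-1) + (1 - rms_decay) * grad^2
        history[i] = rms_decay * history[i]
            + (1.0f - rms_decay) * grad[i] * grad[i];
        param[i] -= lr * grad[i] / (std::sqrt(history[i]) + delta);
      }
    }

The solver.cpp hunk below realizes the same arithmetic with Caffe's BLAS-style helpers, writing the scaled update back into each parameter's diff.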
-rw-r--r--  examples/mnist/lenet_solver_rmsprop.prototxt  |  27
-rwxr-xr-x  examples/mnist/train_lenet_rmsprop.sh         |   3
-rw-r--r--  include/caffe/solver.hpp                      |  25
-rw-r--r--  src/caffe/proto/caffe.proto                   |  25
-rw-r--r--  src/caffe/solver.cpp                          |  76
-rw-r--r--  src/caffe/test/test_gradient_based_solver.cpp | 245
6 files changed, 353 insertions(+), 48 deletions(-)
diff --git a/examples/mnist/lenet_solver_rmsprop.prototxt b/examples/mnist/lenet_solver_rmsprop.prototxt
new file mode 100644
index 00000000..74dadc51
--- /dev/null
+++ b/examples/mnist/lenet_solver_rmsprop.prototxt
@@ -0,0 +1,27 @@
+# The train/test net protocol buffer definition
+net: "examples/mnist/lenet_train_test.prototxt"
+# test_iter specifies how many forward passes the test should carry out.
+# In the case of MNIST, we have test batch size 100 and 100 test iterations,
+# covering the full 10,000 testing images.
+test_iter: 100
+# Carry out testing every 500 training iterations.
+test_interval: 500
+# The base learning rate, momentum and the weight decay of the network.
+base_lr: 0.01
+momentum: 0.0
+weight_decay: 0.0005
+# The learning rate policy
+lr_policy: "inv"
+gamma: 0.0001
+power: 0.75
+# Display every 100 iterations
+display: 100
+# The maximum number of iterations
+max_iter: 10000
+# snapshot intermediate results
+snapshot: 5000
+snapshot_prefix: "examples/mnist/lenet_rmsprop"
+# solver mode: CPU or GPU
+solver_mode: GPU
+solver_type: RMSPROP
+rms_decay: 0.98
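An aside on the schedule chosen above: under lr_policy "inv" the effective rate is base_lr * (1 + gamma * iter)^(-power), per the caffe.proto comments later in this patch. A small standalone sketch (not part of the patch) of how it decays over this solver's 10000 iterations:

    #include <cmath>
    #include <cstdio>

    // Effective learning rate under Caffe's "inv" policy.
    float inv_rate(float base_lr, float gamma, float power, int iter) {
      return base_lr * std::pow(1.0f + gamma * static_cast<float>(iter), -power);
    }

    int main() {
      // Values from lenet_solver_rmsprop.prototxt above.
      for (int iter = 0; iter <= 10000; iter += 2500) {
        std::printf("iter %5d -> lr %.6f\n",
                    iter, inv_rate(0.01f, 0.0001f, 0.75f, iter));
      }
      return 0;
    }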
diff --git a/examples/mnist/train_lenet_rmsprop.sh b/examples/mnist/train_lenet_rmsprop.sh
new file mode 100755
index 00000000..621cab23
--- /dev/null
+++ b/examples/mnist/train_lenet_rmsprop.sh
@@ -0,0 +1,3 @@
+#!/usr/bin/env sh
+
+./build/tools/caffe train --solver=examples/mnist/lenet_solver_rmsprop.prototxt
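Note that, like the other MNIST example scripts, this one uses paths relative to the Caffe root, so it should be invoked from there, e.g. cd $CAFFE_ROOT && ./examples/mnist/train_lenet_rmsprop.sh.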
diff --git a/include/caffe/solver.hpp b/include/caffe/solver.hpp
index 703434b5..fbade938 100644
--- a/include/caffe/solver.hpp
+++ b/include/caffe/solver.hpp
@@ -135,6 +135,29 @@ class AdaGradSolver : public SGDSolver<Dtype> {
DISABLE_COPY_AND_ASSIGN(AdaGradSolver);
};
+
+template <typename Dtype>
+class RMSPropSolver : public SGDSolver<Dtype> {
+ public:
+ explicit RMSPropSolver(const SolverParameter& param)
+ : SGDSolver<Dtype>(param) { constructor_sanity_check(); }
+ explicit RMSPropSolver(const string& param_file)
+ : SGDSolver<Dtype>(param_file) { constructor_sanity_check(); }
+
+ protected:
+ virtual void ComputeUpdateValue(int param_id, Dtype rate);
+ void constructor_sanity_check() {
+ CHECK_EQ(0, this->param_.momentum())
+ << "Momentum cannot be used with RMSProp.";
+ CHECK_GE(this->param_.rms_decay(), 0)
+ << "rms_decay should lie between 0 and 1.";
+ CHECK_LT(this->param_.rms_decay(), 1)
+ << "rms_decay should lie between 0 and 1.";
+ }
+
+ DISABLE_COPY_AND_ASSIGN(RMSPropSolver);
+};
+
template <typename Dtype>
Solver<Dtype>* GetSolver(const SolverParameter& param) {
SolverParameter_SolverType type = param.solver_type();
@@ -146,6 +169,8 @@ Solver<Dtype>* GetSolver(const SolverParameter& param) {
return new NesterovSolver<Dtype>(param);
case SolverParameter_SolverType_ADAGRAD:
return new AdaGradSolver<Dtype>(param);
+ case SolverParameter_SolverType_RMSPROP:
+ return new RMSPropSolver<Dtype>(param);
default:
LOG(FATAL) << "Unknown SolverType: " << type;
}
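A hedged sketch of driving the factory above from C++ (it assumes the standard setters protoc generates for SolverParameter; required fields such as the net and the learning-rate policy are elided here):

    #include <boost/shared_ptr.hpp>
    #include "caffe/solver.hpp"

    // Hypothetical driver: select the new solver through GetSolver().
    void make_rmsprop_solver() {
      caffe::SolverParameter param;
      // ... set net, base_lr, lr_policy, and so on ...
      param.set_solver_type(caffe::SolverParameter_SolverType_RMSPROP);
      param.set_rms_decay(0.98f);
      param.set_momentum(0.0f);  // the constructor CHECKs momentum == 0
      boost::shared_ptr<caffe::Solver<float> > solver(
          caffe::GetSolver<float>(param));
      // solver->Solve();
    }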
diff --git a/src/caffe/proto/caffe.proto b/src/caffe/proto/caffe.proto
index a13c0e79..89f14595 100644
--- a/src/caffe/proto/caffe.proto
+++ b/src/caffe/proto/caffe.proto
@@ -98,7 +98,7 @@ message NetParameter {
// NOTE
// Update the next available ID when you add a new SolverParameter field.
//
-// SolverParameter next available ID: 38 (last added: snapshot_format)
+// SolverParameter next available ID: 39 (last added: rms_decay)
message SolverParameter {
//////////////////////////////////////////////////////////////////////////////
// Specifying the train and test networks
@@ -153,7 +153,23 @@ message SolverParameter {
optional int32 max_iter = 7; // the maximum number of iterations
// accumulate gradients over `iter_size` x `batch_size` instances
optional int32 iter_size = 36 [default = 1];
- optional string lr_policy = 8; // The learning rate decay policy.
+
+ // The learning rate decay policy. The currently implemented learning rate
+ // policies are as follows:
+ // - fixed: always return base_lr.
+ // - step: return base_lr * gamma ^ (floor(iter / step))
+ // - exp: return base_lr * gamma ^ iter
+ // - inv: return base_lr * (1 + gamma * iter) ^ (- power)
+ // - multistep: similar to step but it allows non-uniform steps defined by
+ // stepvalue
+ // - poly: the effective learning rate follows a polynomial decay, reaching
+ // zero at max_iter: return base_lr * (1 - iter/max_iter) ^ power
+ // - sigmoid: the effective learning rate follows a sigmoid decay:
+ // return base_lr * (1 / (1 + exp(-gamma * (iter - stepsize))))
+ //
+ // where base_lr, max_iter, gamma, step, stepvalue and power are defined
+ // in the solver parameter protocol buffer, and iter is the current iteration.
+ optional string lr_policy = 8;
optional float gamma = 9; // The parameter to compute the learning rate.
optional float power = 10; // The parameter to compute the learning rate.
optional float momentum = 11; // The momentum value.
@@ -198,11 +214,16 @@ message SolverParameter {
SGD = 0;
NESTEROV = 1;
ADAGRAD = 2;
+ RMSPROP = 3;
}
optional SolverType solver_type = 30 [default = SGD];
// numerical stability for AdaGrad
optional float delta = 31 [default = 1e-8];
+ // RMSProp decay value
+ // MeanSquare(t) = rms_decay*MeanSquare(t-1) + (1-rms_decay)*SquareGradient(t)
+ optional float rms_decay = 38;
+
// If true, print information about the state of the net that may help with
// debugging learning problems.
optional bool debug_info = 23 [default = false];
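As a rule of thumb (not stated in the patch), an exponential moving average like the one above has an effective window of roughly 1 / (1 - rms_decay) iterations, so the rms_decay of 0.98 in the example solver averages over about the last 50 squared gradients.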
diff --git a/src/caffe/solver.cpp b/src/caffe/solver.cpp
index 32276ac1..43834c0c 100644
--- a/src/caffe/solver.cpp
+++ b/src/caffe/solver.cpp
@@ -859,9 +859,85 @@ void AdaGradSolver<Dtype>::ComputeUpdateValue(int param_id, Dtype rate) {
}
}
+template <typename Dtype>
+void RMSPropSolver<Dtype>::ComputeUpdateValue(int param_id, Dtype rate) {
+ const vector<shared_ptr<Blob<Dtype> > >& net_params = this->net_->params();
+ const vector<float>& net_params_lr = this->net_->params_lr();
+
+ // get the learning rate
+ Dtype delta = this->param_.delta();
+ Dtype rms_decay = this->param_.rms_decay();
+ Dtype local_rate = rate * net_params_lr[param_id];
+
+ switch (Caffe::mode()) {
+ case Caffe::CPU:
+ // compute square of gradient in update
+ caffe_powx(net_params[param_id]->count(),
+ net_params[param_id]->cpu_diff(), Dtype(2),
+ this->update_[param_id]->mutable_cpu_data());
+
+ // update history
+ caffe_cpu_axpby(net_params[param_id]->count(),
+ Dtype(1 - rms_decay), this->update_[param_id]->cpu_data(),
+ rms_decay, this->history_[param_id]->mutable_cpu_data());
+
+ // prepare update
+ caffe_powx(net_params[param_id]->count(),
+ this->history_[param_id]->cpu_data(), Dtype(0.5),
+ this->update_[param_id]->mutable_cpu_data());
+
+ caffe_add_scalar(net_params[param_id]->count(),
+ delta, this->update_[param_id]->mutable_cpu_data());
+
+ caffe_div(net_params[param_id]->count(),
+ net_params[param_id]->cpu_diff(), this->update_[param_id]->cpu_data(),
+ this->update_[param_id]->mutable_cpu_data());
+
+ // scale and copy
+ caffe_cpu_axpby(net_params[param_id]->count(), local_rate,
+ this->update_[param_id]->cpu_data(), Dtype(0),
+ net_params[param_id]->mutable_cpu_diff());
+ break;
+ case Caffe::GPU:
+#ifndef CPU_ONLY
+ // compute square of gradient in update
+ caffe_gpu_powx(net_params[param_id]->count(),
+ net_params[param_id]->gpu_diff(), Dtype(2),
+ this->update_[param_id]->mutable_gpu_data());
+
+ // update history
+ caffe_gpu_axpby(net_params[param_id]->count(),
+ Dtype(1 - rms_decay), this->update_[param_id]->gpu_data(),
+ rms_decay, this->history_[param_id]->mutable_gpu_data());
+
+ // prepare update
+ caffe_gpu_powx(net_params[param_id]->count(),
+ this->history_[param_id]->gpu_data(), Dtype(0.5),
+ this->update_[param_id]->mutable_gpu_data());
+
+ caffe_gpu_add_scalar(net_params[param_id]->count(),
+ delta, this->update_[param_id]->mutable_gpu_data());
+
+ caffe_gpu_div(net_params[param_id]->count(),
+ net_params[param_id]->gpu_diff(), this->update_[param_id]->gpu_data(),
+ this->update_[param_id]->mutable_gpu_data());
+
+ caffe_gpu_axpby(net_params[param_id]->count(), local_rate,
+ this->update_[param_id]->gpu_data(), Dtype(0),
+ net_params[param_id]->mutable_gpu_diff());
+#else
+ NO_GPU;
+#endif
+ break;
+ default:
+ LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode();
+ }
+}
+
INSTANTIATE_CLASS(Solver);
INSTANTIATE_CLASS(SGDSolver);
INSTANTIATE_CLASS(NesterovSolver);
INSTANTIATE_CLASS(AdaGradSolver);
+INSTANTIATE_CLASS(RMSPropSolver);
} // namespace caffe
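Read element-wise, the CPU branch above performs the following sequence, with g the parameter's diff, h its history, and u the scratch update buffer (a scalar paraphrase, not Caffe code); the GPU branch mirrors it with the caffe_gpu_* counterparts:

    // u = g * g                                // caffe_powx(..., 2, ...)
    // h = (1 - rms_decay) * u + rms_decay * h  // caffe_cpu_axpby
    // u = sqrt(h)                              // caffe_powx(..., 0.5, ...)
    // u = u + delta                            // caffe_add_scalar
    // u = g / u                                // caffe_div
    // g = local_rate * u                       // caffe_cpu_axpby with beta = 0

Overwriting the diff with the scaled update lets the solver's ordinary weight-update path apply it unchanged.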
diff --git a/src/caffe/test/test_gradient_based_solver.cpp b/src/caffe/test/test_gradient_based_solver.cpp
index 7bb0ec18..b0918922 100644
--- a/src/caffe/test/test_gradient_based_solver.cpp
+++ b/src/caffe/test/test_gradient_based_solver.cpp
@@ -52,13 +52,14 @@ class GradientBasedSolverTest : public MultiDeviceTest<TypeParam> {
LOG(FATAL) << "Unknown Caffe mode: " << Caffe::mode();
}
InitSolver(param);
- delta_ = (solver_type() == SolverParameter_SolverType_ADAGRAD) ?
- param.delta() : 0;
+ delta_ = (solver_type() == SolverParameter_SolverType_ADAGRAD ||
+ solver_type() == SolverParameter_SolverType_RMSPROP) ?
+ param.delta() : 0;
}
string RunLeastSquaresSolver(const Dtype learning_rate,
- const Dtype weight_decay, const Dtype momentum, const int num_iters,
- const int iter_size = 1, const bool snapshot = false,
+ const Dtype weight_decay, const Dtype momentum, const Dtype rms_decay,
+ const int num_iters, const int iter_size = 1, const bool snapshot = false,
const char* from_snapshot = NULL) {
ostringstream proto;
proto <<
@@ -173,6 +174,9 @@ class GradientBasedSolverTest : public MultiDeviceTest<TypeParam> {
if (momentum != 0) {
proto << "momentum: " << momentum << " ";
}
+ if (rms_decay != 0) {
+ proto << "rms_decay: " << rms_decay << " ";
+ }
MakeTempDir(&snapshot_prefix_);
proto << "snapshot_prefix: '" << snapshot_prefix_ << "/' ";
if (snapshot) {
@@ -204,7 +208,7 @@ class GradientBasedSolverTest : public MultiDeviceTest<TypeParam> {
// updated_params will store the updated weight and bias results,
// using the blobs' diffs to hold the update values themselves.
void ComputeLeastSquaresUpdate(const Dtype learning_rate,
- const Dtype weight_decay, const Dtype momentum,
+ const Dtype weight_decay, const Dtype momentum, const Dtype rms_decay,
vector<shared_ptr<Blob<Dtype> > >* updated_params) {
const int N = num_;
const int D = channels_ * height_ * width_;
@@ -287,6 +291,10 @@ class GradientBasedSolverTest : public MultiDeviceTest<TypeParam> {
case SolverParameter_SolverType_ADAGRAD:
update_value /= std::sqrt(history_value + grad * grad) + delta_;
break;
+ case SolverParameter_SolverType_RMSPROP:
+ update_value /= std::sqrt(rms_decay * history_value
+ + grad * grad * (1 - rms_decay)) + delta_;
+ break;
default:
LOG(FATAL) << "Unknown solver type: " << solver_type();
}
@@ -352,13 +360,14 @@ class GradientBasedSolverTest : public MultiDeviceTest<TypeParam> {
}
void CheckAccumulation(const Dtype kLearningRate, const Dtype kWeightDecay,
- const Dtype kMomentum, const int kNumIters, const int kIterSize) {
+ const Dtype kMomentum, const Dtype rms_decay, const int kNumIters,
+ const int kIterSize) {
const double kPrecision = 1e-2;
const double kMinPrecision = 1e-7;
constant_data_ = true;
// Solve without accumulation and save parameters.
this->RunLeastSquaresSolver(kLearningRate, kWeightDecay, kMomentum,
- kNumIters);
+ rms_decay, kNumIters);
// Save parameters for comparison.
Net<Dtype>& net = *this->solver_->net();
const vector<shared_ptr<Blob<Dtype> > >& param_blobs =
@@ -370,7 +379,7 @@ class GradientBasedSolverTest : public MultiDeviceTest<TypeParam> {
}
// Solve by equivalent accumulation of gradients over divided batches.
this->RunLeastSquaresSolver(kLearningRate, kWeightDecay, kMomentum,
- kNumIters, kIterSize);
+ rms_decay, kNumIters, kIterSize);
Net<Dtype>& net_accum = *this->solver_->net();
const vector<shared_ptr<Blob<Dtype> > >& accum_params =
net_accum.layer_by_name("innerprod")->blobs();
@@ -408,18 +417,19 @@ class GradientBasedSolverTest : public MultiDeviceTest<TypeParam> {
// matches the solver's (K+1)th update.
void TestLeastSquaresUpdate(const Dtype learning_rate = 1.0,
const Dtype weight_decay = 0.0, const Dtype momentum = 0.0,
- const int iter_to_check = 0) {
+ const Dtype rms_decay = 0.0, const int iter_to_check = 0) {
// Initialize the solver and run K (= iter_to_check) solver iterations.
- RunLeastSquaresSolver(learning_rate, weight_decay, momentum, iter_to_check);
+ RunLeastSquaresSolver(learning_rate, weight_decay, momentum, rms_decay,
+ iter_to_check);
// Compute the (K+1)th update using the analytic least squares gradient.
vector<shared_ptr<Blob<Dtype> > > updated_params;
ComputeLeastSquaresUpdate(learning_rate, weight_decay, momentum,
- &updated_params);
+ rms_decay, &updated_params);
// Reinitialize the solver and run K+1 solver iterations.
- RunLeastSquaresSolver(learning_rate, weight_decay, momentum,
- iter_to_check + 1);
+ RunLeastSquaresSolver(learning_rate, weight_decay, momentum, rms_decay,
+ iter_to_check + 1);
// Check that the solver's solution matches ours.
CheckLeastSquaresUpdate(updated_params);
@@ -427,12 +437,12 @@ class GradientBasedSolverTest : public MultiDeviceTest<TypeParam> {
void TestSnapshot(const Dtype learning_rate = 1.0,
const Dtype weight_decay = 0.0, const Dtype momentum = 0.0,
- const int num_iters = 1) {
+ const Dtype rms_decay = 0.0, const int num_iters = 1) {
// Run the solver for num_iters * 2 iterations.
const int total_num_iters = num_iters * 2;
bool snapshot = false;
const int kIterSize = 1;
- RunLeastSquaresSolver(learning_rate, weight_decay, momentum,
+ RunLeastSquaresSolver(learning_rate, weight_decay, momentum, rms_decay,
total_num_iters, kIterSize, snapshot);
// Save the resulting param values.
@@ -463,12 +473,12 @@ class GradientBasedSolverTest : public MultiDeviceTest<TypeParam> {
// Run the solver for num_iters iterations and snapshot.
snapshot = true;
string snapshot_name = RunLeastSquaresSolver(learning_rate, weight_decay,
- momentum, num_iters, kIterSize, snapshot);
+ momentum, rms_decay, num_iters, kIterSize, snapshot);
// Reinitialize the solver and run for num_iters more iterations.
snapshot = false;
- RunLeastSquaresSolver(learning_rate, weight_decay,
- momentum, total_num_iters, kIterSize, snapshot, snapshot_name.c_str());
+ RunLeastSquaresSolver(learning_rate, weight_decay, momentum, rms_decay,
+ total_num_iters, kIterSize, snapshot, snapshot_name.c_str());
// Check that params now match.
const vector<Blob<Dtype>*>& params = solver_->net()->learnable_params();
@@ -548,9 +558,11 @@ TYPED_TEST(SGDSolverTest, TestLeastSquaresUpdateWithMomentum) {
const Dtype kLearningRate = 0.01;
const Dtype kWeightDecay = 0;
const Dtype kMomentum = 0.5;
+ const Dtype kRMSDecay = 0;
const int kNumIters = 1;
for (int i = 0; i <= kNumIters; ++i) {
- this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i);
+ this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum,
+ kRMSDecay, i);
}
}
@@ -559,9 +571,11 @@ TYPED_TEST(SGDSolverTest, TestLeastSquaresUpdateWithMomentumMultiIter) {
const Dtype kLearningRate = 0.01;
const Dtype kWeightDecay = 0;
const Dtype kMomentum = 0.5;
+ const Dtype kRMSDecay = 0;
const int kNumIters = 4;
for (int i = 0; i <= kNumIters; ++i) {
- this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i);
+ this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum,
+ kRMSDecay, i);
}
}
@@ -570,9 +584,11 @@ TYPED_TEST(SGDSolverTest, TestLeastSquaresUpdateWithEverything) {
const Dtype kLearningRate = 0.01;
const Dtype kWeightDecay = 0.5;
const Dtype kMomentum = 0.5;
+ const Dtype kRMSDecay = 0;
const int kNumIters = 4;
for (int i = 0; i <= kNumIters; ++i) {
- this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i);
+ this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum,
+ kRMSDecay, i);
}
}
@@ -581,10 +597,12 @@ TYPED_TEST(SGDSolverTest, TestLeastSquaresUpdateWithEverythingShare) {
const Dtype kLearningRate = 0.01;
const Dtype kWeightDecay = 0.5;
const Dtype kMomentum = 0.5;
+ const Dtype kRMSDecay = 0;
const int kNumIters = 4;
this->share_ = true;
for (int i = 0; i <= kNumIters; ++i) {
- this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i);
+ this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum,
+ kRMSDecay, i);
}
}
@@ -593,10 +611,11 @@ TYPED_TEST(SGDSolverTest, TestLeastSquaresUpdateWithEverythingAccum) {
const Dtype kLearningRate = 0.01;
const Dtype kWeightDecay = 0.5;
const Dtype kMomentum = 0.9;
+ const Dtype kRMSDecay = 0;
const int kNumIters = 4;
const int kIterSize = 2;
- this->CheckAccumulation(kLearningRate, kWeightDecay, kMomentum, kNumIters,
- kIterSize);
+ this->CheckAccumulation(kLearningRate, kWeightDecay, kMomentum, kRMSDecay,
+ kNumIters, kIterSize);
}
TYPED_TEST(SGDSolverTest, TestLeastSquaresUpdateWithEverythingAccumShare) {
@@ -604,11 +623,12 @@ TYPED_TEST(SGDSolverTest, TestLeastSquaresUpdateWithEverythingAccumShare) {
const Dtype kLearningRate = 0.01;
const Dtype kWeightDecay = 0.5;
const Dtype kMomentum = 0.9;
+ const Dtype kRMSDecay = 0;
const int kNumIters = 4;
const int kIterSize = 2;
this->share_ = true;
- this->CheckAccumulation(kLearningRate, kWeightDecay, kMomentum, kNumIters,
- kIterSize);
+ this->CheckAccumulation(kLearningRate, kWeightDecay, kMomentum, kRMSDecay,
+ kNumIters, kIterSize);
}
TYPED_TEST(SGDSolverTest, TestSnapshot) {
@@ -616,9 +636,10 @@ TYPED_TEST(SGDSolverTest, TestSnapshot) {
const Dtype kLearningRate = 0.01;
const Dtype kWeightDecay = 0.5;
const Dtype kMomentum = 0.9;
+ const Dtype kRMSDecay = 0;
const int kNumIters = 4;
for (int i = 1; i <= kNumIters; ++i) {
- this->TestSnapshot(kLearningRate, kWeightDecay, kMomentum, i);
+ this->TestSnapshot(kLearningRate, kWeightDecay, kMomentum, kRMSDecay, i);
}
}
@@ -627,10 +648,11 @@ TYPED_TEST(SGDSolverTest, TestSnapshotShare) {
const Dtype kLearningRate = 0.01;
const Dtype kWeightDecay = 0.5;
const Dtype kMomentum = 0.9;
+ const Dtype kRMSDecay = 0;
const int kNumIters = 4;
this->share_ = true;
for (int i = 1; i <= kNumIters; ++i) {
- this->TestSnapshot(kLearningRate, kWeightDecay, kMomentum, i);
+ this->TestSnapshot(kLearningRate, kWeightDecay, kMomentum, kRMSDecay, i);
}
}
@@ -672,22 +694,26 @@ TYPED_TEST(AdaGradSolverTest, TestAdaGradLeastSquaresUpdateWithEverything) {
const Dtype kLearningRate = 0.01;
const Dtype kWeightDecay = 0.5;
const Dtype kMomentum = 0;
+ const Dtype kRMSDecay = 0;
const int kNumIters = 4;
for (int i = 0; i <= kNumIters; ++i) {
- this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i);
+ this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum,
+ kRMSDecay, i);
}
}
TYPED_TEST(AdaGradSolverTest,
- TestAdaGradLeastSquaresUpdateWithEverythingShare) {
+ TestAdaGradLeastSquaresUpdateWithEverythingShare) {
typedef typename TypeParam::Dtype Dtype;
const Dtype kLearningRate = 0.01;
const Dtype kWeightDecay = 0.5;
const Dtype kMomentum = 0;
+ const Dtype kRMSDecay = 0;
const int kNumIters = 4;
this->share_ = true;
for (int i = 0; i <= kNumIters; ++i) {
- this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i);
+ this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum,
+ kRMSDecay, i);
}
}
@@ -696,10 +722,11 @@ TYPED_TEST(AdaGradSolverTest, TestLeastSquaresUpdateWithEverythingAccum) {
const Dtype kLearningRate = 0.01;
const Dtype kWeightDecay = 0.5;
const Dtype kMomentum = 0;
+ const Dtype kRMSDecay = 0;
const int kNumIters = 4;
const int kIterSize = 2;
- this->CheckAccumulation(kLearningRate, kWeightDecay, kMomentum, kNumIters,
- kIterSize);
+ this->CheckAccumulation(kLearningRate, kWeightDecay, kMomentum, kRMSDecay,
+ kNumIters, kIterSize);
}
TYPED_TEST(AdaGradSolverTest, TestLeastSquaresUpdateWithEverythingAccumShare) {
@@ -707,11 +734,12 @@ TYPED_TEST(AdaGradSolverTest, TestLeastSquaresUpdateWithEverythingAccumShare) {
const Dtype kLearningRate = 0.01;
const Dtype kWeightDecay = 0.5;
const Dtype kMomentum = 0;
+ const Dtype kRMSDecay = 0;
const int kNumIters = 4;
const int kIterSize = 2;
this->share_ = true;
- this->CheckAccumulation(kLearningRate, kWeightDecay, kMomentum, kNumIters,
- kIterSize);
+ this->CheckAccumulation(kLearningRate, kWeightDecay, kMomentum, kRMSDecay,
+ kNumIters, kIterSize);
}
TYPED_TEST(AdaGradSolverTest, TestSnapshot) {
@@ -719,9 +747,10 @@ TYPED_TEST(AdaGradSolverTest, TestSnapshot) {
const Dtype kLearningRate = 0.01;
const Dtype kWeightDecay = 0.5;
const Dtype kMomentum = 0;
+ const Dtype kRMSDecay = 0;
const int kNumIters = 4;
for (int i = 1; i <= kNumIters; ++i) {
- this->TestSnapshot(kLearningRate, kWeightDecay, kMomentum, i);
+ this->TestSnapshot(kLearningRate, kWeightDecay, kMomentum, kRMSDecay, i);
}
}
@@ -730,10 +759,11 @@ TYPED_TEST(AdaGradSolverTest, TestSnapshotShare) {
const Dtype kLearningRate = 0.01;
const Dtype kWeightDecay = 0.5;
const Dtype kMomentum = 0;
+ const Dtype kRMSDecay = 0;
const int kNumIters = 4;
this->share_ = true;
for (int i = 1; i <= kNumIters; ++i) {
- this->TestSnapshot(kLearningRate, kWeightDecay, kMomentum, i);
+ this->TestSnapshot(kLearningRate, kWeightDecay, kMomentum, kRMSDecay, i);
}
}
@@ -787,9 +817,11 @@ TYPED_TEST(NesterovSolverTest, TestNesterovLeastSquaresUpdateWithMomentum) {
const Dtype kLearningRate = 0.01;
const Dtype kWeightDecay = 0;
const Dtype kMomentum = 0.5;
+ const Dtype kRMSDecay = 0;
const int kNumIters = 1;
for (int i = 0; i <= kNumIters; ++i) {
- this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i);
+ this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum,
+ kRMSDecay, i);
}
}
@@ -798,9 +830,11 @@ TYPED_TEST(NesterovSolverTest, TestLeastSquaresUpdateWithMomentumMultiIter) {
const Dtype kLearningRate = 0.01;
const Dtype kWeightDecay = 0;
const Dtype kMomentum = 0.5;
+ const Dtype kRMSDecay = 0;
const int kNumIters = 4;
for (int i = 0; i <= kNumIters; ++i) {
- this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i);
+ this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum,
+ kRMSDecay, i);
}
}
@@ -821,10 +855,12 @@ TYPED_TEST(NesterovSolverTest,
const Dtype kLearningRate = 0.01;
const Dtype kWeightDecay = 0.5;
const Dtype kMomentum = 0.9;
+ const Dtype kRMSDecay = 0;
const int kNumIters = 4;
this->share_ = true;
for (int i = 0; i <= kNumIters; ++i) {
- this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i);
+ this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum,
+ kRMSDecay, i);
}
}
@@ -833,10 +869,11 @@ TYPED_TEST(NesterovSolverTest, TestLeastSquaresUpdateWithEverythingAccum) {
const Dtype kLearningRate = 0.01;
const Dtype kWeightDecay = 0.5;
const Dtype kMomentum = 0.9;
+ const Dtype kRMSDecay = 0;
const int kNumIters = 4;
const int kIterSize = 2;
- this->CheckAccumulation(kLearningRate, kWeightDecay, kMomentum, kNumIters,
- kIterSize);
+ this->CheckAccumulation(kLearningRate, kWeightDecay, kMomentum, kRMSDecay,
+ kNumIters, kIterSize);
}
TYPED_TEST(NesterovSolverTest, TestLeastSquaresUpdateWithEverythingAccumShare) {
@@ -844,11 +881,12 @@ TYPED_TEST(NesterovSolverTest, TestLeastSquaresUpdateWithEverythingAccumShare) {
const Dtype kLearningRate = 0.01;
const Dtype kWeightDecay = 0.5;
const Dtype kMomentum = 0.9;
+ const Dtype kRMSDecay = 0;
const int kNumIters = 4;
const int kIterSize = 2;
this->share_ = true;
- this->CheckAccumulation(kLearningRate, kWeightDecay, kMomentum, kNumIters,
- kIterSize);
+ this->CheckAccumulation(kLearningRate, kWeightDecay, kMomentum, kRMSDecay,
+ kNumIters, kIterSize);
}
TYPED_TEST(NesterovSolverTest, TestSnapshot) {
@@ -856,9 +894,10 @@ TYPED_TEST(NesterovSolverTest, TestSnapshot) {
const Dtype kLearningRate = 0.01;
const Dtype kWeightDecay = 0.5;
const Dtype kMomentum = 0.9;
+ const Dtype kRMSDecay = 0;
const int kNumIters = 4;
for (int i = 1; i <= kNumIters; ++i) {
- this->TestSnapshot(kLearningRate, kWeightDecay, kMomentum, i);
+ this->TestSnapshot(kLearningRate, kWeightDecay, kMomentum, kRMSDecay, i);
}
}
@@ -867,10 +906,124 @@ TYPED_TEST(NesterovSolverTest, TestSnapshotShare) {
const Dtype kLearningRate = 0.01;
const Dtype kWeightDecay = 0.5;
const Dtype kMomentum = 0.9;
+ const Dtype kRMSDecay = 0;
+ const int kNumIters = 4;
+ this->share_ = true;
+ for (int i = 1; i <= kNumIters; ++i) {
+ this->TestSnapshot(kLearningRate, kWeightDecay, kMomentum, kRMSDecay, i);
+ }
+}
+
+template <typename TypeParam>
+class RMSPropSolverTest : public GradientBasedSolverTest<TypeParam> {
+ typedef typename TypeParam::Dtype Dtype;
+
+ protected:
+ virtual void InitSolver(const SolverParameter& param) {
+ this->solver_.reset(new RMSPropSolver<Dtype>(param));
+ }
+ virtual SolverParameter_SolverType solver_type() {
+ return SolverParameter_SolverType_RMSPROP;
+ }
+};
+
+TYPED_TEST_CASE(RMSPropSolverTest, TestDtypesAndDevices);
+
+TYPED_TEST(RMSPropSolverTest, TestRMSPropLeastSquaresUpdateWithWeightDecay) {
+ typedef typename TypeParam::Dtype Dtype;
+ const Dtype kLearningRate = 1.0;
+ const Dtype kWeightDecay = 0.5;
+ this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay);
+}
+
+TYPED_TEST(RMSPropSolverTest, TestRMSPropLeastSquaresUpdateWithRmsDecay) {
+ typedef typename TypeParam::Dtype Dtype;
+ const Dtype kLearningRate = 0.01;
+ const Dtype kWeightDecay = 0.0;
+ const Dtype kMomentum = 0.0;
+ const Dtype kRMSDecay = 0.95;
+ const int kNumIters = 4;
+ for (int i = 0; i <= kNumIters; ++i) {
+ this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum,
+ kRMSDecay, i);
+ }
+}
+
+TYPED_TEST(RMSPropSolverTest, TestRMSPropLeastSquaresUpdateWithEverything) {
+ typedef typename TypeParam::Dtype Dtype;
+ const Dtype kLearningRate = 0.01;
+ const Dtype kWeightDecay = 0.5;
+ const Dtype kMomentum = 0.0;
+ const Dtype kRMSDecay = 0.95;
+ const int kNumIters = 4;
+ for (int i = 0; i <= kNumIters; ++i) {
+ this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum,
+ kRMSDecay, i);
+ }
+}
+
+TYPED_TEST(RMSPropSolverTest,
+ TestRMSPropLeastSquaresUpdateWithEverythingShare) {
+ typedef typename TypeParam::Dtype Dtype;
+ const Dtype kLearningRate = 0.01;
+ const Dtype kWeightDecay = 0.5;
+ const Dtype kMomentum = 0.0;
+ const Dtype kRMSDecay = 0.95;
+ const int kNumIters = 4;
+ this->share_ = true;
+ for (int i = 0; i <= kNumIters; ++i) {
+ this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum,
+ kRMSDecay, i);
+ }
+}
+
+TYPED_TEST(RMSPropSolverTest, TestLeastSquaresUpdateWithEverythingAccum) {
+ typedef typename TypeParam::Dtype Dtype;
+ const Dtype kLearningRate = 0.01;
+ const Dtype kWeightDecay = 0.5;
+ const Dtype kMomentum = 0.0;
+ const Dtype kRMSDecay = 0.95;
+ const int kNumIters = 4;
+ const int kIterSize = 2;
+ this->CheckAccumulation(kLearningRate, kWeightDecay, kMomentum, kRMSDecay,
+ kNumIters, kIterSize);
+}
+
+TYPED_TEST(RMSPropSolverTest, TestLeastSquaresUpdateWithEverythingAccumShare) {
+ typedef typename TypeParam::Dtype Dtype;
+ const Dtype kLearningRate = 0.01;
+ const Dtype kWeightDecay = 0.5;
+ const Dtype kMomentum = 0.0;
+ const Dtype kRMSDecay = 0.95;
+ const int kNumIters = 4;
+ const int kIterSize = 2;
+ this->share_ = true;
+ this->CheckAccumulation(kLearningRate, kWeightDecay, kMomentum, kRMSDecay,
+ kNumIters, kIterSize);
+}
+
+TYPED_TEST(RMSPropSolverTest, TestSnapshot) {
+ typedef typename TypeParam::Dtype Dtype;
+ const Dtype kLearningRate = 0.01;
+ const Dtype kWeightDecay = 0.5;
+ const Dtype kMomentum = 0;
+ const Dtype kRMSDecay = 0.95;
+ const int kNumIters = 4;
+ for (int i = 1; i <= kNumIters; ++i) {
+ this->TestSnapshot(kLearningRate, kWeightDecay, kMomentum, kRMSDecay, i);
+ }
+}
+
+TYPED_TEST(RMSPropSolverTest, TestSnapshotShare) {
+ typedef typename TypeParam::Dtype Dtype;
+ const Dtype kLearningRate = 0.01;
+ const Dtype kWeightDecay = 0.5;
+ const Dtype kMomentum = 0;
+ const Dtype kRMSDecay = 0.95;
const int kNumIters = 4;
this->share_ = true;
for (int i = 1; i <= kNumIters; ++i) {
- this->TestSnapshot(kLearningRate, kWeightDecay, kMomentum, i);
+ this->TestSnapshot(kLearningRate, kWeightDecay, kMomentum, kRMSDecay, i);
}
}
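To exercise the new tests, a typical invocation is make runtest, or the test binary directly with a GTest filter, e.g. ./build/test/test_all.testbin --gtest_filter='*RMSProp*' (the binary name and path may vary with the build configuration).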