summaryrefslogtreecommitdiff
path: root/include/caffe/solver.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'include/caffe/solver.hpp')
-rw-r--r--include/caffe/solver.hpp25
1 files changed, 25 insertions, 0 deletions
diff --git a/include/caffe/solver.hpp b/include/caffe/solver.hpp
index 703434b5..fbade938 100644
--- a/include/caffe/solver.hpp
+++ b/include/caffe/solver.hpp
@@ -135,6 +135,29 @@ class AdaGradSolver : public SGDSolver<Dtype> {
DISABLE_COPY_AND_ASSIGN(AdaGradSolver);
};
+
+template <typename Dtype>
+class RMSPropSolver : public SGDSolver<Dtype> {
+ public:
+ explicit RMSPropSolver(const SolverParameter& param)
+ : SGDSolver<Dtype>(param) { constructor_sanity_check(); }
+ explicit RMSPropSolver(const string& param_file)
+ : SGDSolver<Dtype>(param_file) { constructor_sanity_check(); }
+
+ protected:
+ virtual void ComputeUpdateValue(int param_id, Dtype rate);
+ void constructor_sanity_check() {
+ CHECK_EQ(0, this->param_.momentum())
+ << "Momentum cannot be used with RMSProp.";
+ CHECK_GE(this->param_.rms_decay(), 0)
+ << "rms_decay should lie between 0 and 1.";
+ CHECK_LT(this->param_.rms_decay(), 1)
+ << "rms_decay should lie between 0 and 1.";
+ }
+
+ DISABLE_COPY_AND_ASSIGN(RMSPropSolver);
+};
+
template <typename Dtype>
Solver<Dtype>* GetSolver(const SolverParameter& param) {
SolverParameter_SolverType type = param.solver_type();
@@ -146,6 +169,8 @@ Solver<Dtype>* GetSolver(const SolverParameter& param) {
return new NesterovSolver<Dtype>(param);
case SolverParameter_SolverType_ADAGRAD:
return new AdaGradSolver<Dtype>(param);
+ case SolverParameter_SolverType_RMSPROP:
+ return new RMSPropSolver<Dtype>(param);
default:
LOG(FATAL) << "Unknown SolverType: " << type;
}