diff options
Diffstat (limited to 'include/caffe/solver.hpp')
-rw-r--r-- | include/caffe/solver.hpp | 25 |
1 files changed, 25 insertions, 0 deletions
diff --git a/include/caffe/solver.hpp b/include/caffe/solver.hpp index 703434b5..fbade938 100644 --- a/include/caffe/solver.hpp +++ b/include/caffe/solver.hpp @@ -135,6 +135,29 @@ class AdaGradSolver : public SGDSolver<Dtype> { DISABLE_COPY_AND_ASSIGN(AdaGradSolver); }; + +template <typename Dtype> +class RMSPropSolver : public SGDSolver<Dtype> { + public: + explicit RMSPropSolver(const SolverParameter& param) + : SGDSolver<Dtype>(param) { constructor_sanity_check(); } + explicit RMSPropSolver(const string& param_file) + : SGDSolver<Dtype>(param_file) { constructor_sanity_check(); } + + protected: + virtual void ComputeUpdateValue(int param_id, Dtype rate); + void constructor_sanity_check() { + CHECK_EQ(0, this->param_.momentum()) + << "Momentum cannot be used with RMSProp."; + CHECK_GE(this->param_.rms_decay(), 0) + << "rms_decay should lie between 0 and 1."; + CHECK_LT(this->param_.rms_decay(), 1) + << "rms_decay should lie between 0 and 1."; + } + + DISABLE_COPY_AND_ASSIGN(RMSPropSolver); +}; + template <typename Dtype> Solver<Dtype>* GetSolver(const SolverParameter& param) { SolverParameter_SolverType type = param.solver_type(); @@ -146,6 +169,8 @@ Solver<Dtype>* GetSolver(const SolverParameter& param) { return new NesterovSolver<Dtype>(param); case SolverParameter_SolverType_ADAGRAD: return new AdaGradSolver<Dtype>(param); + case SolverParameter_SolverType_RMSPROP: + return new RMSPropSolver<Dtype>(param); default: LOG(FATAL) << "Unknown SolverType: " << type; } |