diff options
author | Eren Golge <erogol@hotmail.com> | 2015-08-08 23:45:08 -0700 |
---|---|---|
committer | Ronghang Hu <huronghang@hotmail.com> | 2015-08-08 23:45:08 -0700 |
commit | abe99e8748ad7f583c87d1a6132ff2d79e70dd9c (patch) | |
tree | 55b196b5e10f8ed630e79b8e3aa05144ca180652 /include | |
parent | eb3e1149a2fcc9c48d268ffe2319d872081e4c3b (diff) | |
download | caffeonacl-abe99e8748ad7f583c87d1a6132ff2d79e70dd9c.tar.gz caffeonacl-abe99e8748ad7f583c87d1a6132ff2d79e70dd9c.tar.bz2 caffeonacl-abe99e8748ad7f583c87d1a6132ff2d79e70dd9c.zip |
Implement RMSProp Solver
Implement RMSProp solver and cleaned up to adjust to new solver interface that uses
accumulated gradients and refactored regularization.
Diffstat (limited to 'include')
-rw-r--r-- | include/caffe/solver.hpp | 25 |
1 files changed, 25 insertions, 0 deletions
diff --git a/include/caffe/solver.hpp b/include/caffe/solver.hpp index 703434b5..fbade938 100644 --- a/include/caffe/solver.hpp +++ b/include/caffe/solver.hpp @@ -135,6 +135,29 @@ class AdaGradSolver : public SGDSolver<Dtype> { DISABLE_COPY_AND_ASSIGN(AdaGradSolver); }; + +template <typename Dtype> +class RMSPropSolver : public SGDSolver<Dtype> { + public: + explicit RMSPropSolver(const SolverParameter& param) + : SGDSolver<Dtype>(param) { constructor_sanity_check(); } + explicit RMSPropSolver(const string& param_file) + : SGDSolver<Dtype>(param_file) { constructor_sanity_check(); } + + protected: + virtual void ComputeUpdateValue(int param_id, Dtype rate); + void constructor_sanity_check() { + CHECK_EQ(0, this->param_.momentum()) + << "Momentum cannot be used with RMSProp."; + CHECK_GE(this->param_.rms_decay(), 0) + << "rms_decay should lie between 0 and 1."; + CHECK_LT(this->param_.rms_decay(), 1) + << "rms_decay should lie between 0 and 1."; + } + + DISABLE_COPY_AND_ASSIGN(RMSPropSolver); +}; + template <typename Dtype> Solver<Dtype>* GetSolver(const SolverParameter& param) { SolverParameter_SolverType type = param.solver_type(); @@ -146,6 +169,8 @@ Solver<Dtype>* GetSolver(const SolverParameter& param) { return new NesterovSolver<Dtype>(param); case SolverParameter_SolverType_ADAGRAD: return new AdaGradSolver<Dtype>(param); + case SolverParameter_SolverType_RMSPROP: + return new RMSPropSolver<Dtype>(param); default: LOG(FATAL) << "Unknown SolverType: " << type; } |