diff options
author | Matthias Plappert <matthiasplappert@me.com> | 2015-07-18 18:46:51 +0200 |
---|---|---|
committer | Matthias Plappert <matthiasplappert@me.com> | 2015-08-10 11:44:13 +0200 |
commit | f2e523e479b89902b644f3a8bb2ac51a6dc28eee (patch) | |
tree | c68bfe6e46a3305ccf6a51af48dc6bc49e290c40 /include | |
parent | 4c58741ce2e031b61aef53914128801e6edd673d (diff) | |
download | caffeonacl-f2e523e479b89902b644f3a8bb2ac51a6dc28eee.tar.gz caffeonacl-f2e523e479b89902b644f3a8bb2ac51a6dc28eee.tar.bz2 caffeonacl-f2e523e479b89902b644f3a8bb2ac51a6dc28eee.zip |
Clean up and modernize AdaDelta code; add learning rate support; add additional test cases
Diffstat (limited to 'include')
-rw-r--r-- | include/caffe/solver.hpp | 16 |
1 files changed, 5 insertions, 11 deletions
diff --git a/include/caffe/solver.hpp b/include/caffe/solver.hpp index 495cd4f1..5fefd01e 100644 --- a/include/caffe/solver.hpp +++ b/include/caffe/solver.hpp @@ -82,12 +82,12 @@ class SGDSolver : public Solver<Dtype> { const vector<shared_ptr<Blob<Dtype> > >& history() { return history_; } protected: + void PreSolve(); Dtype GetLearningRate(); virtual void ApplyUpdate(); virtual void Normalize(int param_id); virtual void Regularize(int param_id); virtual void ComputeUpdateValue(int param_id, Dtype rate); - virtual void PreSolve(); virtual void ClipGradients(); virtual void SnapshotSolverState(const string& model_filename); virtual void SnapshotSolverStateToBinaryProto(const string& model_filename); @@ -162,19 +162,13 @@ template <typename Dtype> class AdaDeltaSolver : public SGDSolver<Dtype> { public: explicit AdaDeltaSolver(const SolverParameter& param) - : SGDSolver<Dtype>(param) { PreSolve(); constructor_sanity_check(); } + : SGDSolver<Dtype>(param) { AdaDeltaPreSolve(); } explicit AdaDeltaSolver(const string& param_file) - : SGDSolver<Dtype>(param_file) { PreSolve(); constructor_sanity_check(); } + : SGDSolver<Dtype>(param_file) { AdaDeltaPreSolve(); } protected: - virtual void PreSolve(); - virtual void ComputeUpdateValue(); - void constructor_sanity_check() { - CHECK_EQ(0, this->param_.base_lr()) - << "Learning rate cannot be used with AdaDelta."; - CHECK_EQ("", this->param_.lr_policy()) - << "Learning rate policy cannot be applied to AdaDelta."; - } + void AdaDeltaPreSolve(); + virtual void ComputeUpdateValue(int param_id, Dtype rate); DISABLE_COPY_AND_ASSIGN(AdaDeltaSolver); }; |