summary | refs | log | tree | commit | diff
path: root/include
diff options
context:
space:
mode:
author: Matthias Plappert <matthiasplappert@me.com> 2015-07-18 18:46:51 +0200
committer: Matthias Plappert <matthiasplappert@me.com> 2015-08-10 11:44:13 +0200
commit: f2e523e479b89902b644f3a8bb2ac51a6dc28eee (patch)
tree: c68bfe6e46a3305ccf6a51af48dc6bc49e290c40 /include
parent: 4c58741ce2e031b61aef53914128801e6edd673d (diff)
download: caffeonacl-f2e523e479b89902b644f3a8bb2ac51a6dc28eee.tar.gz
caffeonacl-f2e523e479b89902b644f3a8bb2ac51a6dc28eee.tar.bz2
caffeonacl-f2e523e479b89902b644f3a8bb2ac51a6dc28eee.zip
Clean up and modernize AdaDelta code; add learning rate support; add additional test cases
Diffstat (limited to 'include')
-rw-r--r-- include/caffe/solver.hpp 16
1 files changed, 5 insertions, 11 deletions
diff --git a/include/caffe/solver.hpp b/include/caffe/solver.hpp
index 495cd4f1..5fefd01e 100644
--- a/include/caffe/solver.hpp
+++ b/include/caffe/solver.hpp
@@ -82,12 +82,12 @@ class SGDSolver : public Solver<Dtype> {
const vector<shared_ptr<Blob<Dtype> > >& history() { return history_; }
protected:
+ void PreSolve();
Dtype GetLearningRate();
virtual void ApplyUpdate();
virtual void Normalize(int param_id);
virtual void Regularize(int param_id);
virtual void ComputeUpdateValue(int param_id, Dtype rate);
- virtual void PreSolve();
virtual void ClipGradients();
virtual void SnapshotSolverState(const string& model_filename);
virtual void SnapshotSolverStateToBinaryProto(const string& model_filename);
@@ -162,19 +162,13 @@ template <typename Dtype>
class AdaDeltaSolver : public SGDSolver<Dtype> {
public:
explicit AdaDeltaSolver(const SolverParameter& param)
- : SGDSolver<Dtype>(param) { PreSolve(); constructor_sanity_check(); }
+ : SGDSolver<Dtype>(param) { AdaDeltaPreSolve(); }
explicit AdaDeltaSolver(const string& param_file)
- : SGDSolver<Dtype>(param_file) { PreSolve(); constructor_sanity_check(); }
+ : SGDSolver<Dtype>(param_file) { AdaDeltaPreSolve(); }
protected:
- virtual void PreSolve();
- virtual void ComputeUpdateValue();
- void constructor_sanity_check() {
- CHECK_EQ(0, this->param_.base_lr())
- << "Learning rate cannot be used with AdaDelta.";
- CHECK_EQ("", this->param_.lr_policy())
- << "Learning rate policy cannot be applied to AdaDelta.";
- }
+ void AdaDeltaPreSolve();
+ virtual void ComputeUpdateValue(int param_id, Dtype rate);
DISABLE_COPY_AND_ASSIGN(AdaDeltaSolver);
};