diff options
Diffstat (limited to 'include/caffe/solver.hpp')
-rw-r--r-- | include/caffe/solver.hpp | 64 |
1 files changed, 64 insertions, 0 deletions
diff --git a/include/caffe/solver.hpp b/include/caffe/solver.hpp new file mode 100644 index 00000000..98c872dc --- /dev/null +++ b/include/caffe/solver.hpp @@ -0,0 +1,64 @@ +// Copyright Yangqing Jia 2013 + +#ifndef CAFFE_OPTIMIZATION_SOLVER_HPP_ +#define CAFFE_OPTIMIZATION_SOLVER_HPP_ + +#include <vector> + +namespace caffe { + +template <typename Dtype> +class Solver { + public: + explicit Solver(const SolverParameter& param) + : param_(param) {} + // The main entry of the solver function. In default, iter will be zero. Pass + // in a non-zero iter number to resume training for a pre-trained net. + void Solve(Net<Dtype>* net, char* state_file = NULL); + virtual ~Solver() {} + + protected: + // PreSolve is run before any solving iteration starts, allowing one to + // put up some scaffold. + virtual void PreSolve() {} + // Get the update value for the current iteration. + virtual void ComputeUpdateValue() = 0; + // The Solver::Snapshot function implements the basic snapshotting utility + // that stores the learned net. You should implement the SnapshotSolverState() + // function that produces a SolverState protocol buffer that needs to be + // written to disk together with the learned net. + void Snapshot(); + virtual void SnapshotSolverState(SolverState* state) = 0; + // The Restore function implements how one should restore the solver to a + // previously snapshotted state. You should implement the RestoreSolverState() + // function that restores the state from a SolverState protocol buffer. + void Restore(char* state_file); + virtual void RestoreSolverState(const SolverState& state) = 0; + SolverParameter param_; + int iter_; + Net<Dtype>* net_; + + DISABLE_COPY_AND_ASSIGN(Solver); +}; + + +template <typename Dtype> +class SGDSolver : public Solver<Dtype> { + public: + explicit SGDSolver(const SolverParameter& param) + : Solver<Dtype>(param) {} + + protected: + virtual void PreSolve(); + virtual Dtype GetLearningRate(); + virtual void ComputeUpdateValue(); + virtual void SnapshotSolverState(SolverState * state); + virtual void RestoreSolverState(const SolverState& state); + // history maintains the historical momentum data. + vector<shared_ptr<Blob<Dtype> > > history_; +}; + + +} // namspace caffe + +#endif // CAFFE_OPTIMIZATION_SOLVER_HPP_ |