diff options
Diffstat (limited to 'src/caffe/optimization/solver.hpp')
-rw-r--r-- | src/caffe/optimization/solver.hpp | 19 |
1 files changed, 16 insertions, 3 deletions
diff --git a/src/caffe/optimization/solver.hpp b/src/caffe/optimization/solver.hpp index 8dc41aff..a5ea6126 100644 --- a/src/caffe/optimization/solver.hpp +++ b/src/caffe/optimization/solver.hpp @@ -12,8 +12,9 @@ class Solver { public: explicit Solver(const SolverParameter& param) : param_(param) {} - // The main entry of the solver function. - void Solve(Net<Dtype>* net); + // The main entry of the solver function. In default, iter will be zero. Pass + // in a non-zero iter number to resume training for a pre-trained net. + void Solve(Net<Dtype>* net, char* state_file = NULL); virtual ~Solver() {} protected: @@ -22,7 +23,17 @@ class Solver { virtual void PreSolve() {} // Get the update value for the current iteration. virtual void ComputeUpdateValue() = 0; + // The Solver::Snapshot function implements the basic snapshotting utility + // that stores the learned net. You should implement the SnapshotSolverState() + // function that produces a SolverState protocol buffer that needs to be + // written to disk together with the learned net. void Snapshot(bool is_final = false); + virtual void SnapshotSolverState(SolverState* state) = 0; + // The Restore function implements how one should restore the solver to a + // previously snapshotted state. You should implement the RestoreSolverState() + // function that restores the state from a SolverState protocol buffer. + void Restore(char* state_file); + virtual void RestoreSolverState(const SolverState& state) = 0; SolverParameter param_; int iter_; Net<Dtype>* net_; @@ -39,8 +50,10 @@ class SGDSolver : public Solver<Dtype> { protected: virtual void PreSolve(); - Dtype GetLearningRate(); + virtual Dtype GetLearningRate(); virtual void ComputeUpdateValue(); + virtual void SnapshotSolverState(SolverState * state); + virtual void RestoreSolverState(const SolverState& state); // history maintains the historical momentum data. vector<shared_ptr<Blob<Dtype> > > history_; }; |