summaryrefslogtreecommitdiff
path: root/src/caffe/optimization/solver.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/caffe/optimization/solver.hpp')
-rw-r--r--src/caffe/optimization/solver.hpp19
1 files changed, 16 insertions, 3 deletions
diff --git a/src/caffe/optimization/solver.hpp b/src/caffe/optimization/solver.hpp
index 8dc41aff..a5ea6126 100644
--- a/src/caffe/optimization/solver.hpp
+++ b/src/caffe/optimization/solver.hpp
@@ -12,8 +12,9 @@ class Solver {
public:
explicit Solver(const SolverParameter& param)
: param_(param) {}
- // The main entry of the solver function.
- void Solve(Net<Dtype>* net);
+ // The main entry of the solver function. In default, iter will be zero. Pass
+ // in a non-zero iter number to resume training for a pre-trained net.
+ void Solve(Net<Dtype>* net, char* state_file = NULL);
virtual ~Solver() {}
protected:
@@ -22,7 +23,17 @@ class Solver {
virtual void PreSolve() {}
// Get the update value for the current iteration.
virtual void ComputeUpdateValue() = 0;
+ // The Solver::Snapshot function implements the basic snapshotting utility
+ // that stores the learned net. You should implement the SnapshotSolverState()
+ // function that produces a SolverState protocol buffer that needs to be
+ // written to disk together with the learned net.
void Snapshot(bool is_final = false);
+ virtual void SnapshotSolverState(SolverState* state) = 0;
+ // The Restore function implements how one should restore the solver to a
+ // previously snapshotted state. You should implement the RestoreSolverState()
+ // function that restores the state from a SolverState protocol buffer.
+ void Restore(char* state_file);
+ virtual void RestoreSolverState(const SolverState& state) = 0;
SolverParameter param_;
int iter_;
Net<Dtype>* net_;
@@ -39,8 +50,10 @@ class SGDSolver : public Solver<Dtype> {
protected:
virtual void PreSolve();
- Dtype GetLearningRate();
+ virtual Dtype GetLearningRate();
virtual void ComputeUpdateValue();
+ virtual void SnapshotSolverState(SolverState * state);
+ virtual void RestoreSolverState(const SolverState& state);
// history maintains the historical momentum data.
vector<shared_ptr<Blob<Dtype> > > history_;
};