summaryrefslogtreecommitdiff
path: root/include/caffe/solver.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'include/caffe/solver.hpp')
-rw-r--r--include/caffe/solver.hpp64
1 files changed, 64 insertions, 0 deletions
diff --git a/include/caffe/solver.hpp b/include/caffe/solver.hpp
new file mode 100644
index 00000000..98c872dc
--- /dev/null
+++ b/include/caffe/solver.hpp
@@ -0,0 +1,64 @@
+// Copyright Yangqing Jia 2013
+
+#ifndef CAFFE_OPTIMIZATION_SOLVER_HPP_
+#define CAFFE_OPTIMIZATION_SOLVER_HPP_
+
+#include <vector>
+
+namespace caffe {
+
+template <typename Dtype>
+class Solver {
+ public:
+ explicit Solver(const SolverParameter& param)
+ : param_(param) {}
+ // The main entry of the solver function. In default, iter will be zero. Pass
+ // in a non-zero iter number to resume training for a pre-trained net.
+ void Solve(Net<Dtype>* net, char* state_file = NULL);
+ virtual ~Solver() {}
+
+ protected:
+ // PreSolve is run before any solving iteration starts, allowing one to
+ // put up some scaffold.
+ virtual void PreSolve() {}
+ // Get the update value for the current iteration.
+ virtual void ComputeUpdateValue() = 0;
+ // The Solver::Snapshot function implements the basic snapshotting utility
+ // that stores the learned net. You should implement the SnapshotSolverState()
+ // function that produces a SolverState protocol buffer that needs to be
+ // written to disk together with the learned net.
+ void Snapshot();
+ virtual void SnapshotSolverState(SolverState* state) = 0;
+ // The Restore function implements how one should restore the solver to a
+ // previously snapshotted state. You should implement the RestoreSolverState()
+ // function that restores the state from a SolverState protocol buffer.
+ void Restore(char* state_file);
+ virtual void RestoreSolverState(const SolverState& state) = 0;
+ SolverParameter param_;
+ int iter_;
+ Net<Dtype>* net_;
+
+ DISABLE_COPY_AND_ASSIGN(Solver);
+};
+
+
+template <typename Dtype>
+class SGDSolver : public Solver<Dtype> {
+ public:
+ explicit SGDSolver(const SolverParameter& param)
+ : Solver<Dtype>(param) {}
+
+ protected:
+ virtual void PreSolve();
+ virtual Dtype GetLearningRate();
+ virtual void ComputeUpdateValue();
+ virtual void SnapshotSolverState(SolverState * state);
+ virtual void RestoreSolverState(const SolverState& state);
+ // history maintains the historical momentum data.
+ vector<shared_ptr<Blob<Dtype> > > history_;
+};
+
+
+} // namspace caffe
+
+#endif // CAFFE_OPTIMIZATION_SOLVER_HPP_