author     Yangqing Jia <jiayq84@gmail.com>           2013-10-15 11:28:26 -0700
committer  Yangqing Jia <jiayq84@gmail.com>           2013-10-15 11:28:26 -0700
commit     4f7e519ea0a1db36a54c68f5021f6c4acfc3a657 (patch)
tree       b547f2a8625264de104a5c6431e048f24bc241a0 /src
parent     a0f2c505f4c1d9e0c5eeb9e4e5b29fe2afe6b6b4 (diff)
Moved the layer factory implementation to cpp; added snapshot and restore functions to solver.
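The solver half of this change makes training resumable: Solve() now takes an optional solver-state file, and Snapshot() writes a companion .solverstate file next to the learned net. A rough usage sketch of the new entry point (only the Solve() signature comes from this commit; the surrounding net/solver setup is illustrative and assumed):

```cpp
// Hypothetical caller of the new Solve(net, resume_file) overload.
#include "caffe/caffe.hpp"
#include "caffe/optimization/solver.hpp"

void TrainOrResume(caffe::Net<float>* train_net,
                   const caffe::SolverParameter& solver_param,
                   char* resume_file) {  // pass NULL for a fresh run
  caffe::SGDSolver<float> solver(solver_param);
  // With resume_file == NULL, iter_ starts at 0 as before; otherwise the
  // solver reloads the learned net and the SGD history before training.
  solver.Solve(train_net, resume_file);
}
```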
Diffstat (limited to 'src')
-rw-r--r--  src/caffe/caffe.hpp                                                      1
-rw-r--r--  src/caffe/layer.hpp                                                      4
-rw-r--r--  src/caffe/layer_factory.cpp (renamed from src/caffe/layer_factory.hpp)   2
-rw-r--r--  src/caffe/net.cpp                                                        2
-rw-r--r--  src/caffe/optimization/solver.cpp                                       46
-rw-r--r--  src/caffe/optimization/solver.hpp                                       19
-rw-r--r--  src/caffe/proto/caffe.proto                                             18
-rw-r--r--  src/caffe/vision_layers.hpp                                              1
8 files changed, 86 insertions, 7 deletions
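A note on the factory move visible in the diffstat: GetLayer<Dtype>() is now only declared in layer.hpp and defined in layer_factory.cpp, so that .cpp file has to explicitly instantiate every Dtype the rest of the library links against, which is what the two added "template Layer<float>* ..." / "template Layer<double>* ..." lines in the diff below do. A stripped-down illustration of the idiom (generic C++, not Caffe code; the names are made up):

```cpp
// widget.hpp -- only the declaration is visible to other translation units.
template <typename T>
T* MakeDefault();

// widget.cpp -- the template body lives here, so every T that callers use
// must be explicitly instantiated (mirroring the GetLayer<float>/<double>
// lines added in layer_factory.cpp below).
template <typename T>
T* MakeDefault() {
  return new T();
}

template float* MakeDefault<float>();
template double* MakeDefault<double>();
```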
diff --git a/src/caffe/caffe.hpp b/src/caffe/caffe.hpp
index 800138f9..5806bc02 100644
--- a/src/caffe/caffe.hpp
+++ b/src/caffe/caffe.hpp
@@ -7,7 +7,6 @@
 
 #include "caffe/blob.hpp"
 #include "caffe/filler.hpp"
 #include "caffe/layer.hpp"
-#include "caffe/layer_factory.hpp"
 #include "caffe/net.hpp"
 #include "caffe/vision_layers.hpp"
diff --git a/src/caffe/layer.hpp b/src/caffe/layer.hpp
index cbfde0cb..adc63657 100644
--- a/src/caffe/layer.hpp
+++ b/src/caffe/layer.hpp
@@ -127,6 +127,10 @@ void Layer<Dtype>::ToProto(LayerParameter* param, bool write_diff) {
   }
 }
 
+// The layer factory function
+template <typename Dtype>
+Layer<Dtype>* GetLayer(const LayerParameter& param);
+
 }  // namespace caffe
 
 #endif  // CAFFE_LAYER_H_
diff --git a/src/caffe/layer_factory.hpp b/src/caffe/layer_factory.cpp
index d231e17b..6961bb3f 100644
--- a/src/caffe/layer_factory.hpp
+++ b/src/caffe/layer_factory.cpp
@@ -54,6 +54,8 @@ Layer<Dtype>* GetLayer(const LayerParameter& param) {
   return (Layer<Dtype>*)(NULL);
 }
 
+template Layer<float>* GetLayer(const LayerParameter& param);
+template Layer<double>* GetLayer(const LayerParameter& param);
 
 }  // namespace caffe
diff --git a/src/caffe/net.cpp b/src/caffe/net.cpp
index 22250da5..e1442ecb 100644
--- a/src/caffe/net.cpp
+++ b/src/caffe/net.cpp
@@ -6,7 +6,7 @@
 #include <vector>
 
 #include "caffe/proto/caffe.pb.h"
-#include "caffe/layer_factory.hpp"
+#include "caffe/layer.hpp"
 #include "caffe/net.hpp"
 
 using std::pair;
diff --git a/src/caffe/optimization/solver.cpp b/src/caffe/optimization/solver.cpp
index b2a57600..73c69c03 100644
--- a/src/caffe/optimization/solver.cpp
+++ b/src/caffe/optimization/solver.cpp
@@ -18,11 +18,17 @@ using std::min;
 
 namespace caffe {
 
 template <typename Dtype>
-void Solver<Dtype>::Solve(Net<Dtype>* net) {
+void Solver<Dtype>::Solve(Net<Dtype>* net, char* resume_file) {
   net_ = net;
   LOG(INFO) << "Solving " << net_->name();
   PreSolve();
+  iter_ = 0;
+  if (resume_file) {
+    LOG(INFO) << "Restoring previous solver status from " << resume_file;
+    Restore(resume_file);
+  }
+
   // For a network that is trained by the solver, no bottom or top vecs
   // should be given, and we will just provide dummy vecs.
   vector<Blob<Dtype>*> bottom_vec;
@@ -56,8 +62,26 @@ void Solver<Dtype>::Snapshot(bool is_final) {
     sprintf(iter_str_buffer, "_iter_%d", iter_);
     filename += iter_str_buffer;
   }
-  LOG(ERROR) << "Snapshotting to " << filename;
+  LOG(INFO) << "Snapshotting to " << filename;
   WriteProtoToBinaryFile(net_param, filename.c_str());
+  SolverState state;
+  SnapshotSolverState(&state);
+  state.set_iter(iter_);
+  state.set_learned_net(filename);
+  filename += ".solverstate";
+  LOG(INFO) << "Snapshotting solver state to " << filename;
+  WriteProtoToBinaryFile(state, filename.c_str());
+}
+
+template <typename Dtype>
+void Solver<Dtype>::Restore(char* state_file) {
+  SolverState state;
+  NetParameter net_param;
+  ReadProtoFromBinaryFile(state_file, &state);
+  ReadProtoFromBinaryFile(state.learned_net().c_str(), &net_param);
+  net_->CopyTrainedLayersFrom(net_param);
+  iter_ = state.iter();
+  RestoreSolverState(state);
 }
@@ -167,6 +191,24 @@ void SGDSolver<Dtype>::ComputeUpdateValue() {
   }
 }
 
+template <typename Dtype>
+void SGDSolver<Dtype>::SnapshotSolverState(SolverState* state) {
+  state->clear_history();
+  for (int i = 0; i < history_.size(); ++i) {
+    // Add history
+    BlobProto* history_blob = state->add_history();
+    history_[i]->ToProto(history_blob);
+  }
+}
+
+template <typename Dtype>
+void SGDSolver<Dtype>::RestoreSolverState(const SolverState& state) {
+  CHECK_EQ(state.history_size(), history_.size())
+      << "Incorrect length of history blobs.";
+  for (int i = 0; i < history_.size(); ++i) {
+    history_[i]->FromProto(state.history(i));
+  }
+}
 
 INSTANTIATE_CLASS(Solver);
 INSTANTIATE_CLASS(SGDSolver);
diff --git a/src/caffe/optimization/solver.hpp b/src/caffe/optimization/solver.hpp
index 8dc41aff..a5ea6126 100644
--- a/src/caffe/optimization/solver.hpp
+++ b/src/caffe/optimization/solver.hpp
@@ -12,8 +12,9 @@ class Solver {
  public:
   explicit Solver(const SolverParameter& param)
       : param_(param) {}
-  // The main entry of the solver function.
-  void Solve(Net<Dtype>* net);
+  // The main entry of the solver function. In default, iter will be zero. Pass
+  // in a non-zero iter number to resume training for a pre-trained net.
+  void Solve(Net<Dtype>* net, char* state_file = NULL);
   virtual ~Solver() {}
 
 protected:
@@ -22,7 +23,17 @@ class Solver {
   virtual void PreSolve() {}
   // Get the update value for the current iteration.
   virtual void ComputeUpdateValue() = 0;
+  // The Solver::Snapshot function implements the basic snapshotting utility
+  // that stores the learned net. You should implement the SnapshotSolverState()
+  // function that produces a SolverState protocol buffer that needs to be
+  // written to disk together with the learned net.
   void Snapshot(bool is_final = false);
+  virtual void SnapshotSolverState(SolverState* state) = 0;
+  // The Restore function implements how one should restore the solver to a
+  // previously snapshotted state. You should implement the RestoreSolverState()
+  // function that restores the state from a SolverState protocol buffer.
+  void Restore(char* state_file);
+  virtual void RestoreSolverState(const SolverState& state) = 0;
   SolverParameter param_;
   int iter_;
   Net<Dtype>* net_;
@@ -39,8 +50,10 @@ class SGDSolver : public Solver<Dtype> {
 
 protected:
   virtual void PreSolve();
-  Dtype GetLearningRate();
+  virtual Dtype GetLearningRate();
   virtual void ComputeUpdateValue();
+  virtual void SnapshotSolverState(SolverState * state);
+  virtual void RestoreSolverState(const SolverState& state);
   // history maintains the historical momentum data.
   vector<shared_ptr<Blob<Dtype> > > history_;
 };
diff --git a/src/caffe/proto/caffe.proto b/src/caffe/proto/caffe.proto
index 87f2c2cc..4be96963 100644
--- a/src/caffe/proto/caffe.proto
+++ b/src/caffe/proto/caffe.proto
@@ -105,4 +105,22 @@ message SolverParameter {
   optional float stepsize = 12; // the stepsize for learning rate policy "step"
   optional string snapshot_prefix = 13; // The prefix for the snapshot.
+
+  // Adagrad solver parameters
+  // For Adagrad, we will first run normal sgd using the sgd parameters above
+  // for adagrad_skip iterations, and then kick in the adagrad algorithm, with
+  // the learning rate being adagrad_gamma * adagrad_skip. Note that the adagrad
+  // algorithm will NOT use the learning rate multiplier that is specified in
+  // the layer parameter specifications, as it will adjust the learning rate
+  // of individual parameters in a data-dependent way.
+  // WORK IN PROGRESS: not actually implemented yet.
+  optional float adagrad_gamma = 14; // adagrad learning rate multiplier
+  optional float adagrad_skip = 15; // the steps to skip before adagrad kicks in
 }
+
+// A message that stores the solver snapshots
+message SolverState {
+  optional int32 iter = 1; // The current iteration
+  optional string learned_net = 2; // The file that stores the learned net.
+  repeated BlobProto history = 3; // The history for sgd solvers
+}
\ No newline at end of file
diff --git a/src/caffe/vision_layers.hpp b/src/caffe/vision_layers.hpp
index b07307bb..0dc34763 100644
--- a/src/caffe/vision_layers.hpp
+++ b/src/caffe/vision_layers.hpp
@@ -274,6 +274,7 @@ class DataLayer : public Layer<Dtype> {
   pthread_t thread_;
   shared_ptr<Blob<Dtype> > prefetch_data_;
   shared_ptr<Blob<Dtype> > prefetch_label_;
+  Blob<Dtype> data_mean_;
 };
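Taken together, the new SolverState message and the solver changes define the snapshot format: the learned NetParameter is written to "<prefix>_iter_<N>", and "<prefix>_iter_<N>.solverstate" stores a SolverState that points back at it and carries the SGD momentum history. A rough sketch of the write/read round trip, using only the proto accessors generated from the message above and the ReadProtoFromBinaryFile/WriteProtoToBinaryFile helpers already called in the diff (the caffe/util/io.hpp include path and the file names are assumptions):

```cpp
#include <string>

#include "caffe/proto/caffe.pb.h"
#include "caffe/util/io.hpp"  // assumed home of the proto I/O helpers

// What Snapshot() records after the learned net itself has been saved.
void WriteState(const std::string& snapshot_name, int iter) {
  caffe::SolverState state;
  state.set_iter(iter);
  state.set_learned_net(snapshot_name);
  // SnapshotSolverState() would append one BlobProto per momentum blob here.
  caffe::WriteProtoToBinaryFile(state, (snapshot_name + ".solverstate").c_str());
}

// What Restore() reads back before training resumes; returns the iteration.
int ReadState(const std::string& state_file, caffe::NetParameter* learned_net) {
  caffe::SolverState state;
  caffe::ReadProtoFromBinaryFile(state_file.c_str(), &state);
  caffe::ReadProtoFromBinaryFile(state.learned_net().c_str(), learned_net);
  return state.iter();
}
```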