summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorYangqing Jia <jiayq84@gmail.com>2013-10-15 11:28:26 -0700
committerYangqing Jia <jiayq84@gmail.com>2013-10-15 11:28:26 -0700
commit4f7e519ea0a1db36a54c68f5021f6c4acfc3a657 (patch)
treeb547f2a8625264de104a5c6431e048f24bc241a0 /src
parenta0f2c505f4c1d9e0c5eeb9e4e5b29fe2afe6b6b4 (diff)
downloadcaffe-4f7e519ea0a1db36a54c68f5021f6c4acfc3a657.tar.gz
caffe-4f7e519ea0a1db36a54c68f5021f6c4acfc3a657.tar.bz2
caffe-4f7e519ea0a1db36a54c68f5021f6c4acfc3a657.zip
Moved the layer factory implementation to cpp; added snapshot and restore functions to solver.
Diffstat (limited to 'src')
-rw-r--r--src/caffe/caffe.hpp1
-rw-r--r--src/caffe/layer.hpp4
-rw-r--r--src/caffe/layer_factory.cpp (renamed from src/caffe/layer_factory.hpp)2
-rw-r--r--src/caffe/net.cpp2
-rw-r--r--src/caffe/optimization/solver.cpp46
-rw-r--r--src/caffe/optimization/solver.hpp19
-rw-r--r--src/caffe/proto/caffe.proto18
-rw-r--r--src/caffe/vision_layers.hpp1
8 files changed, 86 insertions, 7 deletions
diff --git a/src/caffe/caffe.hpp b/src/caffe/caffe.hpp
index 800138f9..5806bc02 100644
--- a/src/caffe/caffe.hpp
+++ b/src/caffe/caffe.hpp
@@ -7,7 +7,6 @@
#include "caffe/blob.hpp"
#include "caffe/filler.hpp"
#include "caffe/layer.hpp"
-#include "caffe/layer_factory.hpp"
#include "caffe/net.hpp"
#include "caffe/vision_layers.hpp"
diff --git a/src/caffe/layer.hpp b/src/caffe/layer.hpp
index cbfde0cb..adc63657 100644
--- a/src/caffe/layer.hpp
+++ b/src/caffe/layer.hpp
@@ -127,6 +127,10 @@ void Layer<Dtype>::ToProto(LayerParameter* param, bool write_diff) {
}
}
+// The layer factory function
+template <typename Dtype>
+Layer<Dtype>* GetLayer(const LayerParameter& param);
+
} // namespace caffe
#endif // CAFFE_LAYER_H_
diff --git a/src/caffe/layer_factory.hpp b/src/caffe/layer_factory.cpp
index d231e17b..6961bb3f 100644
--- a/src/caffe/layer_factory.hpp
+++ b/src/caffe/layer_factory.cpp
@@ -54,6 +54,8 @@ Layer<Dtype>* GetLayer(const LayerParameter& param) {
return (Layer<Dtype>*)(NULL);
}
+template Layer<float>* GetLayer(const LayerParameter& param);
+template Layer<double>* GetLayer(const LayerParameter& param);
} // namespace caffe
diff --git a/src/caffe/net.cpp b/src/caffe/net.cpp
index 22250da5..e1442ecb 100644
--- a/src/caffe/net.cpp
+++ b/src/caffe/net.cpp
@@ -6,7 +6,7 @@
#include <vector>
#include "caffe/proto/caffe.pb.h"
-#include "caffe/layer_factory.hpp"
+#include "caffe/layer.hpp"
#include "caffe/net.hpp"
using std::pair;
diff --git a/src/caffe/optimization/solver.cpp b/src/caffe/optimization/solver.cpp
index b2a57600..73c69c03 100644
--- a/src/caffe/optimization/solver.cpp
+++ b/src/caffe/optimization/solver.cpp
@@ -18,11 +18,17 @@ using std::min;
namespace caffe {
template <typename Dtype>
-void Solver<Dtype>::Solve(Net<Dtype>* net) {
+void Solver<Dtype>::Solve(Net<Dtype>* net, char* resume_file) {
net_ = net;
LOG(INFO) << "Solving " << net_->name();
PreSolve();
+
iter_ = 0;
+ if (resume_file) {
+ LOG(INFO) << "Restoring previous solver status from " << resume_file;
+ Restore(resume_file);
+ }
+
// For a network that is trained by the solver, no bottom or top vecs
// should be given, and we will just provide dummy vecs.
vector<Blob<Dtype>*> bottom_vec;
@@ -56,8 +62,26 @@ void Solver<Dtype>::Snapshot(bool is_final) {
sprintf(iter_str_buffer, "_iter_%d", iter_);
filename += iter_str_buffer;
}
- LOG(ERROR) << "Snapshotting to " << filename;
+ LOG(INFO) << "Snapshotting to " << filename;
WriteProtoToBinaryFile(net_param, filename.c_str());
+ SolverState state;
+ SnapshotSolverState(&state);
+ state.set_iter(iter_);
+ state.set_learned_net(filename);
+ filename += ".solverstate";
+ LOG(INFO) << "Snapshotting solver state to " << filename;
+ WriteProtoToBinaryFile(state, filename.c_str());
+}
+
+template <typename Dtype>
+void Solver<Dtype>::Restore(char* state_file) {
+ SolverState state;
+ NetParameter net_param;
+ ReadProtoFromBinaryFile(state_file, &state);
+ ReadProtoFromBinaryFile(state.learned_net().c_str(), &net_param);
+ net_->CopyTrainedLayersFrom(net_param);
+ iter_ = state.iter();
+ RestoreSolverState(state);
}
@@ -167,6 +191,24 @@ void SGDSolver<Dtype>::ComputeUpdateValue() {
}
}
+template <typename Dtype>
+void SGDSolver<Dtype>::SnapshotSolverState(SolverState* state) {
+ state->clear_history();
+ for (int i = 0; i < history_.size(); ++i) {
+ // Add history
+ BlobProto* history_blob = state->add_history();
+ history_[i]->ToProto(history_blob);
+ }
+}
+
+template <typename Dtype>
+void SGDSolver<Dtype>::RestoreSolverState(const SolverState& state) {
+ CHECK_EQ(state.history_size(), history_.size())
+ << "Incorrect length of history blobs.";
+ for (int i = 0; i < history_.size(); ++i) {
+ history_[i]->FromProto(state.history(i));
+ }
+}
INSTANTIATE_CLASS(Solver);
INSTANTIATE_CLASS(SGDSolver);
diff --git a/src/caffe/optimization/solver.hpp b/src/caffe/optimization/solver.hpp
index 8dc41aff..a5ea6126 100644
--- a/src/caffe/optimization/solver.hpp
+++ b/src/caffe/optimization/solver.hpp
@@ -12,8 +12,9 @@ class Solver {
public:
explicit Solver(const SolverParameter& param)
: param_(param) {}
- // The main entry of the solver function.
- void Solve(Net<Dtype>* net);
+ // The main entry of the solver function. In default, iter will be zero. Pass
+ // in a non-zero iter number to resume training for a pre-trained net.
+ void Solve(Net<Dtype>* net, char* state_file = NULL);
virtual ~Solver() {}
protected:
@@ -22,7 +23,17 @@ class Solver {
virtual void PreSolve() {}
// Get the update value for the current iteration.
virtual void ComputeUpdateValue() = 0;
+ // The Solver::Snapshot function implements the basic snapshotting utility
+ // that stores the learned net. You should implement the SnapshotSolverState()
+ // function that produces a SolverState protocol buffer that needs to be
+ // written to disk together with the learned net.
void Snapshot(bool is_final = false);
+ virtual void SnapshotSolverState(SolverState* state) = 0;
+ // The Restore function implements how one should restore the solver to a
+ // previously snapshotted state. You should implement the RestoreSolverState()
+ // function that restores the state from a SolverState protocol buffer.
+ void Restore(char* state_file);
+ virtual void RestoreSolverState(const SolverState& state) = 0;
SolverParameter param_;
int iter_;
Net<Dtype>* net_;
@@ -39,8 +50,10 @@ class SGDSolver : public Solver<Dtype> {
protected:
virtual void PreSolve();
- Dtype GetLearningRate();
+ virtual Dtype GetLearningRate();
virtual void ComputeUpdateValue();
+ virtual void SnapshotSolverState(SolverState * state);
+ virtual void RestoreSolverState(const SolverState& state);
// history maintains the historical momentum data.
vector<shared_ptr<Blob<Dtype> > > history_;
};
diff --git a/src/caffe/proto/caffe.proto b/src/caffe/proto/caffe.proto
index 87f2c2cc..4be96963 100644
--- a/src/caffe/proto/caffe.proto
+++ b/src/caffe/proto/caffe.proto
@@ -105,4 +105,22 @@ message SolverParameter {
optional float stepsize = 12; // the stepsize for learning rate policy "step"
optional string snapshot_prefix = 13; // The prefix for the snapshot.
+
+ // Adagrad solver parameters
+ // For Adagrad, we will first run normal sgd using the sgd parameters above
+ // for adagrad_skip iterations, and then kick in the adagrad algorithm, with
+ // the learning rate being adagrad_gamma * adagrad_skip. Note that the adagrad
+ // algorithm will NOT use the learning rate multiplier that is specified in
+ // the layer parameter specifications, as it will adjust the learning rate
+ // of individual parameters in a data-dependent way.
+ // WORK IN PROGRESS: not actually implemented yet.
+ optional float adagrad_gamma = 14; // adagrad learning rate multiplier
+ optional float adagrad_skip = 15; // the steps to skip before adagrad kicks in
}
+
+// A message that stores the solver snapshots
+message SolverState {
+ optional int32 iter = 1; // The current iteration
+ optional string learned_net = 2; // The file that stores the learned net.
+ repeated BlobProto history = 3; // The history for sgd solvers
+} \ No newline at end of file
diff --git a/src/caffe/vision_layers.hpp b/src/caffe/vision_layers.hpp
index b07307bb..0dc34763 100644
--- a/src/caffe/vision_layers.hpp
+++ b/src/caffe/vision_layers.hpp
@@ -274,6 +274,7 @@ class DataLayer : public Layer<Dtype> {
pthread_t thread_;
shared_ptr<Blob<Dtype> > prefetch_data_;
shared_ptr<Blob<Dtype> > prefetch_label_;
+ Blob<Dtype> data_mean_;
};