summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--examples/train_net.cpp9
-rw-r--r--include/caffe/net.hpp6
-rw-r--r--src/caffe/net.cpp51
3 files changed, 52 insertions, 14 deletions
diff --git a/examples/train_net.cpp b/examples/train_net.cpp
index 06ca1213..b4181d64 100644
--- a/examples/train_net.cpp
+++ b/examples/train_net.cpp
@@ -3,7 +3,7 @@
// This is a simple script that allows one to quickly train a network whose
// parameters are specified by text format protocol buffers.
// Usage:
-// train_net net_proto_file solver_proto_file
+// train_net net_proto_file solver_proto_file [resume_point_file]
#include <cuda_runtime.h>
@@ -28,7 +28,12 @@ int main(int argc, char** argv) {
LOG(ERROR) << "Starting Optimization";
SGDSolver<float> solver(solver_param);
- solver.Solve(&caffe_net);
+ if (argc == 4) {
+ LOG(ERROR) << "Resuming from " << argv[3];
+ solver.Solve(&caffe_net, argv[3]);
+ } else {
+ solver.Solve(&caffe_net);
+ }
LOG(ERROR) << "Optimization Done.";
return 0;
diff --git a/include/caffe/net.hpp b/include/caffe/net.hpp
index f0a5ebb9..9bbfd37e 100644
--- a/include/caffe/net.hpp
+++ b/include/caffe/net.hpp
@@ -65,13 +65,19 @@ class Net {
void Update();
protected:
+ // Function to get misc parameters, e.g. the learning rate multiplier and
+ // weight decay.
+ void GetLearningRateAndWeightDecay();
+
// Individual layers in the net
vector<shared_ptr<Layer<Dtype> > > layers_;
vector<string> layer_names_;
+ vector<bool> layer_need_backward_;
// blobs stores the blobs that store intermediate results between the
// layers.
vector<shared_ptr<Blob<Dtype> > > blobs_;
vector<string> blob_names_;
+ vector<bool> blob_need_backward_;
// bottom_vecs stores the vectors containing the input for each layer
vector<vector<Blob<Dtype>*> > bottom_vecs_;
vector<vector<int> > bottom_id_vecs_;
diff --git a/src/caffe/net.cpp b/src/caffe/net.cpp
index 165869d4..50f4f93a 100644
--- a/src/caffe/net.cpp
+++ b/src/caffe/net.cpp
@@ -34,6 +34,7 @@ Net<Dtype>::Net(const NetParameter& param,
bottom[i]->height(), bottom[i]->width()));
blobs_.push_back(blob_pointer);
blob_names_.push_back(blob_name);
+ blob_need_backward_.push_back(false);
net_input_blob_indices_.push_back(i);
blob_name_to_idx[blob_name] = i;
available_blobs.insert(blob_name);
@@ -49,17 +50,21 @@ Net<Dtype>::Net(const NetParameter& param,
layers_.push_back(shared_ptr<Layer<Dtype> >(GetLayer<Dtype>(layer_param)));
layer_names_.push_back(layer_param.name());
LOG(INFO) << "Creating Layer " << layer_param.name();
+ bool need_backward = false;
// Figure out this layer's input and output
for (int j = 0; j < layer_connection.bottom_size(); ++j) {
const string& blob_name = layer_connection.bottom(j);
+ const int blob_id = blob_name_to_idx[blob_name];
if (available_blobs.find(blob_name) == available_blobs.end()) {
LOG(FATAL) << "Unknown blob input " << blob_name <<
" to layer" << j;
}
LOG(INFO) << layer_param.name() << " <- " << blob_name;
bottom_vecs_[i].push_back(
- blobs_[blob_name_to_idx[blob_name]].get());
- bottom_id_vecs_[i].push_back(blob_name_to_idx[blob_name]);
+ blobs_[blob_id].get());
+ bottom_id_vecs_[i].push_back(blob_id);
+ // If a blob needs backward, this layer should provide it.
+ need_backward |= blob_need_backward_[blob_id];
available_blobs.erase(blob_name);
}
for (int j = 0; j < layer_connection.top_size(); ++j) {
@@ -83,12 +88,30 @@ Net<Dtype>::Net(const NetParameter& param,
shared_ptr<Blob<Dtype> > blob_pointer(new Blob<Dtype>());
blobs_.push_back(blob_pointer);
blob_names_.push_back(blob_name);
+ blob_need_backward_.push_back(false);
blob_name_to_idx[blob_name] = blob_names_.size() - 1;
available_blobs.insert(blob_name);
top_vecs_[i].push_back(blobs_[blob_names_.size() - 1].get());
top_id_vecs_[i].push_back(blob_names_.size() - 1);
}
}
+ // After this layer is connected, set it up.
+ LOG(INFO) << "Setting up " << layer_names_[i];
+ layers_[i]->SetUp(bottom_vecs_[i], &top_vecs_[i]);
+ // Check if this layer needs backward operation itself
+ for (int j = 0; j < layers_[i]->layer_param().blobs_lr_size(); ++j) {
+ need_backward |= (layers_[i]->layer_param().blobs_lr(j) > 0);
+ }
+ // Finally, set the backward flag
+ layer_need_backward_.push_back(need_backward);
+ if (need_backward) {
+ LOG(INFO) << layer_names_[i] << " needs backward computation.";
+ for (int j = 0; j < top_id_vecs_[i].size(); ++j) {
+ blob_need_backward_[top_id_vecs_[i][j]] = true;
+ }
+ } else {
+ LOG(INFO) << layer_names_[i] << " does not need backward computation.";
+ }
}
// In the end, all remaining blobs are considered output blobs.
for (set<string>::iterator it = available_blobs.begin();
@@ -97,11 +120,15 @@ Net<Dtype>::Net(const NetParameter& param,
net_output_blob_indices_.push_back(blob_name_to_idx[*it]);
net_output_blobs_.push_back(blobs_[blob_name_to_idx[*it]].get());
}
+ GetLearningRateAndWeightDecay();
+ LOG(INFO) << "Network initialization done.";
+}
+
- LOG(INFO) << "Setting up the layers.";
+template <typename Dtype>
+void Net<Dtype>::GetLearningRateAndWeightDecay() {
+ LOG(INFO) << "Collecting Learning Rate and Weight Decay.";
for (int i = 0; i < layers_.size(); ++i) {
- LOG(INFO) << "Setting up " << layer_names_[i];
- layers_[i]->SetUp(bottom_vecs_[i], &top_vecs_[i]);
vector<shared_ptr<Blob<Dtype> > >& layer_blobs = layers_[i]->blobs();
for (int j = 0; j < layer_blobs.size(); ++j) {
params_.push_back(layer_blobs[j]);
@@ -111,7 +138,7 @@ Net<Dtype>::Net(const NetParameter& param,
CHECK_EQ(layers_[i]->layer_param().blobs_lr_size(), layer_blobs.size());
for (int j = 0; j < layer_blobs.size(); ++j) {
float local_lr = layers_[i]->layer_param().blobs_lr(j);
- CHECK_GT(local_lr, 0.);
+ CHECK_GE(local_lr, 0.);
params_lr_.push_back(local_lr);
}
} else {
@@ -125,7 +152,7 @@ Net<Dtype>::Net(const NetParameter& param,
layer_blobs.size());
for (int j = 0; j < layer_blobs.size(); ++j) {
float local_decay = layers_[i]->layer_param().weight_decay(j);
- CHECK_GT(local_decay, 0.);
+ CHECK_GE(local_decay, 0.);
params_weight_decay_.push_back(local_decay);
}
} else {
@@ -139,7 +166,6 @@ Net<Dtype>::Net(const NetParameter& param,
<< top_vecs_[i][topid]->width();
}
}
- LOG(INFO) << "Network initialization done.";
}
template <typename Dtype>
@@ -159,11 +185,12 @@ const vector<Blob<Dtype>*>& Net<Dtype>::Forward(
template <typename Dtype>
Dtype Net<Dtype>::Backward() {
Dtype loss = 0;
- // TODO(Yangqing): figure out those layers that do not need backward.
for (int i = layers_.size() - 1; i >= 0; --i) {
- Dtype layer_loss = layers_[i]->Backward(
- top_vecs_[i], true, &bottom_vecs_[i]);
- loss += layer_loss;
+ if (layer_need_backward_[i]) {
+ Dtype layer_loss = layers_[i]->Backward(
+ top_vecs_[i], true, &bottom_vecs_[i]);
+ loss += layer_loss;
+ }
}
return loss;
}