diff options
author | Yangqing Jia <jiayq84@gmail.com> | 2013-10-23 11:14:50 -0700 |
---|---|---|
committer | Yangqing Jia <jiayq84@gmail.com> | 2013-10-23 11:14:55 -0700 |
commit | 749ba87e48049a7676f4dca8052d17fff6c485af (patch) | |
tree | 29bdca1f52e51e4a63f82089bbd62cc943cf96fa /src | |
parent | 62089dd8da5bb78e758d3a7fe84095f75a4120f1 (diff) | |
download | caffe-749ba87e48049a7676f4dca8052d17fff6c485af.tar.gz caffe-749ba87e48049a7676f4dca8052d17fff6c485af.tar.bz2 caffe-749ba87e48049a7676f4dca8052d17fff6c485af.zip |
need backward computation, and train_net resume point. Not debugged.
Diffstat (limited to 'src')
-rw-r--r-- | src/caffe/net.cpp | 51 |
1 files changed, 39 insertions, 12 deletions
diff --git a/src/caffe/net.cpp b/src/caffe/net.cpp index 165869d4..50f4f93a 100644 --- a/src/caffe/net.cpp +++ b/src/caffe/net.cpp @@ -34,6 +34,7 @@ Net<Dtype>::Net(const NetParameter& param, bottom[i]->height(), bottom[i]->width())); blobs_.push_back(blob_pointer); blob_names_.push_back(blob_name); + blob_need_backward_.push_back(false); net_input_blob_indices_.push_back(i); blob_name_to_idx[blob_name] = i; available_blobs.insert(blob_name); @@ -49,17 +50,21 @@ Net<Dtype>::Net(const NetParameter& param, layers_.push_back(shared_ptr<Layer<Dtype> >(GetLayer<Dtype>(layer_param))); layer_names_.push_back(layer_param.name()); LOG(INFO) << "Creating Layer " << layer_param.name(); + bool need_backward = false; // Figure out this layer's input and output for (int j = 0; j < layer_connection.bottom_size(); ++j) { const string& blob_name = layer_connection.bottom(j); + const int blob_id = blob_name_to_idx[blob_name]; if (available_blobs.find(blob_name) == available_blobs.end()) { LOG(FATAL) << "Unknown blob input " << blob_name << " to layer" << j; } LOG(INFO) << layer_param.name() << " <- " << blob_name; bottom_vecs_[i].push_back( - blobs_[blob_name_to_idx[blob_name]].get()); - bottom_id_vecs_[i].push_back(blob_name_to_idx[blob_name]); + blobs_[blob_id].get()); + bottom_id_vecs_[i].push_back(blob_id); + // If a blob needs backward, this layer should provide it. + need_backward |= blob_need_backward_[blob_id]; available_blobs.erase(blob_name); } for (int j = 0; j < layer_connection.top_size(); ++j) { @@ -83,12 +88,30 @@ Net<Dtype>::Net(const NetParameter& param, shared_ptr<Blob<Dtype> > blob_pointer(new Blob<Dtype>()); blobs_.push_back(blob_pointer); blob_names_.push_back(blob_name); + blob_need_backward_.push_back(false); blob_name_to_idx[blob_name] = blob_names_.size() - 1; available_blobs.insert(blob_name); top_vecs_[i].push_back(blobs_[blob_names_.size() - 1].get()); top_id_vecs_[i].push_back(blob_names_.size() - 1); } } + // After this layer is connected, set it up. + LOG(INFO) << "Setting up " << layer_names_[i]; + layers_[i]->SetUp(bottom_vecs_[i], &top_vecs_[i]); + // Check if this layer needs backward operation itself + for (int j = 0; j < layers_[i]->layer_param().blobs_lr_size(); ++j) { + need_backward |= (layers_[i]->layer_param().blobs_lr(j) > 0); + } + // Finally, set the backward flag + layer_need_backward_.push_back(need_backward); + if (need_backward) { + LOG(INFO) << layer_names_[i] << " needs backward computation."; + for (int j = 0; j < top_id_vecs_[i].size(); ++j) { + blob_need_backward_[top_id_vecs_[i][j]] = true; + } + } else { + LOG(INFO) << layer_names_[i] << " does not need backward computation."; + } } // In the end, all remaining blobs are considered output blobs. for (set<string>::iterator it = available_blobs.begin(); @@ -97,11 +120,15 @@ Net<Dtype>::Net(const NetParameter& param, net_output_blob_indices_.push_back(blob_name_to_idx[*it]); net_output_blobs_.push_back(blobs_[blob_name_to_idx[*it]].get()); } + GetLearningRateAndWeightDecay(); + LOG(INFO) << "Network initialization done."; +} + - LOG(INFO) << "Setting up the layers."; +template <typename Dtype> +void Net<Dtype>::GetLearningRateAndWeightDecay() { + LOG(INFO) << "Collecting Learning Rate and Weight Decay."; for (int i = 0; i < layers_.size(); ++i) { - LOG(INFO) << "Setting up " << layer_names_[i]; - layers_[i]->SetUp(bottom_vecs_[i], &top_vecs_[i]); vector<shared_ptr<Blob<Dtype> > >& layer_blobs = layers_[i]->blobs(); for (int j = 0; j < layer_blobs.size(); ++j) { params_.push_back(layer_blobs[j]); @@ -111,7 +138,7 @@ Net<Dtype>::Net(const NetParameter& param, CHECK_EQ(layers_[i]->layer_param().blobs_lr_size(), layer_blobs.size()); for (int j = 0; j < layer_blobs.size(); ++j) { float local_lr = layers_[i]->layer_param().blobs_lr(j); - CHECK_GT(local_lr, 0.); + CHECK_GE(local_lr, 0.); params_lr_.push_back(local_lr); } } else { @@ -125,7 +152,7 @@ Net<Dtype>::Net(const NetParameter& param, layer_blobs.size()); for (int j = 0; j < layer_blobs.size(); ++j) { float local_decay = layers_[i]->layer_param().weight_decay(j); - CHECK_GT(local_decay, 0.); + CHECK_GE(local_decay, 0.); params_weight_decay_.push_back(local_decay); } } else { @@ -139,7 +166,6 @@ Net<Dtype>::Net(const NetParameter& param, << top_vecs_[i][topid]->width(); } } - LOG(INFO) << "Network initialization done."; } template <typename Dtype> @@ -159,11 +185,12 @@ const vector<Blob<Dtype>*>& Net<Dtype>::Forward( template <typename Dtype> Dtype Net<Dtype>::Backward() { Dtype loss = 0; - // TODO(Yangqing): figure out those layers that do not need backward. for (int i = layers_.size() - 1; i >= 0; --i) { - Dtype layer_loss = layers_[i]->Backward( - top_vecs_[i], true, &bottom_vecs_[i]); - loss += layer_loss; + if (layer_need_backward_[i]) { + Dtype layer_loss = layers_[i]->Backward( + top_vecs_[i], true, &bottom_vecs_[i]); + loss += layer_loss; + } } return loss; } |