From 2d975527a4c4eb3afd067bf9673415f25d15793b Mon Sep 17 00:00:00 2001 From: Yangqing Jia Date: Mon, 14 Oct 2013 15:29:30 -0700 Subject: allow in-place neuron layers --- src/caffe/layers/neuron_layer.cpp | 8 ++++++-- src/caffe/net.cpp | 31 ++++++++++++++++++++++--------- 2 files changed, 28 insertions(+), 11 deletions(-) (limited to 'src') diff --git a/src/caffe/layers/neuron_layer.cpp b/src/caffe/layers/neuron_layer.cpp index fcf6ff52..dd09dca3 100644 --- a/src/caffe/layers/neuron_layer.cpp +++ b/src/caffe/layers/neuron_layer.cpp @@ -12,8 +12,12 @@ void NeuronLayer::SetUp(const vector*>& bottom, vector*>* top) { CHECK_EQ(bottom.size(), 1) << "Neuron Layer takes a single blob as input."; CHECK_EQ(top->size(), 1) << "Neuron Layer takes a single blob as output."; - (*top)[0]->Reshape(bottom[0]->num(), bottom[0]->channels(), - bottom[0]->height(), bottom[0]->width()); + // NeuronLayer allows in-place computations. If the computation is not + // in-place, we will need to initialize the top blob. + if ((*top)[0] != bottom[0]) { + (*top)[0]->Reshape(bottom[0]->num(), bottom[0]->channels(), + bottom[0]->height(), bottom[0]->width()); + } }; INSTANTIATE_CLASS(NeuronLayer); diff --git a/src/caffe/net.cpp b/src/caffe/net.cpp index ff1cca4b..22250da5 100644 --- a/src/caffe/net.cpp +++ b/src/caffe/net.cpp @@ -64,17 +64,30 @@ Net::Net(const NetParameter& param, } for (int j = 0; j < layer_connection.top_size(); ++j) { const string& blob_name = layer_connection.top(j); - if (blob_name_to_idx.find(blob_name) != blob_name_to_idx.end()) { + // Check if we are doing in-place computation + if (layer_connection.bottom_size() > j && + blob_name == layer_connection.bottom(j)) { + // In-place computation + LOG(INFO) << layer_param.name() << " -> " << blob_name << " (in-place)"; + available_blobs.insert(blob_name); + top_vecs_[i].push_back( + blobs_[blob_name_to_idx[blob_name]].get()); + top_id_vecs_[i].push_back(blob_name_to_idx[blob_name]); + } else if (blob_name_to_idx.find(blob_name) != blob_name_to_idx.end()) { + // If we are not doing in-place computation but has duplicated blobs, + // raise an error. LOG(FATAL) << "Duplicate blobs produced by multiple sources."; + } else { + // Normal output. + LOG(INFO) << layer_param.name() << " -> " << blob_name; + shared_ptr > blob_pointer(new Blob()); + blobs_.push_back(blob_pointer); + blob_names_.push_back(blob_name); + blob_name_to_idx[blob_name] = blob_names_.size() - 1; + available_blobs.insert(blob_name); + top_vecs_[i].push_back(blobs_[blob_names_.size() - 1].get()); + top_id_vecs_[i].push_back(blob_names_.size() - 1); } - LOG(INFO) << layer_param.name() << " -> " << blob_name; - shared_ptr > blob_pointer(new Blob()); - blobs_.push_back(blob_pointer); - blob_names_.push_back(blob_name); - blob_name_to_idx[blob_name] = blob_names_.size() - 1; - available_blobs.insert(blob_name); - top_vecs_[i].push_back(blobs_[blob_names_.size() - 1].get()); - top_id_vecs_[i].push_back(blob_names_.size() - 1); } } // In the end, all remaining blobs are considered output blobs. -- cgit v1.2.3