author    | Ronghang Hu <huronghang@hotmail.com> | 2015-08-13 13:28:11 -0700
committer | Ronghang Hu <huronghang@hotmail.com> | 2015-08-13 13:28:11 -0700
commit    | bb0a90e8aabb0e1ffd5732439591814d10d7fd45 (patch)
tree      | cc0f44d36d36ada214e4dfecaebab518d5bdf547 /src
parent    | 8181870b9ac330a094ab0f8d53f54a0202f697a0 (diff)
parent    | 6b50ed6fc1897ce1ccd673cf0287788b38b58a6d (diff)
download  | caffeonacl-bb0a90e8aabb0e1ffd5732439591814d10d7fd45.tar.gz
          | caffeonacl-bb0a90e8aabb0e1ffd5732439591814d10d7fd45.tar.bz2
          | caffeonacl-bb0a90e8aabb0e1ffd5732439591814d10d7fd45.zip
Merge pull request #2903 from ronghanghu/multi_gpu
Multi-GPU Data Parallelism
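This merge adds data-parallel training across multiple GPUs: a thread-local `Caffe` singleton (`common.cpp`), a shared `DataReader` plus double-buffered prefetching in the data layers, and a tree-structured `P2PSync` in `parallel.cpp` that broadcasts parameters down and reduces gradients up a chain of peer-to-peer-connected devices. As orientation before the diff, the sketch below shows roughly how a driver wires these pieces together. It mirrors what the accompanying `tools/caffe.cpp` changes do (those fall outside the `src/`-limited diffstat), but the function name and the exact setup calls are illustrative assumptions, not lines from this commit.

```cpp
// Hypothetical driver sketch, assuming the GetSolver<> factory from the
// solver.hpp of this era; names here are illustrative, not part of the diff.
#include <vector>
#include <boost/shared_ptr.hpp>
#include "caffe/caffe.hpp"
#include "caffe/parallel.hpp"

void TrainOnGpus(const caffe::SolverParameter& solver_param,
                 const std::vector<int>& gpus) {
  caffe::Caffe::set_mode(caffe::Caffe::GPU);
  caffe::Caffe::SetDevice(gpus[0]);   // root solver runs on the first GPU
  // Data layers and nets consult the per-thread solver count / root-solver
  // flag added to the Caffe singleton in common.cpp.
  caffe::Caffe::set_solver_count(gpus.size());
  boost::shared_ptr<caffe::Solver<float> > solver(
      caffe::GetSolver<float>(solver_param));
  if (gpus.size() > 1) {
    // P2PSync pairs the devices into a reduction tree (DevicePair::compute),
    // spawns one worker solver thread per non-root GPU, and runs the root
    // solver's Solve() on the current thread (see parallel.cpp below).
    caffe::P2PSync<float> sync(solver, NULL, solver->param());
    sync.run(gpus);
  } else {
    solver->Solve();
  }
}
```

From the command line this typically surfaces as `caffe train --solver=solver.prototxt --gpu=0,1`, again handled by the tool sources outside this diffstat.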
Diffstat (limited to 'src')
-rw-r--r-- | src/caffe/common.cpp | 16
-rw-r--r-- | src/caffe/data_reader.cpp | 119
-rw-r--r-- | src/caffe/data_transformer.cpp | 4
-rw-r--r-- | src/caffe/internal_thread.cpp | 58
-rw-r--r-- | src/caffe/layer.cpp | 27
-rw-r--r-- | src/caffe/layers/base_data_layer.cpp | 90
-rw-r--r-- | src/caffe/layers/base_data_layer.cu | 15
-rw-r--r-- | src/caffe/layers/data_layer.cpp | 81
-rw-r--r-- | src/caffe/layers/image_data_layer.cpp | 28
-rw-r--r-- | src/caffe/layers/window_data_layer.cpp | 20
-rw-r--r-- | src/caffe/net.cpp | 213
-rw-r--r-- | src/caffe/parallel.cpp | 438
-rw-r--r-- | src/caffe/proto/caffe.proto | 8
-rw-r--r-- | src/caffe/solver.cpp | 76
-rw-r--r-- | src/caffe/syncedmem.cpp | 46
-rw-r--r-- | src/caffe/test/test_gradient_based_solver.cpp | 75
-rw-r--r-- | src/caffe/test/test_internal_thread.cpp | 34
-rw-r--r-- | src/caffe/test/test_layer_factory.cpp | 14
-rw-r--r-- | src/caffe/test/test_upgrade_proto.cpp | 12
-rw-r--r-- | src/caffe/util/blocking_queue.cpp | 96
20 files changed, 1240 insertions, 230 deletions
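A recurring pattern in the diff below is double buffering through a pair of blocking queues: `DataReader::QueuePair` recycles `Datum` objects between a `free_` and a `full_` queue, and `BasePrefetchingDataLayer` does the same with a fixed array of `Batch` objects via `prefetch_free_` / `prefetch_full_`. The stand-alone sketch below illustrates that producer/consumer handshake; it uses C++11 `std::thread` and `std::condition_variable` rather than Caffe's boost-based `BlockingQueue` (`src/caffe/util/blocking_queue.cpp`), so the types and names are illustrative, not the ones added by this commit.

```cpp
// Minimal sketch of the free/full double-buffering pattern used by
// DataReader and BasePrefetchingDataLayer. Illustrative only.
#include <condition_variable>
#include <iostream>
#include <mutex>
#include <queue>
#include <thread>

template <typename T>
class BlockingQueue {
 public:
  void push(T t) {
    {
      std::lock_guard<std::mutex> lock(mutex_);
      queue_.push(t);
    }
    cond_.notify_one();
  }
  T pop() {  // blocks until an element is available
    std::unique_lock<std::mutex> lock(mutex_);
    cond_.wait(lock, [this] { return !queue_.empty(); });
    T t = queue_.front();
    queue_.pop();
    return t;
  }
 private:
  std::queue<T> queue_;
  std::mutex mutex_;
  std::condition_variable cond_;
};

int main() {
  BlockingQueue<int*> free_q, full_q;  // pool of reusable buffers
  int buffers[4];                      // fixed prefetch capacity
  for (int i = 0; i < 4; ++i) free_q.push(&buffers[i]);

  // Producer: take an empty buffer, fill it, hand it to the consumer
  // (stand-in for load_batch() / read_one() on the prefetch thread).
  std::thread producer([&] {
    for (int i = 0; i < 10; ++i) {
      int* buf = free_q.pop();
      *buf = i;
      full_q.push(buf);
    }
  });

  // Consumer: take a filled buffer, use it, recycle it
  // (like Forward_cpu popping prefetch_full_ and pushing prefetch_free_).
  for (int i = 0; i < 10; ++i) {
    int* buf = full_q.pop();
    std::cout << "consumed " << *buf << std::endl;
    free_q.push(buf);
  }
  producer.join();
  return 0;
}
```

Because the buffer pool is fixed in size, the producer naturally stalls when the consumer falls behind, which is what bounds host memory use of the prefetch queue (`DataParameter.prefetch` in `caffe.proto`).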
diff --git a/src/caffe/common.cpp b/src/caffe/common.cpp index af96cac4..7077f378 100644 --- a/src/caffe/common.cpp +++ b/src/caffe/common.cpp @@ -1,3 +1,4 @@ +#include <boost/thread.hpp> #include <glog/logging.h> #include <cstdio> #include <ctime> @@ -7,7 +8,15 @@ namespace caffe { -shared_ptr<Caffe> Caffe::singleton_; +// Make sure each thread can have different values. +static boost::thread_specific_ptr<Caffe> thread_instance_; + +Caffe& Caffe::Get() { + if (!thread_instance_.get()) { + thread_instance_.reset(new Caffe()); + } + return *(thread_instance_.get()); +} // random seeding int64_t cluster_seedgen(void) { @@ -42,7 +51,8 @@ void GlobalInit(int* pargc, char*** pargv) { #ifdef CPU_ONLY // CPU-only Caffe. Caffe::Caffe() - : random_generator_(), mode_(Caffe::CPU) { } + : random_generator_(), mode_(Caffe::CPU), + solver_count_(1), root_solver_(true) { } Caffe::~Caffe() { } @@ -86,7 +96,7 @@ void* Caffe::RNG::generator() { Caffe::Caffe() : cublas_handle_(NULL), curand_generator_(NULL), random_generator_(), - mode_(Caffe::CPU) { + mode_(Caffe::CPU), solver_count_(1), root_solver_(true) { // Try to create a cublas handler, and report an error if failed (but we will // keep the program running as one might just want to run CPU code). if (cublasCreate(&cublas_handle_) != CUBLAS_STATUS_SUCCESS) { diff --git a/src/caffe/data_reader.cpp b/src/caffe/data_reader.cpp new file mode 100644 index 00000000..16378203 --- /dev/null +++ b/src/caffe/data_reader.cpp @@ -0,0 +1,119 @@ +#include <boost/thread.hpp> +#include <map> +#include <string> +#include <vector> + +#include "caffe/common.hpp" +#include "caffe/data_layers.hpp" +#include "caffe/data_reader.hpp" +#include "caffe/proto/caffe.pb.h" + +namespace caffe { + +using boost::weak_ptr; + +map<const string, weak_ptr<DataReader::Body> > DataReader::bodies_; +static boost::mutex bodies_mutex_; + +DataReader::DataReader(const LayerParameter& param) + : queue_pair_(new QueuePair( // + param.data_param().prefetch() * param.data_param().batch_size())) { + // Get or create a body + boost::mutex::scoped_lock lock(bodies_mutex_); + string key = source_key(param); + weak_ptr<Body>& weak = bodies_[key]; + body_ = weak.lock(); + if (!body_) { + body_.reset(new Body(param)); + bodies_[key] = weak_ptr<Body>(body_); + } + body_->new_queue_pairs_.push(queue_pair_); +} + +DataReader::~DataReader() { + string key = source_key(body_->param_); + body_.reset(); + boost::mutex::scoped_lock lock(bodies_mutex_); + if (bodies_[key].expired()) { + bodies_.erase(key); + } +} + +// + +DataReader::QueuePair::QueuePair(int size) { + // Initialize the free queue with requested number of datums + for (int i = 0; i < size; ++i) { + free_.push(new Datum()); + } +} + +DataReader::QueuePair::~QueuePair() { + Datum* datum; + while (free_.try_pop(&datum)) { + delete datum; + } + while (full_.try_pop(&datum)) { + delete datum; + } +} + +// + +DataReader::Body::Body(const LayerParameter& param) + : param_(param), + new_queue_pairs_() { + StartInternalThread(); +} + +DataReader::Body::~Body() { + StopInternalThread(); +} + +void DataReader::Body::InternalThreadEntry() { + shared_ptr<db::DB> db(db::GetDB(param_.data_param().backend())); + db->Open(param_.data_param().source(), db::READ); + shared_ptr<db::Cursor> cursor(db->NewCursor()); + vector<shared_ptr<QueuePair> > qps; + try { + int solver_count = param_.phase() == TRAIN ? Caffe::solver_count() : 1; + + // To ensure deterministic runs, only start running once all solvers + // are ready. 
But solvers need to peek on one item during initialization, + // so read one item, then wait for the next solver. + for (int i = 0; i < solver_count; ++i) { + shared_ptr<QueuePair> qp(new_queue_pairs_.pop()); + read_one(cursor.get(), qp.get()); + qps.push_back(qp); + } + // Main loop + while (!must_stop()) { + for (int i = 0; i < solver_count; ++i) { + read_one(cursor.get(), qps[i].get()); + } + // Check no additional readers have been created. This can happen if + // more than one net is trained at a time per process, whether single + // or multi solver. It might also happen if two data layers have same + // name and same source. + CHECK_EQ(new_queue_pairs_.size(), 0); + } + } catch (boost::thread_interrupted&) { + // Interrupted exception is expected on shutdown + } +} + +void DataReader::Body::read_one(db::Cursor* cursor, QueuePair* qp) { + Datum* datum = qp->free_.pop(); + // TODO deserialize in-place instead of copy? + datum->ParseFromString(cursor->value()); + qp->full_.push(datum); + + // go to the next iter + cursor->Next(); + if (!cursor->valid()) { + DLOG(INFO) << "Restarting data prefetching from start."; + cursor->SeekToFirst(); + } +} + +} // namespace caffe diff --git a/src/caffe/data_transformer.cpp b/src/caffe/data_transformer.cpp index 22633922..4666d9bd 100644 --- a/src/caffe/data_transformer.cpp +++ b/src/caffe/data_transformer.cpp @@ -19,7 +19,9 @@ DataTransformer<Dtype>::DataTransformer(const TransformationParameter& param, CHECK_EQ(param_.mean_value_size(), 0) << "Cannot specify mean_file and mean_value at the same time"; const string& mean_file = param.mean_file(); - LOG(INFO) << "Loading mean file from: " << mean_file; + if (Caffe::root_solver()) { + LOG(INFO) << "Loading mean file from: " << mean_file; + } BlobProto blob_proto; ReadProtoFromBinaryFileOrDie(mean_file.c_str(), &blob_proto); data_mean_.FromProto(blob_proto); diff --git a/src/caffe/internal_thread.cpp b/src/caffe/internal_thread.cpp index c2d19d43..104884e0 100644 --- a/src/caffe/internal_thread.cpp +++ b/src/caffe/internal_thread.cpp @@ -1,40 +1,66 @@ #include <boost/thread.hpp> +#include <exception> + #include "caffe/internal_thread.hpp" +#include "caffe/util/math_functions.hpp" namespace caffe { InternalThread::~InternalThread() { - WaitForInternalThreadToExit(); + StopInternalThread(); } bool InternalThread::is_started() const { - return thread_.get() != NULL && thread_->joinable(); + return thread_ && thread_->joinable(); +} + +bool InternalThread::must_stop() { + return thread_ && thread_->interruption_requested(); } +void InternalThread::StartInternalThread() { + CHECK(!is_started()) << "Threads should persist and not be restarted."; + + int device = 0; +#ifndef CPU_ONLY + CUDA_CHECK(cudaGetDevice(&device)); +#endif + Caffe::Brew mode = Caffe::mode(); + int rand_seed = caffe_rng_rand(); + int solver_count = Caffe::solver_count(); + bool root_solver = Caffe::root_solver(); -bool InternalThread::StartInternalThread() { - if (!WaitForInternalThreadToExit()) { - return false; - } try { - thread_.reset( - new boost::thread(&InternalThread::InternalThreadEntry, this)); - } catch (...) { - return false; + thread_.reset(new boost::thread(&InternalThread::entry, this, device, mode, + rand_seed, solver_count, root_solver)); + } catch (std::exception& e) { + LOG(FATAL) << "Thread exception: " << e.what(); } - return true; } -/** Will not return until the internal thread has exited. 
*/ -bool InternalThread::WaitForInternalThreadToExit() { +void InternalThread::entry(int device, Caffe::Brew mode, int rand_seed, + int solver_count, bool root_solver) { +#ifndef CPU_ONLY + CUDA_CHECK(cudaSetDevice(device)); +#endif + Caffe::set_mode(mode); + Caffe::set_random_seed(rand_seed); + Caffe::set_solver_count(solver_count); + Caffe::set_root_solver(root_solver); + + InternalThreadEntry(); +} + +void InternalThread::StopInternalThread() { if (is_started()) { + thread_->interrupt(); try { thread_->join(); - } catch (...) { - return false; + } catch (boost::thread_interrupted&) { + } catch (std::exception& e) { + LOG(FATAL) << "Thread exception: " << e.what(); } } - return true; } } // namespace caffe diff --git a/src/caffe/layer.cpp b/src/caffe/layer.cpp new file mode 100644 index 00000000..3b912898 --- /dev/null +++ b/src/caffe/layer.cpp @@ -0,0 +1,27 @@ +#include <boost/thread.hpp> +#include "caffe/layer.hpp" + +namespace caffe { + +template <typename Dtype> +void Layer<Dtype>::InitMutex() { + forward_mutex_.reset(new boost::mutex()); +} + +template <typename Dtype> +void Layer<Dtype>::Lock() { + if (IsShared()) { + forward_mutex_->lock(); + } +} + +template <typename Dtype> +void Layer<Dtype>::Unlock() { + if (IsShared()) { + forward_mutex_->unlock(); + } +} + +INSTANTIATE_CLASS(Layer); + +} // namespace caffe diff --git a/src/caffe/layers/base_data_layer.cpp b/src/caffe/layers/base_data_layer.cpp index 26a11182..20f76f62 100644 --- a/src/caffe/layers/base_data_layer.cpp +++ b/src/caffe/layers/base_data_layer.cpp @@ -1,7 +1,9 @@ +#include <boost/thread.hpp> #include <string> #include <vector> #include "caffe/data_layers.hpp" +#include "caffe/net.hpp" #include "caffe/util/io.hpp" namespace caffe { @@ -28,55 +30,91 @@ void BaseDataLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom, } template <typename Dtype> +BasePrefetchingDataLayer<Dtype>::BasePrefetchingDataLayer( + const LayerParameter& param) + : BaseDataLayer<Dtype>(param), + prefetch_free_(), prefetch_full_() { + for (int i = 0; i < PREFETCH_COUNT; ++i) { + prefetch_free_.push(&prefetch_[i]); + } +} + +template <typename Dtype> void BasePrefetchingDataLayer<Dtype>::LayerSetUp( const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) { BaseDataLayer<Dtype>::LayerSetUp(bottom, top); - // Now, start the prefetch thread. Before calling prefetch, we make two - // cpu_data calls so that the prefetch thread does not accidentally make - // simultaneous cudaMalloc calls when the main thread is running. In some - // GPUs this seems to cause failures if we do not so. - this->prefetch_data_.mutable_cpu_data(); - if (this->output_labels_) { - this->prefetch_label_.mutable_cpu_data(); + // Before starting the prefetch thread, we make cpu_data and gpu_data + // calls so that the prefetch thread does not accidentally make simultaneous + // cudaMalloc calls when the main thread is running. In some GPUs this + // seems to cause failures if we do not so. 
+ for (int i = 0; i < PREFETCH_COUNT; ++i) { + prefetch_[i].data_.mutable_cpu_data(); + if (this->output_labels_) { + prefetch_[i].label_.mutable_cpu_data(); + } + } +#ifndef CPU_ONLY + if (Caffe::mode() == Caffe::GPU) { + for (int i = 0; i < PREFETCH_COUNT; ++i) { + prefetch_[i].data_.mutable_gpu_data(); + if (this->output_labels_) { + prefetch_[i].label_.mutable_gpu_data(); + } + } } +#endif DLOG(INFO) << "Initializing prefetch"; - this->CreatePrefetchThread(); + this->data_transformer_->InitRand(); + StartInternalThread(); DLOG(INFO) << "Prefetch initialized."; } template <typename Dtype> -void BasePrefetchingDataLayer<Dtype>::CreatePrefetchThread() { - this->data_transformer_->InitRand(); - CHECK(StartInternalThread()) << "Thread execution failed"; -} +void BasePrefetchingDataLayer<Dtype>::InternalThreadEntry() { +#ifndef CPU_ONLY + cudaStream_t stream; + if (Caffe::mode() == Caffe::GPU) { + cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking); + } +#endif -template <typename Dtype> -void BasePrefetchingDataLayer<Dtype>::JoinPrefetchThread() { - CHECK(WaitForInternalThreadToExit()) << "Thread joining failed"; + try { + while (!must_stop()) { + Batch<Dtype>* batch = prefetch_free_.pop(); + load_batch(batch); +#ifndef CPU_ONLY + if (Caffe::mode() == Caffe::GPU) { + batch->data_.data().get()->async_gpu_push(stream); + cudaStreamSynchronize(stream); + } +#endif + prefetch_full_.push(batch); + } + } catch (boost::thread_interrupted&) { + // Interrupted exception is expected on shutdown + } } template <typename Dtype> void BasePrefetchingDataLayer<Dtype>::Forward_cpu( const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) { - // First, join the thread - JoinPrefetchThread(); - DLOG(INFO) << "Thread joined"; + Batch<Dtype>* batch = prefetch_full_.pop("Data layer prefetch queue empty"); // Reshape to loaded data. - top[0]->ReshapeLike(prefetch_data_); + top[0]->Reshape(batch->data_.num(), batch->data_.channels(), + batch->data_.height(), batch->data_.width()); // Copy the data - caffe_copy(prefetch_data_.count(), prefetch_data_.cpu_data(), + caffe_copy(batch->data_.count(), batch->data_.cpu_data(), top[0]->mutable_cpu_data()); DLOG(INFO) << "Prefetch copied"; if (this->output_labels_) { // Reshape to loaded labels. - top[1]->ReshapeLike(prefetch_label_); + top[1]->ReshapeLike(batch->label_); // Copy the labels. - caffe_copy(prefetch_label_.count(), prefetch_label_.cpu_data(), - top[1]->mutable_cpu_data()); + caffe_copy(batch->label_.count(), batch->label_.cpu_data(), + top[1]->mutable_cpu_data()); } - // Start a new prefetch thread - DLOG(INFO) << "CreatePrefetchThread"; - CreatePrefetchThread(); + + prefetch_free_.push(batch); } #ifdef CPU_ONLY diff --git a/src/caffe/layers/base_data_layer.cu b/src/caffe/layers/base_data_layer.cu index 9335a5bc..56439bc5 100644 --- a/src/caffe/layers/base_data_layer.cu +++ b/src/caffe/layers/base_data_layer.cu @@ -7,22 +7,21 @@ namespace caffe { template <typename Dtype> void BasePrefetchingDataLayer<Dtype>::Forward_gpu( const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) { - // First, join the thread - JoinPrefetchThread(); + Batch<Dtype>* batch = prefetch_full_.pop("Data layer prefetch queue empty"); // Reshape to loaded data. 
- top[0]->ReshapeLike(this->prefetch_data_); + top[0]->ReshapeLike(batch->data_); // Copy the data - caffe_copy(prefetch_data_.count(), prefetch_data_.cpu_data(), + caffe_copy(batch->data_.count(), batch->data_.gpu_data(), top[0]->mutable_gpu_data()); if (this->output_labels_) { // Reshape to loaded labels. - top[1]->ReshapeLike(prefetch_label_); + top[1]->ReshapeLike(batch->label_); // Copy the labels. - caffe_copy(prefetch_label_.count(), prefetch_label_.cpu_data(), + caffe_copy(batch->label_.count(), batch->label_.gpu_data(), top[1]->mutable_gpu_data()); } - // Start a new prefetch thread - CreatePrefetchThread(); + + prefetch_free_.push(batch); } INSTANTIATE_LAYER_GPU_FORWARD(BasePrefetchingDataLayer); diff --git a/src/caffe/layers/data_layer.cpp b/src/caffe/layers/data_layer.cpp index 161a75e0..0932d9fe 100644 --- a/src/caffe/layers/data_layer.cpp +++ b/src/caffe/layers/data_layer.cpp @@ -11,93 +11,85 @@ #include "caffe/proto/caffe.pb.h" #include "caffe/util/benchmark.hpp" #include "caffe/util/io.hpp" -#include "caffe/util/math_functions.hpp" -#include "caffe/util/rng.hpp" namespace caffe { template <typename Dtype> -DataLayer<Dtype>::~DataLayer<Dtype>() { - this->JoinPrefetchThread(); +DataLayer<Dtype>::DataLayer(const LayerParameter& param) + : BasePrefetchingDataLayer<Dtype>(param), + reader_(param) { +} + +template <typename Dtype> +DataLayer<Dtype>::~DataLayer() { + this->StopInternalThread(); } template <typename Dtype> void DataLayer<Dtype>::DataLayerSetUp(const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) { - // Initialize DB - db_.reset(db::GetDB(this->layer_param_.data_param().backend())); - db_->Open(this->layer_param_.data_param().source(), db::READ); - cursor_.reset(db_->NewCursor()); + const int batch_size = this->layer_param_.data_param().batch_size(); + // Read a data point, and use it to initialize the top blob. + Datum& datum = *(reader_.full().peek()); - // Check if we should randomly skip a few data points - if (this->layer_param_.data_param().rand_skip()) { - unsigned int skip = caffe_rng_rand() % - this->layer_param_.data_param().rand_skip(); - LOG(INFO) << "Skipping first " << skip << " data points."; - while (skip-- > 0) { - cursor_->Next(); - } - } - // Read a data point, to initialize the prefetch and top blobs. - Datum datum; - datum.ParseFromString(cursor_->value()); // Use data_transformer to infer the expected blob shape from datum. vector<int> top_shape = this->data_transformer_->InferBlobShape(datum); this->transformed_data_.Reshape(top_shape); // Reshape top[0] and prefetch_data according to the batch_size. - top_shape[0] = this->layer_param_.data_param().batch_size(); - this->prefetch_data_.Reshape(top_shape); - top[0]->ReshapeLike(this->prefetch_data_); - + top_shape[0] = batch_size; + top[0]->Reshape(top_shape); + for (int i = 0; i < this->PREFETCH_COUNT; ++i) { + this->prefetch_[i].data_.Reshape(top_shape); + } LOG(INFO) << "output data size: " << top[0]->num() << "," << top[0]->channels() << "," << top[0]->height() << "," << top[0]->width(); // label if (this->output_labels_) { - vector<int> label_shape(1, this->layer_param_.data_param().batch_size()); + vector<int> label_shape(1, batch_size); top[1]->Reshape(label_shape); - this->prefetch_label_.Reshape(label_shape); + for (int i = 0; i < this->PREFETCH_COUNT; ++i) { + this->prefetch_[i].label_.Reshape(label_shape); + } } } -// This function is used to create a thread that prefetches the data. 
-template <typename Dtype> -void DataLayer<Dtype>::InternalThreadEntry() { +// This function is called on prefetch thread +template<typename Dtype> +void DataLayer<Dtype>::load_batch(Batch<Dtype>* batch) { CPUTimer batch_timer; batch_timer.Start(); double read_time = 0; double trans_time = 0; CPUTimer timer; - CHECK(this->prefetch_data_.count()); + CHECK(batch->data_.count()); CHECK(this->transformed_data_.count()); // Reshape according to the first datum of each batch // on single input batches allows for inputs of varying dimension. const int batch_size = this->layer_param_.data_param().batch_size(); - Datum datum; - datum.ParseFromString(cursor_->value()); + Datum& datum = *(reader_.full().peek()); // Use data_transformer to infer the expected blob shape from datum. vector<int> top_shape = this->data_transformer_->InferBlobShape(datum); this->transformed_data_.Reshape(top_shape); - // Reshape prefetch_data according to the batch_size. + // Reshape batch according to the batch_size. top_shape[0] = batch_size; - this->prefetch_data_.Reshape(top_shape); + batch->data_.Reshape(top_shape); - Dtype* top_data = this->prefetch_data_.mutable_cpu_data(); + Dtype* top_data = batch->data_.mutable_cpu_data(); Dtype* top_label = NULL; // suppress warnings about uninitialized variables if (this->output_labels_) { - top_label = this->prefetch_label_.mutable_cpu_data(); + top_label = batch->label_.mutable_cpu_data(); } - timer.Start(); for (int item_id = 0; item_id < batch_size; ++item_id) { + timer.Start(); // get a datum - Datum datum; - datum.ParseFromString(cursor_->value()); + Datum& datum = *(reader_.full().pop("Waiting for data")); read_time += timer.MicroSeconds(); timer.Start(); // Apply data transformations (mirror, scale, crop...) - int offset = this->prefetch_data_.offset(item_id); + int offset = batch->data_.offset(item_id); this->transformed_data_.set_cpu_data(top_data + offset); this->data_transformer_->Transform(datum, &(this->transformed_data_)); // Copy label. @@ -105,13 +97,8 @@ void DataLayer<Dtype>::InternalThreadEntry() { top_label[item_id] = datum.label(); } trans_time += timer.MicroSeconds(); - timer.Start(); - // go to the next item. 
- cursor_->Next(); - if (!cursor_->valid()) { - DLOG(INFO) << "Restarting data prefetching from start."; - cursor_->SeekToFirst(); - } + + reader_.free().push(const_cast<Datum*>(&datum)); } timer.Stop(); batch_timer.Stop(); diff --git a/src/caffe/layers/image_data_layer.cpp b/src/caffe/layers/image_data_layer.cpp index dcc53348..223ba3a7 100644 --- a/src/caffe/layers/image_data_layer.cpp +++ b/src/caffe/layers/image_data_layer.cpp @@ -17,7 +17,7 @@ namespace caffe { template <typename Dtype> ImageDataLayer<Dtype>::~ImageDataLayer<Dtype>() { - this->JoinPrefetchThread(); + this->StopInternalThread(); } template <typename Dtype> @@ -70,8 +70,10 @@ void ImageDataLayer<Dtype>::DataLayerSetUp(const vector<Blob<Dtype>*>& bottom, const int batch_size = this->layer_param_.image_data_param().batch_size(); CHECK_GT(batch_size, 0) << "Positive batch size required"; top_shape[0] = batch_size; - this->prefetch_data_.Reshape(top_shape); - top[0]->ReshapeLike(this->prefetch_data_); + for (int i = 0; i < this->PREFETCH_COUNT; ++i) { + this->prefetch_[i].data_.Reshape(top_shape); + } + top[0]->Reshape(top_shape); LOG(INFO) << "output data size: " << top[0]->num() << "," << top[0]->channels() << "," << top[0]->height() << "," @@ -79,7 +81,9 @@ void ImageDataLayer<Dtype>::DataLayerSetUp(const vector<Blob<Dtype>*>& bottom, // label vector<int> label_shape(1, batch_size); top[1]->Reshape(label_shape); - this->prefetch_label_.Reshape(label_shape); + for (int i = 0; i < this->PREFETCH_COUNT; ++i) { + this->prefetch_[i].label_.Reshape(label_shape); + } } template <typename Dtype> @@ -89,15 +93,15 @@ void ImageDataLayer<Dtype>::ShuffleImages() { shuffle(lines_.begin(), lines_.end(), prefetch_rng); } -// This function is used to create a thread that prefetches the data. +// This function is called on prefetch thread template <typename Dtype> -void ImageDataLayer<Dtype>::InternalThreadEntry() { +void ImageDataLayer<Dtype>::load_batch(Batch<Dtype>* batch) { CPUTimer batch_timer; batch_timer.Start(); double read_time = 0; double trans_time = 0; CPUTimer timer; - CHECK(this->prefetch_data_.count()); + CHECK(batch->data_.count()); CHECK(this->transformed_data_.count()); ImageDataParameter image_data_param = this->layer_param_.image_data_param(); const int batch_size = image_data_param.batch_size(); @@ -114,12 +118,12 @@ void ImageDataLayer<Dtype>::InternalThreadEntry() { // Use data_transformer to infer the expected blob shape from a cv_img. vector<int> top_shape = this->data_transformer_->InferBlobShape(cv_img); this->transformed_data_.Reshape(top_shape); - // Reshape prefetch_data according to the batch_size. + // Reshape batch according to the batch_size. top_shape[0] = batch_size; - this->prefetch_data_.Reshape(top_shape); + batch->data_.Reshape(top_shape); - Dtype* prefetch_data = this->prefetch_data_.mutable_cpu_data(); - Dtype* prefetch_label = this->prefetch_label_.mutable_cpu_data(); + Dtype* prefetch_data = batch->data_.mutable_cpu_data(); + Dtype* prefetch_label = batch->label_.mutable_cpu_data(); // datum scales const int lines_size = lines_.size(); @@ -133,7 +137,7 @@ void ImageDataLayer<Dtype>::InternalThreadEntry() { read_time += timer.MicroSeconds(); timer.Start(); // Apply transformations (mirror, crop...) 
to the image - int offset = this->prefetch_data_.offset(item_id); + int offset = batch->data_.offset(item_id); this->transformed_data_.set_cpu_data(prefetch_data + offset); this->data_transformer_->Transform(cv_img, &(this->transformed_data_)); trans_time += timer.MicroSeconds(); diff --git a/src/caffe/layers/window_data_layer.cpp b/src/caffe/layers/window_data_layer.cpp index c127d56b..f637f2ec 100644 --- a/src/caffe/layers/window_data_layer.cpp +++ b/src/caffe/layers/window_data_layer.cpp @@ -27,7 +27,7 @@ namespace caffe { template <typename Dtype> WindowDataLayer<Dtype>::~WindowDataLayer<Dtype>() { - this->JoinPrefetchThread(); + this->StopInternalThread(); } template <typename Dtype> @@ -171,7 +171,9 @@ void WindowDataLayer<Dtype>::DataLayerSetUp(const vector<Blob<Dtype>*>& bottom, CHECK_GT(crop_size, 0); const int batch_size = this->layer_param_.window_data_param().batch_size(); top[0]->Reshape(batch_size, channels, crop_size, crop_size); - this->prefetch_data_.Reshape(batch_size, channels, crop_size, crop_size); + for (int i = 0; i < this->PREFETCH_COUNT; ++i) + this->prefetch_[i].data_.Reshape( + batch_size, channels, crop_size, crop_size); LOG(INFO) << "output data size: " << top[0]->num() << "," << top[0]->channels() << "," << top[0]->height() << "," @@ -179,7 +181,9 @@ void WindowDataLayer<Dtype>::DataLayerSetUp(const vector<Blob<Dtype>*>& bottom, // label vector<int> label_shape(1, batch_size); top[1]->Reshape(label_shape); - this->prefetch_label_.Reshape(label_shape); + for (int i = 0; i < this->PREFETCH_COUNT; ++i) { + this->prefetch_[i].label_.Reshape(label_shape); + } // data mean has_mean_file_ = this->transform_param_.has_mean_file(); @@ -217,9 +221,9 @@ unsigned int WindowDataLayer<Dtype>::PrefetchRand() { return (*prefetch_rng)(); } -// Thread fetching the data +// This function is called on prefetch thread template <typename Dtype> -void WindowDataLayer<Dtype>::InternalThreadEntry() { +void WindowDataLayer<Dtype>::load_batch(Batch<Dtype>* batch) { // At each iteration, sample N windows where N*p are foreground (object) // windows and N*(1-p) are background (non-object) windows CPUTimer batch_timer; @@ -227,8 +231,8 @@ void WindowDataLayer<Dtype>::InternalThreadEntry() { double read_time = 0; double trans_time = 0; CPUTimer timer; - Dtype* top_data = this->prefetch_data_.mutable_cpu_data(); - Dtype* top_label = this->prefetch_label_.mutable_cpu_data(); + Dtype* top_data = batch->data_.mutable_cpu_data(); + Dtype* top_label = batch->label_.mutable_cpu_data(); const Dtype scale = this->layer_param_.window_data_param().scale(); const int batch_size = this->layer_param_.window_data_param().batch_size(); const int context_pad = this->layer_param_.window_data_param().context_pad(); @@ -252,7 +256,7 @@ void WindowDataLayer<Dtype>::InternalThreadEntry() { bool use_square = (crop_mode == "square") ? 
true : false; // zero out batch - caffe_set(this->prefetch_data_.count(), Dtype(0), top_data); + caffe_set(batch->data_.count(), Dtype(0), top_data); const int num_fg = static_cast<int>(static_cast<float>(batch_size) * fg_fraction); diff --git a/src/caffe/net.cpp b/src/caffe/net.cpp index 91883a10..7875285f 100644 --- a/src/caffe/net.cpp +++ b/src/caffe/net.cpp @@ -10,6 +10,7 @@ #include "caffe/common.hpp" #include "caffe/layer.hpp" #include "caffe/net.hpp" +#include "caffe/parallel.hpp" #include "caffe/proto/caffe.pb.h" #include "caffe/util/hdf5.hpp" #include "caffe/util/insert_splits.hpp" @@ -21,12 +22,14 @@ namespace caffe { template <typename Dtype> -Net<Dtype>::Net(const NetParameter& param) { +Net<Dtype>::Net(const NetParameter& param, const Net* root_net) + : root_net_(root_net) { Init(param); } template <typename Dtype> -Net<Dtype>::Net(const string& param_file, Phase phase) { +Net<Dtype>::Net(const string& param_file, Phase phase, const Net* root_net) + : root_net_(root_net) { NetParameter param; ReadNetParamsFromTextFileOrDie(param_file, ¶m); param.mutable_state()->set_phase(phase); @@ -35,14 +38,18 @@ Net<Dtype>::Net(const string& param_file, Phase phase) { template <typename Dtype> void Net<Dtype>::Init(const NetParameter& in_param) { + CHECK(Caffe::root_solver() || root_net_) + << "root_net_ needs to be set for all non-root solvers"; // Set phase from the state. phase_ = in_param.state().phase(); // Filter layers based on their include/exclude rules and // the current NetState. NetParameter filtered_param; FilterNet(in_param, &filtered_param); - LOG(INFO) << "Initializing net from parameters: " << std::endl - << filtered_param.DebugString(); + if (Caffe::root_solver()) { + LOG(INFO) << "Initializing net from parameters: " << std::endl + << filtered_param.DebugString(); + } // Create a copy of filtered_param with splits added where necessary. NetParameter param; InsertSplits(filtered_param, ¶m); @@ -66,7 +73,8 @@ void Net<Dtype>::Init(const NetParameter& in_param) { const int layer_id = -1; // inputs have fake layer ID -1 AppendTop(param, layer_id, input_id, &available_blobs, &blob_name_to_idx); } - DLOG(INFO) << "Memory required for data: " << memory_used_ * sizeof(Dtype); + DLOG_IF(INFO, Caffe::root_solver()) + << "Memory required for data: " << memory_used_ * sizeof(Dtype); // For each layer, set up its input and output bottom_vecs_.resize(param.layer_size()); top_vecs_.resize(param.layer_size()); @@ -75,6 +83,9 @@ void Net<Dtype>::Init(const NetParameter& in_param) { top_id_vecs_.resize(param.layer_size()); bottom_need_backward_.resize(param.layer_size()); for (int layer_id = 0; layer_id < param.layer_size(); ++layer_id) { + // For non-root solvers, whether this layer is shared from root_net_. + bool share_from_root = !Caffe::root_solver() + && root_net_->layers_[layer_id]->ShareInParallel(); // Inherit phase from net if unset. 
if (!param.layer(layer_id).has_phase()) { param.mutable_layer(layer_id)->set_phase(phase_); @@ -87,9 +98,17 @@ void Net<Dtype>::Init(const NetParameter& in_param) { << "propagate_down param must be specified " << "either 0 or bottom_size times "; } - layers_.push_back(LayerRegistry<Dtype>::CreateLayer(layer_param)); + if (share_from_root) { + LOG(INFO) << "Sharing layer " << layer_param.name() << " from root net"; + layers_.push_back(root_net_->layers_[layer_id]); + layers_[layer_id]->SetShared(true); + } else { + layers_.push_back(LayerRegistry<Dtype>::CreateLayer(layer_param)); + } layer_names_.push_back(layer_param.name()); - LOG(INFO) << "Creating Layer " << layer_param.name(); + if (Caffe::root_solver()) { + LOG(INFO) << "Creating Layer " << layer_param.name(); + } bool need_backward = false; // Figure out this layer's input and output @@ -119,20 +138,42 @@ void Net<Dtype>::Init(const NetParameter& in_param) { } } // After this layer is connected, set it up. - LOG(INFO) << "Setting up " << layer_names_[layer_id]; - layers_[layer_id]->SetUp(bottom_vecs_[layer_id], top_vecs_[layer_id]); + if (share_from_root) { + // Set up size of top blobs using root_net_ + const vector<Blob<Dtype>*>& base_top = root_net_->top_vecs_[layer_id]; + const vector<Blob<Dtype>*>& this_top = this->top_vecs_[layer_id]; + for (int top_id = 0; top_id < base_top.size(); ++top_id) { + this_top[top_id]->ReshapeLike(*base_top[top_id]); + LOG(INFO) << "Created top blob " << top_id << " (shape: " + << this_top[top_id]->shape_string() << ") for shared layer " + << layer_param.name(); + } + } else { + layers_[layer_id]->SetUp(bottom_vecs_[layer_id], top_vecs_[layer_id]); + } + if (Caffe::root_solver()) { + LOG(INFO) << "Setting up " << layer_names_[layer_id]; + } for (int top_id = 0; top_id < top_vecs_[layer_id].size(); ++top_id) { if (blob_loss_weights_.size() <= top_id_vecs_[layer_id][top_id]) { blob_loss_weights_.resize(top_id_vecs_[layer_id][top_id] + 1, Dtype(0)); } blob_loss_weights_[top_id_vecs_[layer_id][top_id]] = layer->loss(top_id); - LOG(INFO) << "Top shape: " << top_vecs_[layer_id][top_id]->shape_string(); + if (Caffe::root_solver()) { + LOG(INFO) << "Top shape: " + << top_vecs_[layer_id][top_id]->shape_string(); + } if (layer->loss(top_id)) { - LOG(INFO) << " with loss weight " << layer->loss(top_id); + if (Caffe::root_solver()) { + LOG(INFO) << " with loss weight " << layer->loss(top_id); + } } memory_used_ += top_vecs_[layer_id][top_id]->count(); } - DLOG(INFO) << "Memory required for data: " << memory_used_ * sizeof(Dtype); + if (Caffe::root_solver()) { + DLOG(INFO) << "Memory required for data: " + << memory_used_ * sizeof(Dtype); + } const int param_size = layer_param.param_size(); const int num_param_blobs = layers_[layer_id]->blobs().size(); CHECK_LE(param_size, num_param_blobs) @@ -191,10 +232,14 @@ void Net<Dtype>::Init(const NetParameter& in_param) { } if (!layer_contributes_loss) { layer_need_backward_[layer_id] = false; } if (layer_need_backward_[layer_id]) { - LOG(INFO) << layer_names_[layer_id] << " needs backward computation."; + if (Caffe::root_solver()) { + LOG(INFO) << layer_names_[layer_id] << " needs backward computation."; + } } else { - LOG(INFO) << layer_names_[layer_id] - << " does not need backward computation."; + if (Caffe::root_solver()) { + LOG(INFO) << layer_names_[layer_id] + << " does not need backward computation."; + } } for (int bottom_id = 0; bottom_id < bottom_vecs_[layer_id].size(); ++bottom_id) { @@ -234,7 +279,9 @@ void Net<Dtype>::Init(const NetParameter& in_param) { 
// In the end, all remaining blobs are considered output blobs. for (set<string>::iterator it = available_blobs.begin(); it != available_blobs.end(); ++it) { - LOG(INFO) << "This network produces output " << *it; + if (Caffe::root_solver()) { + LOG(INFO) << "This network produces output " << *it; + } net_output_blobs_.push_back(blobs_[blob_name_to_idx[*it]].get()); net_output_blob_indices_.push_back(blob_name_to_idx[*it]); } @@ -246,8 +293,10 @@ void Net<Dtype>::Init(const NetParameter& in_param) { } ShareWeights(); debug_info_ = param.debug_info(); - LOG(INFO) << "Network initialization done."; - LOG(INFO) << "Memory required for data: " << memory_used_ * sizeof(Dtype); + if (Caffe::root_solver()) { + LOG(INFO) << "Network initialization done."; + LOG(INFO) << "Memory required for data: " << memory_used_ * sizeof(Dtype); + } } template <typename Dtype> @@ -286,27 +335,33 @@ bool Net<Dtype>::StateMeetsRule(const NetState& state, // Check whether the rule is broken due to phase. if (rule.has_phase()) { if (rule.phase() != state.phase()) { - LOG(INFO) << "The NetState phase (" << state.phase() - << ") differed from the phase (" << rule.phase() - << ") specified by a rule in layer " << layer_name; + if (Caffe::root_solver()) { + LOG(INFO) << "The NetState phase (" << state.phase() + << ") differed from the phase (" << rule.phase() + << ") specified by a rule in layer " << layer_name; + } return false; } } // Check whether the rule is broken due to min level. if (rule.has_min_level()) { if (state.level() < rule.min_level()) { - LOG(INFO) << "The NetState level (" << state.level() - << ") is above the min_level (" << rule.min_level() - << ") specified by a rule in layer " << layer_name; + if (Caffe::root_solver()) { + LOG(INFO) << "The NetState level (" << state.level() + << ") is above the min_level (" << rule.min_level() + << ") specified by a rule in layer " << layer_name; + } return false; } } // Check whether the rule is broken due to max level. 
if (rule.has_max_level()) { if (state.level() > rule.max_level()) { - LOG(INFO) << "The NetState level (" << state.level() - << ") is above the max_level (" << rule.max_level() - << ") specified by a rule in layer " << layer_name; + if (Caffe::root_solver()) { + LOG(INFO) << "The NetState level (" << state.level() + << ") is above the max_level (" << rule.max_level() + << ") specified by a rule in layer " << layer_name; + } return false; } } @@ -319,8 +374,10 @@ bool Net<Dtype>::StateMeetsRule(const NetState& state, if (rule.stage(i) == state.stage(j)) { has_stage = true; } } if (!has_stage) { - LOG(INFO) << "The NetState did not contain stage '" << rule.stage(i) - << "' specified by a rule in layer " << layer_name; + if (Caffe::root_solver()) { + LOG(INFO) << "The NetState did not contain stage '" << rule.stage(i) + << "' specified by a rule in layer " << layer_name; + } return false; } } @@ -333,8 +390,10 @@ bool Net<Dtype>::StateMeetsRule(const NetState& state, if (rule.not_stage(i) == state.stage(j)) { has_stage = true; } } if (has_stage) { - LOG(INFO) << "The NetState contained a not_stage '" << rule.not_stage(i) - << "' specified by a rule in layer " << layer_name; + if (Caffe::root_solver()) { + LOG(INFO) << "The NetState contained a not_stage '" << rule.not_stage(i) + << "' specified by a rule in layer " << layer_name; + } return false; } } @@ -356,7 +415,9 @@ void Net<Dtype>::AppendTop(const NetParameter& param, const int layer_id, if (blob_name_to_idx && layer_param && layer_param->bottom_size() > top_id && blob_name == layer_param->bottom(top_id)) { // In-place computation - LOG(INFO) << layer_param->name() << " -> " << blob_name << " (in-place)"; + if (Caffe::root_solver()) { + LOG(INFO) << layer_param->name() << " -> " << blob_name << " (in-place)"; + } top_vecs_[layer_id].push_back(blobs_[(*blob_name_to_idx)[blob_name]].get()); top_id_vecs_[layer_id].push_back((*blob_name_to_idx)[blob_name]); } else if (blob_name_to_idx && @@ -366,10 +427,12 @@ void Net<Dtype>::AppendTop(const NetParameter& param, const int layer_id, LOG(FATAL) << "Duplicate blobs produced by multiple sources."; } else { // Normal output. 
- if (layer_param) { - LOG(INFO) << layer_param->name() << " -> " << blob_name; - } else { - LOG(INFO) << "Input " << top_id << " -> " << blob_name; + if (Caffe::root_solver()) { + if (layer_param) { + LOG(INFO) << layer_param->name() << " -> " << blob_name; + } else { + LOG(INFO) << "Input " << top_id << " -> " << blob_name; + } } shared_ptr<Blob<Dtype> > blob_pointer(new Blob<Dtype>()); const int blob_id = blobs_.size(); @@ -409,7 +472,9 @@ int Net<Dtype>::AppendBottom(const NetParameter& param, const int layer_id, << " (at index " << bottom_id << ") to layer " << layer_id; } const int blob_id = (*blob_name_to_idx)[blob_name]; - LOG(INFO) << layer_names_[layer_id] << " <- " << blob_name; + if (Caffe::root_solver()) { + LOG(INFO) << layer_names_[layer_id] << " <- " << blob_name; + } bottom_vecs_[layer_id].push_back(blobs_[blob_id].get()); bottom_id_vecs_[layer_id].push_back(blob_id); available_blobs->erase(blob_name); @@ -468,9 +533,10 @@ void Net<Dtype>::AppendParam(const NetParameter& param, const int layer_id, param_layer_indices_[owner_net_param_id]; const int owner_layer_id = owner_index.first; const int owner_param_id = owner_index.second; - LOG(INFO) << "Sharing parameters '" << param_name << "' owned by " - << "layer '" << layer_names_[owner_layer_id] << "', param " - << "index " << owner_param_id; + LOG_IF(INFO, Caffe::root_solver()) << "Sharing parameters '" << param_name + << "' owned by " + << "layer '" << layer_names_[owner_layer_id] << "', param " + << "index " << owner_param_id; Blob<Dtype>* this_blob = layers_[layer_id]->blobs()[param_id].get(); Blob<Dtype>* owner_blob = layers_[owner_layer_id]->blobs()[owner_param_id].get(); @@ -596,8 +662,10 @@ void Net<Dtype>::InputDebugInfo(const int input_id) { const Blob<Dtype>& blob = *net_input_blobs_[input_id]; const string& blob_name = blob_names_[net_input_blob_indices_[input_id]]; const Dtype data_abs_val_mean = blob.asum_data() / blob.count(); - LOG(INFO) << " [Forward] " - << "Input " << blob_name << " data: " << data_abs_val_mean; + if (Caffe::root_solver()) { + LOG(INFO) << " [Forward] " + << "Input " << blob_name << " data: " << data_abs_val_mean; + } } template <typename Dtype> @@ -606,9 +674,12 @@ void Net<Dtype>::ForwardDebugInfo(const int layer_id) { const Blob<Dtype>& blob = *top_vecs_[layer_id][top_id]; const string& blob_name = blob_names_[top_id_vecs_[layer_id][top_id]]; const Dtype data_abs_val_mean = blob.asum_data() / blob.count(); - LOG(INFO) << " [Forward] " - << "Layer " << layer_names_[layer_id] << ", top blob " << blob_name - << " data: " << data_abs_val_mean; + if (Caffe::root_solver()) { + LOG(INFO) << " [Forward] " + << "Layer " << layer_names_[layer_id] + << ", top blob " << blob_name + << " data: " << data_abs_val_mean; + } } for (int param_id = 0; param_id < layers_[layer_id]->blobs().size(); ++param_id) { @@ -616,9 +687,12 @@ void Net<Dtype>::ForwardDebugInfo(const int layer_id) { const int net_param_id = param_id_vecs_[layer_id][param_id]; const string& blob_name = param_display_names_[net_param_id]; const Dtype data_abs_val_mean = blob.asum_data() / blob.count(); - LOG(INFO) << " [Forward] " - << "Layer " << layer_names_[layer_id] << ", param blob " << blob_name - << " data: " << data_abs_val_mean; + if (Caffe::root_solver()) { + LOG(INFO) << " [Forward] " + << "Layer " << layer_names_[layer_id] + << ", param blob " << blob_name + << " data: " << data_abs_val_mean; + } } } @@ -630,18 +704,24 @@ void Net<Dtype>::BackwardDebugInfo(const int layer_id) { const Blob<Dtype>& blob = 
*bottom_vec[bottom_id]; const string& blob_name = blob_names_[bottom_id_vecs_[layer_id][bottom_id]]; const Dtype diff_abs_val_mean = blob.asum_diff() / blob.count(); - LOG(INFO) << " [Backward] " - << "Layer " << layer_names_[layer_id] << ", bottom blob " << blob_name - << " diff: " << diff_abs_val_mean; + if (Caffe::root_solver()) { + LOG(INFO) << " [Backward] " + << "Layer " << layer_names_[layer_id] + << ", bottom blob " << blob_name + << " diff: " << diff_abs_val_mean; + } } for (int param_id = 0; param_id < layers_[layer_id]->blobs().size(); ++param_id) { if (!layers_[layer_id]->param_propagate_down(param_id)) { continue; } const Blob<Dtype>& blob = *layers_[layer_id]->blobs()[param_id]; const Dtype diff_abs_val_mean = blob.asum_diff() / blob.count(); - LOG(INFO) << " [Backward] " - << "Layer " << layer_names_[layer_id] << ", param blob " << param_id - << " diff: " << diff_abs_val_mean; + if (Caffe::root_solver()) { + LOG(INFO) << " [Backward] " + << "Layer " << layer_names_[layer_id] + << ", param blob " << param_id + << " diff: " << diff_abs_val_mean; + } } } @@ -654,17 +734,22 @@ void Net<Dtype>::UpdateDebugInfo(const int param_id) { const Dtype diff_abs_val_mean = blob.asum_diff() / blob.count(); if (param_owner < 0) { const Dtype data_abs_val_mean = blob.asum_data() / blob.count(); - LOG(INFO) << " [Update] Layer " << layer_name - << ", param " << param_display_name - << " data: " << data_abs_val_mean << "; diff: " << diff_abs_val_mean; + if (Caffe::root_solver()) { + LOG(INFO) << " [Update] Layer " << layer_name + << ", param " << param_display_name + << " data: " << data_abs_val_mean + << "; diff: " << diff_abs_val_mean; + } } else { const string& owner_layer_name = layer_names_[param_layer_indices_[param_owner].first]; - LOG(INFO) << " [Update] Layer " << layer_name - << ", param blob " << param_display_name - << " (owned by layer " << owner_layer_name << ", " - << "param " << param_display_names_[param_owners_[param_id]] << ")" - << " diff: " << diff_abs_val_mean; + if (Caffe::root_solver()) { + LOG(INFO) << " [Update] Layer " << layer_name + << ", param blob " << param_display_name + << " (owned by layer " << owner_layer_name << ", " << "param " + << param_display_names_[param_owners_[param_id]] << ")" + << " diff: " << diff_abs_val_mean; + } } } @@ -721,8 +806,8 @@ void Net<Dtype>::Backward() { const Dtype l2norm_data = std::sqrt(sumsq_data); const Dtype l2norm_diff = std::sqrt(sumsq_diff); LOG(ERROR) << " [Backward] All net params (data, diff): " - << "L1 norm = (" << asum_data << ", " << asum_diff << "); " - << "L2 norm = (" << l2norm_data << ", " << l2norm_diff << ")"; + << "L1 norm = (" << asum_data << ", " << asum_diff << "); " + << "L2 norm = (" << l2norm_data << ", " << l2norm_diff << ")"; } } diff --git a/src/caffe/parallel.cpp b/src/caffe/parallel.cpp new file mode 100644 index 00000000..6e7d802b --- /dev/null +++ b/src/caffe/parallel.cpp @@ -0,0 +1,438 @@ +#ifndef CPU_ONLY +#include <cuda_runtime.h> +#endif +#include <glog/logging.h> +#include <stdio.h> +#include <sys/ioctl.h> +#include <sys/mman.h> +#include <sys/stat.h> + +#include <cstdlib> +#include <sstream> +#include <string> +#include <vector> + +#include "boost/thread.hpp" +#include "caffe/caffe.hpp" +#include "caffe/parallel.hpp" + +namespace caffe { + +enum Op { + copy, + replace_cpu, + replace_gpu, + replace_cpu_diff, + replace_gpu_diff +}; + +template<typename Dtype> +static void apply_buffers(const vector<Blob<Dtype>*>& blobs, + Dtype* buffer, size_t total_size, Op op) { + Dtype* ptr = buffer; + for 
(int i = 0; i < blobs.size(); ++i) { + int size = blobs[i]->count(); + switch (op) { + case copy: { + // Init buffer to current values of blobs + caffe_copy(size, + reinterpret_cast<const Dtype*>(blobs[i]->data()->cpu_data()), + ptr); + break; + } + case replace_cpu: + blobs[i]->data()->set_cpu_data(ptr); + break; + case replace_gpu: + blobs[i]->data()->set_gpu_data(ptr); + break; + case replace_cpu_diff: + blobs[i]->diff()->set_cpu_data(ptr); + break; + case replace_gpu_diff: + blobs[i]->diff()->set_gpu_data(ptr); + break; + } + ptr += size; + } + CHECK_EQ(total_size, ptr - buffer); +} + +// Buffer size necessary to store given blobs +template<typename Dtype> +static size_t total_size(const vector<Blob<Dtype>*>& params) { + size_t size = 0; + for (int i = 0; i < params.size(); ++i) + size += params[i]->count(); + return size; +} + +template<typename Dtype> +Params<Dtype>::Params(shared_ptr<Solver<Dtype> > root_solver) + : size_(total_size<Dtype>(root_solver->net()->learnable_params())), + data_(), + diff_() { +} + +template<typename Dtype> +GPUParams<Dtype>::GPUParams(shared_ptr<Solver<Dtype> > root_solver, int device) + : Params<Dtype>(root_solver) { +#ifndef CPU_ONLY + int initial_device; + CUDA_CHECK(cudaGetDevice(&initial_device)); + + // Allocate device buffers + CUDA_CHECK(cudaSetDevice(device)); + CUDA_CHECK(cudaMalloc(&data_, size_ * sizeof(Dtype))); + + // Copy blob values + const vector<Blob<Dtype>*>& net = + root_solver->net()->learnable_params(); + apply_buffers(net, data_, size_, copy); + + CUDA_CHECK(cudaMalloc(&diff_, size_ * sizeof(Dtype))); + caffe_gpu_set(size_, Dtype(0), diff_); + + CUDA_CHECK(cudaSetDevice(initial_device)); +#else + NO_GPU; +#endif +} + +template<typename Dtype> +GPUParams<Dtype>::~GPUParams() { +#ifndef CPU_ONLY + CUDA_CHECK(cudaFree(data_)); + CUDA_CHECK(cudaFree(diff_)); +#endif +} + +template<typename Dtype> +void GPUParams<Dtype>::configure(Solver<Dtype>* solver) const { + const vector<Blob<Dtype>*>& net = + solver->net()->learnable_params(); + apply_buffers(net, data_, size_, replace_gpu); + apply_buffers(net, diff_, size_, replace_gpu_diff); +} + +void DevicePair::compute(const vector<int> devices, vector<DevicePair>* pairs) { +#ifndef CPU_ONLY + vector<int> remaining(devices); + + // Depth for reduction tree + int remaining_depth = static_cast<int>(ceil(log2(remaining.size()))); + + // Group GPUs by board + for (int d = 0; d < remaining_depth; ++d) { + for (int i = 0; i < remaining.size(); ++i) { + for (int j = i + 1; j < remaining.size(); ++j) { + cudaDeviceProp a, b; + CUDA_CHECK(cudaGetDeviceProperties(&a, remaining[i])); + CUDA_CHECK(cudaGetDeviceProperties(&b, remaining[j])); + if (a.isMultiGpuBoard && b.isMultiGpuBoard) { + if (a.multiGpuBoardGroupID == b.multiGpuBoardGroupID) { + pairs->push_back(DevicePair(remaining[i], remaining[j])); + DLOG(INFO) << "GPU board: " << remaining[i] << ":" << remaining[j]; + remaining.erase(remaining.begin() + j); + break; + } + } + } + } + } + ostringstream s; + for (int i = 0; i < remaining.size(); ++i) { + s << (i ? 
", " : "") << remaining[i]; + } + DLOG(INFO) << "GPUs paired by boards, remaining: " << s.str(); + + // Group by P2P accessibility + remaining_depth = ceil(log2(remaining.size())); + for (int d = 0; d < remaining_depth; ++d) { + for (int i = 0; i < remaining.size(); ++i) { + for (int j = i + 1; j < remaining.size(); ++j) { + int access; + CUDA_CHECK( + cudaDeviceCanAccessPeer(&access, remaining[i], remaining[j])); + if (access) { + pairs->push_back(DevicePair(remaining[i], remaining[j])); + DLOG(INFO) << "P2P pair: " << remaining[i] << ":" << remaining[j]; + remaining.erase(remaining.begin() + j); + break; + } + } + } + } + s.str(""); + for (int i = 0; i < remaining.size(); ++i) { + s << (i ? ", " : "") << remaining[i]; + } + DLOG(INFO) << "GPUs paired by P2P access, remaining: " << s.str(); + + // Group remaining + remaining_depth = ceil(log2(remaining.size())); + for (int d = 0; d < remaining_depth; ++d) { + for (int i = 0; i < remaining.size(); ++i) { + pairs->push_back(DevicePair(remaining[i], remaining[i + 1])); + DLOG(INFO) << "Remaining pair: " << remaining[i] << ":" + << remaining[i + 1]; + remaining.erase(remaining.begin() + i + 1); + } + } + + // Should only be the parent node remaining + CHECK_EQ(remaining.size(), 1); + + pairs->insert(pairs->begin(), DevicePair(-1, remaining[0])); + + CHECK(pairs->size() == devices.size()); + for (int i = 0; i < pairs->size(); ++i) { + CHECK((*pairs)[i].parent() != (*pairs)[i].device()); + for (int j = i + 1; j < pairs->size(); ++j) { + CHECK((*pairs)[i].device() != (*pairs)[j].device()); + } + } +#else + NO_GPU; +#endif +} + +// + +template<typename Dtype> +P2PSync<Dtype>::P2PSync(shared_ptr<Solver<Dtype> > root_solver, + P2PSync<Dtype>* parent, const SolverParameter& param) + : GPUParams<Dtype>(root_solver, param.device_id()), + parent_(parent), + children_(), + queue_(), + initial_iter_(root_solver->iter()), + solver_() { +#ifndef CPU_ONLY + int initial_device; + CUDA_CHECK(cudaGetDevice(&initial_device)); + const int self = param.device_id(); + CUDA_CHECK(cudaSetDevice(self)); + + if (parent == NULL) { + solver_ = root_solver; + } else { + Caffe::set_root_solver(false); + solver_.reset(new WorkerSolver<Dtype>(param, root_solver.get())); + Caffe::set_root_solver(true); + } + this->configure(solver_.get()); + solver_->add_callback(this); + + if (parent) { + // Enable p2p access between devices + const int peer = parent->solver_->param().device_id(); + int access; + CUDA_CHECK(cudaDeviceCanAccessPeer(&access, self, peer)); + if (access) { + CUDA_CHECK(cudaDeviceEnablePeerAccess(peer, 0)); + } else { + LOG(INFO)<< "GPU " << self << " does not have p2p access to GPU " << peer; + } + // Allocate receiving buffer on parent + CUDA_CHECK(cudaSetDevice(peer)); + CUDA_CHECK(cudaMalloc(&parent_grads_, size_ * sizeof(Dtype))); + CUDA_CHECK(cudaSetDevice(self)); + } + + CUDA_CHECK(cudaSetDevice(initial_device)); +#else + NO_GPU; +#endif +} + +template<typename Dtype> +P2PSync<Dtype>::~P2PSync() { +#ifndef CPU_ONLY + int initial_device; + CUDA_CHECK(cudaGetDevice(&initial_device)); + const int self = solver_->param().device_id(); + CUDA_CHECK(cudaSetDevice(self)); + + if (parent_) { + CUDA_CHECK(cudaFree(parent_grads_)); + const int peer = parent_->solver_->param().device_id(); + int access; + CUDA_CHECK(cudaDeviceCanAccessPeer(&access, self, peer)); + if (access) { + CUDA_CHECK(cudaDeviceDisablePeerAccess(peer)); + } + } + + CUDA_CHECK(cudaSetDevice(initial_device)); +#endif +} + +template<typename Dtype> +void P2PSync<Dtype>::InternalThreadEntry() { + 
Caffe::SetDevice(solver_->param().device_id()); + CHECK(Caffe::root_solver()); + Caffe::set_root_solver(false); + // See if there is a defined seed and reset random state if so + if (solver_->param().random_seed() >= 0) { + // Fetch random seed and modulate by device ID to make sure + // everyone doesn't have the same seed. We seem to have some + // solver instability if we have everyone with the same seed + Caffe::set_random_seed( + solver_->param().random_seed() + solver_->param().device_id()); + } + solver_->Step(solver_->param().max_iter() - initial_iter_); +} + +template<typename Dtype> +void P2PSync<Dtype>::on_start() { +#ifndef CPU_ONLY +#ifdef DEBUG + int device; + CUDA_CHECK(cudaGetDevice(&device)); + CHECK(device == solver_->param().device_id()); +#else +// CHECK(false); +#endif + + // Wait for update from parent + if (parent_) { + P2PSync<Dtype> *parent = queue_.pop(); + CHECK(parent == parent_); + } + + // Update children + for (int i = children_.size() - 1; i >= 0; i--) { + Dtype* src = data_; + Dtype* dst = children_[i]->data_; + +#ifdef DEBUG + cudaPointerAttributes attributes; + CUDA_CHECK(cudaPointerGetAttributes(&attributes, src)); + CHECK(attributes.device == device); + CUDA_CHECK(cudaPointerGetAttributes(&attributes, dst)); + CHECK(attributes.device == children_[i]->solver_->param().device_id()); +#endif + + CUDA_CHECK(cudaMemcpyAsync(dst, src, size_ * sizeof(Dtype), + cudaMemcpyDeviceToDevice, cudaStreamDefault)); + CUDA_CHECK(cudaStreamSynchronize(cudaStreamDefault)); + children_[i]->queue_.push(this); + } +#endif +} + +template<typename Dtype> +void P2PSync<Dtype>::on_gradients_ready() { +#ifndef CPU_ONLY +#ifdef DEBUG + int device; + CUDA_CHECK(cudaGetDevice(&device)); + CHECK(device == solver_->param().device_id()); +#endif + + // Sum children gradients as they appear in the queue + for (int i = 0; i < children_.size(); ++i) { + P2PSync<Dtype> *child = queue_.pop(); + Dtype* src = child->parent_grads_; + Dtype* dst = diff_; + +#ifdef DEBUG + bool ok = false; + for (int j = 0; j < children_.size(); ++j) { + if (child == children_[j]) { + ok = true; + } + } + CHECK(ok); + cudaPointerAttributes attributes; + CUDA_CHECK(cudaPointerGetAttributes(&attributes, src)); + CHECK(attributes.device == device); + CUDA_CHECK(cudaPointerGetAttributes(&attributes, dst)); + CHECK(attributes.device == device); +#endif + + caffe_gpu_add(size_, src, dst, dst); + } + + // Send gradients to parent + if (parent_) { + Dtype* src = diff_; + Dtype* dst = parent_grads_; + +#ifdef DEBUG + cudaPointerAttributes attributes; + CUDA_CHECK(cudaPointerGetAttributes(&attributes, src)); + CHECK(attributes.device == device); + CUDA_CHECK(cudaPointerGetAttributes(&attributes, dst)); + CHECK(attributes.device == parent_->solver_->param().device_id()); +#endif + + CUDA_CHECK(cudaMemcpyAsync(dst, src, size_ * sizeof(Dtype), // + cudaMemcpyDeviceToDevice, cudaStreamDefault)); + CUDA_CHECK(cudaStreamSynchronize(cudaStreamDefault)); + parent_->queue_.push(this); + } else { + // Loss functions divide gradients by the batch size, so to compensate + // for split batch, the root solver divides by number of solvers. + caffe_gpu_scal(size_, Dtype(1.0 / Caffe::solver_count()), diff_); + } +#endif +} + +template<typename Dtype> +void P2PSync<Dtype>::run(const vector<int>& gpus) { + // Pair devices for map-reduce synchronization + vector<DevicePair> pairs; + DevicePair::compute(gpus, &pairs); + ostringstream s; + for (int i = 1; i < pairs.size(); ++i) { + s << (i == 1 ? 
"" : ", ") << pairs[i].parent() << ":" << pairs[i].device(); + } + LOG(INFO)<< "GPUs pairs " << s.str(); + + SolverParameter param(solver_->param()); + vector<shared_ptr<P2PSync<Dtype> > > syncs(gpus.size()); + + // Build the GPU tree by finding the parent for each solver + for (int attempts = 0; attempts < pairs.size(); ++attempts) { + for (int i = 1; i < pairs.size(); ++i) { + if (!syncs[i].get()) { + P2PSync<Dtype>* parent = NULL; + for (int j = 0; j < syncs.size(); ++j) { + P2PSync<Dtype>* sync = j == 0 ? this : syncs[j].get(); + if (sync) { + const SolverParameter& p = sync->solver()->param(); + if (p.device_id() == pairs[i].parent()) { + parent = sync; + } + } + } + if (parent) { + param.set_device_id(pairs[i].device()); + syncs[i].reset(new P2PSync<Dtype>(solver_, parent, param)); + parent->children_.push_back((P2PSync<Dtype>*) syncs[i].get()); + } + } + } + } + + LOG(INFO)<< "Starting Optimization"; + + for (int i = 1; i < syncs.size(); ++i) { + syncs[i]->StartInternalThread(); + } + + // Run root solver on current thread + solver_->Solve(); + + for (int i = 1; i < syncs.size(); ++i) { + syncs[i]->StopInternalThread(); + } +} + +INSTANTIATE_CLASS(Params); +INSTANTIATE_CLASS(GPUParams); +INSTANTIATE_CLASS(P2PSync); + +} // namespace caffe diff --git a/src/caffe/proto/caffe.proto b/src/caffe/proto/caffe.proto index 7cfcaa8b..fc0d961a 100644 --- a/src/caffe/proto/caffe.proto +++ b/src/caffe/proto/caffe.proto @@ -501,6 +501,7 @@ message DataParameter { // to avoid all asynchronous sgd clients to start at the same point. The skip // point would be set as rand_skip * rand(0,1). Note that rand_skip should not // be larger than the number of keys in the database. + // DEPRECATED. Each solver accesses a different subset of the database. optional uint32 rand_skip = 7 [default = 0]; optional DB backend = 8 [default = LEVELDB]; // DEPRECATED. See TransformationParameter. For data pre-processing, we can do @@ -516,6 +517,9 @@ message DataParameter { optional bool mirror = 6 [default = false]; // Force the encoded image to have 3 color channels optional bool force_encoded_color = 9 [default = false]; + // Prefetch queue (Number of batches to prefetch to host memory, increase if + // data access bandwidth varies). + optional uint32 prefetch = 10 [default = 4]; } message DropoutParameter { @@ -737,6 +741,10 @@ message PythonParameter { // string, dictionary in Python dict format, JSON, etc. You may parse this // string in `setup` method and use it in `forward` and `backward`. optional string param_str = 3 [default = '']; + // Whether this PythonLayer is shared among worker solvers during data parallelism. + // If true, each worker solver sequentially run forward from this layer. + // This value should be set true if you are using it as a data layer. 
+ optional bool share_in_parallel = 4 [default = false]; } // Message that stores parameters used by ReductionLayer diff --git a/src/caffe/solver.cpp b/src/caffe/solver.cpp index 78902ca0..248f238e 100644 --- a/src/caffe/solver.cpp +++ b/src/caffe/solver.cpp @@ -18,14 +18,14 @@ namespace caffe { template <typename Dtype> -Solver<Dtype>::Solver(const SolverParameter& param) - : net_() { +Solver<Dtype>::Solver(const SolverParameter& param, const Solver* root_solver) + : net_(), callbacks_(), root_solver_(root_solver) { Init(param); } template <typename Dtype> -Solver<Dtype>::Solver(const string& param_file) - : net_() { +Solver<Dtype>::Solver(const string& param_file, const Solver* root_solver) + : net_(), callbacks_(), root_solver_(root_solver) { SolverParameter param; ReadProtoFromTextFileOrDie(param_file, ¶m); Init(param); @@ -33,17 +33,21 @@ Solver<Dtype>::Solver(const string& param_file) template <typename Dtype> void Solver<Dtype>::Init(const SolverParameter& param) { - LOG(INFO) << "Initializing solver from parameters: " << std::endl - << param.DebugString(); + CHECK(Caffe::root_solver() || root_solver_) + << "root_solver_ needs to be set for all non-root solvers"; + LOG_IF(INFO, Caffe::root_solver()) << "Initializing solver from parameters: " + << std::endl << param.DebugString(); param_ = param; CHECK_GE(param_.average_loss(), 1) << "average_loss should be non-negative."; - if (param_.random_seed() >= 0) { + if (Caffe::root_solver() && param_.random_seed() >= 0) { Caffe::set_random_seed(param_.random_seed()); } // Scaffolding code InitTrainNet(); - InitTestNets(); - LOG(INFO) << "Solver scaffolding done."; + if (Caffe::root_solver()) { + InitTestNets(); + LOG(INFO) << "Solver scaffolding done."; + } iter_ = 0; current_step_ = 0; } @@ -59,19 +63,22 @@ void Solver<Dtype>::InitTrainNet() { << "one of these fields specifying a train_net: " << field_names; NetParameter net_param; if (param_.has_train_net_param()) { - LOG(INFO) << "Creating training net specified in train_net_param."; + LOG_IF(INFO, Caffe::root_solver()) + << "Creating training net specified in train_net_param."; net_param.CopyFrom(param_.train_net_param()); } else if (param_.has_train_net()) { - LOG(INFO) << "Creating training net from train_net file: " - << param_.train_net(); + LOG_IF(INFO, Caffe::root_solver()) + << "Creating training net from train_net file: " << param_.train_net(); ReadNetParamsFromTextFileOrDie(param_.train_net(), &net_param); } if (param_.has_net_param()) { - LOG(INFO) << "Creating training net specified in net_param."; + LOG_IF(INFO, Caffe::root_solver()) + << "Creating training net specified in net_param."; net_param.CopyFrom(param_.net_param()); } if (param_.has_net()) { - LOG(INFO) << "Creating training net from net file: " << param_.net(); + LOG_IF(INFO, Caffe::root_solver()) + << "Creating training net from net file: " << param_.net(); ReadNetParamsFromTextFileOrDie(param_.net(), &net_param); } // Set the correct NetState. 
We start with the solver defaults (lowest @@ -83,11 +90,16 @@ void Solver<Dtype>::InitTrainNet() { net_state.MergeFrom(net_param.state()); net_state.MergeFrom(param_.train_state()); net_param.mutable_state()->CopyFrom(net_state); - net_.reset(new Net<Dtype>(net_param)); + if (Caffe::root_solver()) { + net_.reset(new Net<Dtype>(net_param)); + } else { + net_.reset(new Net<Dtype>(net_param, root_solver_->net_.get())); + } } template <typename Dtype> void Solver<Dtype>::InitTestNets() { + CHECK(Caffe::root_solver()); const bool has_net_param = param_.has_net_param(); const bool has_net_file = param_.has_net(); const int num_generic_nets = has_net_param + has_net_file; @@ -157,7 +169,12 @@ void Solver<Dtype>::InitTestNets() { net_params[i].mutable_state()->CopyFrom(net_state); LOG(INFO) << "Creating test net (#" << i << ") specified by " << sources[i]; - test_nets_[i].reset(new Net<Dtype>(net_params[i])); + if (Caffe::root_solver()) { + test_nets_[i].reset(new Net<Dtype>(net_params[i])); + } else { + test_nets_[i].reset(new Net<Dtype>(net_params[i], + root_solver_->test_nets_[i].get())); + } test_nets_[i]->set_debug_info(param_.debug_info()); } } @@ -175,10 +192,14 @@ void Solver<Dtype>::Step(int iters) { // zero-init the params net_->ClearParamDiffs(); if (param_.test_interval() && iter_ % param_.test_interval() == 0 - && (iter_ > 0 || param_.test_initialization())) { + && (iter_ > 0 || param_.test_initialization()) + && Caffe::root_solver()) { TestAll(); } + for (int i = 0; i < callbacks_.size(); ++i) { + callbacks_[i]->on_start(); + } const bool display = param_.display() && iter_ % param_.display() == 0; net_->set_debug_info(display && param_.debug_info()); // accumulate the loss and gradient @@ -198,7 +219,8 @@ void Solver<Dtype>::Step(int iters) { losses[idx] = loss; } if (display) { - LOG(INFO) << "Iteration " << iter_ << ", loss = " << smoothed_loss; + LOG_IF(INFO, Caffe::root_solver()) << "Iteration " << iter_ + << ", loss = " << smoothed_loss; const vector<Blob<Dtype>*>& result = net_->output_blobs(); int score_index = 0; for (int j = 0; j < result.size(); ++j) { @@ -213,12 +235,15 @@ void Solver<Dtype>::Step(int iters) { loss_msg_stream << " (* " << loss_weight << " = " << loss_weight * result_vec[k] << " loss)"; } - LOG(INFO) << " Train net output #" + LOG_IF(INFO, Caffe::root_solver()) << " Train net output #" << score_index++ << ": " << output_name << " = " << result_vec[k] << loss_msg_stream.str(); } } } + for (int i = 0; i < callbacks_.size(); ++i) { + callbacks_[i]->on_gradients_ready(); + } ApplyUpdate(); // Increment the internal iter_ counter -- its value should always indicate @@ -226,7 +251,9 @@ void Solver<Dtype>::Step(int iters) { ++iter_; // Save a snapshot if needed. 
- if (param_.snapshot() && iter_ % param_.snapshot() == 0) { + if (param_.snapshot() + && iter_ % param_.snapshot() == 0 + && Caffe::root_solver()) { Snapshot(); } } @@ -234,6 +261,7 @@ void Solver<Dtype>::Step(int iters) { template <typename Dtype> void Solver<Dtype>::Solve(const char* resume_file) { + CHECK(Caffe::root_solver()); LOG(INFO) << "Solving " << net_->name(); LOG(INFO) << "Learning Rate Policy: " << param_.lr_policy(); @@ -278,6 +306,7 @@ void Solver<Dtype>::TestAll() { template <typename Dtype> void Solver<Dtype>::Test(const int test_net_id) { + CHECK(Caffe::root_solver()); LOG(INFO) << "Iteration " << iter_ << ", Testing net (#" << test_net_id << ")"; CHECK_NOTNULL(test_nets_[test_net_id].get())-> @@ -328,13 +357,14 @@ void Solver<Dtype>::Test(const int test_net_id) { << " = " << loss_weight * mean_score << " loss)"; } LOG(INFO) << " Test net output #" << i << ": " << output_name << " = " - << mean_score << loss_msg_stream.str(); + << mean_score << loss_msg_stream.str(); } } template <typename Dtype> void Solver<Dtype>::Snapshot() { + CHECK(Caffe::root_solver()); string model_filename; switch (param_.snapshot_format()) { case caffe::SolverParameter_SnapshotFormat_BINARYPROTO: @@ -379,6 +409,7 @@ string Solver<Dtype>::SnapshotToHDF5() { template <typename Dtype> void Solver<Dtype>::Restore(const char* state_file) { + CHECK(Caffe::root_solver()); string state_filename(state_file); if (state_filename.size() >= 3 && state_filename.compare(state_filename.size() - 3, 3, ".h5") == 0) { @@ -480,6 +511,7 @@ void SGDSolver<Dtype>::ClipGradients() { template <typename Dtype> void SGDSolver<Dtype>::ApplyUpdate() { + CHECK(Caffe::root_solver()); Dtype rate = GetLearningRate(); if (this->param_.display() && this->iter_ % this->param_.display() == 0) { LOG(INFO) << "Iteration " << this->iter_ << ", lr = " << rate; @@ -723,6 +755,7 @@ void SGDSolver<Dtype>::RestoreSolverStateFromHDF5(const string& state_file) { template <typename Dtype> void NesterovSolver<Dtype>::ComputeUpdateValue(int param_id, Dtype rate) { + CHECK(Caffe::root_solver()); const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params(); const vector<float>& net_params_lr = this->net_->params_lr(); Dtype momentum = this->param_.momentum(); @@ -783,6 +816,7 @@ void NesterovSolver<Dtype>::ComputeUpdateValue(int param_id, Dtype rate) { template <typename Dtype> void AdaGradSolver<Dtype>::ComputeUpdateValue(int param_id, Dtype rate) { + CHECK(Caffe::root_solver()); const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params(); const vector<float>& net_params_lr = this->net_->params_lr(); Dtype delta = this->param_.delta(); diff --git a/src/caffe/syncedmem.cpp b/src/caffe/syncedmem.cpp index 7617ccfb..a667a867 100644 --- a/src/caffe/syncedmem.cpp +++ b/src/caffe/syncedmem.cpp @@ -12,8 +12,14 @@ SyncedMemory::~SyncedMemory() { } #ifndef CPU_ONLY - if (gpu_ptr_) { + if (gpu_ptr_ && own_gpu_data_) { + int initial_device; + cudaGetDevice(&initial_device); + if (gpu_device_ != -1) { + CUDA_CHECK(cudaSetDevice(gpu_device_)); + } CUDA_CHECK(cudaFree(gpu_ptr_)); + cudaSetDevice(initial_device); } #endif // CPU_ONLY } @@ -48,13 +54,17 @@ inline void SyncedMemory::to_gpu() { #ifndef CPU_ONLY switch (head_) { case UNINITIALIZED: + CUDA_CHECK(cudaGetDevice(&gpu_device_)); CUDA_CHECK(cudaMalloc(&gpu_ptr_, size_)); caffe_gpu_memset(size_, 0, gpu_ptr_); head_ = HEAD_AT_GPU; + own_gpu_data_ = true; break; case HEAD_AT_CPU: if (gpu_ptr_ == NULL) { + CUDA_CHECK(cudaGetDevice(&gpu_device_)); CUDA_CHECK(cudaMalloc(&gpu_ptr_, 
size_)); + own_gpu_data_ = true; } caffe_gpu_memcpy(size_, cpu_ptr_, gpu_ptr_); head_ = SYNCED; @@ -92,6 +102,26 @@ const void* SyncedMemory::gpu_data() { #endif } +void SyncedMemory::set_gpu_data(void* data) { +#ifndef CPU_ONLY + CHECK(data); + if (own_gpu_data_) { + int initial_device; + cudaGetDevice(&initial_device); + if (gpu_device_ != -1) { + CUDA_CHECK(cudaSetDevice(gpu_device_)); + } + CUDA_CHECK(cudaFree(gpu_ptr_)); + cudaSetDevice(initial_device); + } + gpu_ptr_ = data; + head_ = HEAD_AT_GPU; + own_gpu_data_ = false; +#else + NO_GPU; +#endif +} + void* SyncedMemory::mutable_cpu_data() { to_cpu(); head_ = HEAD_AT_CPU; @@ -108,6 +138,20 @@ void* SyncedMemory::mutable_gpu_data() { #endif } +#ifndef CPU_ONLY +void SyncedMemory::async_gpu_push(const cudaStream_t& stream) { + CHECK(head_ == HEAD_AT_CPU); + if (gpu_ptr_ == NULL) { + CUDA_CHECK(cudaGetDevice(&gpu_device_)); + CUDA_CHECK(cudaMalloc(&gpu_ptr_, size_)); + own_gpu_data_ = true; + } + const cudaMemcpyKind put = cudaMemcpyHostToDevice; + CUDA_CHECK(cudaMemcpyAsync(gpu_ptr_, cpu_ptr_, size_, put, stream)); + // Assume caller will synchronize on the stream before use + head_ = SYNCED; +} +#endif } // namespace caffe diff --git a/src/caffe/test/test_gradient_based_solver.cpp b/src/caffe/test/test_gradient_based_solver.cpp index c97d4ede..1d255a86 100644 --- a/src/caffe/test/test_gradient_based_solver.cpp +++ b/src/caffe/test/test_gradient_based_solver.cpp @@ -8,6 +8,7 @@ #include "gtest/gtest.h" #include "caffe/common.hpp" +#include "caffe/parallel.hpp" #include "caffe/proto/caffe.pb.h" #include "caffe/solver.hpp" #include "caffe/util/io.hpp" @@ -35,6 +36,7 @@ class GradientBasedSolverTest : public MultiDeviceTest<TypeParam> { string snapshot_prefix_; shared_ptr<SGDSolver<Dtype> > solver_; + shared_ptr<P2PSync<Dtype> > sync_; int seed_; // Dimensions are determined by generate_sample_data.py // TODO this is brittle and the hdf5 file should be checked instead. @@ -71,8 +73,8 @@ class GradientBasedSolverTest : public MultiDeviceTest<TypeParam> { string RunLeastSquaresSolver(const Dtype learning_rate, const Dtype weight_decay, const Dtype momentum, const int num_iters, - const int iter_size = 1, const bool snapshot = false, - const char* from_snapshot = NULL) { + const int iter_size = 1, const int devices = 1, + const bool snapshot = false, const char* from_snapshot = NULL) { ostringstream proto; proto << "snapshot_after_train: " << snapshot << " " @@ -185,7 +187,20 @@ class GradientBasedSolverTest : public MultiDeviceTest<TypeParam> { this->solver_->net()->Forward(empty_bottom_vec); } } - this->solver_->Solve(); + if (devices == 1) { + this->solver_->Solve(); + } else { + LOG(INFO) << "Multi-GPU test on " << devices << " devices"; + vector<int> gpus; + for (int i = 0; i < devices; ++i) { + gpus.push_back(i); + } + Caffe::set_solver_count(gpus.size()); + this->sync_.reset(new P2PSync<Dtype>( + this->solver_, NULL, this->solver_->param())); + this->sync_->run(gpus); + Caffe::set_solver_count(1); + } if (snapshot) { ostringstream resume_file; resume_file << snapshot_prefix_ << "/_iter_" << num_iters @@ -428,20 +443,38 @@ class GradientBasedSolverTest : public MultiDeviceTest<TypeParam> { void TestLeastSquaresUpdate(const Dtype learning_rate = 1.0, const Dtype weight_decay = 0.0, const Dtype momentum = 0.0, const int iter_to_check = 0) { - // Initialize the solver and run K (= iter_to_check) solver iterations. 
- RunLeastSquaresSolver(learning_rate, weight_decay, momentum, iter_to_check); - - // Compute the (K+1)th update using the analytic least squares gradient. - vector<shared_ptr<Blob<Dtype> > > updated_params; - ComputeLeastSquaresUpdate(learning_rate, weight_decay, momentum, - &updated_params); - - // Reinitialize the solver and run K+1 solver iterations. - RunLeastSquaresSolver(learning_rate, weight_decay, momentum, - iter_to_check + 1); - - // Check that the solver's solution matches ours. - CheckLeastSquaresUpdate(updated_params); + const int kNum = num_; + const int kIterSize = 1; + // Test over all numbers of devices. + int available_devices = 1; +#ifndef CPU_ONLY + if (Caffe::mode() == Caffe::GPU) { + CUDA_CHECK(cudaGetDeviceCount(&available_devices)); + } +#endif + for (int devices = 1; devices <= available_devices; ++devices) { + // Configure batch size for single / multi device equivalence. + // Constant data is needed for multi device as for accumulation. + num_ = kNum * devices; + + // Initialize the solver and run K (= iter_to_check) solver iterations + // (on single device). + RunLeastSquaresSolver(learning_rate, weight_decay, momentum, + iter_to_check, kIterSize, 1); + + // Compute the (K+1)th update using the analytic least squares gradient. + vector<shared_ptr<Blob<Dtype> > > updated_params; + ComputeLeastSquaresUpdate(learning_rate, weight_decay, momentum, + &updated_params); + + // Reinitialize the solver and run K+1 solver iterations. + num_ = kNum; + RunLeastSquaresSolver(learning_rate, weight_decay, momentum, + iter_to_check + 1, kIterSize, devices); + + // Check that the solver's solution matches ours. + CheckLeastSquaresUpdate(updated_params); + } } void TestSnapshot(const Dtype learning_rate = 1.0, @@ -451,8 +484,9 @@ class GradientBasedSolverTest : public MultiDeviceTest<TypeParam> { const int total_num_iters = num_iters * 2; bool snapshot = false; const int kIterSize = 1; + const int kDevices = 1; RunLeastSquaresSolver(learning_rate, weight_decay, momentum, - total_num_iters, kIterSize, snapshot); + total_num_iters, kIterSize, kDevices, snapshot); // Save the resulting param values. vector<shared_ptr<Blob<Dtype> > > param_copies; @@ -482,12 +516,13 @@ class GradientBasedSolverTest : public MultiDeviceTest<TypeParam> { // Run the solver for num_iters iterations and snapshot. snapshot = true; string snapshot_name = RunLeastSquaresSolver(learning_rate, weight_decay, - momentum, num_iters, kIterSize, snapshot); + momentum, num_iters, kIterSize, kDevices, snapshot); // Reinitialize the solver and run for num_iters more iterations. snapshot = false; RunLeastSquaresSolver(learning_rate, weight_decay, momentum, - total_num_iters, kIterSize, snapshot, snapshot_name.c_str()); + total_num_iters, kIterSize, kDevices, + snapshot, snapshot_name.c_str()); // Check that params now match. 
const vector<Blob<Dtype>*>& params = solver_->net()->learnable_params(); diff --git a/src/caffe/test/test_internal_thread.cpp b/src/caffe/test/test_internal_thread.cpp index 31882b6d..93f1cc54 100644 --- a/src/caffe/test/test_internal_thread.cpp +++ b/src/caffe/test/test_internal_thread.cpp @@ -2,6 +2,7 @@ #include "gtest/gtest.h" #include "caffe/internal_thread.hpp" +#include "caffe/util/math_functions.hpp" #include "caffe/test/test_caffe_main.hpp" @@ -13,11 +14,40 @@ class InternalThreadTest : public ::testing::Test {}; TEST_F(InternalThreadTest, TestStartAndExit) { InternalThread thread; EXPECT_FALSE(thread.is_started()); - EXPECT_TRUE(thread.StartInternalThread()); + thread.StartInternalThread(); EXPECT_TRUE(thread.is_started()); - EXPECT_TRUE(thread.WaitForInternalThreadToExit()); + thread.StopInternalThread(); EXPECT_FALSE(thread.is_started()); } +class TestThreadA : public InternalThread { + void InternalThreadEntry() { + EXPECT_EQ(4244559767, caffe_rng_rand()); + } +}; + +class TestThreadB : public InternalThread { + void InternalThreadEntry() { + EXPECT_EQ(1726478280, caffe_rng_rand()); + } +}; + +TEST_F(InternalThreadTest, TestRandomSeed) { + TestThreadA t1; + Caffe::set_random_seed(9658361); + t1.StartInternalThread(); + t1.StopInternalThread(); + + TestThreadA t2; + Caffe::set_random_seed(9658361); + t2.StartInternalThread(); + t2.StopInternalThread(); + + TestThreadB t3; + Caffe::set_random_seed(3435563); + t3.StartInternalThread(); + t3.StopInternalThread(); +} + } // namespace caffe diff --git a/src/caffe/test/test_layer_factory.cpp b/src/caffe/test/test_layer_factory.cpp index efb1b37a..c86fafd0 100644 --- a/src/caffe/test/test_layer_factory.cpp +++ b/src/caffe/test/test_layer_factory.cpp @@ -1,11 +1,14 @@ #include <map> #include <string> +#include "boost/scoped_ptr.hpp" #include "gtest/gtest.h" #include "caffe/common.hpp" #include "caffe/layer.hpp" #include "caffe/layer_factory.hpp" +#include "caffe/util/db.hpp" +#include "caffe/util/io.hpp" #include "caffe/test/test_caffe_main.hpp" @@ -21,11 +24,20 @@ TYPED_TEST(LayerFactoryTest, TestCreateLayer) { typename LayerRegistry<Dtype>::CreatorRegistry& registry = LayerRegistry<Dtype>::Registry(); shared_ptr<Layer<Dtype> > layer; - LayerParameter layer_param; for (typename LayerRegistry<Dtype>::CreatorRegistry::iterator iter = registry.begin(); iter != registry.end(); ++iter) { // Special case: PythonLayer is checked by pytest if (iter->first == "Python") { continue; } + LayerParameter layer_param; + // Data layers expect a DB + if (iter->first == "Data") { + string tmp; + MakeTempDir(&tmp); + boost::scoped_ptr<db::DB> db(db::GetDB(DataParameter_DB_LEVELDB)); + db->Open(tmp, db::NEW); + db->Close(); + layer_param.mutable_data_param()->set_source(tmp); + } layer_param.set_type(iter->first); layer = LayerRegistry<Dtype>::CreateLayer(layer_param); EXPECT_EQ(iter->first, layer->type()); diff --git a/src/caffe/test/test_upgrade_proto.cpp b/src/caffe/test/test_upgrade_proto.cpp index eec62765..00672023 100644 --- a/src/caffe/test/test_upgrade_proto.cpp +++ b/src/caffe/test/test_upgrade_proto.cpp @@ -2,12 +2,15 @@ #include <string> #include <vector> +#include "boost/scoped_ptr.hpp" #include "google/protobuf/text_format.h" #include "gtest/gtest.h" #include "caffe/blob.hpp" #include "caffe/common.hpp" #include "caffe/layer.hpp" +#include "caffe/util/db.hpp" +#include "caffe/util/io.hpp" #include "caffe/util/upgrade_proto.hpp" #include "caffe/test/test_caffe_main.hpp" @@ -2901,6 +2904,15 @@ TEST_F(NetUpgradeTest, TestUpgradeV1LayerType) { 
continue; // Empty string isn't actually a valid layer type. } layer_param.set_type(v2_layer_type); + // Data layers expect a DB + if (v2_layer_type == "Data") { + string tmp; + MakeTempDir(&tmp); + boost::scoped_ptr<db::DB> db(db::GetDB(DataParameter_DB_LEVELDB)); + db->Open(tmp, db::NEW); + db->Close(); + layer_param.mutable_data_param()->set_source(tmp); + } layer = LayerRegistry<float>::CreateLayer(layer_param); EXPECT_EQ(v2_layer_type, layer->type()); } diff --git a/src/caffe/util/blocking_queue.cpp b/src/caffe/util/blocking_queue.cpp new file mode 100644 index 00000000..d1d1fa86 --- /dev/null +++ b/src/caffe/util/blocking_queue.cpp @@ -0,0 +1,96 @@ +#include <boost/thread.hpp> +#include <string> + +#include "caffe/data_layers.hpp" +#include "caffe/data_reader.hpp" +#include "caffe/parallel.hpp" +#include "caffe/util/blocking_queue.hpp" + +namespace caffe { + +template<typename T> +class BlockingQueue<T>::sync { + public: + mutable boost::mutex mutex_; + boost::condition_variable condition_; +}; + +template<typename T> +BlockingQueue<T>::BlockingQueue() + : sync_(new sync()) { +} + +template<typename T> +void BlockingQueue<T>::push(const T& t) { + boost::mutex::scoped_lock lock(sync_->mutex_); + queue_.push(t); + lock.unlock(); + sync_->condition_.notify_one(); +} + +template<typename T> +bool BlockingQueue<T>::try_pop(T* t) { + boost::mutex::scoped_lock lock(sync_->mutex_); + + if (queue_.empty()) { + return false; + } + + *t = queue_.front(); + queue_.pop(); + return true; +} + +template<typename T> +T BlockingQueue<T>::pop(const string& log_on_wait) { + boost::mutex::scoped_lock lock(sync_->mutex_); + + while (queue_.empty()) { + if (!log_on_wait.empty()) { + LOG_EVERY_N(INFO, 1000)<< log_on_wait; + } + sync_->condition_.wait(lock); + } + + T t = queue_.front(); + queue_.pop(); + return t; +} + +template<typename T> +bool BlockingQueue<T>::try_peek(T* t) { + boost::mutex::scoped_lock lock(sync_->mutex_); + + if (queue_.empty()) { + return false; + } + + *t = queue_.front(); + return true; +} + +template<typename T> +T BlockingQueue<T>::peek() { + boost::mutex::scoped_lock lock(sync_->mutex_); + + while (queue_.empty()) { + sync_->condition_.wait(lock); + } + + return queue_.front(); +} + +template<typename T> +size_t BlockingQueue<T>::size() const { + boost::mutex::scoped_lock lock(sync_->mutex_); + return queue_.size(); +} + +template class BlockingQueue<Batch<float>*>; +template class BlockingQueue<Batch<double>*>; +template class BlockingQueue<Datum*>; +template class BlockingQueue<shared_ptr<DataReader::QueuePair> >; +template class BlockingQueue<P2PSync<float>*>; +template class BlockingQueue<P2PSync<double>*>; + +} // namespace caffe |
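Usage sketch for the changes above: the multi-device branch added to test_gradient_based_solver.cpp shows the intended driver pattern for the new P2PSync class — set the solver count, construct a root sync with a NULL parent, and call run() with the list of GPU ids. The following is a minimal stand-alone sketch along those lines, not code from the commit; the solver definition path "examples/solver.prototxt" and the GPU id list are placeholder assumptions.

// Minimal multi-GPU driver sketch (assumptions: placeholder solver prototxt
// path and GPU ids 0,1). It mirrors the multi-device branch of
// test_gradient_based_solver.cpp rather than reproducing any tool verbatim.
#include <vector>

#include "boost/shared_ptr.hpp"

#include "caffe/common.hpp"
#include "caffe/parallel.hpp"
#include "caffe/proto/caffe.pb.h"
#include "caffe/solver.hpp"
#include "caffe/util/io.hpp"

int main() {
  caffe::SolverParameter solver_param;
  caffe::ReadProtoFromTextFileOrDie("examples/solver.prototxt", &solver_param);

  // GPUs taking part in data parallelism; the first id hosts the root solver.
  std::vector<int> gpus;
  gpus.push_back(0);
  gpus.push_back(1);

  solver_param.set_device_id(gpus[0]);
  caffe::Caffe::SetDevice(gpus[0]);
  caffe::Caffe::set_mode(caffe::Caffe::GPU);
  caffe::Caffe::set_solver_count(gpus.size());

  boost::shared_ptr<caffe::Solver<float> > solver(
      new caffe::SGDSolver<float>(solver_param));

  if (gpus.size() > 1) {
    // Root sync has no parent; run() pairs the devices, builds the GPU tree,
    // starts one internal thread per non-root device, and solves the root
    // solver on the calling thread.
    caffe::P2PSync<float> sync(solver, NULL, solver->param());
    sync.run(gpus);
  } else {
    solver->Solve();
  }
  return 0;
}

P2PSync implements the new Solver callbacks shown in the solver.cpp hunks above: on_start() copies the parameter buffer down the GPU tree before each iteration, and on_gradients_ready() sums child gradients up the tree, with the root scaling the accumulated diff by 1/solver_count to compensate for the batch being split across solvers before ApplyUpdate().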