author     Ronghang Hu <huronghang@hotmail.com>    2015-08-13 13:28:11 -0700
committer  Ronghang Hu <huronghang@hotmail.com>    2015-08-13 13:28:11 -0700
commit     bb0a90e8aabb0e1ffd5732439591814d10d7fd45 (patch)
tree       cc0f44d36d36ada214e4dfecaebab518d5bdf547 /src
parent     8181870b9ac330a094ab0f8d53f54a0202f697a0 (diff)
parent     6b50ed6fc1897ce1ccd673cf0287788b38b58a6d (diff)
Merge pull request #2903 from ronghanghu/multi_gpu
Multi-GPU Data Parallelism
Diffstat (limited to 'src')
-rw-r--r--  src/caffe/common.cpp                            16
-rw-r--r--  src/caffe/data_reader.cpp                      119
-rw-r--r--  src/caffe/data_transformer.cpp                   4
-rw-r--r--  src/caffe/internal_thread.cpp                   58
-rw-r--r--  src/caffe/layer.cpp                             27
-rw-r--r--  src/caffe/layers/base_data_layer.cpp            90
-rw-r--r--  src/caffe/layers/base_data_layer.cu             15
-rw-r--r--  src/caffe/layers/data_layer.cpp                 81
-rw-r--r--  src/caffe/layers/image_data_layer.cpp           28
-rw-r--r--  src/caffe/layers/window_data_layer.cpp          20
-rw-r--r--  src/caffe/net.cpp                              213
-rw-r--r--  src/caffe/parallel.cpp                         438
-rw-r--r--  src/caffe/proto/caffe.proto                      8
-rw-r--r--  src/caffe/solver.cpp                            76
-rw-r--r--  src/caffe/syncedmem.cpp                         46
-rw-r--r--  src/caffe/test/test_gradient_based_solver.cpp   75
-rw-r--r--  src/caffe/test/test_internal_thread.cpp         34
-rw-r--r--  src/caffe/test/test_layer_factory.cpp           14
-rw-r--r--  src/caffe/test/test_upgrade_proto.cpp           12
-rw-r--r--  src/caffe/util/blocking_queue.cpp               96
20 files changed, 1240 insertions(+), 230 deletions(-)
diff --git a/src/caffe/common.cpp b/src/caffe/common.cpp
index af96cac4..7077f378 100644
--- a/src/caffe/common.cpp
+++ b/src/caffe/common.cpp
@@ -1,3 +1,4 @@
+#include <boost/thread.hpp>
#include <glog/logging.h>
#include <cstdio>
#include <ctime>
@@ -7,7 +8,15 @@
namespace caffe {
-shared_ptr<Caffe> Caffe::singleton_;
+// Make sure each thread can have different values.
+static boost::thread_specific_ptr<Caffe> thread_instance_;
+
+Caffe& Caffe::Get() {
+ if (!thread_instance_.get()) {
+ thread_instance_.reset(new Caffe());
+ }
+ return *(thread_instance_.get());
+}
// random seeding
int64_t cluster_seedgen(void) {
@@ -42,7 +51,8 @@ void GlobalInit(int* pargc, char*** pargv) {
#ifdef CPU_ONLY // CPU-only Caffe.
Caffe::Caffe()
- : random_generator_(), mode_(Caffe::CPU) { }
+ : random_generator_(), mode_(Caffe::CPU),
+ solver_count_(1), root_solver_(true) { }
Caffe::~Caffe() { }
@@ -86,7 +96,7 @@ void* Caffe::RNG::generator() {
Caffe::Caffe()
: cublas_handle_(NULL), curand_generator_(NULL), random_generator_(),
- mode_(Caffe::CPU) {
+ mode_(Caffe::CPU), solver_count_(1), root_solver_(true) {
// Try to create a cublas handler, and report an error if failed (but we will
// keep the program running as one might just want to run CPU code).
if (cublasCreate(&cublas_handle_) != CUBLAS_STATUS_SUCCESS) {
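
Note: the common.cpp change above replaces the single process-wide Caffe singleton with one instance per thread, so each worker solver can carry its own mode, random seed, solver count and root-solver flag. A minimal, self-contained sketch of the boost::thread_specific_ptr pattern (illustrative names only, not Caffe code):

    #include <boost/thread.hpp>
    #include <iostream>

    class Context {
     public:
      // Lazily creates one Context per calling thread; each instance is
      // destroyed automatically when its owning thread exits.
      static Context& Get() {
        if (!instance_.get()) {
          instance_.reset(new Context());
        }
        return *instance_;
      }
      int mode;

     private:
      Context() : mode(0) {}
      static boost::thread_specific_ptr<Context> instance_;
    };

    boost::thread_specific_ptr<Context> Context::instance_;

    int main() {
      Context::Get().mode = 1;                           // main thread's copy
      boost::thread t([] {
        // A worker thread sees a fresh instance, not the main thread's.
        std::cout << Context::Get().mode << std::endl;   // prints 0
      });
      t.join();
      std::cout << Context::Get().mode << std::endl;     // still 1
      return 0;
    }
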
diff --git a/src/caffe/data_reader.cpp b/src/caffe/data_reader.cpp
new file mode 100644
index 00000000..16378203
--- /dev/null
+++ b/src/caffe/data_reader.cpp
@@ -0,0 +1,119 @@
+#include <boost/thread.hpp>
+#include <map>
+#include <string>
+#include <vector>
+
+#include "caffe/common.hpp"
+#include "caffe/data_layers.hpp"
+#include "caffe/data_reader.hpp"
+#include "caffe/proto/caffe.pb.h"
+
+namespace caffe {
+
+using boost::weak_ptr;
+
+map<const string, weak_ptr<DataReader::Body> > DataReader::bodies_;
+static boost::mutex bodies_mutex_;
+
+DataReader::DataReader(const LayerParameter& param)
+ : queue_pair_(new QueuePair( //
+ param.data_param().prefetch() * param.data_param().batch_size())) {
+ // Get or create a body
+ boost::mutex::scoped_lock lock(bodies_mutex_);
+ string key = source_key(param);
+ weak_ptr<Body>& weak = bodies_[key];
+ body_ = weak.lock();
+ if (!body_) {
+ body_.reset(new Body(param));
+ bodies_[key] = weak_ptr<Body>(body_);
+ }
+ body_->new_queue_pairs_.push(queue_pair_);
+}
+
+DataReader::~DataReader() {
+ string key = source_key(body_->param_);
+ body_.reset();
+ boost::mutex::scoped_lock lock(bodies_mutex_);
+ if (bodies_[key].expired()) {
+ bodies_.erase(key);
+ }
+}
+
+//
+
+DataReader::QueuePair::QueuePair(int size) {
+ // Initialize the free queue with requested number of datums
+ for (int i = 0; i < size; ++i) {
+ free_.push(new Datum());
+ }
+}
+
+DataReader::QueuePair::~QueuePair() {
+ Datum* datum;
+ while (free_.try_pop(&datum)) {
+ delete datum;
+ }
+ while (full_.try_pop(&datum)) {
+ delete datum;
+ }
+}
+
+//
+
+DataReader::Body::Body(const LayerParameter& param)
+ : param_(param),
+ new_queue_pairs_() {
+ StartInternalThread();
+}
+
+DataReader::Body::~Body() {
+ StopInternalThread();
+}
+
+void DataReader::Body::InternalThreadEntry() {
+ shared_ptr<db::DB> db(db::GetDB(param_.data_param().backend()));
+ db->Open(param_.data_param().source(), db::READ);
+ shared_ptr<db::Cursor> cursor(db->NewCursor());
+ vector<shared_ptr<QueuePair> > qps;
+ try {
+ int solver_count = param_.phase() == TRAIN ? Caffe::solver_count() : 1;
+
+ // To ensure deterministic runs, only start running once all solvers
+ // are ready. But solvers need to peek on one item during initialization,
+ // so read one item, then wait for the next solver.
+ for (int i = 0; i < solver_count; ++i) {
+ shared_ptr<QueuePair> qp(new_queue_pairs_.pop());
+ read_one(cursor.get(), qp.get());
+ qps.push_back(qp);
+ }
+ // Main loop
+ while (!must_stop()) {
+ for (int i = 0; i < solver_count; ++i) {
+ read_one(cursor.get(), qps[i].get());
+ }
+ // Check no additional readers have been created. This can happen if
+ // more than one net is trained at a time per process, whether single
+ // or multi solver. It might also happen if two data layers have the
+ // same name and the same source.
+ CHECK_EQ(new_queue_pairs_.size(), 0);
+ }
+ } catch (boost::thread_interrupted&) {
+ // Interrupted exception is expected on shutdown
+ }
+}
+
+void DataReader::Body::read_one(db::Cursor* cursor, QueuePair* qp) {
+ Datum* datum = qp->free_.pop();
+ // TODO deserialize in-place instead of copy?
+ datum->ParseFromString(cursor->value());
+ qp->full_.push(datum);
+
+ // go to the next iter
+ cursor->Next();
+ if (!cursor->valid()) {
+ DLOG(INFO) << "Restarting data prefetching from start.";
+ cursor->SeekToFirst();
+ }
+}
+
+} // namespace caffe
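
Note: the new DataReader above pairs a free_ and a full_ queue per consumer. The single reading thread pops a pre-allocated Datum from free_, fills it from the DB cursor, and pushes it to full_; the data layer pops from full_ and returns the Datum to free_, so the same buffers are recycled with no per-item allocation in the steady state. A standalone sketch of that producer/consumer protocol, using a hypothetical stand-in for Caffe's BlockingQueue (the real one is added in src/caffe/util/blocking_queue.cpp in this same merge):

    #include <condition_variable>
    #include <iostream>
    #include <mutex>
    #include <queue>
    #include <string>
    #include <thread>

    // Hypothetical stand-in for caffe::BlockingQueue.
    template <typename T>
    class BlockingQueue {
     public:
      void push(const T& t) {
        { std::lock_guard<std::mutex> lock(mutex_); queue_.push(t); }
        cond_.notify_one();
      }
      T pop() {
        std::unique_lock<std::mutex> lock(mutex_);
        cond_.wait(lock, [this] { return !queue_.empty(); });
        T t = queue_.front();
        queue_.pop();
        return t;
      }
     private:
      std::queue<T> queue_;
      std::mutex mutex_;
      std::condition_variable cond_;
    };

    struct Datum { std::string value; };

    int main() {
      BlockingQueue<Datum*> free_q, full_q;
      for (int i = 0; i < 4; ++i) free_q.push(new Datum());  // pre-allocate buffers

      std::thread reader([&] {                // plays the role of DataReader::Body
        for (int i = 0; i < 8; ++i) {
          Datum* d = free_q.pop();            // reuse a recycled buffer
          d->value = "record " + std::to_string(i);
          full_q.push(d);
        }
      });

      for (int i = 0; i < 8; ++i) {           // plays the role of the data layer
        Datum* d = full_q.pop();
        std::cout << d->value << std::endl;
        free_q.push(d);                       // recycle
      }
      reader.join();
      for (int i = 0; i < 4; ++i) delete free_q.pop();  // all buffers are back in free_q
      return 0;
    }
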
diff --git a/src/caffe/data_transformer.cpp b/src/caffe/data_transformer.cpp
index 22633922..4666d9bd 100644
--- a/src/caffe/data_transformer.cpp
+++ b/src/caffe/data_transformer.cpp
@@ -19,7 +19,9 @@ DataTransformer<Dtype>::DataTransformer(const TransformationParameter& param,
CHECK_EQ(param_.mean_value_size(), 0) <<
"Cannot specify mean_file and mean_value at the same time";
const string& mean_file = param.mean_file();
- LOG(INFO) << "Loading mean file from: " << mean_file;
+ if (Caffe::root_solver()) {
+ LOG(INFO) << "Loading mean file from: " << mean_file;
+ }
BlobProto blob_proto;
ReadProtoFromBinaryFileOrDie(mean_file.c_str(), &blob_proto);
data_mean_.FromProto(blob_proto);
diff --git a/src/caffe/internal_thread.cpp b/src/caffe/internal_thread.cpp
index c2d19d43..104884e0 100644
--- a/src/caffe/internal_thread.cpp
+++ b/src/caffe/internal_thread.cpp
@@ -1,40 +1,66 @@
#include <boost/thread.hpp>
+#include <exception>
+
#include "caffe/internal_thread.hpp"
+#include "caffe/util/math_functions.hpp"
namespace caffe {
InternalThread::~InternalThread() {
- WaitForInternalThreadToExit();
+ StopInternalThread();
}
bool InternalThread::is_started() const {
- return thread_.get() != NULL && thread_->joinable();
+ return thread_ && thread_->joinable();
+}
+
+bool InternalThread::must_stop() {
+ return thread_ && thread_->interruption_requested();
}
+void InternalThread::StartInternalThread() {
+ CHECK(!is_started()) << "Threads should persist and not be restarted.";
+
+ int device = 0;
+#ifndef CPU_ONLY
+ CUDA_CHECK(cudaGetDevice(&device));
+#endif
+ Caffe::Brew mode = Caffe::mode();
+ int rand_seed = caffe_rng_rand();
+ int solver_count = Caffe::solver_count();
+ bool root_solver = Caffe::root_solver();
-bool InternalThread::StartInternalThread() {
- if (!WaitForInternalThreadToExit()) {
- return false;
- }
try {
- thread_.reset(
- new boost::thread(&InternalThread::InternalThreadEntry, this));
- } catch (...) {
- return false;
+ thread_.reset(new boost::thread(&InternalThread::entry, this, device, mode,
+ rand_seed, solver_count, root_solver));
+ } catch (std::exception& e) {
+ LOG(FATAL) << "Thread exception: " << e.what();
}
- return true;
}
-/** Will not return until the internal thread has exited. */
-bool InternalThread::WaitForInternalThreadToExit() {
+void InternalThread::entry(int device, Caffe::Brew mode, int rand_seed,
+ int solver_count, bool root_solver) {
+#ifndef CPU_ONLY
+ CUDA_CHECK(cudaSetDevice(device));
+#endif
+ Caffe::set_mode(mode);
+ Caffe::set_random_seed(rand_seed);
+ Caffe::set_solver_count(solver_count);
+ Caffe::set_root_solver(root_solver);
+
+ InternalThreadEntry();
+}
+
+void InternalThread::StopInternalThread() {
if (is_started()) {
+ thread_->interrupt();
try {
thread_->join();
- } catch (...) {
- return false;
+ } catch (boost::thread_interrupted&) {
+ } catch (std::exception& e) {
+ LOG(FATAL) << "Thread exception: " << e.what();
}
}
- return true;
}
} // namespace caffe
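
Note: in the rewritten InternalThread above, StartInternalThread snapshots the caller's device, mode, random seed, solver count and root-solver flag and re-applies them inside the new thread, while StopInternalThread asks the thread to stop via boost interruption rather than a shared flag. A minimal sketch of that interruption pattern (not Caffe code; assumes Boost.Thread is available):

    #include <boost/thread.hpp>
    #include <iostream>

    // Worker loop that stops cleanly once interrupt() has been requested.
    void worker() {
      try {
        while (!boost::this_thread::interruption_requested()) {
          // Do one unit of work. Blocking calls such as sleeps and condition
          // waits are interruption points and throw boost::thread_interrupted.
          boost::this_thread::sleep_for(boost::chrono::milliseconds(10));
        }
      } catch (boost::thread_interrupted&) {
        // Expected on shutdown; swallow and return so join() succeeds.
      }
    }

    int main() {
      boost::thread t(worker);
      boost::this_thread::sleep_for(boost::chrono::milliseconds(100));
      t.interrupt();   // the equivalent of StopInternalThread()
      t.join();
      std::cout << "worker stopped" << std::endl;
      return 0;
    }
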
diff --git a/src/caffe/layer.cpp b/src/caffe/layer.cpp
new file mode 100644
index 00000000..3b912898
--- /dev/null
+++ b/src/caffe/layer.cpp
@@ -0,0 +1,27 @@
+#include <boost/thread.hpp>
+#include "caffe/layer.hpp"
+
+namespace caffe {
+
+template <typename Dtype>
+void Layer<Dtype>::InitMutex() {
+ forward_mutex_.reset(new boost::mutex());
+}
+
+template <typename Dtype>
+void Layer<Dtype>::Lock() {
+ if (IsShared()) {
+ forward_mutex_->lock();
+ }
+}
+
+template <typename Dtype>
+void Layer<Dtype>::Unlock() {
+ if (IsShared()) {
+ forward_mutex_->unlock();
+ }
+}
+
+INSTANTIATE_CLASS(Layer);
+
+} // namespace caffe
diff --git a/src/caffe/layers/base_data_layer.cpp b/src/caffe/layers/base_data_layer.cpp
index 26a11182..20f76f62 100644
--- a/src/caffe/layers/base_data_layer.cpp
+++ b/src/caffe/layers/base_data_layer.cpp
@@ -1,7 +1,9 @@
+#include <boost/thread.hpp>
#include <string>
#include <vector>
#include "caffe/data_layers.hpp"
+#include "caffe/net.hpp"
#include "caffe/util/io.hpp"
namespace caffe {
@@ -28,55 +30,91 @@ void BaseDataLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
}
template <typename Dtype>
+BasePrefetchingDataLayer<Dtype>::BasePrefetchingDataLayer(
+ const LayerParameter& param)
+ : BaseDataLayer<Dtype>(param),
+ prefetch_free_(), prefetch_full_() {
+ for (int i = 0; i < PREFETCH_COUNT; ++i) {
+ prefetch_free_.push(&prefetch_[i]);
+ }
+}
+
+template <typename Dtype>
void BasePrefetchingDataLayer<Dtype>::LayerSetUp(
const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) {
BaseDataLayer<Dtype>::LayerSetUp(bottom, top);
- // Now, start the prefetch thread. Before calling prefetch, we make two
- // cpu_data calls so that the prefetch thread does not accidentally make
- // simultaneous cudaMalloc calls when the main thread is running. In some
- // GPUs this seems to cause failures if we do not so.
- this->prefetch_data_.mutable_cpu_data();
- if (this->output_labels_) {
- this->prefetch_label_.mutable_cpu_data();
+ // Before starting the prefetch thread, we make cpu_data and gpu_data
+ // calls so that the prefetch thread does not accidentally make simultaneous
+ // cudaMalloc calls when the main thread is running. In some GPUs this
+ // seems to cause failures if we do not do so.
+ for (int i = 0; i < PREFETCH_COUNT; ++i) {
+ prefetch_[i].data_.mutable_cpu_data();
+ if (this->output_labels_) {
+ prefetch_[i].label_.mutable_cpu_data();
+ }
+ }
+#ifndef CPU_ONLY
+ if (Caffe::mode() == Caffe::GPU) {
+ for (int i = 0; i < PREFETCH_COUNT; ++i) {
+ prefetch_[i].data_.mutable_gpu_data();
+ if (this->output_labels_) {
+ prefetch_[i].label_.mutable_gpu_data();
+ }
+ }
}
+#endif
DLOG(INFO) << "Initializing prefetch";
- this->CreatePrefetchThread();
+ this->data_transformer_->InitRand();
+ StartInternalThread();
DLOG(INFO) << "Prefetch initialized.";
}
template <typename Dtype>
-void BasePrefetchingDataLayer<Dtype>::CreatePrefetchThread() {
- this->data_transformer_->InitRand();
- CHECK(StartInternalThread()) << "Thread execution failed";
-}
+void BasePrefetchingDataLayer<Dtype>::InternalThreadEntry() {
+#ifndef CPU_ONLY
+ cudaStream_t stream;
+ if (Caffe::mode() == Caffe::GPU) {
+ cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking);
+ }
+#endif
-template <typename Dtype>
-void BasePrefetchingDataLayer<Dtype>::JoinPrefetchThread() {
- CHECK(WaitForInternalThreadToExit()) << "Thread joining failed";
+ try {
+ while (!must_stop()) {
+ Batch<Dtype>* batch = prefetch_free_.pop();
+ load_batch(batch);
+#ifndef CPU_ONLY
+ if (Caffe::mode() == Caffe::GPU) {
+ batch->data_.data().get()->async_gpu_push(stream);
+ cudaStreamSynchronize(stream);
+ }
+#endif
+ prefetch_full_.push(batch);
+ }
+ } catch (boost::thread_interrupted&) {
+ // Interrupted exception is expected on shutdown
+ }
}
template <typename Dtype>
void BasePrefetchingDataLayer<Dtype>::Forward_cpu(
const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) {
- // First, join the thread
- JoinPrefetchThread();
- DLOG(INFO) << "Thread joined";
+ Batch<Dtype>* batch = prefetch_full_.pop("Data layer prefetch queue empty");
// Reshape to loaded data.
- top[0]->ReshapeLike(prefetch_data_);
+ top[0]->Reshape(batch->data_.num(), batch->data_.channels(),
+ batch->data_.height(), batch->data_.width());
// Copy the data
- caffe_copy(prefetch_data_.count(), prefetch_data_.cpu_data(),
+ caffe_copy(batch->data_.count(), batch->data_.cpu_data(),
top[0]->mutable_cpu_data());
DLOG(INFO) << "Prefetch copied";
if (this->output_labels_) {
// Reshape to loaded labels.
- top[1]->ReshapeLike(prefetch_label_);
+ top[1]->ReshapeLike(batch->label_);
// Copy the labels.
- caffe_copy(prefetch_label_.count(), prefetch_label_.cpu_data(),
- top[1]->mutable_cpu_data());
+ caffe_copy(batch->label_.count(), batch->label_.cpu_data(),
+ top[1]->mutable_cpu_data());
}
- // Start a new prefetch thread
- DLOG(INFO) << "CreatePrefetchThread";
- CreatePrefetchThread();
+
+ prefetch_free_.push(batch);
}
#ifdef CPU_ONLY
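
Note: BasePrefetchingDataLayer now keeps PREFETCH_COUNT Batch objects cycling through the prefetch_free_/prefetch_full_ queues, and in GPU mode the prefetch thread stages each batch on the device over its own non-blocking CUDA stream before handing it to Forward. A stripped-down sketch of that staging step (assumes a CUDA-capable build; illustrative only, not Caffe's async_gpu_push itself, and error checking is omitted):

    #include <cuda_runtime.h>
    #include <cstdio>

    int main() {
      const size_t bytes = 1 << 20;
      float* h_batch = NULL;
      float* d_batch = NULL;
      // Pinned host memory so the copy below can actually run asynchronously.
      cudaMallocHost(reinterpret_cast<void**>(&h_batch), bytes);
      cudaMalloc(reinterpret_cast<void**>(&d_batch), bytes);

      cudaStream_t stream;
      cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking);

      // Stage the next batch on the GPU without serializing against work on the
      // default stream, then wait until it is resident before publishing the
      // batch to the consumer.
      cudaMemcpyAsync(d_batch, h_batch, bytes, cudaMemcpyHostToDevice, stream);
      cudaStreamSynchronize(stream);

      printf("batch staged on device\n");
      cudaStreamDestroy(stream);
      cudaFree(d_batch);
      cudaFreeHost(h_batch);
      return 0;
    }
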
diff --git a/src/caffe/layers/base_data_layer.cu b/src/caffe/layers/base_data_layer.cu
index 9335a5bc..56439bc5 100644
--- a/src/caffe/layers/base_data_layer.cu
+++ b/src/caffe/layers/base_data_layer.cu
@@ -7,22 +7,21 @@ namespace caffe {
template <typename Dtype>
void BasePrefetchingDataLayer<Dtype>::Forward_gpu(
const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) {
- // First, join the thread
- JoinPrefetchThread();
+ Batch<Dtype>* batch = prefetch_full_.pop("Data layer prefetch queue empty");
// Reshape to loaded data.
- top[0]->ReshapeLike(this->prefetch_data_);
+ top[0]->ReshapeLike(batch->data_);
// Copy the data
- caffe_copy(prefetch_data_.count(), prefetch_data_.cpu_data(),
+ caffe_copy(batch->data_.count(), batch->data_.gpu_data(),
top[0]->mutable_gpu_data());
if (this->output_labels_) {
// Reshape to loaded labels.
- top[1]->ReshapeLike(prefetch_label_);
+ top[1]->ReshapeLike(batch->label_);
// Copy the labels.
- caffe_copy(prefetch_label_.count(), prefetch_label_.cpu_data(),
+ caffe_copy(batch->label_.count(), batch->label_.gpu_data(),
top[1]->mutable_gpu_data());
}
- // Start a new prefetch thread
- CreatePrefetchThread();
+
+ prefetch_free_.push(batch);
}
INSTANTIATE_LAYER_GPU_FORWARD(BasePrefetchingDataLayer);
diff --git a/src/caffe/layers/data_layer.cpp b/src/caffe/layers/data_layer.cpp
index 161a75e0..0932d9fe 100644
--- a/src/caffe/layers/data_layer.cpp
+++ b/src/caffe/layers/data_layer.cpp
@@ -11,93 +11,85 @@
#include "caffe/proto/caffe.pb.h"
#include "caffe/util/benchmark.hpp"
#include "caffe/util/io.hpp"
-#include "caffe/util/math_functions.hpp"
-#include "caffe/util/rng.hpp"
namespace caffe {
template <typename Dtype>
-DataLayer<Dtype>::~DataLayer<Dtype>() {
- this->JoinPrefetchThread();
+DataLayer<Dtype>::DataLayer(const LayerParameter& param)
+ : BasePrefetchingDataLayer<Dtype>(param),
+ reader_(param) {
+}
+
+template <typename Dtype>
+DataLayer<Dtype>::~DataLayer() {
+ this->StopInternalThread();
}
template <typename Dtype>
void DataLayer<Dtype>::DataLayerSetUp(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top) {
- // Initialize DB
- db_.reset(db::GetDB(this->layer_param_.data_param().backend()));
- db_->Open(this->layer_param_.data_param().source(), db::READ);
- cursor_.reset(db_->NewCursor());
+ const int batch_size = this->layer_param_.data_param().batch_size();
+ // Read a data point, and use it to initialize the top blob.
+ Datum& datum = *(reader_.full().peek());
- // Check if we should randomly skip a few data points
- if (this->layer_param_.data_param().rand_skip()) {
- unsigned int skip = caffe_rng_rand() %
- this->layer_param_.data_param().rand_skip();
- LOG(INFO) << "Skipping first " << skip << " data points.";
- while (skip-- > 0) {
- cursor_->Next();
- }
- }
- // Read a data point, to initialize the prefetch and top blobs.
- Datum datum;
- datum.ParseFromString(cursor_->value());
// Use data_transformer to infer the expected blob shape from datum.
vector<int> top_shape = this->data_transformer_->InferBlobShape(datum);
this->transformed_data_.Reshape(top_shape);
// Reshape top[0] and prefetch_data according to the batch_size.
- top_shape[0] = this->layer_param_.data_param().batch_size();
- this->prefetch_data_.Reshape(top_shape);
- top[0]->ReshapeLike(this->prefetch_data_);
-
+ top_shape[0] = batch_size;
+ top[0]->Reshape(top_shape);
+ for (int i = 0; i < this->PREFETCH_COUNT; ++i) {
+ this->prefetch_[i].data_.Reshape(top_shape);
+ }
LOG(INFO) << "output data size: " << top[0]->num() << ","
<< top[0]->channels() << "," << top[0]->height() << ","
<< top[0]->width();
// label
if (this->output_labels_) {
- vector<int> label_shape(1, this->layer_param_.data_param().batch_size());
+ vector<int> label_shape(1, batch_size);
top[1]->Reshape(label_shape);
- this->prefetch_label_.Reshape(label_shape);
+ for (int i = 0; i < this->PREFETCH_COUNT; ++i) {
+ this->prefetch_[i].label_.Reshape(label_shape);
+ }
}
}
-// This function is used to create a thread that prefetches the data.
-template <typename Dtype>
-void DataLayer<Dtype>::InternalThreadEntry() {
+// This function is called on prefetch thread
+template<typename Dtype>
+void DataLayer<Dtype>::load_batch(Batch<Dtype>* batch) {
CPUTimer batch_timer;
batch_timer.Start();
double read_time = 0;
double trans_time = 0;
CPUTimer timer;
- CHECK(this->prefetch_data_.count());
+ CHECK(batch->data_.count());
CHECK(this->transformed_data_.count());
// Reshape according to the first datum of each batch
// on single input batches allows for inputs of varying dimension.
const int batch_size = this->layer_param_.data_param().batch_size();
- Datum datum;
- datum.ParseFromString(cursor_->value());
+ Datum& datum = *(reader_.full().peek());
// Use data_transformer to infer the expected blob shape from datum.
vector<int> top_shape = this->data_transformer_->InferBlobShape(datum);
this->transformed_data_.Reshape(top_shape);
- // Reshape prefetch_data according to the batch_size.
+ // Reshape batch according to the batch_size.
top_shape[0] = batch_size;
- this->prefetch_data_.Reshape(top_shape);
+ batch->data_.Reshape(top_shape);
- Dtype* top_data = this->prefetch_data_.mutable_cpu_data();
+ Dtype* top_data = batch->data_.mutable_cpu_data();
Dtype* top_label = NULL; // suppress warnings about uninitialized variables
if (this->output_labels_) {
- top_label = this->prefetch_label_.mutable_cpu_data();
+ top_label = batch->label_.mutable_cpu_data();
}
- timer.Start();
for (int item_id = 0; item_id < batch_size; ++item_id) {
+ timer.Start();
// get a datum
- Datum datum;
- datum.ParseFromString(cursor_->value());
+ Datum& datum = *(reader_.full().pop("Waiting for data"));
read_time += timer.MicroSeconds();
timer.Start();
// Apply data transformations (mirror, scale, crop...)
- int offset = this->prefetch_data_.offset(item_id);
+ int offset = batch->data_.offset(item_id);
this->transformed_data_.set_cpu_data(top_data + offset);
this->data_transformer_->Transform(datum, &(this->transformed_data_));
// Copy label.
@@ -105,13 +97,8 @@ void DataLayer<Dtype>::InternalThreadEntry() {
top_label[item_id] = datum.label();
}
trans_time += timer.MicroSeconds();
- timer.Start();
- // go to the next item.
- cursor_->Next();
- if (!cursor_->valid()) {
- DLOG(INFO) << "Restarting data prefetching from start.";
- cursor_->SeekToFirst();
- }
+
+ reader_.free().push(const_cast<Datum*>(&datum));
}
timer.Stop();
batch_timer.Stop();
diff --git a/src/caffe/layers/image_data_layer.cpp b/src/caffe/layers/image_data_layer.cpp
index dcc53348..223ba3a7 100644
--- a/src/caffe/layers/image_data_layer.cpp
+++ b/src/caffe/layers/image_data_layer.cpp
@@ -17,7 +17,7 @@ namespace caffe {
template <typename Dtype>
ImageDataLayer<Dtype>::~ImageDataLayer<Dtype>() {
- this->JoinPrefetchThread();
+ this->StopInternalThread();
}
template <typename Dtype>
@@ -70,8 +70,10 @@ void ImageDataLayer<Dtype>::DataLayerSetUp(const vector<Blob<Dtype>*>& bottom,
const int batch_size = this->layer_param_.image_data_param().batch_size();
CHECK_GT(batch_size, 0) << "Positive batch size required";
top_shape[0] = batch_size;
- this->prefetch_data_.Reshape(top_shape);
- top[0]->ReshapeLike(this->prefetch_data_);
+ for (int i = 0; i < this->PREFETCH_COUNT; ++i) {
+ this->prefetch_[i].data_.Reshape(top_shape);
+ }
+ top[0]->Reshape(top_shape);
LOG(INFO) << "output data size: " << top[0]->num() << ","
<< top[0]->channels() << "," << top[0]->height() << ","
@@ -79,7 +81,9 @@ void ImageDataLayer<Dtype>::DataLayerSetUp(const vector<Blob<Dtype>*>& bottom,
// label
vector<int> label_shape(1, batch_size);
top[1]->Reshape(label_shape);
- this->prefetch_label_.Reshape(label_shape);
+ for (int i = 0; i < this->PREFETCH_COUNT; ++i) {
+ this->prefetch_[i].label_.Reshape(label_shape);
+ }
}
template <typename Dtype>
@@ -89,15 +93,15 @@ void ImageDataLayer<Dtype>::ShuffleImages() {
shuffle(lines_.begin(), lines_.end(), prefetch_rng);
}
-// This function is used to create a thread that prefetches the data.
+// This function is called on prefetch thread
template <typename Dtype>
-void ImageDataLayer<Dtype>::InternalThreadEntry() {
+void ImageDataLayer<Dtype>::load_batch(Batch<Dtype>* batch) {
CPUTimer batch_timer;
batch_timer.Start();
double read_time = 0;
double trans_time = 0;
CPUTimer timer;
- CHECK(this->prefetch_data_.count());
+ CHECK(batch->data_.count());
CHECK(this->transformed_data_.count());
ImageDataParameter image_data_param = this->layer_param_.image_data_param();
const int batch_size = image_data_param.batch_size();
@@ -114,12 +118,12 @@ void ImageDataLayer<Dtype>::InternalThreadEntry() {
// Use data_transformer to infer the expected blob shape from a cv_img.
vector<int> top_shape = this->data_transformer_->InferBlobShape(cv_img);
this->transformed_data_.Reshape(top_shape);
- // Reshape prefetch_data according to the batch_size.
+ // Reshape batch according to the batch_size.
top_shape[0] = batch_size;
- this->prefetch_data_.Reshape(top_shape);
+ batch->data_.Reshape(top_shape);
- Dtype* prefetch_data = this->prefetch_data_.mutable_cpu_data();
- Dtype* prefetch_label = this->prefetch_label_.mutable_cpu_data();
+ Dtype* prefetch_data = batch->data_.mutable_cpu_data();
+ Dtype* prefetch_label = batch->label_.mutable_cpu_data();
// datum scales
const int lines_size = lines_.size();
@@ -133,7 +137,7 @@ void ImageDataLayer<Dtype>::InternalThreadEntry() {
read_time += timer.MicroSeconds();
timer.Start();
// Apply transformations (mirror, crop...) to the image
- int offset = this->prefetch_data_.offset(item_id);
+ int offset = batch->data_.offset(item_id);
this->transformed_data_.set_cpu_data(prefetch_data + offset);
this->data_transformer_->Transform(cv_img, &(this->transformed_data_));
trans_time += timer.MicroSeconds();
diff --git a/src/caffe/layers/window_data_layer.cpp b/src/caffe/layers/window_data_layer.cpp
index c127d56b..f637f2ec 100644
--- a/src/caffe/layers/window_data_layer.cpp
+++ b/src/caffe/layers/window_data_layer.cpp
@@ -27,7 +27,7 @@ namespace caffe {
template <typename Dtype>
WindowDataLayer<Dtype>::~WindowDataLayer<Dtype>() {
- this->JoinPrefetchThread();
+ this->StopInternalThread();
}
template <typename Dtype>
@@ -171,7 +171,9 @@ void WindowDataLayer<Dtype>::DataLayerSetUp(const vector<Blob<Dtype>*>& bottom,
CHECK_GT(crop_size, 0);
const int batch_size = this->layer_param_.window_data_param().batch_size();
top[0]->Reshape(batch_size, channels, crop_size, crop_size);
- this->prefetch_data_.Reshape(batch_size, channels, crop_size, crop_size);
+ for (int i = 0; i < this->PREFETCH_COUNT; ++i)
+ this->prefetch_[i].data_.Reshape(
+ batch_size, channels, crop_size, crop_size);
LOG(INFO) << "output data size: " << top[0]->num() << ","
<< top[0]->channels() << "," << top[0]->height() << ","
@@ -179,7 +181,9 @@ void WindowDataLayer<Dtype>::DataLayerSetUp(const vector<Blob<Dtype>*>& bottom,
// label
vector<int> label_shape(1, batch_size);
top[1]->Reshape(label_shape);
- this->prefetch_label_.Reshape(label_shape);
+ for (int i = 0; i < this->PREFETCH_COUNT; ++i) {
+ this->prefetch_[i].label_.Reshape(label_shape);
+ }
// data mean
has_mean_file_ = this->transform_param_.has_mean_file();
@@ -217,9 +221,9 @@ unsigned int WindowDataLayer<Dtype>::PrefetchRand() {
return (*prefetch_rng)();
}
-// Thread fetching the data
+// This function is called on prefetch thread
template <typename Dtype>
-void WindowDataLayer<Dtype>::InternalThreadEntry() {
+void WindowDataLayer<Dtype>::load_batch(Batch<Dtype>* batch) {
// At each iteration, sample N windows where N*p are foreground (object)
// windows and N*(1-p) are background (non-object) windows
CPUTimer batch_timer;
@@ -227,8 +231,8 @@ void WindowDataLayer<Dtype>::InternalThreadEntry() {
double read_time = 0;
double trans_time = 0;
CPUTimer timer;
- Dtype* top_data = this->prefetch_data_.mutable_cpu_data();
- Dtype* top_label = this->prefetch_label_.mutable_cpu_data();
+ Dtype* top_data = batch->data_.mutable_cpu_data();
+ Dtype* top_label = batch->label_.mutable_cpu_data();
const Dtype scale = this->layer_param_.window_data_param().scale();
const int batch_size = this->layer_param_.window_data_param().batch_size();
const int context_pad = this->layer_param_.window_data_param().context_pad();
@@ -252,7 +256,7 @@ void WindowDataLayer<Dtype>::InternalThreadEntry() {
bool use_square = (crop_mode == "square") ? true : false;
// zero out batch
- caffe_set(this->prefetch_data_.count(), Dtype(0), top_data);
+ caffe_set(batch->data_.count(), Dtype(0), top_data);
const int num_fg = static_cast<int>(static_cast<float>(batch_size)
* fg_fraction);
diff --git a/src/caffe/net.cpp b/src/caffe/net.cpp
index 91883a10..7875285f 100644
--- a/src/caffe/net.cpp
+++ b/src/caffe/net.cpp
@@ -10,6 +10,7 @@
#include "caffe/common.hpp"
#include "caffe/layer.hpp"
#include "caffe/net.hpp"
+#include "caffe/parallel.hpp"
#include "caffe/proto/caffe.pb.h"
#include "caffe/util/hdf5.hpp"
#include "caffe/util/insert_splits.hpp"
@@ -21,12 +22,14 @@
namespace caffe {
template <typename Dtype>
-Net<Dtype>::Net(const NetParameter& param) {
+Net<Dtype>::Net(const NetParameter& param, const Net* root_net)
+ : root_net_(root_net) {
Init(param);
}
template <typename Dtype>
-Net<Dtype>::Net(const string& param_file, Phase phase) {
+Net<Dtype>::Net(const string& param_file, Phase phase, const Net* root_net)
+ : root_net_(root_net) {
NetParameter param;
ReadNetParamsFromTextFileOrDie(param_file, &param);
param.mutable_state()->set_phase(phase);
@@ -35,14 +38,18 @@ Net<Dtype>::Net(const string& param_file, Phase phase) {
template <typename Dtype>
void Net<Dtype>::Init(const NetParameter& in_param) {
+ CHECK(Caffe::root_solver() || root_net_)
+ << "root_net_ needs to be set for all non-root solvers";
// Set phase from the state.
phase_ = in_param.state().phase();
// Filter layers based on their include/exclude rules and
// the current NetState.
NetParameter filtered_param;
FilterNet(in_param, &filtered_param);
- LOG(INFO) << "Initializing net from parameters: " << std::endl
- << filtered_param.DebugString();
+ if (Caffe::root_solver()) {
+ LOG(INFO) << "Initializing net from parameters: " << std::endl
+ << filtered_param.DebugString();
+ }
// Create a copy of filtered_param with splits added where necessary.
NetParameter param;
InsertSplits(filtered_param, &param);
@@ -66,7 +73,8 @@ void Net<Dtype>::Init(const NetParameter& in_param) {
const int layer_id = -1; // inputs have fake layer ID -1
AppendTop(param, layer_id, input_id, &available_blobs, &blob_name_to_idx);
}
- DLOG(INFO) << "Memory required for data: " << memory_used_ * sizeof(Dtype);
+ DLOG_IF(INFO, Caffe::root_solver())
+ << "Memory required for data: " << memory_used_ * sizeof(Dtype);
// For each layer, set up its input and output
bottom_vecs_.resize(param.layer_size());
top_vecs_.resize(param.layer_size());
@@ -75,6 +83,9 @@ void Net<Dtype>::Init(const NetParameter& in_param) {
top_id_vecs_.resize(param.layer_size());
bottom_need_backward_.resize(param.layer_size());
for (int layer_id = 0; layer_id < param.layer_size(); ++layer_id) {
+ // For non-root solvers, whether this layer is shared from root_net_.
+ bool share_from_root = !Caffe::root_solver()
+ && root_net_->layers_[layer_id]->ShareInParallel();
// Inherit phase from net if unset.
if (!param.layer(layer_id).has_phase()) {
param.mutable_layer(layer_id)->set_phase(phase_);
@@ -87,9 +98,17 @@ void Net<Dtype>::Init(const NetParameter& in_param) {
<< "propagate_down param must be specified "
<< "either 0 or bottom_size times ";
}
- layers_.push_back(LayerRegistry<Dtype>::CreateLayer(layer_param));
+ if (share_from_root) {
+ LOG(INFO) << "Sharing layer " << layer_param.name() << " from root net";
+ layers_.push_back(root_net_->layers_[layer_id]);
+ layers_[layer_id]->SetShared(true);
+ } else {
+ layers_.push_back(LayerRegistry<Dtype>::CreateLayer(layer_param));
+ }
layer_names_.push_back(layer_param.name());
- LOG(INFO) << "Creating Layer " << layer_param.name();
+ if (Caffe::root_solver()) {
+ LOG(INFO) << "Creating Layer " << layer_param.name();
+ }
bool need_backward = false;
// Figure out this layer's input and output
@@ -119,20 +138,42 @@ void Net<Dtype>::Init(const NetParameter& in_param) {
}
}
// After this layer is connected, set it up.
- LOG(INFO) << "Setting up " << layer_names_[layer_id];
- layers_[layer_id]->SetUp(bottom_vecs_[layer_id], top_vecs_[layer_id]);
+ if (share_from_root) {
+ // Set up size of top blobs using root_net_
+ const vector<Blob<Dtype>*>& base_top = root_net_->top_vecs_[layer_id];
+ const vector<Blob<Dtype>*>& this_top = this->top_vecs_[layer_id];
+ for (int top_id = 0; top_id < base_top.size(); ++top_id) {
+ this_top[top_id]->ReshapeLike(*base_top[top_id]);
+ LOG(INFO) << "Created top blob " << top_id << " (shape: "
+ << this_top[top_id]->shape_string() << ") for shared layer "
+ << layer_param.name();
+ }
+ } else {
+ layers_[layer_id]->SetUp(bottom_vecs_[layer_id], top_vecs_[layer_id]);
+ }
+ if (Caffe::root_solver()) {
+ LOG(INFO) << "Setting up " << layer_names_[layer_id];
+ }
for (int top_id = 0; top_id < top_vecs_[layer_id].size(); ++top_id) {
if (blob_loss_weights_.size() <= top_id_vecs_[layer_id][top_id]) {
blob_loss_weights_.resize(top_id_vecs_[layer_id][top_id] + 1, Dtype(0));
}
blob_loss_weights_[top_id_vecs_[layer_id][top_id]] = layer->loss(top_id);
- LOG(INFO) << "Top shape: " << top_vecs_[layer_id][top_id]->shape_string();
+ if (Caffe::root_solver()) {
+ LOG(INFO) << "Top shape: "
+ << top_vecs_[layer_id][top_id]->shape_string();
+ }
if (layer->loss(top_id)) {
- LOG(INFO) << " with loss weight " << layer->loss(top_id);
+ if (Caffe::root_solver()) {
+ LOG(INFO) << " with loss weight " << layer->loss(top_id);
+ }
}
memory_used_ += top_vecs_[layer_id][top_id]->count();
}
- DLOG(INFO) << "Memory required for data: " << memory_used_ * sizeof(Dtype);
+ if (Caffe::root_solver()) {
+ DLOG(INFO) << "Memory required for data: "
+ << memory_used_ * sizeof(Dtype);
+ }
const int param_size = layer_param.param_size();
const int num_param_blobs = layers_[layer_id]->blobs().size();
CHECK_LE(param_size, num_param_blobs)
@@ -191,10 +232,14 @@ void Net<Dtype>::Init(const NetParameter& in_param) {
}
if (!layer_contributes_loss) { layer_need_backward_[layer_id] = false; }
if (layer_need_backward_[layer_id]) {
- LOG(INFO) << layer_names_[layer_id] << " needs backward computation.";
+ if (Caffe::root_solver()) {
+ LOG(INFO) << layer_names_[layer_id] << " needs backward computation.";
+ }
} else {
- LOG(INFO) << layer_names_[layer_id]
- << " does not need backward computation.";
+ if (Caffe::root_solver()) {
+ LOG(INFO) << layer_names_[layer_id]
+ << " does not need backward computation.";
+ }
}
for (int bottom_id = 0; bottom_id < bottom_vecs_[layer_id].size();
++bottom_id) {
@@ -234,7 +279,9 @@ void Net<Dtype>::Init(const NetParameter& in_param) {
// In the end, all remaining blobs are considered output blobs.
for (set<string>::iterator it = available_blobs.begin();
it != available_blobs.end(); ++it) {
- LOG(INFO) << "This network produces output " << *it;
+ if (Caffe::root_solver()) {
+ LOG(INFO) << "This network produces output " << *it;
+ }
net_output_blobs_.push_back(blobs_[blob_name_to_idx[*it]].get());
net_output_blob_indices_.push_back(blob_name_to_idx[*it]);
}
@@ -246,8 +293,10 @@ void Net<Dtype>::Init(const NetParameter& in_param) {
}
ShareWeights();
debug_info_ = param.debug_info();
- LOG(INFO) << "Network initialization done.";
- LOG(INFO) << "Memory required for data: " << memory_used_ * sizeof(Dtype);
+ if (Caffe::root_solver()) {
+ LOG(INFO) << "Network initialization done.";
+ LOG(INFO) << "Memory required for data: " << memory_used_ * sizeof(Dtype);
+ }
}
template <typename Dtype>
@@ -286,27 +335,33 @@ bool Net<Dtype>::StateMeetsRule(const NetState& state,
// Check whether the rule is broken due to phase.
if (rule.has_phase()) {
if (rule.phase() != state.phase()) {
- LOG(INFO) << "The NetState phase (" << state.phase()
- << ") differed from the phase (" << rule.phase()
- << ") specified by a rule in layer " << layer_name;
+ if (Caffe::root_solver()) {
+ LOG(INFO) << "The NetState phase (" << state.phase()
+ << ") differed from the phase (" << rule.phase()
+ << ") specified by a rule in layer " << layer_name;
+ }
return false;
}
}
// Check whether the rule is broken due to min level.
if (rule.has_min_level()) {
if (state.level() < rule.min_level()) {
- LOG(INFO) << "The NetState level (" << state.level()
- << ") is above the min_level (" << rule.min_level()
- << ") specified by a rule in layer " << layer_name;
+ if (Caffe::root_solver()) {
+ LOG(INFO) << "The NetState level (" << state.level()
+ << ") is above the min_level (" << rule.min_level()
+ << ") specified by a rule in layer " << layer_name;
+ }
return false;
}
}
// Check whether the rule is broken due to max level.
if (rule.has_max_level()) {
if (state.level() > rule.max_level()) {
- LOG(INFO) << "The NetState level (" << state.level()
- << ") is above the max_level (" << rule.max_level()
- << ") specified by a rule in layer " << layer_name;
+ if (Caffe::root_solver()) {
+ LOG(INFO) << "The NetState level (" << state.level()
+ << ") is above the max_level (" << rule.max_level()
+ << ") specified by a rule in layer " << layer_name;
+ }
return false;
}
}
@@ -319,8 +374,10 @@ bool Net<Dtype>::StateMeetsRule(const NetState& state,
if (rule.stage(i) == state.stage(j)) { has_stage = true; }
}
if (!has_stage) {
- LOG(INFO) << "The NetState did not contain stage '" << rule.stage(i)
- << "' specified by a rule in layer " << layer_name;
+ if (Caffe::root_solver()) {
+ LOG(INFO) << "The NetState did not contain stage '" << rule.stage(i)
+ << "' specified by a rule in layer " << layer_name;
+ }
return false;
}
}
@@ -333,8 +390,10 @@ bool Net<Dtype>::StateMeetsRule(const NetState& state,
if (rule.not_stage(i) == state.stage(j)) { has_stage = true; }
}
if (has_stage) {
- LOG(INFO) << "The NetState contained a not_stage '" << rule.not_stage(i)
- << "' specified by a rule in layer " << layer_name;
+ if (Caffe::root_solver()) {
+ LOG(INFO) << "The NetState contained a not_stage '" << rule.not_stage(i)
+ << "' specified by a rule in layer " << layer_name;
+ }
return false;
}
}
@@ -356,7 +415,9 @@ void Net<Dtype>::AppendTop(const NetParameter& param, const int layer_id,
if (blob_name_to_idx && layer_param && layer_param->bottom_size() > top_id &&
blob_name == layer_param->bottom(top_id)) {
// In-place computation
- LOG(INFO) << layer_param->name() << " -> " << blob_name << " (in-place)";
+ if (Caffe::root_solver()) {
+ LOG(INFO) << layer_param->name() << " -> " << blob_name << " (in-place)";
+ }
top_vecs_[layer_id].push_back(blobs_[(*blob_name_to_idx)[blob_name]].get());
top_id_vecs_[layer_id].push_back((*blob_name_to_idx)[blob_name]);
} else if (blob_name_to_idx &&
@@ -366,10 +427,12 @@ void Net<Dtype>::AppendTop(const NetParameter& param, const int layer_id,
LOG(FATAL) << "Duplicate blobs produced by multiple sources.";
} else {
// Normal output.
- if (layer_param) {
- LOG(INFO) << layer_param->name() << " -> " << blob_name;
- } else {
- LOG(INFO) << "Input " << top_id << " -> " << blob_name;
+ if (Caffe::root_solver()) {
+ if (layer_param) {
+ LOG(INFO) << layer_param->name() << " -> " << blob_name;
+ } else {
+ LOG(INFO) << "Input " << top_id << " -> " << blob_name;
+ }
}
shared_ptr<Blob<Dtype> > blob_pointer(new Blob<Dtype>());
const int blob_id = blobs_.size();
@@ -409,7 +472,9 @@ int Net<Dtype>::AppendBottom(const NetParameter& param, const int layer_id,
<< " (at index " << bottom_id << ") to layer " << layer_id;
}
const int blob_id = (*blob_name_to_idx)[blob_name];
- LOG(INFO) << layer_names_[layer_id] << " <- " << blob_name;
+ if (Caffe::root_solver()) {
+ LOG(INFO) << layer_names_[layer_id] << " <- " << blob_name;
+ }
bottom_vecs_[layer_id].push_back(blobs_[blob_id].get());
bottom_id_vecs_[layer_id].push_back(blob_id);
available_blobs->erase(blob_name);
@@ -468,9 +533,10 @@ void Net<Dtype>::AppendParam(const NetParameter& param, const int layer_id,
param_layer_indices_[owner_net_param_id];
const int owner_layer_id = owner_index.first;
const int owner_param_id = owner_index.second;
- LOG(INFO) << "Sharing parameters '" << param_name << "' owned by "
- << "layer '" << layer_names_[owner_layer_id] << "', param "
- << "index " << owner_param_id;
+ LOG_IF(INFO, Caffe::root_solver()) << "Sharing parameters '" << param_name
+ << "' owned by "
+ << "layer '" << layer_names_[owner_layer_id] << "', param "
+ << "index " << owner_param_id;
Blob<Dtype>* this_blob = layers_[layer_id]->blobs()[param_id].get();
Blob<Dtype>* owner_blob =
layers_[owner_layer_id]->blobs()[owner_param_id].get();
@@ -596,8 +662,10 @@ void Net<Dtype>::InputDebugInfo(const int input_id) {
const Blob<Dtype>& blob = *net_input_blobs_[input_id];
const string& blob_name = blob_names_[net_input_blob_indices_[input_id]];
const Dtype data_abs_val_mean = blob.asum_data() / blob.count();
- LOG(INFO) << " [Forward] "
- << "Input " << blob_name << " data: " << data_abs_val_mean;
+ if (Caffe::root_solver()) {
+ LOG(INFO) << " [Forward] "
+ << "Input " << blob_name << " data: " << data_abs_val_mean;
+ }
}
template <typename Dtype>
@@ -606,9 +674,12 @@ void Net<Dtype>::ForwardDebugInfo(const int layer_id) {
const Blob<Dtype>& blob = *top_vecs_[layer_id][top_id];
const string& blob_name = blob_names_[top_id_vecs_[layer_id][top_id]];
const Dtype data_abs_val_mean = blob.asum_data() / blob.count();
- LOG(INFO) << " [Forward] "
- << "Layer " << layer_names_[layer_id] << ", top blob " << blob_name
- << " data: " << data_abs_val_mean;
+ if (Caffe::root_solver()) {
+ LOG(INFO) << " [Forward] "
+ << "Layer " << layer_names_[layer_id]
+ << ", top blob " << blob_name
+ << " data: " << data_abs_val_mean;
+ }
}
for (int param_id = 0; param_id < layers_[layer_id]->blobs().size();
++param_id) {
@@ -616,9 +687,12 @@ void Net<Dtype>::ForwardDebugInfo(const int layer_id) {
const int net_param_id = param_id_vecs_[layer_id][param_id];
const string& blob_name = param_display_names_[net_param_id];
const Dtype data_abs_val_mean = blob.asum_data() / blob.count();
- LOG(INFO) << " [Forward] "
- << "Layer " << layer_names_[layer_id] << ", param blob " << blob_name
- << " data: " << data_abs_val_mean;
+ if (Caffe::root_solver()) {
+ LOG(INFO) << " [Forward] "
+ << "Layer " << layer_names_[layer_id]
+ << ", param blob " << blob_name
+ << " data: " << data_abs_val_mean;
+ }
}
}
@@ -630,18 +704,24 @@ void Net<Dtype>::BackwardDebugInfo(const int layer_id) {
const Blob<Dtype>& blob = *bottom_vec[bottom_id];
const string& blob_name = blob_names_[bottom_id_vecs_[layer_id][bottom_id]];
const Dtype diff_abs_val_mean = blob.asum_diff() / blob.count();
- LOG(INFO) << " [Backward] "
- << "Layer " << layer_names_[layer_id] << ", bottom blob " << blob_name
- << " diff: " << diff_abs_val_mean;
+ if (Caffe::root_solver()) {
+ LOG(INFO) << " [Backward] "
+ << "Layer " << layer_names_[layer_id]
+ << ", bottom blob " << blob_name
+ << " diff: " << diff_abs_val_mean;
+ }
}
for (int param_id = 0; param_id < layers_[layer_id]->blobs().size();
++param_id) {
if (!layers_[layer_id]->param_propagate_down(param_id)) { continue; }
const Blob<Dtype>& blob = *layers_[layer_id]->blobs()[param_id];
const Dtype diff_abs_val_mean = blob.asum_diff() / blob.count();
- LOG(INFO) << " [Backward] "
- << "Layer " << layer_names_[layer_id] << ", param blob " << param_id
- << " diff: " << diff_abs_val_mean;
+ if (Caffe::root_solver()) {
+ LOG(INFO) << " [Backward] "
+ << "Layer " << layer_names_[layer_id]
+ << ", param blob " << param_id
+ << " diff: " << diff_abs_val_mean;
+ }
}
}
@@ -654,17 +734,22 @@ void Net<Dtype>::UpdateDebugInfo(const int param_id) {
const Dtype diff_abs_val_mean = blob.asum_diff() / blob.count();
if (param_owner < 0) {
const Dtype data_abs_val_mean = blob.asum_data() / blob.count();
- LOG(INFO) << " [Update] Layer " << layer_name
- << ", param " << param_display_name
- << " data: " << data_abs_val_mean << "; diff: " << diff_abs_val_mean;
+ if (Caffe::root_solver()) {
+ LOG(INFO) << " [Update] Layer " << layer_name
+ << ", param " << param_display_name
+ << " data: " << data_abs_val_mean
+ << "; diff: " << diff_abs_val_mean;
+ }
} else {
const string& owner_layer_name =
layer_names_[param_layer_indices_[param_owner].first];
- LOG(INFO) << " [Update] Layer " << layer_name
- << ", param blob " << param_display_name
- << " (owned by layer " << owner_layer_name << ", "
- << "param " << param_display_names_[param_owners_[param_id]] << ")"
- << " diff: " << diff_abs_val_mean;
+ if (Caffe::root_solver()) {
+ LOG(INFO) << " [Update] Layer " << layer_name
+ << ", param blob " << param_display_name
+ << " (owned by layer " << owner_layer_name << ", " << "param "
+ << param_display_names_[param_owners_[param_id]] << ")"
+ << " diff: " << diff_abs_val_mean;
+ }
}
}
@@ -721,8 +806,8 @@ void Net<Dtype>::Backward() {
const Dtype l2norm_data = std::sqrt(sumsq_data);
const Dtype l2norm_diff = std::sqrt(sumsq_diff);
LOG(ERROR) << " [Backward] All net params (data, diff): "
- << "L1 norm = (" << asum_data << ", " << asum_diff << "); "
- << "L2 norm = (" << l2norm_data << ", " << l2norm_diff << ")";
+ << "L1 norm = (" << asum_data << ", " << asum_diff << "); "
+ << "L2 norm = (" << l2norm_data << ", " << l2norm_diff << ")";
}
}
diff --git a/src/caffe/parallel.cpp b/src/caffe/parallel.cpp
new file mode 100644
index 00000000..6e7d802b
--- /dev/null
+++ b/src/caffe/parallel.cpp
@@ -0,0 +1,438 @@
+#ifndef CPU_ONLY
+#include <cuda_runtime.h>
+#endif
+#include <glog/logging.h>
+#include <stdio.h>
+#include <sys/ioctl.h>
+#include <sys/mman.h>
+#include <sys/stat.h>
+
+#include <cstdlib>
+#include <sstream>
+#include <string>
+#include <vector>
+
+#include "boost/thread.hpp"
+#include "caffe/caffe.hpp"
+#include "caffe/parallel.hpp"
+
+namespace caffe {
+
+enum Op {
+ copy,
+ replace_cpu,
+ replace_gpu,
+ replace_cpu_diff,
+ replace_gpu_diff
+};
+
+template<typename Dtype>
+static void apply_buffers(const vector<Blob<Dtype>*>& blobs,
+ Dtype* buffer, size_t total_size, Op op) {
+ Dtype* ptr = buffer;
+ for (int i = 0; i < blobs.size(); ++i) {
+ int size = blobs[i]->count();
+ switch (op) {
+ case copy: {
+ // Init buffer to current values of blobs
+ caffe_copy(size,
+ reinterpret_cast<const Dtype*>(blobs[i]->data()->cpu_data()),
+ ptr);
+ break;
+ }
+ case replace_cpu:
+ blobs[i]->data()->set_cpu_data(ptr);
+ break;
+ case replace_gpu:
+ blobs[i]->data()->set_gpu_data(ptr);
+ break;
+ case replace_cpu_diff:
+ blobs[i]->diff()->set_cpu_data(ptr);
+ break;
+ case replace_gpu_diff:
+ blobs[i]->diff()->set_gpu_data(ptr);
+ break;
+ }
+ ptr += size;
+ }
+ CHECK_EQ(total_size, ptr - buffer);
+}
+
+// Buffer size necessary to store given blobs
+template<typename Dtype>
+static size_t total_size(const vector<Blob<Dtype>*>& params) {
+ size_t size = 0;
+ for (int i = 0; i < params.size(); ++i)
+ size += params[i]->count();
+ return size;
+}
+
+template<typename Dtype>
+Params<Dtype>::Params(shared_ptr<Solver<Dtype> > root_solver)
+ : size_(total_size<Dtype>(root_solver->net()->learnable_params())),
+ data_(),
+ diff_() {
+}
+
+template<typename Dtype>
+GPUParams<Dtype>::GPUParams(shared_ptr<Solver<Dtype> > root_solver, int device)
+ : Params<Dtype>(root_solver) {
+#ifndef CPU_ONLY
+ int initial_device;
+ CUDA_CHECK(cudaGetDevice(&initial_device));
+
+ // Allocate device buffers
+ CUDA_CHECK(cudaSetDevice(device));
+ CUDA_CHECK(cudaMalloc(&data_, size_ * sizeof(Dtype)));
+
+ // Copy blob values
+ const vector<Blob<Dtype>*>& net =
+ root_solver->net()->learnable_params();
+ apply_buffers(net, data_, size_, copy);
+
+ CUDA_CHECK(cudaMalloc(&diff_, size_ * sizeof(Dtype)));
+ caffe_gpu_set(size_, Dtype(0), diff_);
+
+ CUDA_CHECK(cudaSetDevice(initial_device));
+#else
+ NO_GPU;
+#endif
+}
+
+template<typename Dtype>
+GPUParams<Dtype>::~GPUParams() {
+#ifndef CPU_ONLY
+ CUDA_CHECK(cudaFree(data_));
+ CUDA_CHECK(cudaFree(diff_));
+#endif
+}
+
+template<typename Dtype>
+void GPUParams<Dtype>::configure(Solver<Dtype>* solver) const {
+ const vector<Blob<Dtype>*>& net =
+ solver->net()->learnable_params();
+ apply_buffers(net, data_, size_, replace_gpu);
+ apply_buffers(net, diff_, size_, replace_gpu_diff);
+}
+
+void DevicePair::compute(const vector<int> devices, vector<DevicePair>* pairs) {
+#ifndef CPU_ONLY
+ vector<int> remaining(devices);
+
+ // Depth for reduction tree
+ int remaining_depth = static_cast<int>(ceil(log2(remaining.size())));
+
+ // Group GPUs by board
+ for (int d = 0; d < remaining_depth; ++d) {
+ for (int i = 0; i < remaining.size(); ++i) {
+ for (int j = i + 1; j < remaining.size(); ++j) {
+ cudaDeviceProp a, b;
+ CUDA_CHECK(cudaGetDeviceProperties(&a, remaining[i]));
+ CUDA_CHECK(cudaGetDeviceProperties(&b, remaining[j]));
+ if (a.isMultiGpuBoard && b.isMultiGpuBoard) {
+ if (a.multiGpuBoardGroupID == b.multiGpuBoardGroupID) {
+ pairs->push_back(DevicePair(remaining[i], remaining[j]));
+ DLOG(INFO) << "GPU board: " << remaining[i] << ":" << remaining[j];
+ remaining.erase(remaining.begin() + j);
+ break;
+ }
+ }
+ }
+ }
+ }
+ ostringstream s;
+ for (int i = 0; i < remaining.size(); ++i) {
+ s << (i ? ", " : "") << remaining[i];
+ }
+ DLOG(INFO) << "GPUs paired by boards, remaining: " << s.str();
+
+ // Group by P2P accessibility
+ remaining_depth = ceil(log2(remaining.size()));
+ for (int d = 0; d < remaining_depth; ++d) {
+ for (int i = 0; i < remaining.size(); ++i) {
+ for (int j = i + 1; j < remaining.size(); ++j) {
+ int access;
+ CUDA_CHECK(
+ cudaDeviceCanAccessPeer(&access, remaining[i], remaining[j]));
+ if (access) {
+ pairs->push_back(DevicePair(remaining[i], remaining[j]));
+ DLOG(INFO) << "P2P pair: " << remaining[i] << ":" << remaining[j];
+ remaining.erase(remaining.begin() + j);
+ break;
+ }
+ }
+ }
+ }
+ s.str("");
+ for (int i = 0; i < remaining.size(); ++i) {
+ s << (i ? ", " : "") << remaining[i];
+ }
+ DLOG(INFO) << "GPUs paired by P2P access, remaining: " << s.str();
+
+ // Group remaining
+ remaining_depth = ceil(log2(remaining.size()));
+ for (int d = 0; d < remaining_depth; ++d) {
+ for (int i = 0; i < remaining.size(); ++i) {
+ pairs->push_back(DevicePair(remaining[i], remaining[i + 1]));
+ DLOG(INFO) << "Remaining pair: " << remaining[i] << ":"
+ << remaining[i + 1];
+ remaining.erase(remaining.begin() + i + 1);
+ }
+ }
+
+ // Should only be the parent node remaining
+ CHECK_EQ(remaining.size(), 1);
+
+ pairs->insert(pairs->begin(), DevicePair(-1, remaining[0]));
+
+ CHECK(pairs->size() == devices.size());
+ for (int i = 0; i < pairs->size(); ++i) {
+ CHECK((*pairs)[i].parent() != (*pairs)[i].device());
+ for (int j = i + 1; j < pairs->size(); ++j) {
+ CHECK((*pairs)[i].device() != (*pairs)[j].device());
+ }
+ }
+#else
+ NO_GPU;
+#endif
+}
+
+//
+
+template<typename Dtype>
+P2PSync<Dtype>::P2PSync(shared_ptr<Solver<Dtype> > root_solver,
+ P2PSync<Dtype>* parent, const SolverParameter& param)
+ : GPUParams<Dtype>(root_solver, param.device_id()),
+ parent_(parent),
+ children_(),
+ queue_(),
+ initial_iter_(root_solver->iter()),
+ solver_() {
+#ifndef CPU_ONLY
+ int initial_device;
+ CUDA_CHECK(cudaGetDevice(&initial_device));
+ const int self = param.device_id();
+ CUDA_CHECK(cudaSetDevice(self));
+
+ if (parent == NULL) {
+ solver_ = root_solver;
+ } else {
+ Caffe::set_root_solver(false);
+ solver_.reset(new WorkerSolver<Dtype>(param, root_solver.get()));
+ Caffe::set_root_solver(true);
+ }
+ this->configure(solver_.get());
+ solver_->add_callback(this);
+
+ if (parent) {
+ // Enable p2p access between devices
+ const int peer = parent->solver_->param().device_id();
+ int access;
+ CUDA_CHECK(cudaDeviceCanAccessPeer(&access, self, peer));
+ if (access) {
+ CUDA_CHECK(cudaDeviceEnablePeerAccess(peer, 0));
+ } else {
+ LOG(INFO)<< "GPU " << self << " does not have p2p access to GPU " << peer;
+ }
+ // Allocate receiving buffer on parent
+ CUDA_CHECK(cudaSetDevice(peer));
+ CUDA_CHECK(cudaMalloc(&parent_grads_, size_ * sizeof(Dtype)));
+ CUDA_CHECK(cudaSetDevice(self));
+ }
+
+ CUDA_CHECK(cudaSetDevice(initial_device));
+#else
+ NO_GPU;
+#endif
+}
+
+template<typename Dtype>
+P2PSync<Dtype>::~P2PSync() {
+#ifndef CPU_ONLY
+ int initial_device;
+ CUDA_CHECK(cudaGetDevice(&initial_device));
+ const int self = solver_->param().device_id();
+ CUDA_CHECK(cudaSetDevice(self));
+
+ if (parent_) {
+ CUDA_CHECK(cudaFree(parent_grads_));
+ const int peer = parent_->solver_->param().device_id();
+ int access;
+ CUDA_CHECK(cudaDeviceCanAccessPeer(&access, self, peer));
+ if (access) {
+ CUDA_CHECK(cudaDeviceDisablePeerAccess(peer));
+ }
+ }
+
+ CUDA_CHECK(cudaSetDevice(initial_device));
+#endif
+}
+
+template<typename Dtype>
+void P2PSync<Dtype>::InternalThreadEntry() {
+ Caffe::SetDevice(solver_->param().device_id());
+ CHECK(Caffe::root_solver());
+ Caffe::set_root_solver(false);
+ // See if there is a defined seed and reset random state if so
+ if (solver_->param().random_seed() >= 0) {
+ // Fetch random seed and modulate by device ID to make sure
+ // everyone doesn't have the same seed. We seem to have some
+ // solver instability if we have everyone with the same seed
+ Caffe::set_random_seed(
+ solver_->param().random_seed() + solver_->param().device_id());
+ }
+ solver_->Step(solver_->param().max_iter() - initial_iter_);
+}
+
+template<typename Dtype>
+void P2PSync<Dtype>::on_start() {
+#ifndef CPU_ONLY
+#ifdef DEBUG
+ int device;
+ CUDA_CHECK(cudaGetDevice(&device));
+ CHECK(device == solver_->param().device_id());
+#else
+// CHECK(false);
+#endif
+
+ // Wait for update from parent
+ if (parent_) {
+ P2PSync<Dtype> *parent = queue_.pop();
+ CHECK(parent == parent_);
+ }
+
+ // Update children
+ for (int i = children_.size() - 1; i >= 0; i--) {
+ Dtype* src = data_;
+ Dtype* dst = children_[i]->data_;
+
+#ifdef DEBUG
+ cudaPointerAttributes attributes;
+ CUDA_CHECK(cudaPointerGetAttributes(&attributes, src));
+ CHECK(attributes.device == device);
+ CUDA_CHECK(cudaPointerGetAttributes(&attributes, dst));
+ CHECK(attributes.device == children_[i]->solver_->param().device_id());
+#endif
+
+ CUDA_CHECK(cudaMemcpyAsync(dst, src, size_ * sizeof(Dtype),
+ cudaMemcpyDeviceToDevice, cudaStreamDefault));
+ CUDA_CHECK(cudaStreamSynchronize(cudaStreamDefault));
+ children_[i]->queue_.push(this);
+ }
+#endif
+}
+
+template<typename Dtype>
+void P2PSync<Dtype>::on_gradients_ready() {
+#ifndef CPU_ONLY
+#ifdef DEBUG
+ int device;
+ CUDA_CHECK(cudaGetDevice(&device));
+ CHECK(device == solver_->param().device_id());
+#endif
+
+ // Sum children gradients as they appear in the queue
+ for (int i = 0; i < children_.size(); ++i) {
+ P2PSync<Dtype> *child = queue_.pop();
+ Dtype* src = child->parent_grads_;
+ Dtype* dst = diff_;
+
+#ifdef DEBUG
+ bool ok = false;
+ for (int j = 0; j < children_.size(); ++j) {
+ if (child == children_[j]) {
+ ok = true;
+ }
+ }
+ CHECK(ok);
+ cudaPointerAttributes attributes;
+ CUDA_CHECK(cudaPointerGetAttributes(&attributes, src));
+ CHECK(attributes.device == device);
+ CUDA_CHECK(cudaPointerGetAttributes(&attributes, dst));
+ CHECK(attributes.device == device);
+#endif
+
+ caffe_gpu_add(size_, src, dst, dst);
+ }
+
+ // Send gradients to parent
+ if (parent_) {
+ Dtype* src = diff_;
+ Dtype* dst = parent_grads_;
+
+#ifdef DEBUG
+ cudaPointerAttributes attributes;
+ CUDA_CHECK(cudaPointerGetAttributes(&attributes, src));
+ CHECK(attributes.device == device);
+ CUDA_CHECK(cudaPointerGetAttributes(&attributes, dst));
+ CHECK(attributes.device == parent_->solver_->param().device_id());
+#endif
+
+ CUDA_CHECK(cudaMemcpyAsync(dst, src, size_ * sizeof(Dtype), //
+ cudaMemcpyDeviceToDevice, cudaStreamDefault));
+ CUDA_CHECK(cudaStreamSynchronize(cudaStreamDefault));
+ parent_->queue_.push(this);
+ } else {
+ // Loss functions divide gradients by the batch size, so to compensate
+ // for the split batch, the root solver divides by the number of solvers.
+ caffe_gpu_scal(size_, Dtype(1.0 / Caffe::solver_count()), diff_);
+ }
+#endif
+}
+
+template<typename Dtype>
+void P2PSync<Dtype>::run(const vector<int>& gpus) {
+ // Pair devices for map-reduce synchronization
+ vector<DevicePair> pairs;
+ DevicePair::compute(gpus, &pairs);
+ ostringstream s;
+ for (int i = 1; i < pairs.size(); ++i) {
+ s << (i == 1 ? "" : ", ") << pairs[i].parent() << ":" << pairs[i].device();
+ }
+ LOG(INFO)<< "GPUs pairs " << s.str();
+
+ SolverParameter param(solver_->param());
+ vector<shared_ptr<P2PSync<Dtype> > > syncs(gpus.size());
+
+ // Build the GPU tree by finding the parent for each solver
+ for (int attempts = 0; attempts < pairs.size(); ++attempts) {
+ for (int i = 1; i < pairs.size(); ++i) {
+ if (!syncs[i].get()) {
+ P2PSync<Dtype>* parent = NULL;
+ for (int j = 0; j < syncs.size(); ++j) {
+ P2PSync<Dtype>* sync = j == 0 ? this : syncs[j].get();
+ if (sync) {
+ const SolverParameter& p = sync->solver()->param();
+ if (p.device_id() == pairs[i].parent()) {
+ parent = sync;
+ }
+ }
+ }
+ if (parent) {
+ param.set_device_id(pairs[i].device());
+ syncs[i].reset(new P2PSync<Dtype>(solver_, parent, param));
+ parent->children_.push_back((P2PSync<Dtype>*) syncs[i].get());
+ }
+ }
+ }
+ }
+
+ LOG(INFO)<< "Starting Optimization";
+
+ for (int i = 1; i < syncs.size(); ++i) {
+ syncs[i]->StartInternalThread();
+ }
+
+ // Run root solver on current thread
+ solver_->Solve();
+
+ for (int i = 1; i < syncs.size(); ++i) {
+ syncs[i]->StopInternalThread();
+ }
+}
+
+INSTANTIATE_CLASS(Params);
+INSTANTIATE_CLASS(GPUParams);
+INSTANTIATE_CLASS(P2PSync);
+
+} // namespace caffe
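
Note: the on_gradients_ready() tree reduction above ends with the root solver scaling the summed gradient by 1/Caffe::solver_count(). As a concrete example of why (illustrative numbers, not from the patch): with four GPUs and batch_size 32, each solver produces a gradient already averaged over its own 32 samples; after the reduction the root holds the sum of four such averages, and multiplying by 1/4 turns that sum into the average over the effective batch of 128, so the update keeps the same scale as single-GPU training with that larger batch.
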
diff --git a/src/caffe/proto/caffe.proto b/src/caffe/proto/caffe.proto
index 7cfcaa8b..fc0d961a 100644
--- a/src/caffe/proto/caffe.proto
+++ b/src/caffe/proto/caffe.proto
@@ -501,6 +501,7 @@ message DataParameter {
// to avoid all asynchronous sgd clients to start at the same point. The skip
// point would be set as rand_skip * rand(0,1). Note that rand_skip should not
// be larger than the number of keys in the database.
+ // DEPRECATED. Each solver accesses a different subset of the database.
optional uint32 rand_skip = 7 [default = 0];
optional DB backend = 8 [default = LEVELDB];
// DEPRECATED. See TransformationParameter. For data pre-processing, we can do
@@ -516,6 +517,9 @@ message DataParameter {
optional bool mirror = 6 [default = false];
// Force the encoded image to have 3 color channels
optional bool force_encoded_color = 9 [default = false];
+ // Prefetch queue (Number of batches to prefetch to host memory, increase if
+ // data access bandwidth varies).
+ optional uint32 prefetch = 10 [default = 4];
}
message DropoutParameter {
@@ -737,6 +741,10 @@ message PythonParameter {
// string, dictionary in Python dict format, JSON, etc. You may parse this
// string in `setup` method and use it in `forward` and `backward`.
optional string param_str = 3 [default = ''];
+ // Whether this PythonLayer is shared among worker solvers during data
+ // parallelism. If true, each worker solver sequentially runs forward from
+ // this layer. Set this to true when using the layer as a data layer.
+ optional bool share_in_parallel = 4 [default = false];
}
// Message that stores parameters used by ReductionLayer
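
The two new proto fields are reachable from C++ through the usual generated accessors; a minimal sketch follows (the field values and function name are illustrative).

#include "caffe/proto/caffe.pb.h"

// Deepen the host-side prefetch queue and mark a Python data layer as shared
// across worker solvers.
void ConfigureForParallelInput(caffe::LayerParameter* data_layer,
                               caffe::LayerParameter* python_layer) {
  data_layer->mutable_data_param()->set_prefetch(8);
  python_layer->mutable_python_param()->set_share_in_parallel(true);
}
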
diff --git a/src/caffe/solver.cpp b/src/caffe/solver.cpp
index 78902ca0..248f238e 100644
--- a/src/caffe/solver.cpp
+++ b/src/caffe/solver.cpp
@@ -18,14 +18,14 @@
namespace caffe {
template <typename Dtype>
-Solver<Dtype>::Solver(const SolverParameter& param)
- : net_() {
+Solver<Dtype>::Solver(const SolverParameter& param, const Solver* root_solver)
+ : net_(), callbacks_(), root_solver_(root_solver) {
Init(param);
}
template <typename Dtype>
-Solver<Dtype>::Solver(const string& param_file)
- : net_() {
+Solver<Dtype>::Solver(const string& param_file, const Solver* root_solver)
+ : net_(), callbacks_(), root_solver_(root_solver) {
SolverParameter param;
ReadProtoFromTextFileOrDie(param_file, &param);
Init(param);
@@ -33,17 +33,21 @@ Solver<Dtype>::Solver(const string& param_file)
template <typename Dtype>
void Solver<Dtype>::Init(const SolverParameter& param) {
- LOG(INFO) << "Initializing solver from parameters: " << std::endl
- << param.DebugString();
+ CHECK(Caffe::root_solver() || root_solver_)
+ << "root_solver_ needs to be set for all non-root solvers";
+ LOG_IF(INFO, Caffe::root_solver()) << "Initializing solver from parameters: "
+ << std::endl << param.DebugString();
param_ = param;
CHECK_GE(param_.average_loss(), 1) << "average_loss should be non-negative.";
- if (param_.random_seed() >= 0) {
+ if (Caffe::root_solver() && param_.random_seed() >= 0) {
Caffe::set_random_seed(param_.random_seed());
}
// Scaffolding code
InitTrainNet();
- InitTestNets();
- LOG(INFO) << "Solver scaffolding done.";
+ if (Caffe::root_solver()) {
+ InitTestNets();
+ LOG(INFO) << "Solver scaffolding done.";
+ }
iter_ = 0;
current_step_ = 0;
}
@@ -59,19 +63,22 @@ void Solver<Dtype>::InitTrainNet() {
<< "one of these fields specifying a train_net: " << field_names;
NetParameter net_param;
if (param_.has_train_net_param()) {
- LOG(INFO) << "Creating training net specified in train_net_param.";
+ LOG_IF(INFO, Caffe::root_solver())
+ << "Creating training net specified in train_net_param.";
net_param.CopyFrom(param_.train_net_param());
} else if (param_.has_train_net()) {
- LOG(INFO) << "Creating training net from train_net file: "
- << param_.train_net();
+ LOG_IF(INFO, Caffe::root_solver())
+ << "Creating training net from train_net file: " << param_.train_net();
ReadNetParamsFromTextFileOrDie(param_.train_net(), &net_param);
}
if (param_.has_net_param()) {
- LOG(INFO) << "Creating training net specified in net_param.";
+ LOG_IF(INFO, Caffe::root_solver())
+ << "Creating training net specified in net_param.";
net_param.CopyFrom(param_.net_param());
}
if (param_.has_net()) {
- LOG(INFO) << "Creating training net from net file: " << param_.net();
+ LOG_IF(INFO, Caffe::root_solver())
+ << "Creating training net from net file: " << param_.net();
ReadNetParamsFromTextFileOrDie(param_.net(), &net_param);
}
// Set the correct NetState. We start with the solver defaults (lowest
@@ -83,11 +90,16 @@ void Solver<Dtype>::InitTrainNet() {
net_state.MergeFrom(net_param.state());
net_state.MergeFrom(param_.train_state());
net_param.mutable_state()->CopyFrom(net_state);
- net_.reset(new Net<Dtype>(net_param));
+ if (Caffe::root_solver()) {
+ net_.reset(new Net<Dtype>(net_param));
+ } else {
+ net_.reset(new Net<Dtype>(net_param, root_solver_->net_.get()));
+ }
}
template <typename Dtype>
void Solver<Dtype>::InitTestNets() {
+ CHECK(Caffe::root_solver());
const bool has_net_param = param_.has_net_param();
const bool has_net_file = param_.has_net();
const int num_generic_nets = has_net_param + has_net_file;
@@ -157,7 +169,12 @@ void Solver<Dtype>::InitTestNets() {
net_params[i].mutable_state()->CopyFrom(net_state);
LOG(INFO)
<< "Creating test net (#" << i << ") specified by " << sources[i];
- test_nets_[i].reset(new Net<Dtype>(net_params[i]));
+ if (Caffe::root_solver()) {
+ test_nets_[i].reset(new Net<Dtype>(net_params[i]));
+ } else {
+ test_nets_[i].reset(new Net<Dtype>(net_params[i],
+ root_solver_->test_nets_[i].get()));
+ }
test_nets_[i]->set_debug_info(param_.debug_info());
}
}
@@ -175,10 +192,14 @@ void Solver<Dtype>::Step(int iters) {
// zero-init the params
net_->ClearParamDiffs();
if (param_.test_interval() && iter_ % param_.test_interval() == 0
- && (iter_ > 0 || param_.test_initialization())) {
+ && (iter_ > 0 || param_.test_initialization())
+ && Caffe::root_solver()) {
TestAll();
}
+ for (int i = 0; i < callbacks_.size(); ++i) {
+ callbacks_[i]->on_start();
+ }
const bool display = param_.display() && iter_ % param_.display() == 0;
net_->set_debug_info(display && param_.debug_info());
// accumulate the loss and gradient
@@ -198,7 +219,8 @@ void Solver<Dtype>::Step(int iters) {
losses[idx] = loss;
}
if (display) {
- LOG(INFO) << "Iteration " << iter_ << ", loss = " << smoothed_loss;
+ LOG_IF(INFO, Caffe::root_solver()) << "Iteration " << iter_
+ << ", loss = " << smoothed_loss;
const vector<Blob<Dtype>*>& result = net_->output_blobs();
int score_index = 0;
for (int j = 0; j < result.size(); ++j) {
@@ -213,12 +235,15 @@ void Solver<Dtype>::Step(int iters) {
loss_msg_stream << " (* " << loss_weight
<< " = " << loss_weight * result_vec[k] << " loss)";
}
- LOG(INFO) << " Train net output #"
+ LOG_IF(INFO, Caffe::root_solver()) << " Train net output #"
<< score_index++ << ": " << output_name << " = "
<< result_vec[k] << loss_msg_stream.str();
}
}
}
+ for (int i = 0; i < callbacks_.size(); ++i) {
+ callbacks_[i]->on_gradients_ready();
+ }
ApplyUpdate();
// Increment the internal iter_ counter -- its value should always indicate
@@ -226,7 +251,9 @@ void Solver<Dtype>::Step(int iters) {
++iter_;
// Save a snapshot if needed.
- if (param_.snapshot() && iter_ % param_.snapshot() == 0) {
+ if (param_.snapshot()
+ && iter_ % param_.snapshot() == 0
+ && Caffe::root_solver()) {
Snapshot();
}
}
@@ -234,6 +261,7 @@ void Solver<Dtype>::Step(int iters) {
template <typename Dtype>
void Solver<Dtype>::Solve(const char* resume_file) {
+ CHECK(Caffe::root_solver());
LOG(INFO) << "Solving " << net_->name();
LOG(INFO) << "Learning Rate Policy: " << param_.lr_policy();
@@ -278,6 +306,7 @@ void Solver<Dtype>::TestAll() {
template <typename Dtype>
void Solver<Dtype>::Test(const int test_net_id) {
+ CHECK(Caffe::root_solver());
LOG(INFO) << "Iteration " << iter_
<< ", Testing net (#" << test_net_id << ")";
CHECK_NOTNULL(test_nets_[test_net_id].get())->
@@ -328,13 +357,14 @@ void Solver<Dtype>::Test(const int test_net_id) {
<< " = " << loss_weight * mean_score << " loss)";
}
LOG(INFO) << " Test net output #" << i << ": " << output_name << " = "
- << mean_score << loss_msg_stream.str();
+ << mean_score << loss_msg_stream.str();
}
}
template <typename Dtype>
void Solver<Dtype>::Snapshot() {
+ CHECK(Caffe::root_solver());
string model_filename;
switch (param_.snapshot_format()) {
case caffe::SolverParameter_SnapshotFormat_BINARYPROTO:
@@ -379,6 +409,7 @@ string Solver<Dtype>::SnapshotToHDF5() {
template <typename Dtype>
void Solver<Dtype>::Restore(const char* state_file) {
+ CHECK(Caffe::root_solver());
string state_filename(state_file);
if (state_filename.size() >= 3 &&
state_filename.compare(state_filename.size() - 3, 3, ".h5") == 0) {
@@ -480,6 +511,7 @@ void SGDSolver<Dtype>::ClipGradients() {
template <typename Dtype>
void SGDSolver<Dtype>::ApplyUpdate() {
+ CHECK(Caffe::root_solver());
Dtype rate = GetLearningRate();
if (this->param_.display() && this->iter_ % this->param_.display() == 0) {
LOG(INFO) << "Iteration " << this->iter_ << ", lr = " << rate;
@@ -723,6 +755,7 @@ void SGDSolver<Dtype>::RestoreSolverStateFromHDF5(const string& state_file) {
template <typename Dtype>
void NesterovSolver<Dtype>::ComputeUpdateValue(int param_id, Dtype rate) {
+ CHECK(Caffe::root_solver());
const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params();
const vector<float>& net_params_lr = this->net_->params_lr();
Dtype momentum = this->param_.momentum();
@@ -783,6 +816,7 @@ void NesterovSolver<Dtype>::ComputeUpdateValue(int param_id, Dtype rate) {
template <typename Dtype>
void AdaGradSolver<Dtype>::ComputeUpdateValue(int param_id, Dtype rate) {
+ CHECK(Caffe::root_solver());
const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params();
const vector<float>& net_params_lr = this->net_->params_lr();
Dtype delta = this->param_.delta();
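
Step() now brackets each iteration with the registered callbacks: on_start() before the forward/backward passes and on_gradients_ready() just before ApplyUpdate(), which is where P2PSync hooks in to move parameters and gradients between devices. A minimal sketch of a custom callback, assuming the accompanying header change (not in this src/ diff) declares Solver<Dtype>::Callback with these two hooks and a registration method such as add_callback(); both names are assumptions here.

#include <glog/logging.h>
#include "caffe/solver.hpp"

// Hypothetical callback that only logs the two hook points.
template <typename Dtype>
class LoggingCallback : public caffe::Solver<Dtype>::Callback {
 protected:
  virtual void on_start() {
    LOG(INFO) << "iteration start";
  }
  virtual void on_gradients_ready() {
    LOG(INFO) << "gradients accumulated, about to apply the update";
  }
};
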
diff --git a/src/caffe/syncedmem.cpp b/src/caffe/syncedmem.cpp
index 7617ccfb..a667a867 100644
--- a/src/caffe/syncedmem.cpp
+++ b/src/caffe/syncedmem.cpp
@@ -12,8 +12,14 @@ SyncedMemory::~SyncedMemory() {
}
#ifndef CPU_ONLY
- if (gpu_ptr_) {
+ if (gpu_ptr_ && own_gpu_data_) {
+ int initial_device;
+ cudaGetDevice(&initial_device);
+ if (gpu_device_ != -1) {
+ CUDA_CHECK(cudaSetDevice(gpu_device_));
+ }
CUDA_CHECK(cudaFree(gpu_ptr_));
+ cudaSetDevice(initial_device);
}
#endif // CPU_ONLY
}
@@ -48,13 +54,17 @@ inline void SyncedMemory::to_gpu() {
#ifndef CPU_ONLY
switch (head_) {
case UNINITIALIZED:
+ CUDA_CHECK(cudaGetDevice(&gpu_device_));
CUDA_CHECK(cudaMalloc(&gpu_ptr_, size_));
caffe_gpu_memset(size_, 0, gpu_ptr_);
head_ = HEAD_AT_GPU;
+ own_gpu_data_ = true;
break;
case HEAD_AT_CPU:
if (gpu_ptr_ == NULL) {
+ CUDA_CHECK(cudaGetDevice(&gpu_device_));
CUDA_CHECK(cudaMalloc(&gpu_ptr_, size_));
+ own_gpu_data_ = true;
}
caffe_gpu_memcpy(size_, cpu_ptr_, gpu_ptr_);
head_ = SYNCED;
@@ -92,6 +102,26 @@ const void* SyncedMemory::gpu_data() {
#endif
}
+void SyncedMemory::set_gpu_data(void* data) {
+#ifndef CPU_ONLY
+ CHECK(data);
+ if (own_gpu_data_) {
+ int initial_device;
+ cudaGetDevice(&initial_device);
+ if (gpu_device_ != -1) {
+ CUDA_CHECK(cudaSetDevice(gpu_device_));
+ }
+ CUDA_CHECK(cudaFree(gpu_ptr_));
+ cudaSetDevice(initial_device);
+ }
+ gpu_ptr_ = data;
+ head_ = HEAD_AT_GPU;
+ own_gpu_data_ = false;
+#else
+ NO_GPU;
+#endif
+}
+
void* SyncedMemory::mutable_cpu_data() {
to_cpu();
head_ = HEAD_AT_CPU;
@@ -108,6 +138,20 @@ void* SyncedMemory::mutable_gpu_data() {
#endif
}
+#ifndef CPU_ONLY
+void SyncedMemory::async_gpu_push(const cudaStream_t& stream) {
+ CHECK(head_ == HEAD_AT_CPU);
+ if (gpu_ptr_ == NULL) {
+ CUDA_CHECK(cudaGetDevice(&gpu_device_));
+ CUDA_CHECK(cudaMalloc(&gpu_ptr_, size_));
+ own_gpu_data_ = true;
+ }
+ const cudaMemcpyKind put = cudaMemcpyHostToDevice;
+ CUDA_CHECK(cudaMemcpyAsync(gpu_ptr_, cpu_ptr_, size_, put, stream));
+ // Assume caller will synchronize on the stream before use
+ head_ = SYNCED;
+}
+#endif
} // namespace caffe
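
The new SyncedMemory hooks let prefetch threads stage host batches on a private CUDA stream (async_gpu_push) and let parallel.cpp map parameter buffers it does not own (set_gpu_data). A minimal sketch of the staging path, assuming the memory currently holds valid CPU data, as the CHECK in async_gpu_push() requires; the helper name is illustrative.

#include "caffe/syncedmem.hpp"
#include "caffe/util/device_alternate.hpp"

// Hypothetical helper: push CPU contents to the GPU on a non-default stream.
void StageToGpu(caffe::SyncedMemory* mem) {
  cudaStream_t stream;
  CUDA_CHECK(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking));
  mem->async_gpu_push(stream);
  // The copy marks the memory SYNCED but leaves synchronization to the
  // caller, so wait before anything consumes gpu_data().
  CUDA_CHECK(cudaStreamSynchronize(stream));
  CUDA_CHECK(cudaStreamDestroy(stream));
}
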
diff --git a/src/caffe/test/test_gradient_based_solver.cpp b/src/caffe/test/test_gradient_based_solver.cpp
index c97d4ede..1d255a86 100644
--- a/src/caffe/test/test_gradient_based_solver.cpp
+++ b/src/caffe/test/test_gradient_based_solver.cpp
@@ -8,6 +8,7 @@
#include "gtest/gtest.h"
#include "caffe/common.hpp"
+#include "caffe/parallel.hpp"
#include "caffe/proto/caffe.pb.h"
#include "caffe/solver.hpp"
#include "caffe/util/io.hpp"
@@ -35,6 +36,7 @@ class GradientBasedSolverTest : public MultiDeviceTest<TypeParam> {
string snapshot_prefix_;
shared_ptr<SGDSolver<Dtype> > solver_;
+ shared_ptr<P2PSync<Dtype> > sync_;
int seed_;
// Dimensions are determined by generate_sample_data.py
// TODO this is brittle and the hdf5 file should be checked instead.
@@ -71,8 +73,8 @@ class GradientBasedSolverTest : public MultiDeviceTest<TypeParam> {
string RunLeastSquaresSolver(const Dtype learning_rate,
const Dtype weight_decay, const Dtype momentum, const int num_iters,
- const int iter_size = 1, const bool snapshot = false,
- const char* from_snapshot = NULL) {
+ const int iter_size = 1, const int devices = 1,
+ const bool snapshot = false, const char* from_snapshot = NULL) {
ostringstream proto;
proto <<
"snapshot_after_train: " << snapshot << " "
@@ -185,7 +187,20 @@ class GradientBasedSolverTest : public MultiDeviceTest<TypeParam> {
this->solver_->net()->Forward(empty_bottom_vec);
}
}
- this->solver_->Solve();
+ if (devices == 1) {
+ this->solver_->Solve();
+ } else {
+ LOG(INFO) << "Multi-GPU test on " << devices << " devices";
+ vector<int> gpus;
+ for (int i = 0; i < devices; ++i) {
+ gpus.push_back(i);
+ }
+ Caffe::set_solver_count(gpus.size());
+ this->sync_.reset(new P2PSync<Dtype>(
+ this->solver_, NULL, this->solver_->param()));
+ this->sync_->run(gpus);
+ Caffe::set_solver_count(1);
+ }
if (snapshot) {
ostringstream resume_file;
resume_file << snapshot_prefix_ << "/_iter_" << num_iters
@@ -428,20 +443,38 @@ class GradientBasedSolverTest : public MultiDeviceTest<TypeParam> {
void TestLeastSquaresUpdate(const Dtype learning_rate = 1.0,
const Dtype weight_decay = 0.0, const Dtype momentum = 0.0,
const int iter_to_check = 0) {
- // Initialize the solver and run K (= iter_to_check) solver iterations.
- RunLeastSquaresSolver(learning_rate, weight_decay, momentum, iter_to_check);
-
- // Compute the (K+1)th update using the analytic least squares gradient.
- vector<shared_ptr<Blob<Dtype> > > updated_params;
- ComputeLeastSquaresUpdate(learning_rate, weight_decay, momentum,
- &updated_params);
-
- // Reinitialize the solver and run K+1 solver iterations.
- RunLeastSquaresSolver(learning_rate, weight_decay, momentum,
- iter_to_check + 1);
-
- // Check that the solver's solution matches ours.
- CheckLeastSquaresUpdate(updated_params);
+ const int kNum = num_;
+ const int kIterSize = 1;
+ // Test over all numbers of devices.
+ int available_devices = 1;
+#ifndef CPU_ONLY
+ if (Caffe::mode() == Caffe::GPU) {
+ CUDA_CHECK(cudaGetDeviceCount(&available_devices));
+ }
+#endif
+ for (int devices = 1; devices <= available_devices; ++devices) {
+ // Configure batch size for single / multi device equivalence.
+ // Constant data is needed for multi-device equivalence, as for accumulation.
+ num_ = kNum * devices;
+
+ // Initialize the solver and run K (= iter_to_check) solver iterations
+ // (on single device).
+ RunLeastSquaresSolver(learning_rate, weight_decay, momentum,
+ iter_to_check, kIterSize, 1);
+
+ // Compute the (K+1)th update using the analytic least squares gradient.
+ vector<shared_ptr<Blob<Dtype> > > updated_params;
+ ComputeLeastSquaresUpdate(learning_rate, weight_decay, momentum,
+ &updated_params);
+
+ // Reinitialize the solver and run K+1 solver iterations.
+ num_ = kNum;
+ RunLeastSquaresSolver(learning_rate, weight_decay, momentum,
+ iter_to_check + 1, kIterSize, devices);
+
+ // Check that the solver's solution matches ours.
+ CheckLeastSquaresUpdate(updated_params);
+ }
}
void TestSnapshot(const Dtype learning_rate = 1.0,
@@ -451,8 +484,9 @@ class GradientBasedSolverTest : public MultiDeviceTest<TypeParam> {
const int total_num_iters = num_iters * 2;
bool snapshot = false;
const int kIterSize = 1;
+ const int kDevices = 1;
RunLeastSquaresSolver(learning_rate, weight_decay, momentum,
- total_num_iters, kIterSize, snapshot);
+ total_num_iters, kIterSize, kDevices, snapshot);
// Save the resulting param values.
vector<shared_ptr<Blob<Dtype> > > param_copies;
@@ -482,12 +516,13 @@ class GradientBasedSolverTest : public MultiDeviceTest<TypeParam> {
// Run the solver for num_iters iterations and snapshot.
snapshot = true;
string snapshot_name = RunLeastSquaresSolver(learning_rate, weight_decay,
- momentum, num_iters, kIterSize, snapshot);
+ momentum, num_iters, kIterSize, kDevices, snapshot);
// Reinitialize the solver and run for num_iters more iterations.
snapshot = false;
RunLeastSquaresSolver(learning_rate, weight_decay, momentum,
- total_num_iters, kIterSize, snapshot, snapshot_name.c_str());
+ total_num_iters, kIterSize, kDevices,
+ snapshot, snapshot_name.c_str());
// Check that params now match.
const vector<Blob<Dtype>*>& params = solver_->net()->learnable_params();
diff --git a/src/caffe/test/test_internal_thread.cpp b/src/caffe/test/test_internal_thread.cpp
index 31882b6d..93f1cc54 100644
--- a/src/caffe/test/test_internal_thread.cpp
+++ b/src/caffe/test/test_internal_thread.cpp
@@ -2,6 +2,7 @@
#include "gtest/gtest.h"
#include "caffe/internal_thread.hpp"
+#include "caffe/util/math_functions.hpp"
#include "caffe/test/test_caffe_main.hpp"
@@ -13,11 +14,40 @@ class InternalThreadTest : public ::testing::Test {};
TEST_F(InternalThreadTest, TestStartAndExit) {
InternalThread thread;
EXPECT_FALSE(thread.is_started());
- EXPECT_TRUE(thread.StartInternalThread());
+ thread.StartInternalThread();
EXPECT_TRUE(thread.is_started());
- EXPECT_TRUE(thread.WaitForInternalThreadToExit());
+ thread.StopInternalThread();
EXPECT_FALSE(thread.is_started());
}
+class TestThreadA : public InternalThread {
+ void InternalThreadEntry() {
+ EXPECT_EQ(4244559767, caffe_rng_rand());
+ }
+};
+
+class TestThreadB : public InternalThread {
+ void InternalThreadEntry() {
+ EXPECT_EQ(1726478280, caffe_rng_rand());
+ }
+};
+
+TEST_F(InternalThreadTest, TestRandomSeed) {
+ TestThreadA t1;
+ Caffe::set_random_seed(9658361);
+ t1.StartInternalThread();
+ t1.StopInternalThread();
+
+ TestThreadA t2;
+ Caffe::set_random_seed(9658361);
+ t2.StartInternalThread();
+ t2.StopInternalThread();
+
+ TestThreadB t3;
+ Caffe::set_random_seed(3435563);
+ t3.StartInternalThread();
+ t3.StopInternalThread();
+}
+
} // namespace caffe
diff --git a/src/caffe/test/test_layer_factory.cpp b/src/caffe/test/test_layer_factory.cpp
index efb1b37a..c86fafd0 100644
--- a/src/caffe/test/test_layer_factory.cpp
+++ b/src/caffe/test/test_layer_factory.cpp
@@ -1,11 +1,14 @@
#include <map>
#include <string>
+#include "boost/scoped_ptr.hpp"
#include "gtest/gtest.h"
#include "caffe/common.hpp"
#include "caffe/layer.hpp"
#include "caffe/layer_factory.hpp"
+#include "caffe/util/db.hpp"
+#include "caffe/util/io.hpp"
#include "caffe/test/test_caffe_main.hpp"
@@ -21,11 +24,20 @@ TYPED_TEST(LayerFactoryTest, TestCreateLayer) {
typename LayerRegistry<Dtype>::CreatorRegistry& registry =
LayerRegistry<Dtype>::Registry();
shared_ptr<Layer<Dtype> > layer;
- LayerParameter layer_param;
for (typename LayerRegistry<Dtype>::CreatorRegistry::iterator iter =
registry.begin(); iter != registry.end(); ++iter) {
// Special case: PythonLayer is checked by pytest
if (iter->first == "Python") { continue; }
+ LayerParameter layer_param;
+ // Data layers expect a DB
+ if (iter->first == "Data") {
+ string tmp;
+ MakeTempDir(&tmp);
+ boost::scoped_ptr<db::DB> db(db::GetDB(DataParameter_DB_LEVELDB));
+ db->Open(tmp, db::NEW);
+ db->Close();
+ layer_param.mutable_data_param()->set_source(tmp);
+ }
layer_param.set_type(iter->first);
layer = LayerRegistry<Dtype>::CreateLayer(layer_param);
EXPECT_EQ(iter->first, layer->type());
diff --git a/src/caffe/test/test_upgrade_proto.cpp b/src/caffe/test/test_upgrade_proto.cpp
index eec62765..00672023 100644
--- a/src/caffe/test/test_upgrade_proto.cpp
+++ b/src/caffe/test/test_upgrade_proto.cpp
@@ -2,12 +2,15 @@
#include <string>
#include <vector>
+#include "boost/scoped_ptr.hpp"
#include "google/protobuf/text_format.h"
#include "gtest/gtest.h"
#include "caffe/blob.hpp"
#include "caffe/common.hpp"
#include "caffe/layer.hpp"
+#include "caffe/util/db.hpp"
+#include "caffe/util/io.hpp"
#include "caffe/util/upgrade_proto.hpp"
#include "caffe/test/test_caffe_main.hpp"
@@ -2901,6 +2904,15 @@ TEST_F(NetUpgradeTest, TestUpgradeV1LayerType) {
continue; // Empty string isn't actually a valid layer type.
}
layer_param.set_type(v2_layer_type);
+ // Data layers expect a DB
+ if (v2_layer_type == "Data") {
+ string tmp;
+ MakeTempDir(&tmp);
+ boost::scoped_ptr<db::DB> db(db::GetDB(DataParameter_DB_LEVELDB));
+ db->Open(tmp, db::NEW);
+ db->Close();
+ layer_param.mutable_data_param()->set_source(tmp);
+ }
layer = LayerRegistry<float>::CreateLayer(layer_param);
EXPECT_EQ(v2_layer_type, layer->type());
}
diff --git a/src/caffe/util/blocking_queue.cpp b/src/caffe/util/blocking_queue.cpp
new file mode 100644
index 00000000..d1d1fa86
--- /dev/null
+++ b/src/caffe/util/blocking_queue.cpp
@@ -0,0 +1,96 @@
+#include <boost/thread.hpp>
+#include <string>
+
+#include "caffe/data_layers.hpp"
+#include "caffe/data_reader.hpp"
+#include "caffe/parallel.hpp"
+#include "caffe/util/blocking_queue.hpp"
+
+namespace caffe {
+
+template<typename T>
+class BlockingQueue<T>::sync {
+ public:
+ mutable boost::mutex mutex_;
+ boost::condition_variable condition_;
+};
+
+template<typename T>
+BlockingQueue<T>::BlockingQueue()
+ : sync_(new sync()) {
+}
+
+template<typename T>
+void BlockingQueue<T>::push(const T& t) {
+ boost::mutex::scoped_lock lock(sync_->mutex_);
+ queue_.push(t);
+ lock.unlock();
+ sync_->condition_.notify_one();
+}
+
+template<typename T>
+bool BlockingQueue<T>::try_pop(T* t) {
+ boost::mutex::scoped_lock lock(sync_->mutex_);
+
+ if (queue_.empty()) {
+ return false;
+ }
+
+ *t = queue_.front();
+ queue_.pop();
+ return true;
+}
+
+template<typename T>
+T BlockingQueue<T>::pop(const string& log_on_wait) {
+ boost::mutex::scoped_lock lock(sync_->mutex_);
+
+ while (queue_.empty()) {
+ if (!log_on_wait.empty()) {
+ LOG_EVERY_N(INFO, 1000)<< log_on_wait;
+ }
+ sync_->condition_.wait(lock);
+ }
+
+ T t = queue_.front();
+ queue_.pop();
+ return t;
+}
+
+template<typename T>
+bool BlockingQueue<T>::try_peek(T* t) {
+ boost::mutex::scoped_lock lock(sync_->mutex_);
+
+ if (queue_.empty()) {
+ return false;
+ }
+
+ *t = queue_.front();
+ return true;
+}
+
+template<typename T>
+T BlockingQueue<T>::peek() {
+ boost::mutex::scoped_lock lock(sync_->mutex_);
+
+ while (queue_.empty()) {
+ sync_->condition_.wait(lock);
+ }
+
+ return queue_.front();
+}
+
+template<typename T>
+size_t BlockingQueue<T>::size() const {
+ boost::mutex::scoped_lock lock(sync_->mutex_);
+ return queue_.size();
+}
+
+template class BlockingQueue<Batch<float>*>;
+template class BlockingQueue<Batch<double>*>;
+template class BlockingQueue<Datum*>;
+template class BlockingQueue<shared_ptr<DataReader::QueuePair> >;
+template class BlockingQueue<P2PSync<float>*>;
+template class BlockingQueue<P2PSync<double>*>;
+
+} // namespace caffe
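
BlockingQueue is the hand-off between reader/prefetch threads and the consuming layers: push() on the producer side, pop() blocking on the consumer side (with optional periodic logging while waiting), and try_pop()/try_peek() as non-blocking variants. A minimal round-trip sketch reusing one of the element types instantiated above (Datum*); the function is illustrative.

#include "caffe/proto/caffe.pb.h"
#include "caffe/util/blocking_queue.hpp"

// Hypothetical single-thread round trip through the queue.
void RoundTrip(caffe::BlockingQueue<caffe::Datum*>* queue,
               caffe::Datum* datum) {
  queue->push(datum);                                   // producer side
  caffe::Datum* next = queue->pop("Waiting for data");  // blocks while empty
  // Non-blocking variant: returns false here because the queue is empty again.
  caffe::Datum* maybe = NULL;
  bool got_more = queue->try_pop(&maybe);
  (void) next;
  (void) got_more;
}
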