summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJun Shi <junshi@yahoo-inc.com>2016-01-22 09:58:37 -0800
committerJun Shi <junshi@yahoo-inc.com>2016-03-05 07:07:21 -0800
commit01528918c707df82e5910bea0270d7987db5abd8 (patch)
tree6a9b136eda30dd7b15a7d2daa6e874c22e2b4467
parent6a0b98768d4745714e31949b87382ff562be6724 (diff)
downloadcaffeonacl-01528918c707df82e5910bea0270d7987db5abd8.tar.gz
caffeonacl-01528918c707df82e5910bea0270d7987db5abd8.tar.bz2
caffeonacl-01528918c707df82e5910bea0270d7987db5abd8.zip
split p2psync::run()
-rw-r--r--include/caffe/parallel.hpp5
-rw-r--r--src/caffe/parallel.cpp20
-rw-r--r--src/caffe/test/test_gradient_based_solver.cpp2
-rw-r--r--tools/caffe.cpp2
4 files changed, 19 insertions, 10 deletions
diff --git a/include/caffe/parallel.hpp b/include/caffe/parallel.hpp
index 85fc2b55..6c496c88 100644
--- a/include/caffe/parallel.hpp
+++ b/include/caffe/parallel.hpp
@@ -93,7 +93,10 @@ class P2PSync : public GPUParams<Dtype>, public Solver<Dtype>::Callback,
return solver_;
}
- void run(const vector<int>& gpus);
+ void Run(const vector<int>& gpus);
+ void Prepare(const vector<int>& gpus,
+ vector<shared_ptr<P2PSync<Dtype> > >* syncs);
+ inline const int initial_iter() const { return initial_iter_; }
protected:
void on_start();
diff --git a/src/caffe/parallel.cpp b/src/caffe/parallel.cpp
index 62f5d738..5bc41c6a 100644
--- a/src/caffe/parallel.cpp
+++ b/src/caffe/parallel.cpp
@@ -380,7 +380,8 @@ void P2PSync<Dtype>::on_gradients_ready() {
}
template<typename Dtype>
-void P2PSync<Dtype>::run(const vector<int>& gpus) {
+void P2PSync<Dtype>::Prepare(const vector<int>& gpus,
+ vector<shared_ptr<P2PSync<Dtype> > >* syncs) {
// Pair devices for map-reduce synchronization
vector<DevicePair> pairs;
DevicePair::compute(gpus, &pairs);
@@ -391,15 +392,14 @@ void P2PSync<Dtype>::run(const vector<int>& gpus) {
LOG(INFO)<< "GPUs pairs " << s.str();
SolverParameter param(solver_->param());
- vector<shared_ptr<P2PSync<Dtype> > > syncs(gpus.size());
// Build the GPU tree by finding the parent for each solver
for (int attempts = 0; attempts < pairs.size(); ++attempts) {
for (int i = 1; i < pairs.size(); ++i) {
- if (!syncs[i].get()) {
+ if (!syncs->at(i).get()) {
P2PSync<Dtype>* parent = NULL;
- for (int j = 0; j < syncs.size(); ++j) {
- P2PSync<Dtype>* sync = j == 0 ? this : syncs[j].get();
+ for (int j = 0; j < syncs->size(); ++j) {
+ P2PSync<Dtype>* sync = j == 0 ? this : syncs->at(j).get();
if (sync) {
const SolverParameter& p = sync->solver()->param();
if (p.device_id() == pairs[i].parent()) {
@@ -409,12 +409,18 @@ void P2PSync<Dtype>::run(const vector<int>& gpus) {
}
if (parent) {
param.set_device_id(pairs[i].device());
- syncs[i].reset(new P2PSync<Dtype>(solver_, parent, param));
- parent->children_.push_back((P2PSync<Dtype>*) syncs[i].get());
+ syncs->at(i).reset(new P2PSync<Dtype>(solver_, parent, param));
+ parent->children_.push_back((P2PSync<Dtype>*) syncs->at(i).get());
}
}
}
}
+}
+
+template<typename Dtype>
+void P2PSync<Dtype>::Run(const vector<int>& gpus) {
+ vector<shared_ptr<P2PSync<Dtype> > > syncs(gpus.size());
+ Prepare(gpus, &syncs);
LOG(INFO)<< "Starting Optimization";
diff --git a/src/caffe/test/test_gradient_based_solver.cpp b/src/caffe/test/test_gradient_based_solver.cpp
index 09ec3a7e..975a8f0f 100644
--- a/src/caffe/test/test_gradient_based_solver.cpp
+++ b/src/caffe/test/test_gradient_based_solver.cpp
@@ -204,7 +204,7 @@ class GradientBasedSolverTest : public MultiDeviceTest<TypeParam> {
Caffe::set_solver_count(gpus.size());
this->sync_.reset(new P2PSync<Dtype>(
this->solver_, NULL, this->solver_->param()));
- this->sync_->run(gpus);
+ this->sync_->Run(gpus);
Caffe::set_solver_count(1);
}
if (snapshot) {
diff --git a/tools/caffe.cpp b/tools/caffe.cpp
index 95b2f82c..5d9331f0 100644
--- a/tools/caffe.cpp
+++ b/tools/caffe.cpp
@@ -214,7 +214,7 @@ int train() {
if (gpus.size() > 1) {
caffe::P2PSync<float> sync(solver, NULL, solver->param());
- sync.run(gpus);
+ sync.Run(gpus);
} else {
LOG(INFO) << "Starting Optimization";
solver->Solve();