diff options
author | Jun Shi <junshi@yahoo-inc.com> | 2016-01-22 09:58:37 -0800 |
---|---|---|
committer | Jun Shi <junshi@yahoo-inc.com> | 2016-03-05 07:07:21 -0800 |
commit | 01528918c707df82e5910bea0270d7987db5abd8 (patch) | |
tree | 6a9b136eda30dd7b15a7d2daa6e874c22e2b4467 | |
parent | 6a0b98768d4745714e31949b87382ff562be6724 (diff) | |
download | caffeonacl-01528918c707df82e5910bea0270d7987db5abd8.tar.gz caffeonacl-01528918c707df82e5910bea0270d7987db5abd8.tar.bz2 caffeonacl-01528918c707df82e5910bea0270d7987db5abd8.zip |
Split P2PSync::run() into Prepare() and Run()
-rw-r--r-- | include/caffe/parallel.hpp | 5 | ||||
-rw-r--r-- | src/caffe/parallel.cpp | 20 | ||||
-rw-r--r-- | src/caffe/test/test_gradient_based_solver.cpp | 2 | ||||
-rw-r--r-- | tools/caffe.cpp | 2 |
4 files changed, 19 insertions, 10 deletions
```diff
diff --git a/include/caffe/parallel.hpp b/include/caffe/parallel.hpp
index 85fc2b55..6c496c88 100644
--- a/include/caffe/parallel.hpp
+++ b/include/caffe/parallel.hpp
@@ -93,7 +93,10 @@ class P2PSync : public GPUParams<Dtype>, public Solver<Dtype>::Callback,
     return solver_;
   }
 
-  void run(const vector<int>& gpus);
+  void Run(const vector<int>& gpus);
+  void Prepare(const vector<int>& gpus,
+               vector<shared_ptr<P2PSync<Dtype> > >* syncs);
+  inline const int initial_iter() const { return initial_iter_; }
 
  protected:
   void on_start();
diff --git a/src/caffe/parallel.cpp b/src/caffe/parallel.cpp
index 62f5d738..5bc41c6a 100644
--- a/src/caffe/parallel.cpp
+++ b/src/caffe/parallel.cpp
@@ -380,7 +380,8 @@ void P2PSync<Dtype>::on_gradients_ready() {
 }
 
 template<typename Dtype>
-void P2PSync<Dtype>::run(const vector<int>& gpus) {
+void P2PSync<Dtype>::Prepare(const vector<int>& gpus,
+          vector<shared_ptr<P2PSync<Dtype> > >* syncs) {
   // Pair devices for map-reduce synchronization
   vector<DevicePair> pairs;
   DevicePair::compute(gpus, &pairs);
@@ -391,15 +392,14 @@ void P2PSync<Dtype>::run(const vector<int>& gpus) {
   LOG(INFO)<< "GPUs pairs " << s.str();
 
   SolverParameter param(solver_->param());
-  vector<shared_ptr<P2PSync<Dtype> > > syncs(gpus.size());
 
   // Build the GPU tree by finding the parent for each solver
   for (int attempts = 0; attempts < pairs.size(); ++attempts) {
     for (int i = 1; i < pairs.size(); ++i) {
-      if (!syncs[i].get()) {
+      if (!syncs->at(i).get()) {
         P2PSync<Dtype>* parent = NULL;
-        for (int j = 0; j < syncs.size(); ++j) {
-          P2PSync<Dtype>* sync = j == 0 ? this : syncs[j].get();
+        for (int j = 0; j < syncs->size(); ++j) {
+          P2PSync<Dtype>* sync = j == 0 ? this : syncs->at(j).get();
           if (sync) {
             const SolverParameter& p = sync->solver()->param();
             if (p.device_id() == pairs[i].parent()) {
@@ -409,12 +409,18 @@ void P2PSync<Dtype>::run(const vector<int>& gpus) {
         }
         if (parent) {
           param.set_device_id(pairs[i].device());
-          syncs[i].reset(new P2PSync<Dtype>(solver_, parent, param));
-          parent->children_.push_back((P2PSync<Dtype>*) syncs[i].get());
+          syncs->at(i).reset(new P2PSync<Dtype>(solver_, parent, param));
+          parent->children_.push_back((P2PSync<Dtype>*) syncs->at(i).get());
         }
       }
     }
   }
+}
+
+template<typename Dtype>
+void P2PSync<Dtype>::Run(const vector<int>& gpus) {
+  vector<shared_ptr<P2PSync<Dtype> > > syncs(gpus.size());
+  Prepare(gpus, &syncs);
 
   LOG(INFO)<< "Starting Optimization";
 
diff --git a/src/caffe/test/test_gradient_based_solver.cpp b/src/caffe/test/test_gradient_based_solver.cpp
index 09ec3a7e..975a8f0f 100644
--- a/src/caffe/test/test_gradient_based_solver.cpp
+++ b/src/caffe/test/test_gradient_based_solver.cpp
@@ -204,7 +204,7 @@ class GradientBasedSolverTest : public MultiDeviceTest<TypeParam> {
       Caffe::set_solver_count(gpus.size());
       this->sync_.reset(new P2PSync<Dtype>(
           this->solver_, NULL, this->solver_->param()));
-      this->sync_->run(gpus);
+      this->sync_->Run(gpus);
       Caffe::set_solver_count(1);
     }
     if (snapshot) {
diff --git a/tools/caffe.cpp b/tools/caffe.cpp
index 95b2f82c..5d9331f0 100644
--- a/tools/caffe.cpp
+++ b/tools/caffe.cpp
@@ -214,7 +214,7 @@ int train() {
 
   if (gpus.size() > 1) {
     caffe::P2PSync<float> sync(solver, NULL, solver->param());
-    sync.run(gpus);
+    sync.Run(gpus);
   } else {
     LOG(INFO) << "Starting Optimization";
     solver->Solve();
```