diff options
author | mhouston <mhouston@nvidia.com> | 2015-07-10 16:05:48 -0700 |
---|---|---|
committer | Evan Shelhamer <shelhamer@imaginarynumber.net> | 2015-08-09 15:16:02 -0700 |
commit | 335bee737cb2e715abe685e6029afc83ccd8f404 (patch) | |
tree | 98f3591655f2ea8783387befd1b3adf2468e7718 | |
parent | e5575cf17a43a56e4ba9bc5465548ac0512197d8 (diff) | |
download | caffeonacl-335bee737cb2e715abe685e6029afc83ccd8f404.tar.gz caffeonacl-335bee737cb2e715abe685e6029afc83ccd8f404.tar.bz2 caffeonacl-335bee737cb2e715abe685e6029afc83ccd8f404.zip |
Detect topology corner cases and improve broadcast order
- Start with distant nodes in broadcast
- Fix outside loop to loop for full tree depth
-rw-r--r-- | src/caffe/parallel.cpp | 73 |
1 files changed, 41 insertions, 32 deletions
diff --git a/src/caffe/parallel.cpp b/src/caffe/parallel.cpp index 3fef8cfd..5a08df6c 100644 --- a/src/caffe/parallel.cpp +++ b/src/caffe/parallel.cpp @@ -119,18 +119,23 @@ void DevicePair::compute(const vector<int> devices, vector<DevicePair>* pairs) { #ifndef CPU_ONLY vector<int> remaining(devices); + // Depth for reduction tree + int remaining_depth = static_cast<int>(ceil(log2(remaining.size()))); + // Group GPUs by board - for (int i = 0; i < remaining.size(); ++i) { - for (int j = i + 1; j < remaining.size(); ++j) { - cudaDeviceProp a, b; - CUDA_CHECK(cudaGetDeviceProperties(&a, remaining[i])); - CUDA_CHECK(cudaGetDeviceProperties(&b, remaining[j])); - if (a.isMultiGpuBoard && b.isMultiGpuBoard) { - if (a.multiGpuBoardGroupID == b.multiGpuBoardGroupID) { - pairs->push_back(DevicePair(remaining[i], remaining[j])); - DLOG(INFO) << "GPU board: " << remaining[i] << ":" << remaining[j]; - remaining.erase(remaining.begin() + j); - break; + for (int d = 0; d < remaining_depth; ++d) { + for (int i = 0; i < remaining.size(); ++i) { + for (int j = i + 1; j < remaining.size(); ++j) { + cudaDeviceProp a, b; + CUDA_CHECK(cudaGetDeviceProperties(&a, remaining[i])); + CUDA_CHECK(cudaGetDeviceProperties(&b, remaining[j])); + if (a.isMultiGpuBoard && b.isMultiGpuBoard) { + if (a.multiGpuBoardGroupID == b.multiGpuBoardGroupID) { + pairs->push_back(DevicePair(remaining[i], remaining[j])); + DLOG(INFO) << "GPU board: " << remaining[i] << ":" << remaining[j]; + remaining.erase(remaining.begin() + j); + break; + } } } } @@ -142,15 +147,19 @@ void DevicePair::compute(const vector<int> devices, vector<DevicePair>* pairs) { DLOG(INFO) << "GPUs paired by boards, remaining: " << s.str(); // Group by P2P accessibility - for (int i = 0; i < remaining.size(); ++i) { - for (int j = i + 1; j < remaining.size(); ++j) { - int access; - CUDA_CHECK(cudaDeviceCanAccessPeer(&access, remaining[i], remaining[j])); - if (access) { - pairs->push_back(DevicePair(remaining[i], remaining[j])); - DLOG(INFO) << "P2P pair: " << remaining[i] << ":" << remaining[j]; - remaining.erase(remaining.begin() + j); - break; + remaining_depth = ceil(log2(remaining.size())); + for (int d = 0; d < remaining_depth; ++d) { + for (int i = 0; i < remaining.size(); ++i) { + for (int j = i + 1; j < remaining.size(); ++j) { + int access; + CUDA_CHECK( + cudaDeviceCanAccessPeer(&access, remaining[i], remaining[j])); + if (access) { + pairs->push_back(DevicePair(remaining[i], remaining[j])); + DLOG(INFO) << "P2P pair: " << remaining[i] << ":" << remaining[j]; + remaining.erase(remaining.begin() + j); + break; + } } } } @@ -161,15 +170,19 @@ void DevicePair::compute(const vector<int> devices, vector<DevicePair>* pairs) { DLOG(INFO) << "GPUs paired by P2P access, remaining: " << s.str(); // Group remaining - for (int i = 0; i < remaining.size(); ++i) { - for (int j = i + 1; j < remaining.size(); ++j) { - pairs->push_back(DevicePair(remaining[i], remaining[j])); - DLOG(INFO) << "Remaining pair: " << remaining[i] << ":" << remaining[j]; - remaining.erase(remaining.begin() + j); - break; + remaining_depth = ceil(log2(remaining.size())); + for (int d = 0; d < remaining_depth; ++d) { + for (int i = 0; i < remaining.size(); ++i) { + pairs->push_back(DevicePair(remaining[i], remaining[i + 1])); + DLOG(INFO) << "Remaining pair: " << remaining[i] << ":" + << remaining[i + 1]; + remaining.erase(remaining.begin() + i + 1); } } + + // Should only be the parent node remaining CHECK_EQ(remaining.size(), 1); + pairs->insert(pairs->begin(), DevicePair(-1, remaining[0])); CHECK(pairs->size() == devices.size()); @@ -289,7 +302,7 @@ void P2PSync<Dtype>::on_start() { } // Update children - for (int i = 0; i < children_.size(); ++i) { + for (int i = children_.size() - 1; i >= 0; i--) { Dtype* src = data_; Dtype* dst = children_[i]->data_; @@ -301,13 +314,9 @@ void P2PSync<Dtype>::on_start() { CHECK(attributes.device == children_[i]->solver_->param().device_id()); #endif - CUDA_CHECK(cudaMemcpyAsync(dst, src, size_ * sizeof(Dtype), // + CUDA_CHECK(cudaMemcpyAsync(dst, src, size_ * sizeof(Dtype), cudaMemcpyDeviceToDevice, cudaStreamDefault)); - } - if (children_.size()) { CUDA_CHECK(cudaStreamSynchronize(cudaStreamDefault)); - } - for (int i = 0; i < children_.size(); ++i) { children_[i]->queue_.push(this); } #endif |