summaryrefslogtreecommitdiff
path: root/torch/lib
diff options
context:
space:
mode:
authorEdward Yang <ezyang@fb.com>2018-10-19 09:47:02 -0700
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>2018-10-19 09:48:41 -0700
commit68f4a4b3ba8770594f4aed16b7306cda28bf4a57 (patch)
treeff331837ff5187a4bfd70f91d44326e7f4b31125 /torch/lib
parent6ec2f091885807d37deea51e9d8116eb8d7ea8f8 (diff)
downloadpytorch-68f4a4b3ba8770594f4aed16b7306cda28bf4a57.tar.gz
pytorch-68f4a4b3ba8770594f4aed16b7306cda28bf4a57.tar.bz2
pytorch-68f4a4b3ba8770594f4aed16b7306cda28bf4a57.zip
Delete THCStreamGuard in favor of CUDAGuard, also c10d code cleanup (#12849)
Summary: I got annoyed at waiting for OSS to tell me my c10d builds were busted, so I also added support for building the test scripts in fbcode and fixed the warnings this uncovered. Pull Request resolved: https://github.com/pytorch/pytorch/pull/12849 Reviewed By: pietern Differential Revision: D10457671 fbshipit-source-id: 5b0e36c606e397323f313f09dfce64d2df88faed
Diffstat (limited to 'torch/lib')
-rw-r--r--torch/lib/c10d/CUDAUtils.cpp19
-rw-r--r--torch/lib/c10d/CUDAUtils.hpp43
-rw-r--r--torch/lib/c10d/ProcessGroupGloo.cpp26
-rw-r--r--torch/lib/c10d/ProcessGroupGloo.hpp3
-rw-r--r--torch/lib/c10d/ProcessGroupNCCL.cpp52
-rw-r--r--torch/lib/c10d/ProcessGroupNCCL.hpp5
-rw-r--r--torch/lib/c10d/private/CUDAUtils.hpp40
-rw-r--r--torch/lib/c10d/test/CUDATest.cu4
-rw-r--r--torch/lib/c10d/test/CUDATest.hpp3
-rw-r--r--torch/lib/c10d/test/FileStoreTest.cpp4
-rw-r--r--torch/lib/c10d/test/ProcessGroupGlooAsyncTest.cpp20
-rw-r--r--torch/lib/c10d/test/ProcessGroupGlooTest.cpp19
-rw-r--r--torch/lib/c10d/test/ProcessGroupNCCLTest.cpp18
-rw-r--r--torch/lib/c10d/test/TCPStoreTest.cpp4
14 files changed, 80 insertions, 180 deletions
diff --git a/torch/lib/c10d/CUDAUtils.cpp b/torch/lib/c10d/CUDAUtils.cpp
index 2a1a7d1769..db5e0cbaf2 100644
--- a/torch/lib/c10d/CUDAUtils.cpp
+++ b/torch/lib/c10d/CUDAUtils.cpp
@@ -13,7 +13,7 @@ CUDAEvent CUDAEvent::create(unsigned int flags) {
return event;
}
-CUDAEvent::~CUDAEvent() noexcept (false) {
+CUDAEvent::~CUDAEvent() noexcept(false) {
if (event_ != nullptr) {
// cudaEventDestroy must run on the same device of the event,
// otherwise it creates a context on default device as well.
@@ -23,21 +23,4 @@ CUDAEvent::~CUDAEvent() noexcept (false) {
}
}
-CUDAStream CUDAStream::create() {
- CUDAStream stream;
- stream.stream_ = THCStream_new();
- return stream;
-}
-
-CUDAStream::~CUDAStream() {
- if (stream_ != nullptr) {
- THCStream_free(stream_);
- stream_ = nullptr;
- }
-}
-
-cudaStream_t CUDAStream::getStream() const {
- return THCStream_stream(stream_);
-}
-
} // namespace c10d
diff --git a/torch/lib/c10d/CUDAUtils.hpp b/torch/lib/c10d/CUDAUtils.hpp
index c0db85e7ed..b8e99e32ca 100644
--- a/torch/lib/c10d/CUDAUtils.hpp
+++ b/torch/lib/c10d/CUDAUtils.hpp
@@ -1,7 +1,5 @@
#pragma once
-typedef struct CUDAStreamInternals THCStream;
-
#include <algorithm>
#include <cuda.h>
@@ -50,45 +48,4 @@ class CUDAEvent {
cudaEvent_t event_;
};
-// RAII wrapper for CUDA streams.
-//
-// This wrapper uses THCStream instead of cudaStream_t because we need
-// to interact with the THC API for selecting the current stream.
-// Doing this without having a THCStream pointer is cumbersome.
-//
-class CUDAStream {
- public:
- CUDAStream(THCStream* stream) : stream_(stream) {}
-
- CUDAStream() : CUDAStream(nullptr) {}
-
- ~CUDAStream();
-
- static CUDAStream create();
-
- // Must not be copyable.
- CUDAStream& operator=(const CUDAStream&) = delete;
- CUDAStream(const CUDAStream&) = delete;
-
- // Must be move constructable.
- CUDAStream(CUDAStream&& other) {
- std::swap(stream_, other.stream_);
- }
-
- // Must be move assignable.
- CUDAStream& operator=(CUDAStream&& other) {
- std::swap(stream_, other.stream_);
- return *this;
- }
-
- cudaStream_t getStream() const;
-
- THCStream* getTHCStream() {
- return stream_;
- }
-
- protected:
- THCStream* stream_;
-};
-
} // namespace c10d
diff --git a/torch/lib/c10d/ProcessGroupGloo.cpp b/torch/lib/c10d/ProcessGroupGloo.cpp
index a521c36eac..9a01757105 100644
--- a/torch/lib/c10d/ProcessGroupGloo.cpp
+++ b/torch/lib/c10d/ProcessGroupGloo.cpp
@@ -6,6 +6,8 @@
#include <gloo/broadcast_one_to_all.h>
#ifdef USE_CUDA
+#include <ATen/cuda/CUDAGuard.h>
+
#include <gloo/cuda_allreduce_halving_doubling.h>
#include <gloo/cuda_allreduce_ring_chunked.h>
#include <gloo/cuda_broadcast_one_to_all.h>
@@ -104,7 +106,7 @@ const ::gloo::ReductionFunction<T>* reductionFunction(const ReduceOp& r) {
std::vector<cudaStream_t> getStreamVector(AlgorithmEntry& entry) {
std::vector<cudaStream_t> streams(entry.streams.size());
for (size_t i = 0; i < entry.streams.size(); i++) {
- streams[i] = entry.streams[i].getStream();
+ streams[i] = entry.streams[i].stream();
}
return streams;
}
@@ -117,7 +119,7 @@ void synchronizeStreams(THCState* thcState, AlgorithmEntry* entry) {
for (size_t i = 0; i < key.devices.size(); i++) {
const auto& device = key.devices[i];
auto publicStream = THCState_getCurrentStreamOnDevice(thcState, device);
- auto privateStream = entry->streams[i].getStream();
+ auto privateStream = entry->streams[i].stream();
auto event = entry->events[i].getEvent();
// Synchronize private stream with public stream.
@@ -201,7 +203,7 @@ void ProcessGroupGloo::WorkGloo::finish(const AlgorithmEntry& entry) {
deviceGuard.set_index(devices_[i]);
events_[i] = CUDAEvent::create();
const auto& event = events_[i].getEvent();
- const auto& stream = entry.streams[i].getStream();
+ const auto& stream = entry.streams[i].stream();
C10D_CUDA_CHECK(cudaEventRecord(event, stream));
}
}
@@ -528,7 +530,7 @@ EntryType ProcessGroupGloo::construct(const AlgorithmKey& key) {
entry->events.resize(key.devices.size());
for (size_t i = 0; i < key.devices.size(); i++) {
deviceGuard.set_index(key.devices[i]);
- entry->streams[i] = CUDAStream::create();
+ entry->streams[i] = at::cuda::createCUDAStream();
entry->events[i] = CUDAEvent::create();
}
}
@@ -606,10 +608,7 @@ std::shared_ptr<ProcessGroup::Work> ProcessGroupGloo::broadcast(
entry->run = [=]() mutable {
entry->algorithm->run();
for (size_t i = 0; i < tensors.size(); i++) {
- // The THCStreamGuard is a RAII wrapper for temporarily
- // overriding the current THCStream. This also sets the
- // current device to the stream's device.
- THCStreamGuard guard(thcState, entry->streams[i]);
+ at::cuda::CUDAGuard guard(entry->streams[i]);
tensors[i].copy_(entry->src[i]);
}
};
@@ -657,10 +656,7 @@ std::shared_ptr<ProcessGroup::Work> ProcessGroupGloo::allreduce(
entry->run = [=]() mutable {
entry->algorithm->run();
for (size_t i = 0; i < tensors.size(); i++) {
- // The THCStreamGuard is a RAII wrapper for temporarily
- // overriding the current THCStream. This also sets the
- // current device to the stream's device.
- THCStreamGuard guard(thcState, entry->streams[i]);
+ at::cuda::CUDAGuard guard(entry->streams[i]);
tensors[i].copy_(entry->src[i]);
}
};
@@ -722,7 +718,7 @@ uint32_t checkTag(int32_t tag) {
if (tag < 0) {
throw std::runtime_error("Tag must be >= 0");
}
- return (uint32_t) tag;
+ return (uint32_t)tag;
}
std::shared_ptr<ProcessGroup::Work> ProcessGroupGloo::send(
@@ -797,9 +793,7 @@ std::shared_ptr<ProcessGroup::Work> ProcessGroupGloo::barrier() {
key.collectiveType = CollectiveType::BARRIER;
auto entry = checkout(key);
- entry->run = [=]() mutable {
- entry->algorithm->run();
- };
+ entry->run = [=]() mutable { entry->algorithm->run(); };
return enqueue(entry);
}
diff --git a/torch/lib/c10d/ProcessGroupGloo.hpp b/torch/lib/c10d/ProcessGroupGloo.hpp
index 7b4bfbccbe..03f707561c 100644
--- a/torch/lib/c10d/ProcessGroupGloo.hpp
+++ b/torch/lib/c10d/ProcessGroupGloo.hpp
@@ -16,6 +16,7 @@
#include <torch/csrc/utils/hash.h>
#ifdef USE_CUDA
+#include <ATen/cuda/CUDAStream.h>
#include <c10d/CUDAUtils.hpp>
#endif
@@ -119,7 +120,7 @@ struct AlgorithmEntry {
// true, the caller can launch new CUDA kernels and they will be
// correctly sequenced.
//
- std::vector<CUDAStream> streams;
+ std::vector<at::cuda::CUDAStream> streams;
std::vector<CUDAEvent> events;
#endif
diff --git a/torch/lib/c10d/ProcessGroupNCCL.cpp b/torch/lib/c10d/ProcessGroupNCCL.cpp
index db3fcf5e21..9568769d9f 100644
--- a/torch/lib/c10d/ProcessGroupNCCL.cpp
+++ b/torch/lib/c10d/ProcessGroupNCCL.cpp
@@ -6,6 +6,8 @@
#include <THC.h>
#include <THC/THCGeneral.hpp>
+#include <ATen/cuda/CUDAGuard.h>
+
#include <c10d/Utils.hpp>
#include <c10d/private/CUDAUtils.hpp>
@@ -69,18 +71,18 @@ void syncStreams(
THCState* thcState,
const std::vector<at::Device>& devices,
std::vector<CUDAEvent>& ncclEvents,
- std::vector<CUDAStream>& ncclStreams) {
+ std::vector<at::cuda::CUDAStream>& ncclStreams) {
at::DeviceGuard gpuGuard;
for (size_t i = 0; i < devices.size(); ++i) {
gpuGuard.set_index(devices[i].index());
auto currentThcStream =
THCState_getCurrentStreamOnDevice(thcState, devices[i].index());
- CUDAStream& ncclStream = ncclStreams[i];
+ at::cuda::CUDAStream& ncclStream = ncclStreams[i];
CUDAEvent& ncclEvent = ncclEvents[i];
C10D_CUDA_CHECK(cudaEventRecord(ncclEvent.getEvent(), currentThcStream));
C10D_CUDA_CHECK(
- cudaStreamWaitEvent(ncclStream.getStream(), ncclEvent.getEvent(), 0));
+ cudaStreamWaitEvent(ncclStream.stream(), ncclEvent.getEvent(), 0));
}
}
@@ -241,7 +243,7 @@ std::vector<std::shared_ptr<NCCLComm>>& ProcessGroupNCCL::getNCCLComm(
at::DeviceGuard gpuGuard;
std::vector<CUDAEvent> eventVal;
- std::vector<CUDAStream> streamVal;
+ std::vector<at::cuda::CUDAStream> streamVal;
eventVal.resize(devices.size());
streamVal.resize(devices.size());
@@ -258,7 +260,7 @@ std::vector<std::shared_ptr<NCCLComm>>& ProcessGroupNCCL::getNCCLComm(
ncclComms[i] = NCCLComm::create(numRanks, rank, ncclID);
// Also create the NCCL streams and events
- streamVal[i] = CUDAStream::create();
+ streamVal[i] = at::cuda::createCUDAStream();
// Event created using cudaEventDisableTiming flag and not
// cudaEventBlockingSync flag will provide the best performance when used
// with cudaStreamWaitEvent() and cudaEventQuery(). Since we here don't
@@ -377,7 +379,7 @@ std::shared_ptr<ProcessGroup::Work> ProcessGroupNCCL::allreduce(
for (size_t i = 0; i < tensors.size(); ++i) {
gpuGuard.set_index(devices[i].index());
- CUDAStream& ncclStream = ncclStreams_[key][i];
+ at::cuda::CUDAStream& ncclStream = ncclStreams_[key][i];
C10D_NCCL_CHECK(ncclAllReduce(
tensors[i].data_ptr(),
@@ -386,18 +388,17 @@ std::shared_ptr<ProcessGroup::Work> ProcessGroupNCCL::allreduce(
getNcclDataType(tensors[i].type().scalarType()),
ncclOp[opts.reduceOp],
ncclComms[i]->getNcclComm(),
- ncclStream.getStream()));
+ ncclStream.stream()));
}
C10D_NCCL_CHECK(ncclGroupEnd());
// Event should only be recorded after the ncclGroupEnd()
for (size_t i = 0; i < tensors.size(); ++i) {
- CUDAStream& ncclStream = ncclStreams_[key][i];
+ at::cuda::CUDAStream& ncclStream = ncclStreams_[key][i];
CUDAEvent& cudaEvent = work->cudaEvents_[i];
- C10D_CUDA_CHECK(
- cudaEventRecord(cudaEvent.getEvent(), ncclStream.getStream()));
+ C10D_CUDA_CHECK(cudaEventRecord(cudaEvent.getEvent(), ncclStream.stream()));
}
return work;
@@ -427,7 +428,7 @@ std::shared_ptr<ProcessGroup::Work> ProcessGroupNCCL::broadcast(
for (size_t i = 0; i < tensors.size(); ++i) {
gpuGuard.set_index(devices[i].index());
- CUDAStream& ncclStream = ncclStreams_[key][i];
+ at::cuda::CUDAStream& ncclStream = ncclStreams_[key][i];
// root rank of the GPU
int root = opts.rootRank * tensors.size() + opts.rootTensor;
@@ -437,18 +438,17 @@ std::shared_ptr<ProcessGroup::Work> ProcessGroupNCCL::broadcast(
getNcclDataType(tensors[i].type().scalarType()),
root,
ncclComms[i]->getNcclComm(),
- ncclStream.getStream()));
+ ncclStream.stream()));
}
C10D_NCCL_CHECK(ncclGroupEnd());
// Event should only be recorded after the ncclGroupEnd()
for (size_t i = 0; i < tensors.size(); ++i) {
- CUDAStream& ncclStream = ncclStreams_[key][i];
+ at::cuda::CUDAStream& ncclStream = ncclStreams_[key][i];
CUDAEvent& cudaEvent = work->cudaEvents_[i];
- C10D_CUDA_CHECK(
- cudaEventRecord(cudaEvent.getEvent(), ncclStream.getStream()));
+ C10D_CUDA_CHECK(cudaEventRecord(cudaEvent.getEvent(), ncclStream.stream()));
}
return work;
@@ -478,7 +478,7 @@ std::shared_ptr<ProcessGroup::Work> ProcessGroupNCCL::reduce(
for (size_t i = 0; i < tensors.size(); ++i) {
gpuGuard.set_index(devices[i].index());
- CUDAStream& ncclStream = ncclStreams_[key][i];
+ at::cuda::CUDAStream& ncclStream = ncclStreams_[key][i];
// root rank of the GPU
int root = opts.rootRank * tensors.size() + opts.rootTensor;
@@ -490,18 +490,17 @@ std::shared_ptr<ProcessGroup::Work> ProcessGroupNCCL::reduce(
ncclOp[opts.reduceOp],
root,
ncclComms[i]->getNcclComm(),
- ncclStream.getStream()));
+ ncclStream.stream()));
}
C10D_NCCL_CHECK(ncclGroupEnd());
// Event should only be recorded after the ncclGroupEnd()
for (size_t i = 0; i < tensors.size(); ++i) {
- CUDAStream& ncclStream = ncclStreams_[key][i];
+ at::cuda::CUDAStream& ncclStream = ncclStreams_[key][i];
CUDAEvent& cudaEvent = work->cudaEvents_[i];
- C10D_CUDA_CHECK(
- cudaEventRecord(cudaEvent.getEvent(), ncclStream.getStream()));
+ C10D_CUDA_CHECK(cudaEventRecord(cudaEvent.getEvent(), ncclStream.stream()));
}
return work;
@@ -550,7 +549,7 @@ std::shared_ptr<ProcessGroup::Work> ProcessGroupNCCL::allgather(
for (size_t i = 0; i < inputTensors.size(); ++i) {
gpuGuard.set_index(devices[i].index());
- CUDAStream& ncclStream = ncclStreams_[key][i];
+ at::cuda::CUDAStream& ncclStream = ncclStreams_[key][i];
C10D_NCCL_CHECK(ncclAllGather(
inputTensors[i].data_ptr(),
@@ -558,15 +557,15 @@ std::shared_ptr<ProcessGroup::Work> ProcessGroupNCCL::allgather(
inputTensors[i].numel(),
getNcclDataType(inputTensors[i].type().scalarType()),
ncclComms[i]->getNcclComm(),
- ncclStream.getStream()));
+ ncclStream.stream()));
}
C10D_NCCL_CHECK(ncclGroupEnd());
// Copy the flattened output tensors to the outputs
for (size_t i = 0; i < outputTensors.size(); ++i) {
- CUDAStream& ncclStream = ncclStreams_[key][i];
- THCStreamGuard guard(thcState_, ncclStream);
+ at::cuda::CUDAStream& ncclStream = ncclStreams_[key][i];
+ at::cuda::CUDAGuard guard(ncclStream);
for (size_t j = 0; j < outputTensors[0].size(); ++j) {
outputTensors[i][j].copy_(flattenOutputTensors[i][j], true);
}
@@ -574,11 +573,10 @@ std::shared_ptr<ProcessGroup::Work> ProcessGroupNCCL::allgather(
// Event should only be recorded after the ncclGroupEnd()
for (size_t i = 0; i < inputTensors.size(); ++i) {
- CUDAStream& ncclStream = ncclStreams_[key][i];
+ at::cuda::CUDAStream& ncclStream = ncclStreams_[key][i];
CUDAEvent& cudaEvent = work->cudaEvents_[i];
- C10D_CUDA_CHECK(
- cudaEventRecord(cudaEvent.getEvent(), ncclStream.getStream()));
+ C10D_CUDA_CHECK(cudaEventRecord(cudaEvent.getEvent(), ncclStream.stream()));
}
return work;
}
diff --git a/torch/lib/c10d/ProcessGroupNCCL.hpp b/torch/lib/c10d/ProcessGroupNCCL.hpp
index 3eca7c4d95..b060ac0b70 100644
--- a/torch/lib/c10d/ProcessGroupNCCL.hpp
+++ b/torch/lib/c10d/ProcessGroupNCCL.hpp
@@ -8,6 +8,8 @@
#include <c10d/ProcessGroup.hpp>
#include <c10d/Store.hpp>
+#include <ATen/cuda/CUDAContext.h>
+
// forward declaration
struct THCState;
@@ -189,7 +191,8 @@ class ProcessGroupNCCL : public ProcessGroup {
devNCCLCommMap_;
// The CUDA streams used by NCCL kernels
- std::unordered_map<std::string, std::vector<CUDAStream>> ncclStreams_;
+ std::unordered_map<std::string, std::vector<at::cuda::CUDAStream>>
+ ncclStreams_;
// The CUDA events used to sync NCCL streams
std::unordered_map<std::string, std::vector<CUDAEvent>> ncclEvents_;
diff --git a/torch/lib/c10d/private/CUDAUtils.hpp b/torch/lib/c10d/private/CUDAUtils.hpp
index 2375e02d15..d3bb8a33d5 100644
--- a/torch/lib/c10d/private/CUDAUtils.hpp
+++ b/torch/lib/c10d/private/CUDAUtils.hpp
@@ -11,6 +11,7 @@
#include <c10d/CUDAUtils.hpp>
+// TODO: Use AT_CHECK or similar here
#define C10D_CUDA_CHECK(condition) \
do { \
cudaError_t error = (condition); \
@@ -25,42 +26,3 @@
throw std::runtime_error(ss.str()); \
} \
} while (0)
-
-namespace c10d {
-
-// THCStreamGuard is a RAII guard for selecting a THCStream.
-//
-// It sets both the current device to the stream's device and the
-// current stream in the THC state.
-//
-class THCStreamGuard {
- public:
- explicit THCStreamGuard(THCState* state, CUDAStream& stream)
- : device_(THCStream_device(stream.getTHCStream())), state_(state) {
- at::DeviceGuard deviceGuard(device_);
- original_ = THCState_getStream(state_);
- THCStream_retain(original_);
- THCState_setStream(state_, stream.getTHCStream());
- }
-
- THCStreamGuard(THCStreamGuard&& other)
- : device_(other.device_), state_(nullptr), original_(nullptr) {
- std::swap(state_, other.state_);
- std::swap(original_, other.original_);
- }
-
- ~THCStreamGuard() {
- if (original_ != nullptr) {
- at::DeviceGuard deviceGuard(device_);
- THCState_setStream(state_, original_);
- THCStream_free(original_);
- }
- }
-
- private:
- const int device_;
- THCState* state_;
- THCStream* original_;
-};
-
-} // namespace c10d
diff --git a/torch/lib/c10d/test/CUDATest.cu b/torch/lib/c10d/test/CUDATest.cu
index 6b91229da6..ef8a2780c6 100644
--- a/torch/lib/c10d/test/CUDATest.cu
+++ b/torch/lib/c10d/test/CUDATest.cu
@@ -16,8 +16,8 @@ __global__ void waitClocks(const uint64_t count) {
} // namespace
-void cudaSleep(CUDAStream& stream, uint64_t clocks) {
- waitClocks<<<1, 1, 0, stream.getStream()>>>(clocks);
+void cudaSleep(at::cuda::CUDAStream& stream, uint64_t clocks) {
+ waitClocks<<<1, 1, 0, stream.stream()>>>(clocks);
}
int cudaNumDevices() {
diff --git a/torch/lib/c10d/test/CUDATest.hpp b/torch/lib/c10d/test/CUDATest.hpp
index 352a74e5af..2a46d7fbb6 100644
--- a/torch/lib/c10d/test/CUDATest.hpp
+++ b/torch/lib/c10d/test/CUDATest.hpp
@@ -3,12 +3,13 @@
#include <cuda.h>
#include <cuda_runtime.h>
+#include <ATen/cuda/CUDAStream.h>
#include <c10d/CUDAUtils.hpp>
namespace c10d {
namespace test {
-void cudaSleep(CUDAStream& stream, uint64_t clocks);
+void cudaSleep(at::cuda::CUDAStream& stream, uint64_t clocks);
int cudaNumDevices();
diff --git a/torch/lib/c10d/test/FileStoreTest.cpp b/torch/lib/c10d/test/FileStoreTest.cpp
index b7a5d7e208..c34ab7a094 100644
--- a/torch/lib/c10d/test/FileStoreTest.cpp
+++ b/torch/lib/c10d/test/FileStoreTest.cpp
@@ -57,7 +57,7 @@ void testHelper(const std::string prefix = "") {
const auto numIterations = 100;
c10d::test::Semaphore sem1, sem2;
for (auto i = 0; i < numThreads; i++) {
- threads.push_back(std::move(std::thread([&] {
+ threads.push_back(std::thread([&] {
c10d::FileStore fileStore(path);
c10d::PrefixStore store(prefix, fileStore);
sem1.post();
@@ -65,7 +65,7 @@ void testHelper(const std::string prefix = "") {
for (auto j = 0; j < numIterations; j++) {
store.add("counter", 1);
}
- })));
+ }));
}
sem1.wait(numThreads);
sem2.post(numThreads);
diff --git a/torch/lib/c10d/test/ProcessGroupGlooAsyncTest.cpp b/torch/lib/c10d/test/ProcessGroupGlooAsyncTest.cpp
index 74e96acc15..1a57500426 100644
--- a/torch/lib/c10d/test/ProcessGroupGlooAsyncTest.cpp
+++ b/torch/lib/c10d/test/ProcessGroupGlooAsyncTest.cpp
@@ -1,5 +1,7 @@
#include <gloo/transport/tcp/device.h>
+#include <ATen/cuda/CUDAGuard.h>
+
#include <c10d/CUDAUtils.hpp>
#include <c10d/FileStore.hpp>
#include <c10d/ProcessGroupGloo.hpp>
@@ -9,9 +11,8 @@
using namespace c10d::test;
-using c10d::CUDAStream;
+using at::cuda::CUDAStream;
using c10d::ProcessGroup;
-using c10d::THCStreamGuard;
template <typename T, typename... Args>
std::vector<T> initialize(const std::string& path, int N, Args&&... args) {
@@ -22,8 +23,7 @@ std::vector<T> initialize(const std::string& path, int N, Args&&... args) {
std::vector<std::thread> threads;
for (auto i = 0; i < N; i++) {
- threads.push_back(
- std::move(std::thread([i, N, &tests] { tests[i].start(i, N); })));
+ threads.push_back(std::thread([i, N, &tests] { tests[i].start(i, N); }));
}
for (auto& thread : threads) {
@@ -71,11 +71,11 @@ class AsyncInputIsOutputTest : public AsyncTest {
numTensors_(numTensors),
numDevices_(cudaNumDevices()),
state_(::at::globalContext().lazyInitCUDA()) {
-
// Allocate inputs on available devices in a round robin fashion.
inputs_.resize(numTensors_);
for (auto i = 0; i < numTensors_; i++) {
- inputs_[i] = at::empty({16, 16}, at::device({at::kCUDA, i % numDevices_}));
+ inputs_[i] =
+ at::empty({16, 16}, at::device({at::kCUDA, i % numDevices_}));
}
// Allocate a stream per device.
@@ -89,14 +89,14 @@ class AsyncInputIsOutputTest : public AsyncTest {
streams_.resize(numDevices_);
for (auto i = 0; i < numDevices_; i++) {
deviceGuard.set_index(i);
- streams_[i] = CUDAStream::create();
+ streams_[i] = at::cuda::createCUDAStream();
}
}
- std::vector<THCStreamGuard> createStreamGuard() {
- std::vector<THCStreamGuard> guards;
+ std::vector<at::cuda::CUDAGuard> createStreamGuard() {
+ std::vector<at::cuda::CUDAGuard> guards;
for (auto& stream : streams_) {
- guards.push_back(std::move(THCStreamGuard(state_, stream)));
+ guards.push_back(at::cuda::CUDAGuard(stream));
}
return guards;
}
diff --git a/torch/lib/c10d/test/ProcessGroupGlooTest.cpp b/torch/lib/c10d/test/ProcessGroupGlooTest.cpp
index 41484b1007..1a03ded0e7 100644
--- a/torch/lib/c10d/test/ProcessGroupGlooTest.cpp
+++ b/torch/lib/c10d/test/ProcessGroupGlooTest.cpp
@@ -34,10 +34,10 @@ class SignalTest {
// Arms test to send signal to PID when the semaphore unlocks. This
// happens as soon as the first collective completes successfully.
void arm(int pid, int signal) {
- arm_ = std::move(std::thread([=] {
+ arm_ = std::thread([=] {
sem_.wait();
kill(pid, signal);
- }));
+ });
}
std::shared_ptr<::c10d::ProcessGroup::Work> run(int rank, int size) {
@@ -66,7 +66,7 @@ class SignalTest {
sem_.post();
}
- return std::move(work);
+ return work;
}
protected:
@@ -97,19 +97,19 @@ class CollectiveTest {
int num) {
std::vector<CollectiveTest> tests;
for (auto i = 0; i < num; i++) {
- tests.push_back(std::move(CollectiveTest(path)));
+ tests.push_back(CollectiveTest(path));
}
std::vector<std::thread> threads;
for (auto i = 0; i < num; i++) {
- threads.push_back(std::move(
- std::thread([i, &tests] { tests[i].start(i, tests.size()); })));
+ threads.push_back(
+ std::thread([i, &tests] { tests[i].start(i, tests.size()); }));
}
for (auto& thread : threads) {
thread.join();
}
- return std::move(tests);
+ return tests;
}
CollectiveTest(const std::string& path) : path_(path) {}
@@ -151,7 +151,7 @@ std::vector<std::vector<at::Tensor>> copyTensors(
for (size_t j = 0; j < input.size(); j++) {
output[j] = input[j].cpu();
}
- outputs[i] = std::move(output);
+ outputs[i] = output;
}
return outputs;
}
@@ -163,8 +163,7 @@ void testAllreduce(const std::string& path, const at::Backend b) {
// Generate inputs
std::vector<std::vector<at::Tensor>> inputs(size);
for (auto i = 0; i < size; i++) {
- auto tensor =
- at::ones({16, 16}, b) * i;
+ auto tensor = at::ones({16, 16}, b) * i;
inputs[i] = std::vector<at::Tensor>({tensor});
}
diff --git a/torch/lib/c10d/test/ProcessGroupNCCLTest.cpp b/torch/lib/c10d/test/ProcessGroupNCCLTest.cpp
index 23ffebc4a7..99f00af374 100644
--- a/torch/lib/c10d/test/ProcessGroupNCCLTest.cpp
+++ b/torch/lib/c10d/test/ProcessGroupNCCLTest.cpp
@@ -7,11 +7,13 @@
#include <c10d/test/CUDATest.hpp>
#include <c10d/test/TestUtils.hpp>
+#include <ATen/cuda/CUDAGuard.h>
+#include <ATen/cuda/CUDAStream.h>
+
using namespace c10d::test;
-using c10d::CUDAStream;
+using at::cuda::CUDAStream;
using c10d::ProcessGroup;
-using c10d::THCStreamGuard;
class NCCLTestBase {
public:
@@ -68,14 +70,14 @@ class NCCLTest : public NCCLTestBase {
streams_.resize(numDevices_);
for (auto i = 0; i < numDevices_; i++) {
deviceGuard.set_index(i);
- streams_[i] = CUDAStream::create();
+ streams_[i] = at::cuda::createCUDAStream();
}
}
- std::vector<THCStreamGuard> createStreamGuard() {
- std::vector<THCStreamGuard> guards;
+ std::vector<at::cuda::CUDAGuard> createStreamGuard() {
+ std::vector<at::cuda::CUDAGuard> guards;
for (auto& stream : streams_) {
- guards.push_back(std::move(THCStreamGuard(state_, stream)));
+ guards.push_back(at::cuda::CUDAGuard(stream));
}
return guards;
}
@@ -93,7 +95,7 @@ class NCCLTest : public NCCLTestBase {
// Copy inputs to outputs
for (auto i = 0; i < numDevices_; i++) {
- cudaStreamSynchronize(streams_[i].getStream());
+ cudaStreamSynchronize(streams_[i].stream());
outputs[i] = inputs_[i].cpu();
}
@@ -111,7 +113,7 @@ class NCCLTest : public NCCLTestBase {
// Copy inputs to outputs
for (auto i = 0; i < numDevices_; ++i) {
- cudaStreamSynchronize(streams_[i].getStream());
+ cudaStreamSynchronize(streams_[i].stream());
for (auto j = 0; j < worldSize_ * numDevices_; ++j) {
outputs[i][j] = outputs_[i][j].cpu();
}
diff --git a/torch/lib/c10d/test/TCPStoreTest.cpp b/torch/lib/c10d/test/TCPStoreTest.cpp
index 77f3bda3f7..8b3f89ccbb 100644
--- a/torch/lib/c10d/test/TCPStoreTest.cpp
+++ b/torch/lib/c10d/test/TCPStoreTest.cpp
@@ -39,7 +39,7 @@ void testHelper(const std::string& prefix = "") {
std::string expectedCounterRes = std::to_string(numThreads * numIterations);
for (auto i = 0; i < numThreads; i++) {
- threads.push_back(std::move(
+ threads.push_back(
std::thread([&sem1, &sem2, &clientStores, i, &expectedCounterRes] {
for (auto j = 0; j < numIterations; j++) {
clientStores[i]->add("counter", 1);
@@ -65,7 +65,7 @@ void testHelper(const std::string& prefix = "") {
std::string val = "thread_val_" + std::to_string(numIterations - 1);
c10d::test::check(*clientStores[i], key, val);
}
- })));
+ }));
}
sem1.wait(numThreads);