summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
Diffstat (limited to 'test')
-rw-r--r--test/test_cuda.py15
1 files changed, 15 insertions, 0 deletions
diff --git a/test/test_cuda.py b/test/test_cuda.py
index ad52d9b950..88ae5b790d 100644
--- a/test/test_cuda.py
+++ b/test/test_cuda.py
@@ -1497,6 +1497,21 @@ class TestCuda(TestCase):
def test_cuda_synchronize(self):
torch.cuda.synchronize()
+ torch.cuda.synchronize('cuda')
+ torch.cuda.synchronize('cuda:0')
+ torch.cuda.synchronize(0)
+ torch.cuda.synchronize(torch.device('cuda:0'))
+
+ if TEST_MULTIGPU:
+ torch.cuda.synchronize('cuda:1')
+ torch.cuda.synchronize(1)
+ torch.cuda.synchronize(torch.device('cuda:1'))
+
+ with self.assertRaisesRegex(ValueError, "Expected a cuda device, but"):
+ torch.cuda.synchronize(torch.device("cpu"))
+
+ with self.assertRaisesRegex(ValueError, "Expected a cuda device, but"):
+ torch.cuda.synchronize("cpu")
@unittest.skipIf(not TEST_MULTIGPU, "detected only one GPU")
def test_current_stream(self):