diff options
Diffstat (limited to 'test')
-rw-r--r-- | test/test_cuda.py | 15 |
1 files changed, 15 insertions, 0 deletions
diff --git a/test/test_cuda.py b/test/test_cuda.py index ad52d9b950..88ae5b790d 100644 --- a/test/test_cuda.py +++ b/test/test_cuda.py @@ -1497,6 +1497,21 @@ class TestCuda(TestCase): def test_cuda_synchronize(self): torch.cuda.synchronize() + torch.cuda.synchronize('cuda') + torch.cuda.synchronize('cuda:0') + torch.cuda.synchronize(0) + torch.cuda.synchronize(torch.device('cuda:0')) + + if TEST_MULTIGPU: + torch.cuda.synchronize('cuda:1') + torch.cuda.synchronize(1) + torch.cuda.synchronize(torch.device('cuda:1')) + + with self.assertRaisesRegex(ValueError, "Expected a cuda device, but"): + torch.cuda.synchronize(torch.device("cpu")) + + with self.assertRaisesRegex(ValueError, "Expected a cuda device, but"): + torch.cuda.synchronize("cpu") @unittest.skipIf(not TEST_MULTIGPU, "detected only one GPU") def test_current_stream(self): |