Diffstat (limited to 'test/test_nn.py')
-rw-r--r-- | test/test_nn.py | 50 |
1 file changed, 50 insertions, 0 deletions
diff --git a/test/test_nn.py b/test/test_nn.py
index 649c3ffb3a..ca113ecf21 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -13,6 +13,7 @@ from operator import mul
 from collections import OrderedDict
 import hashlib
 import os
+import threading
 
 import torch
 from torch._six import inf, nan
@@ -3837,6 +3838,55 @@ class TestNN(NNTestCase):
 
     @unittest.skipIf(not TEST_CUDA, 'CUDA not available')
     @unittest.skipIf(not TEST_CUDNN, 'CUDNN not available')
+    @skipIfRocm
+    def test_cudnn_multiple_threads_same_device(self):
+        # This function is intended to test the lazy creation and reuse of per-thread
+        # cudnn handles on each device in aten/src/ATen/cudnn/Handles.cpp.
+        # Failure here likely indicates something wrong with that logic.
+        weight = torch.ones((1, 1, 2, 2), device='cuda')
+
+        results = {}
+
+        num_threads = 2
+        trials = 2
+        test_iters = 100
+
+        with torch.backends.cudnn.flags(enabled=True):
+            def _worker(t, input):
+                my_stream = torch.cuda.Stream()
+                results[t] = input
+                with torch.cuda.stream(my_stream):
+                    for i in range(test_iters):
+                        # If all threads are sharing the same cudnn handle,
+                        # the following sequence may occur:
+                        # thread 0 calls setCuDNNStreamToCurrent()
+                        # thread 1 calls setCuDNNStreamToCurrent()
+                        # thread 0 launches its raw convolution, which it thinks is in
+                        # its own stream, but is actually in thread 1's stream.
+                        # thread 0 enqueues its div_, which IS in its own stream,
+                        # but now races with its convolution.
+                        results[t] = torch.nn.functional.conv2d(results[t], weight, padding=0)
+                        results[t].div_(4.0)
+                torch.cuda.current_stream().wait_stream(my_stream)
+
+            for trial in range(trials):
+                for t in range(num_threads):
+                    results[t] = torch.ones((1, 1, 2048, 2048), device='cuda')
+
+                threads = [threading.Thread(target=_worker,
+                                            args=(t, results[t])) for t in range(num_threads)]
+
+                for thread in threads:
+                    thread.start()
+                for thread in threads:
+                    thread.join()
+
+                for t in range(num_threads):
+                    self.assertEqual(results[t].sum().item(),
+                                     (2048 - test_iters) * (2048 - test_iters))
+
+    @unittest.skipIf(not TEST_CUDA, 'CUDA not available')
+    @unittest.skipIf(not TEST_CUDNN, 'CUDNN not available')
     @repeat_test_for_types(ALL_TENSORTYPES)
     @skipIfRocm
     def test_Conv2d_deterministic_cudnn(self, dtype=torch.float):
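
For reference, a minimal standalone sketch (not part of the commit above; the sizes are scaled down and the run stays on CPU purely for illustration) of why the test expects each result to sum to (2048 - test_iters) ** 2: every 2x2 convolution with padding=0 over an all-ones tensor produces 4.0 everywhere while shrinking both spatial dimensions by one, and the following div_(4.0) restores every element to 1.0.

    import torch
    import torch.nn.functional as F

    size, iters = 16, 3                      # stand-ins for 2048 and test_iters
    weight = torch.ones((1, 1, 2, 2))
    x = torch.ones((1, 1, size, size))
    for _ in range(iters):
        # Each 2x2 conv (padding=0) sums four ones to 4.0 and shrinks both
        # spatial dims by one; div_(4.0) brings every element back to 1.0.
        x = F.conv2d(x, weight, padding=0)
        x.div_(4.0)
    assert x.sum().item() == (size - iters) ** 2   # 13 * 13 == 169

If a race corrupts the convolution (for example because two threads issue work through one shared cudnn handle on the wrong stream), the resulting values no longer collapse back to exactly 1.0, and the assertEqual on the sum fails.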