Diffstat (limited to 'test/test_nn.py')
-rw-r--r--  test/test_nn.py  50
1 file changed, 50 insertions, 0 deletions
diff --git a/test/test_nn.py b/test/test_nn.py
index 649c3ffb3a..ca113ecf21 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -13,6 +13,7 @@ from operator import mul
from collections import OrderedDict
import hashlib
import os
+import threading

import torch
from torch._six import inf, nan
@@ -3837,6 +3838,55 @@ class TestNN(NNTestCase):

    @unittest.skipIf(not TEST_CUDA, 'CUDA not available')
    @unittest.skipIf(not TEST_CUDNN, 'CUDNN not available')
+    @skipIfRocm
+    def test_cudnn_multiple_threads_same_device(self):
+        # This function is intended to test the lazy creation and reuse of per-thread
+        # cudnn handles on each device in aten/src/ATen/cudnn/Handles.cpp.
+        # Failure here likely indicates something wrong with that logic.
+        weight = torch.ones((1, 1, 2, 2), device='cuda')
+
+        results = {}
+
+        num_threads = 2
+        trials = 2
+        test_iters = 100
+
+        with torch.backends.cudnn.flags(enabled=True):
+            def _worker(t, input):
+                my_stream = torch.cuda.Stream()
+                results[t] = input
+                with torch.cuda.stream(my_stream):
+                    for i in range(test_iters):
+                        # If all threads are sharing the same cudnn handle,
+                        # the following sequence may occur:
+                        # thread 0 calls setCuDNNStreamToCurrent()
+                        # thread 1 calls setCuDNNStreamToCurrent()
+                        # thread 0 launches its raw convolution, which it thinks is in
+                        # its own stream, but is actually in thread 1's stream.
+                        # thread 0 enqueues its div_, which IS in its own stream,
+                        # but now races with its convolution.
+                        results[t] = torch.nn.functional.conv2d(results[t], weight, padding=0)
+                        results[t].div_(4.0)
+                torch.cuda.current_stream().wait_stream(my_stream)
+
+            for trial in range(trials):
+                for t in range(num_threads):
+                    results[t] = torch.ones((1, 1, 2048, 2048), device='cuda')
+
+                threads = [threading.Thread(target=_worker,
+                                            args=(t, results[t])) for t in range(num_threads)]
+
+                for thread in threads:
+                    thread.start()
+                for thread in threads:
+                    thread.join()
+
+                for t in range(num_threads):
+                    self.assertEqual(results[t].sum().item(),
+                                     (2048 - test_iters) * (2048 - test_iters))
+
+    @unittest.skipIf(not TEST_CUDA, 'CUDA not available')
+    @unittest.skipIf(not TEST_CUDNN, 'CUDNN not available')
    @repeat_test_for_types(ALL_TENSORTYPES)
    @skipIfRocm
    def test_Conv2d_deterministic_cudnn(self, dtype=torch.float):
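
The expected value in the new test's assertion follows from the convolution arithmetic alone: a 2x2 all-ones kernel with padding=0 shrinks each spatial dimension by 1 and produces 4.0 per output element, and the in-place div_(4.0) brings every value back to 1.0, so after test_iters passes an all-ones (2048, 2048) input becomes an all-ones (2048 - test_iters, 2048 - test_iters) tensor whose sum is (2048 - test_iters) ** 2. A minimal CPU-only sketch of that arithmetic (not part of the patch, and independent of the threading and per-thread-handle machinery the test actually exercises):

import torch

test_iters = 100
weight = torch.ones((1, 1, 2, 2))   # same 2x2 all-ones kernel as the test
x = torch.ones((1, 1, 2048, 2048))  # same all-ones input as the test
for _ in range(test_iters):
    # Each pass trims one row and one column and yields 4.0 per output element ...
    x = torch.nn.functional.conv2d(x, weight, padding=0)
    # ... and averaging the 2x2 patches brings every value back to exactly 1.0.
    x.div_(4.0)
assert x.sum().item() == (2048 - test_iters) * (2048 - test_iters)  # 1948 * 1948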