summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authorWill Feng <yf225@cornell.edu>2018-06-07 15:12:05 -0400
committerGitHub <noreply@github.com>2018-06-07 15:12:05 -0400
commitf2c86532f33f913ca756e492a4072588e3976331 (patch)
treebe6c3aae404f3aa5bdb3d05b7e105fe707f51a25 /test
parent14f5484e0d01f4ec7f5e454112dfc9e21e15856b (diff)
downloadpytorch-f2c86532f33f913ca756e492a4072588e3976331.tar.gz
pytorch-f2c86532f33f913ca756e492a4072588e3976331.tar.bz2
pytorch-f2c86532f33f913ca756e492a4072588e3976331.zip
Fix TEST_CUDA import in test_cuda (#8246)
Diffstat (limited to 'test')
-rw-r--r--test/test_cuda.py7
1 files changed, 6 insertions, 1 deletions
diff --git a/test/test_cuda.py b/test/test_cuda.py
index 2a4cd7c553..d9bf6e66a4 100644
--- a/test/test_cuda.py
+++ b/test/test_cuda.py
@@ -13,8 +13,13 @@ from torch import multiprocessing as mp
from test_torch import TestTorch
from common import TestCase, get_gpu_type, to_gpu, freeze_rng_state, run_tests, PY3
-from common_cuda import TEST_CUDA, TEST_MULTIGPU
+# We cannot import TEST_CUDA and TEST_MULTIGPU from common_cuda here,
+# because if we do that, the TEST_CUDNN line from common_cuda will be executed
+# multiple times as well during the execution of this test suite, and it will
+# cause CUDA OOM error on Windows.
+TEST_CUDA = torch.cuda.is_available()
+TEST_MULTIGPU = TEST_CUDA and torch.cuda.device_count() >= 2
if not TEST_CUDA:
print('CUDA not available, skipping tests')