# Source path: root/test/test_cuda_primary_ctx.py
# Source blob: 68aab85bef0cbbf2fd1b63b8c87db0c25bcd44b3
import ctypes
import torch
from common_utils import TestCase, run_tests, skipIfRocm
import unittest

# NOTE: run this file in a brand new process.

# TEST_CUDA and TEST_MULTIGPU are computed locally instead of being imported
# from common_cuda: importing them would also execute common_cuda's TEST_CUDNN
# line multiple times during this test suite, which causes a CUDA OOM error
# on Windows.
TEST_CUDA = torch.cuda.is_available()
# The conditional expression only queries the device count when CUDA is
# available, and yields False otherwise.
TEST_MULTIGPU = (torch.cuda.device_count() >= 2) if TEST_CUDA else False

if not TEST_CUDA:
    print('CUDA not available, skipping tests')
    # Rebind the base-class name so the test class defined in this file
    # subclasses plain `object` instead of the imported TestCase.
    TestCase = object  # noqa: F811


def get_is_primary_context_created(device):
    """Report whether the CUDA primary context of ``device`` is active.

    Calls ``cuDevicePrimaryCtxGetState`` through the handle returned by
    ``torch.cuda.cudart()`` and inspects the ``active`` out-parameter.

    Args:
        device: Integer device ordinal; passed to the call as a C ``int``.

    Returns:
        ``True`` if the call reports a non-zero ``active`` value,
        ``False`` otherwise.

    Raises:
        AssertionError: If the call returns a non-zero status code.
    """
    # Zero-initialised out-parameters filled in by the call; only ``active``
    # is inspected afterwards.
    flags = ctypes.c_uint()
    active = ctypes.c_int()
    status = torch.cuda.cudart().cuDevicePrimaryCtxGetState(
        ctypes.c_int(device), ctypes.byref(flags), ctypes.byref(active))
    # Raised explicitly rather than through an `assert` statement so the
    # check is not stripped when Python runs with -O; a failed call would
    # otherwise be reported as "context not created".
    if status != 0:
        raise AssertionError(
            'cuDevicePrimaryCtxGetState failed with error code {}'.format(status))
    return bool(active.value)


class TestCudaPrimaryCtx(TestCase):
    """Checks that working with a tensor on 'cuda:1' leaves device 0 untouched."""

    def _assert_context_only_on_second_device(self):
        # Device 0 must have no primary context; device 1 must have one.
        self.assertFalse(get_is_primary_context_created(0))
        self.assertTrue(get_is_primary_context_created(1))

    @unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected")
    @skipIfRocm
    def test_cuda_primary_ctx(self):
        # Precondition: neither device has a primary context yet.
        for device_index in (0, 1):
            self.assertFalse(get_is_primary_context_created(device_index))

        # Allocating on 'cuda:1' should create a context on that device only.
        gpu_tensor = torch.randn(1, device='cuda:1')
        self._assert_context_only_on_second_device()

        # Printing the tensor should not create a context on device 0.
        print(gpu_tensor)
        self._assert_context_only_on_second_device()

        # Copying the tensor into a CPU tensor should not create one either.
        cpu_tensor = torch.randn(1, device='cpu')
        cpu_tensor.copy_(gpu_tensor)
        self._assert_context_only_on_second_device()

    # DO NOT ADD ANY OTHER TESTS HERE!  ABOVE TEST REQUIRES FRESH PROCESS

# Script entry point: when executed directly, hand control to the run_tests
# helper imported from common_utils.
if __name__ == '__main__':
    run_tests()