summaryrefslogtreecommitdiff
path: root/test/test_cuda_primary_ctx.py
diff options
context:
space:
mode:
authorZachary DeVito <zdevito@fb.com>2019-02-22 13:37:26 -0800
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>2019-02-22 13:51:45 -0800
commit356a94b64e69aa2e87b33e7bbd0b302feca327a4 (patch)
tree5e864de50ec91b836f8dc05f8f91c9691c0b4cc2 /test/test_cuda_primary_ctx.py
parent81b43202ae1eaa44ada84b193ae0fbfe0d11bc91 (diff)
downloadpytorch-356a94b64e69aa2e87b33e7bbd0b302feca327a4.tar.gz
pytorch-356a94b64e69aa2e87b33e7bbd0b302feca327a4.tar.bz2
pytorch-356a94b64e69aa2e87b33e7bbd0b302feca327a4.zip
Lazily load libcuda libnvrtc from c++ (#17317)
Summary: Fixes https://github.com/pytorch/pytorch/issues/16860 Pull Request resolved: https://github.com/pytorch/pytorch/pull/17317 Differential Revision: D14157877 Pulled By: zdevito fbshipit-source-id: c37aec2d77c2e637d4fc6ceffe2bd32901c70317
Diffstat (limited to 'test/test_cuda_primary_ctx.py')
-rw-r--r--test/test_cuda_primary_ctx.py11
1 files changed, 10 insertions, 1 deletions
diff --git a/test/test_cuda_primary_ctx.py b/test/test_cuda_primary_ctx.py
index 68aab85bef..2211a13487 100644
--- a/test/test_cuda_primary_ctx.py
+++ b/test/test_cuda_primary_ctx.py
@@ -2,6 +2,8 @@ import ctypes
import torch
from common_utils import TestCase, run_tests, skipIfRocm
import unittest
+import glob
+import os
# NOTE: this needs to be run in a brand new process
@@ -17,10 +19,17 @@ if not TEST_CUDA:
TestCase = object # noqa: F811
+_thnvrtc = None
+
+
def get_is_primary_context_created(device):
flags = ctypes.cast((ctypes.c_uint * 1)(), ctypes.POINTER(ctypes.c_uint))
active = ctypes.cast((ctypes.c_int * 1)(), ctypes.POINTER(ctypes.c_int))
- result = torch.cuda.cudart().cuDevicePrimaryCtxGetState(ctypes.c_int(device), flags, active)
+ global _thnvrtc
+ if _thnvrtc is None:
+ path = glob.glob('{}/lib/libthnvrtc.*'.format(os.path.dirname(torch.__file__)))[0]
+ _thnvrtc = ctypes.cdll.LoadLibrary(path)
+ result = _thnvrtc.cuDevicePrimaryCtxGetState(ctypes.c_int(device), flags, active)
assert result == 0, 'cuDevicePrimaryCtxGetState failed'
return bool(active[0])