diff options
author | Zachary DeVito <zdevito@fb.com> | 2019-02-22 13:37:26 -0800 |
---|---|---|
committer | Facebook Github Bot <facebook-github-bot@users.noreply.github.com> | 2019-02-22 13:51:45 -0800 |
commit | 356a94b64e69aa2e87b33e7bbd0b302feca327a4 (patch) | |
tree | 5e864de50ec91b836f8dc05f8f91c9691c0b4cc2 /test/test_cuda_primary_ctx.py | |
parent | 81b43202ae1eaa44ada84b193ae0fbfe0d11bc91 (diff) | |
download | pytorch-356a94b64e69aa2e87b33e7bbd0b302feca327a4.tar.gz pytorch-356a94b64e69aa2e87b33e7bbd0b302feca327a4.tar.bz2 pytorch-356a94b64e69aa2e87b33e7bbd0b302feca327a4.zip |
Lazily load libcuda libnvrtc from c++ (#17317)
Summary:
Fixes https://github.com/pytorch/pytorch/issues/16860
Pull Request resolved: https://github.com/pytorch/pytorch/pull/17317
Differential Revision: D14157877
Pulled By: zdevito
fbshipit-source-id: c37aec2d77c2e637d4fc6ceffe2bd32901c70317
Diffstat (limited to 'test/test_cuda_primary_ctx.py')
-rw-r--r-- | test/test_cuda_primary_ctx.py | 11 |
1 files changed, 10 insertions, 1 deletions
diff --git a/test/test_cuda_primary_ctx.py b/test/test_cuda_primary_ctx.py index 68aab85bef..2211a13487 100644 --- a/test/test_cuda_primary_ctx.py +++ b/test/test_cuda_primary_ctx.py @@ -2,6 +2,8 @@ import ctypes import torch from common_utils import TestCase, run_tests, skipIfRocm import unittest +import glob +import os # NOTE: this needs to be run in a brand new process @@ -17,10 +19,17 @@ if not TEST_CUDA: TestCase = object # noqa: F811 +_thnvrtc = None + + def get_is_primary_context_created(device): flags = ctypes.cast((ctypes.c_uint * 1)(), ctypes.POINTER(ctypes.c_uint)) active = ctypes.cast((ctypes.c_int * 1)(), ctypes.POINTER(ctypes.c_int)) - result = torch.cuda.cudart().cuDevicePrimaryCtxGetState(ctypes.c_int(device), flags, active) + global _thnvrtc + if _thnvrtc is None: + path = glob.glob('{}/lib/libthnvrtc.*'.format(os.path.dirname(torch.__file__)))[0] + _thnvrtc = ctypes.cdll.LoadLibrary(path) + result = _thnvrtc.cuDevicePrimaryCtxGetState(ctypes.c_int(device), flags, active) assert result == 0, 'cuDevicePrimaryCtxGetState failed' return bool(active[0]) |