diff options
author | soumith <soumith@fb.com> | 2016-10-02 11:37:43 -0700 |
---|---|---|
committer | soumith <soumith@fb.com> | 2016-10-02 11:45:46 -0700 |
commit | 833bedb46b605185c452a8f09c4e5f8d3d1117e8 (patch) | |
tree | a0a34ef67ac64a89576f4fa429776df7fa9318b6 /torch/backends | |
parent | 3d8eba7b42ade3bd677447b1c6dd3f6a0cec14de (diff) | |
download | pytorch-833bedb46b605185c452a8f09c4e5f8d3d1117e8.tar.gz pytorch-833bedb46b605185c452a8f09c4e5f8d3d1117e8.tar.bz2 pytorch-833bedb46b605185c452a8f09c4e5f8d3d1117e8.zip |
cudnn relative check in binary builds
Diffstat (limited to 'torch/backends')
-rw-r--r-- | torch/backends/cudnn/__init__.py | 24 |
1 files changed, 20 insertions, 4 deletions
diff --git a/torch/backends/cudnn/__init__.py b/torch/backends/cudnn/__init__.py index 938d1b4fb4..1ccda30fb3 100644 --- a/torch/backends/cudnn/__init__.py +++ b/torch/backends/cudnn/__init__.py @@ -1,16 +1,32 @@ import ctypes import warnings import torch.cuda +import os.path as path lib = None +# TODO: fix libname for OSX / Windows +# TODO: just load 5.1, not 5.1.3 +# TODO: dynamic version checks via cudnnGetVersion libname = 'libcudnn.so.5.1.3' - +thisdir = path.dirname(__file__) +libpaths = ['', path.join(thisdir, '../../lib')] def _loadlib(): global lib - lib = ctypes.cdll.LoadLibrary(libname) - lib.cudnnGetErrorString.restype = ctypes.c_char_p - + loaded = False + for libpath in libpaths: + try: + lib = ctypes.cdll.LoadLibrary(path.join(libpath, libname)) + loaded = True + break + except OSError: + continue + + if loaded: + lib.cudnnGetErrorString.restype = ctypes.c_char_p + else: + lib = None + raise OSError("Could not load cuDNN") def is_acceptable(tensor): if not (isinstance(tensor, torch.cuda.HalfTensor) or |