summaryrefslogtreecommitdiff
path: root/torch/backends
diff options
context:
space:
mode:
authorsoumith <soumith@fb.com>2016-10-02 11:37:43 -0700
committersoumith <soumith@fb.com>2016-10-02 11:45:46 -0700
commit833bedb46b605185c452a8f09c4e5f8d3d1117e8 (patch)
treea0a34ef67ac64a89576f4fa429776df7fa9318b6 /torch/backends
parent3d8eba7b42ade3bd677447b1c6dd3f6a0cec14de (diff)
downloadpytorch-833bedb46b605185c452a8f09c4e5f8d3d1117e8.tar.gz
pytorch-833bedb46b605185c452a8f09c4e5f8d3d1117e8.tar.bz2
pytorch-833bedb46b605185c452a8f09c4e5f8d3d1117e8.zip
cudnn relative check in binary builds
Diffstat (limited to 'torch/backends')
-rw-r--r--torch/backends/cudnn/__init__.py24
1 files changed, 20 insertions, 4 deletions
diff --git a/torch/backends/cudnn/__init__.py b/torch/backends/cudnn/__init__.py
index 938d1b4fb4..1ccda30fb3 100644
--- a/torch/backends/cudnn/__init__.py
+++ b/torch/backends/cudnn/__init__.py
@@ -1,16 +1,32 @@
import ctypes
import warnings
import torch.cuda
+import os.path as path
lib = None
+# TODO: fix libname for OSX / Windows
+# TODO: just load 5.1, not 5.1.3
+# TODO: dynamic version checks via cudnnGetVersion
libname = 'libcudnn.so.5.1.3'
-
+thisdir = path.dirname(__file__)
+libpaths = ['', path.join(thisdir, '../../lib')]
def _loadlib():
global lib
- lib = ctypes.cdll.LoadLibrary(libname)
- lib.cudnnGetErrorString.restype = ctypes.c_char_p
-
+ loaded = False
+ for libpath in libpaths:
+ try:
+ lib = ctypes.cdll.LoadLibrary(path.join(libpath, libname))
+ loaded = True
+ break
+ except OSError:
+ continue
+
+ if loaded:
+ lib.cudnnGetErrorString.restype = ctypes.c_char_p
+ else:
+ lib = None
+ raise OSError("Could not load cuDNN")
def is_acceptable(tensor):
if not (isinstance(tensor, torch.cuda.HalfTensor) or