summaryrefslogtreecommitdiff
path: root/tools/setup_helpers
diff options
context:
space:
mode:
authorJon Crall <erotemic@gmail.com>2018-01-11 07:24:17 -0500
committerSoumith Chintala <soumith@gmail.com>2018-01-11 07:24:17 -0500
commit94f439c07c4f3e13fbd87016212b3534d31a7165 (patch)
tree303e4ac743b1bfa48d4f6f5e18e7857f3b661e29 /tools/setup_helpers
parent0988e328c983b436a25cdcdd1fbbfcf4c054de43 (diff)
downloadpytorch-94f439c07c4f3e13fbd87016212b3534d31a7165.tar.gz
pytorch-94f439c07c4f3e13fbd87016212b3534d31a7165.tar.bz2
pytorch-94f439c07c4f3e13fbd87016212b3534d31a7165.zip
Fixed setup.py to handle CUDNN_LIBRARY envvar with aten (#4597)
* Fixed setup.py to handle CUDNN_LIBRARY envvar with aten * undo changes * Added CUDNN_LIBRARY to bat file
Diffstat (limited to 'tools/setup_helpers')
-rw-r--r--tools/setup_helpers/cudnn.py24
1 files changed, 20 insertions, 4 deletions
diff --git a/tools/setup_helpers/cudnn.py b/tools/setup_helpers/cudnn.py
index 8bff3e5ed8..87efcfda3b 100644
--- a/tools/setup_helpers/cudnn.py
+++ b/tools/setup_helpers/cudnn.py
@@ -19,6 +19,7 @@ CONDA_DIR = os.path.join(os.path.dirname(sys.executable), '..')
WITH_CUDNN = False
CUDNN_LIB_DIR = None
CUDNN_INCLUDE_DIR = None
+CUDNN_LIBRARY = None
if WITH_CUDA and not check_env_flag('NO_CUDNN'):
lib_paths = list(filter(bool, [
os.getenv('CUDNN_LIB_DIR'),
@@ -49,11 +50,15 @@ if WITH_CUDA and not check_env_flag('NO_CUDNN'):
if path is None or not os.path.exists(path):
continue
if IS_WINDOWS:
- if os.path.exists(os.path.join(path, 'cudnn.lib')):
+ library = os.path.join(path, 'cudnn.lib')
+ if os.path.exists(library):
+ CUDNN_LIBRARY = library
CUDNN_LIB_DIR = path
break
else:
- if glob.glob(os.path.join(path, 'libcudnn*')):
+ libraries = sorted(glob.glob(os.path.join(path, 'libcudnn*')))
+ if libraries:
+ CUDNN_LIBRARY = libraries[0]
CUDNN_LIB_DIR = path
break
for path in include_paths:
@@ -62,7 +67,18 @@ if WITH_CUDA and not check_env_flag('NO_CUDNN'):
if os.path.exists((os.path.join(path, 'cudnn.h'))):
CUDNN_INCLUDE_DIR = path
break
- if not CUDNN_LIB_DIR or not CUDNN_INCLUDE_DIR:
- CUDNN_LIB_DIR = CUDNN_INCLUDE_DIR = None
+
+ # Specifying the library directly will overwrite the lib directory
+ library = os.getenv('CUDNN_LIBRARY')
+ if library is not None and os.path.exists(library):
+ CUDNN_LIBRARY = library
+ CUDNN_LIB_DIR = os.path.dirname(CUDNN_LIBRARY)
+
+ if not all([CUDNN_LIBRARY, CUDNN_LIB_DIR, CUDNN_INCLUDE_DIR]):
+ CUDNN_LIBRARY = CUDNN_LIB_DIR = CUDNN_INCLUDE_DIR = None
else:
+ real_cudnn_library = os.path.realpath(CUDNN_LIBRARY)
+ real_cudnn_lib_dir = os.path.realpath(CUDNN_LIB_DIR)
+ assert os.path.dirname(real_cudnn_library) == real_cudnn_lib_dir, (
+ 'cudnn library and lib_dir must agree')
WITH_CUDNN = True