diff options
author | Jon Crall <erotemic@gmail.com> | 2018-01-11 07:24:17 -0500 |
---|---|---|
committer | Soumith Chintala <soumith@gmail.com> | 2018-01-11 07:24:17 -0500 |
commit | 94f439c07c4f3e13fbd87016212b3534d31a7165 (patch) | |
tree | 303e4ac743b1bfa48d4f6f5e18e7857f3b661e29 /tools/setup_helpers | |
parent | 0988e328c983b436a25cdcdd1fbbfcf4c054de43 (diff) | |
download | pytorch-94f439c07c4f3e13fbd87016212b3534d31a7165.tar.gz pytorch-94f439c07c4f3e13fbd87016212b3534d31a7165.tar.bz2 pytorch-94f439c07c4f3e13fbd87016212b3534d31a7165.zip |
Fixed setup.py to handle CUDNN_LIBRARY envvar with aten (#4597)
* Fixed setup.py to handle CUDNN_LIBRARY envvar with aten
* undo changes
* Added CUDNN_LIBRARY to bat file
Diffstat (limited to 'tools/setup_helpers')
-rw-r--r-- | tools/setup_helpers/cudnn.py | 24 |
1 files changed, 20 insertions, 4 deletions
diff --git a/tools/setup_helpers/cudnn.py b/tools/setup_helpers/cudnn.py index 8bff3e5ed8..87efcfda3b 100644 --- a/tools/setup_helpers/cudnn.py +++ b/tools/setup_helpers/cudnn.py @@ -19,6 +19,7 @@ CONDA_DIR = os.path.join(os.path.dirname(sys.executable), '..') WITH_CUDNN = False CUDNN_LIB_DIR = None CUDNN_INCLUDE_DIR = None +CUDNN_LIBRARY = None if WITH_CUDA and not check_env_flag('NO_CUDNN'): lib_paths = list(filter(bool, [ os.getenv('CUDNN_LIB_DIR'), @@ -49,11 +50,15 @@ if WITH_CUDA and not check_env_flag('NO_CUDNN'): if path is None or not os.path.exists(path): continue if IS_WINDOWS: - if os.path.exists(os.path.join(path, 'cudnn.lib')): + library = os.path.join(path, 'cudnn.lib') + if os.path.exists(library): + CUDNN_LIBRARY = library CUDNN_LIB_DIR = path break else: - if glob.glob(os.path.join(path, 'libcudnn*')): + libraries = sorted(glob.glob(os.path.join(path, 'libcudnn*'))) + if libraries: + CUDNN_LIBRARY = libraries[0] CUDNN_LIB_DIR = path break for path in include_paths: @@ -62,7 +67,18 @@ if WITH_CUDA and not check_env_flag('NO_CUDNN'): if os.path.exists((os.path.join(path, 'cudnn.h'))): CUDNN_INCLUDE_DIR = path break - if not CUDNN_LIB_DIR or not CUDNN_INCLUDE_DIR: - CUDNN_LIB_DIR = CUDNN_INCLUDE_DIR = None + + # Specifying the library directly will overwrite the lib directory + library = os.getenv('CUDNN_LIBRARY') + if library is not None and os.path.exists(library): + CUDNN_LIBRARY = library + CUDNN_LIB_DIR = os.path.dirname(CUDNN_LIBRARY) + + if not all([CUDNN_LIBRARY, CUDNN_LIB_DIR, CUDNN_INCLUDE_DIR]): + CUDNN_LIBRARY = CUDNN_LIB_DIR = CUDNN_INCLUDE_DIR = None else: + real_cudnn_library = os.path.realpath(CUDNN_LIBRARY) + real_cudnn_lib_dir = os.path.realpath(CUDNN_LIB_DIR) + assert os.path.dirname(real_cudnn_library) == real_cudnn_lib_dir, ( + 'cudnn library and lib_dir must agree') WITH_CUDNN = True |