diff options
author | peterjc123 <peter_jiachen@163.com> | 2018-05-01 22:57:16 +0800 |
---|---|---|
committer | Edward Z. Yang <ezyang@mit.edu> | 2018-05-01 10:57:16 -0400 |
commit | 15b12e6f8aa9ba97881ff189d904fadea65d72ca (patch) | |
tree | acef820f4870fe3d79aed239f34fec3bc37e65c8 /tools | |
parent | 7968ee0f5921cdff0e534ab87c41b5f4e0e5f220 (diff) | |
download | pytorch-15b12e6f8aa9ba97881ff189d904fadea65d72ca.tar.gz pytorch-15b12e6f8aa9ba97881ff189d904fadea65d72ca.tar.bz2 pytorch-15b12e6f8aa9ba97881ff189d904fadea65d72ca.zip |
Add support for MKLDNN on Windows (#7130)
Diffstat (limited to 'tools')
-rw-r--r-- | tools/setup_helpers/mkldnn.py | 24 |
1 files changed, 15 insertions, 9 deletions
diff --git a/tools/setup_helpers/mkldnn.py b/tools/setup_helpers/mkldnn.py index 4723c9ea54..8e6af45dc7 100644 --- a/tools/setup_helpers/mkldnn.py +++ b/tools/setup_helpers/mkldnn.py @@ -4,15 +4,11 @@ import os import sys from itertools import chain -from .env import check_env_flag +from .env import check_env_flag, IS_LINUX, IS_WINDOWS, IS_CONDA, CONDA_DIR def gather_paths(env_vars): - return list(chain(*(os.getenv(v, '').split(':') for v in env_vars))) - -IS_LINUX = platform.system() == 'Linux' -IS_CONDA = 'conda' in sys.version or 'Continuum' in sys.version -CONDA_DIR = os.path.join(os.path.dirname(sys.executable), '..') + return list(chain(*(os.getenv(v, '').split(os.pathsep) for v in env_vars))) MKLDNN_HOME = os.getenv('MKLDNN_HOME', '/usr/local/mkl-dnn') @@ -20,17 +16,20 @@ WITH_MKLDNN = False MKLDNN_LIB_DIR = None MKLDNN_INCLUDE_DIR = None MKLDNN_LIBRARY = None -if IS_LINUX and not check_env_flag('NO_MKLDNN'): +if (IS_LINUX or IS_WINDOWS) and not check_env_flag('NO_MKLDNN'): lib_paths = list(filter(bool, [ os.getenv('MKLDNN_LIB_DIR'), os.path.join(MKLDNN_HOME, 'lib'), os.path.join(MKLDNN_HOME, 'lib64'), + os.path.join(MKLDNN_HOME, 'lib/x64'), '/usr/lib/', '/usr/lib64/', ] + gather_paths([ 'LIBRARY_PATH', ]) + gather_paths([ 'LD_LIBRARY_PATH', + ]) + gather_paths([ + 'LIB' ]))) include_paths = list(filter(bool, [ os.getenv('MKLDNN_INCLUDE_DIR'), @@ -40,7 +39,14 @@ if IS_LINUX and not check_env_flag('NO_MKLDNN'): 'CPATH', 'C_INCLUDE_PATH', 'CPLUS_INCLUDE_PATH', + 'INCLUDE', ]))) + if IS_WINDOWS: + mkldnn_regex = 'mkldnn*.lib' + mklml_regex = 'mklml*.lib' + else: + mkldnn_regex = 'libmkldnn*' + mklml_regex = 'libmklml_intel*' if IS_CONDA: lib_paths.append(os.path.join(CONDA_DIR, 'lib')) include_paths.append(os.path.join(CONDA_DIR, 'include')) @@ -48,9 +54,9 @@ if IS_LINUX and not check_env_flag('NO_MKLDNN'): if path is None or not os.path.exists(path): continue else: - libraries = sorted(glob.glob(os.path.join(path, 'libmkldnn*'))) + libraries = sorted(glob.glob(os.path.join(path, mkldnn_regex))) if libraries: - if not glob.glob(os.path.join(path, 'libmklml_intel*')): + if not glob.glob(os.path.join(path, mklml_regex)): print("WARNING: MKL-DNN is not compiled with Intel MKL small library") print("Convolution performance might be suboptimal") print("Refer https://github.com/01org/mkl-dnn for detail info") |