summaryrefslogtreecommitdiff
path: root/tools
diff options
context:
space:
mode:
authorpeterjc123 <peter_jiachen@163.com>2018-05-01 22:57:16 +0800
committerEdward Z. Yang <ezyang@mit.edu>2018-05-01 10:57:16 -0400
commit15b12e6f8aa9ba97881ff189d904fadea65d72ca (patch)
treeacef820f4870fe3d79aed239f34fec3bc37e65c8 /tools
parent7968ee0f5921cdff0e534ab87c41b5f4e0e5f220 (diff)
downloadpytorch-15b12e6f8aa9ba97881ff189d904fadea65d72ca.tar.gz
pytorch-15b12e6f8aa9ba97881ff189d904fadea65d72ca.tar.bz2
pytorch-15b12e6f8aa9ba97881ff189d904fadea65d72ca.zip
Add support for MKLDNN on Windows (#7130)
Diffstat (limited to 'tools')
-rw-r--r--tools/setup_helpers/mkldnn.py24
1 files changed, 15 insertions, 9 deletions
diff --git a/tools/setup_helpers/mkldnn.py b/tools/setup_helpers/mkldnn.py
index 4723c9ea54..8e6af45dc7 100644
--- a/tools/setup_helpers/mkldnn.py
+++ b/tools/setup_helpers/mkldnn.py
@@ -4,15 +4,11 @@ import os
import sys
from itertools import chain
-from .env import check_env_flag
+from .env import check_env_flag, IS_LINUX, IS_WINDOWS, IS_CONDA, CONDA_DIR
def gather_paths(env_vars):
- return list(chain(*(os.getenv(v, '').split(':') for v in env_vars)))
-
-IS_LINUX = platform.system() == 'Linux'
-IS_CONDA = 'conda' in sys.version or 'Continuum' in sys.version
-CONDA_DIR = os.path.join(os.path.dirname(sys.executable), '..')
+ return list(chain(*(os.getenv(v, '').split(os.pathsep) for v in env_vars)))
MKLDNN_HOME = os.getenv('MKLDNN_HOME', '/usr/local/mkl-dnn')
@@ -20,17 +16,20 @@ WITH_MKLDNN = False
MKLDNN_LIB_DIR = None
MKLDNN_INCLUDE_DIR = None
MKLDNN_LIBRARY = None
-if IS_LINUX and not check_env_flag('NO_MKLDNN'):
+if (IS_LINUX or IS_WINDOWS) and not check_env_flag('NO_MKLDNN'):
lib_paths = list(filter(bool, [
os.getenv('MKLDNN_LIB_DIR'),
os.path.join(MKLDNN_HOME, 'lib'),
os.path.join(MKLDNN_HOME, 'lib64'),
+ os.path.join(MKLDNN_HOME, 'lib/x64'),
'/usr/lib/',
'/usr/lib64/',
] + gather_paths([
'LIBRARY_PATH',
]) + gather_paths([
'LD_LIBRARY_PATH',
+ ]) + gather_paths([
+ 'LIB'
])))
include_paths = list(filter(bool, [
os.getenv('MKLDNN_INCLUDE_DIR'),
@@ -40,7 +39,14 @@ if IS_LINUX and not check_env_flag('NO_MKLDNN'):
'CPATH',
'C_INCLUDE_PATH',
'CPLUS_INCLUDE_PATH',
+ 'INCLUDE',
])))
+ if IS_WINDOWS:
+ mkldnn_regex = 'mkldnn*.lib'
+ mklml_regex = 'mklml*.lib'
+ else:
+ mkldnn_regex = 'libmkldnn*'
+ mklml_regex = 'libmklml_intel*'
if IS_CONDA:
lib_paths.append(os.path.join(CONDA_DIR, 'lib'))
include_paths.append(os.path.join(CONDA_DIR, 'include'))
@@ -48,9 +54,9 @@ if IS_LINUX and not check_env_flag('NO_MKLDNN'):
if path is None or not os.path.exists(path):
continue
else:
- libraries = sorted(glob.glob(os.path.join(path, 'libmkldnn*')))
+ libraries = sorted(glob.glob(os.path.join(path, mkldnn_regex)))
if libraries:
- if not glob.glob(os.path.join(path, 'libmklml_intel*')):
+ if not glob.glob(os.path.join(path, mklml_regex)):
print("WARNING: MKL-DNN is not compiled with Intel MKL small library")
print("Convolution performance might be suboptimal")
print("Refer https://github.com/01org/mkl-dnn for detail info")