Author:    Peter Goldsborough <peter@goldsborough.me>  2018-02-23 07:15:30 -0800
Committer: Soumith Chintala <soumith@gmail.com>        2018-02-23 10:15:30 -0500
Commit:    008ba18c5bc2a8fb5f061e812c18c531887d946e (patch)
Tree:      5b1bea6274f7f4b85a36de2281f278fc656e1913
Parent:    e2519e7dd15de727e2da0d1e9220dd662be39df7 (diff)
Improve CUDA extension support (#5324)
* Also pass torch includes to nvcc build
* Export ATen/cuda headers with install
* Refactor flags common to C++ and CUDA
* Improve tests for C++/CUDA extensions
* Export .cuh files under THC
* Refactor and clean cpp_extension.py slightly
* Include ATen in cuda extension test
* Clarifying comment in cuda_extension.cu
* Replace cuda_extension.cu with cuda_extension_kernel.cu in setup.py
* Copy compile args in C++ extension and add second kernel
* Conditionally add -std=c++11 to cuda_flags
* Also export cuDNN headers
* Add comment about deepcopy
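
For context, this is the user-facing workflow the commit improves: a CUDA extension is JIT-compiled and loaded at runtime via torch.utils.cpp_extension.load(), as the updated test below does. A minimal sketch (the sigmoid_add binding itself lives in cuda_extension.cpp, which is only partially shown in this diff):

    import torch
    from torch.utils.cpp_extension import load

    # JIT-compile the C++/CUDA sources into a Python module
    # (paths as in test_cpp_extensions.py below).
    module = load(
        name='torch_test_cuda_extension',
        sources=[
            'cpp_extensions/cuda_extension.cpp',
            'cpp_extensions/cuda_extension.cu',
        ],
        extra_cuda_cflags=['-O2'],
        verbose=True)

    x = torch.randn(100).cuda()
    y = torch.randn(100).cuda()
    z = module.sigmoid_add(x, y)  # assumes cuda_extension.cpp binds sigmoid_add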
 setup.py                                      |  4
 test/cpp_extensions/cuda_extension.cpp        |  2
 test/cpp_extensions/cuda_extension.cu         | 29
 test/cpp_extensions/cuda_extension_kernel.cu  |  2
 test/cpp_extensions/cuda_extension_kernel2.cu | 23
 test/cpp_extensions/setup.py                  |  7
 test/test_cpp_extensions.py                   |  2
 torch/utils/cpp_extension.py                  | 73
 8 files changed, 115 insertions(+), 27 deletions(-)
diff --git a/setup.py b/setup.py
index a0066083a3..9b3d45d09d 100644
--- a/setup.py
+++ b/setup.py
@@ -864,11 +864,15 @@ if __name__ == '__main__':
'lib/torch_shm_manager',
'lib/*.h',
'lib/include/ATen/*.h',
+ 'lib/include/ATen/cuda/*.cuh',
+ 'lib/include/ATen/cudnn/*.h',
+ 'lib/include/ATen/cuda/detail/*.cuh',
'lib/include/pybind11/*.h',
'lib/include/pybind11/detail/*.h',
'lib/include/TH/*.h',
'lib/include/TH/generic/*.h',
'lib/include/THC/*.h',
+ 'lib/include/THC/*.cuh',
'lib/include/THC/generic/*.h',
'lib/include/torch/csrc/*.h',
'lib/include/torch/csrc/autograd/*.h',
diff --git a/test/cpp_extensions/cuda_extension.cpp b/test/cpp_extensions/cuda_extension.cpp
index f772843d51..4c1703a873 100644
--- a/test/cpp_extensions/cuda_extension.cpp
+++ b/test/cpp_extensions/cuda_extension.cpp
@@ -1,6 +1,6 @@
#include <torch/torch.h>
-// Declare the function from cuda_extension_kernel.cu. It will be compiled
+// Declare the function from cuda_extension.cu. It will be compiled
// separately with nvcc and linked with the object file of cuda_extension.cpp
// into one shared library.
void sigmoid_add_cuda(const float* x, const float* y, float* output, int size);
diff --git a/test/cpp_extensions/cuda_extension.cu b/test/cpp_extensions/cuda_extension.cu
new file mode 100644
index 0000000000..29511af8a0
--- /dev/null
+++ b/test/cpp_extensions/cuda_extension.cu
@@ -0,0 +1,29 @@
+// NOTE: This is a copy of cuda_extension_kernel.cu. It's kept here to test
+// collision handling when a C++ file and CUDA file share the same filename.
+// Setuptools can't deal with this at all, so the setup.py-based test uses
+// cuda_extension_kernel.cu and the JIT test uses this file. Symlinks don't
+// work well on Windows, so this is the most thorough solution right now.
+
+#include <cuda.h>
+#include <cuda_runtime.h>
+
+#include <ATen/ATen.h>
+
+__global__ void sigmoid_add_kernel(
+ const float* __restrict__ x,
+ const float* __restrict__ y,
+ float* __restrict__ output,
+ const int size) {
+ const int index = blockIdx.x * blockDim.x + threadIdx.x;
+ if (index < size) {
+ const float sigmoid_x = 1.0f / (1.0f + __expf(-x[index]));
+ const float sigmoid_y = 1.0f / (1.0f + __expf(-y[index]));
+ output[index] = sigmoid_x + sigmoid_y;
+ }
+}
+
+void sigmoid_add_cuda(const float* x, const float* y, float* output, int size) {
+ const int threads = 1024;
+ const int blocks = (size + threads - 1) / threads;
+ sigmoid_add_kernel<<<blocks, threads>>>(x, y, output, size);
+}
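
The launch configuration in sigmoid_add_cuda uses the standard ceiling-division idiom so that blocks * threads covers all size elements. A plain-Python sketch of that arithmetic:

    def launch_config(size, threads=1024):
        # Ceiling division: the smallest block count such that
        # blocks * threads >= size.
        blocks = (size + threads - 1) // threads
        return blocks, threads

    assert launch_config(1)    == (1, 1024)  # one partially-filled block
    assert launch_config(1024) == (1, 1024)  # exactly one full block
    assert launch_config(1025) == (2, 1024)  # spills into a second block

The index < size guard in the kernel then masks off the excess threads in the final block.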
diff --git a/test/cpp_extensions/cuda_extension_kernel.cu b/test/cpp_extensions/cuda_extension_kernel.cu
index 686cfaa494..6602199898 100644
--- a/test/cpp_extensions/cuda_extension_kernel.cu
+++ b/test/cpp_extensions/cuda_extension_kernel.cu
@@ -1,6 +1,8 @@
#include <cuda.h>
#include <cuda_runtime.h>
+#include <ATen/ATen.h>
+
__global__ void sigmoid_add_kernel(
const float* __restrict__ x,
const float* __restrict__ y,
diff --git a/test/cpp_extensions/cuda_extension_kernel2.cu b/test/cpp_extensions/cuda_extension_kernel2.cu
new file mode 100644
index 0000000000..817bdf64ac
--- /dev/null
+++ b/test/cpp_extensions/cuda_extension_kernel2.cu
@@ -0,0 +1,23 @@
+#include <cuda.h>
+#include <cuda_runtime.h>
+
+#include <ATen/ATen.h>
+
+__global__ void tanh_add_kernel(
+ const float* __restrict__ x,
+ const float* __restrict__ y,
+ float* __restrict__ output,
+ const int size) {
+ const int index = blockIdx.x * blockDim.x + threadIdx.x;
+ if (index < size) {
+ const float tanh_x = 2.0f / (1.0f + __expf(-2.0f * x[index])) - 1;
+ const float tanh_y = 2.0f / (1.0f + __expf(-2.0f * y[index])) - 1;
+ output[index] = tanh_x + tanh_y;
+ }
+}
+
+void tanh_add_cuda(const float* x, const float* y, float* output, int size) {
+ const int threads = 1024;
+ const int blocks = (size + threads - 1) / threads;
+ tanh_add_kernel<<<blocks, threads>>>(x, y, output, size);
+}
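
tanh_add_kernel evaluates tanh through the identity tanh(v) = 2 / (1 + e^(-2v)) - 1, which lets it reuse the fast __expf intrinsic just like the sigmoid kernel above. A quick numerical check of that identity in plain Python:

    import math

    for v in (-3.0, -0.5, 0.0, 0.5, 3.0):
        via_exp = 2.0 / (1.0 + math.exp(-2.0 * v)) - 1.0
        assert abs(via_exp - math.tanh(v)) < 1e-12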
diff --git a/test/cpp_extensions/setup.py b/test/cpp_extensions/setup.py
index 014f074a48..d85dbd1cbc 100644
--- a/test/cpp_extensions/setup.py
+++ b/test/cpp_extensions/setup.py
@@ -10,8 +10,11 @@ ext_modules = [
if torch.cuda.is_available():
extension = CUDAExtension(
- 'torch_test_cuda_extension',
- ['cuda_extension.cpp', 'cuda_extension_kernel.cu'],
+ 'torch_test_cuda_extension', [
+ 'cuda_extension.cpp',
+ 'cuda_extension_kernel.cu',
+ 'cuda_extension_kernel2.cu',
+ ],
extra_compile_args={'cxx': ['-g'],
'nvcc': ['-O2']})
ext_modules.append(extension)
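
For reference, a minimal standalone setup.py using this API might look as follows (a sketch; 'my_cuda_ext' and the source file names are placeholders):

    from setuptools import setup
    from torch.utils.cpp_extension import BuildExtension, CUDAExtension

    setup(
        name='my_cuda_ext',
        ext_modules=[
            CUDAExtension(
                'my_cuda_ext',
                ['ext.cpp', 'ext_kernel.cu'],
                # The dict form of extra_compile_args lets BuildExtension route
                # different flags to the host compiler ('cxx') and to nvcc.
                extra_compile_args={'cxx': ['-g'], 'nvcc': ['-O2']}),
        ],
        cmdclass={'build_ext': BuildExtension})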
diff --git a/test/test_cpp_extensions.py b/test/test_cpp_extensions.py
index 3ff66a5382..bbf82d881e 100644
--- a/test/test_cpp_extensions.py
+++ b/test/test_cpp_extensions.py
@@ -67,7 +67,7 @@ class TestCppExtension(common.TestCase):
name='torch_test_cuda_extension',
sources=[
'cpp_extensions/cuda_extension.cpp',
- 'cpp_extensions/cuda_extension_kernel.cu'
+ 'cpp_extensions/cuda_extension.cu'
],
extra_cuda_cflags=['-O2'],
verbose=True)
diff --git a/torch/utils/cpp_extension.py b/torch/utils/cpp_extension.py
index 66f3aeb7ad..058e404072 100644
--- a/torch/utils/cpp_extension.py
+++ b/torch/utils/cpp_extension.py
@@ -1,3 +1,4 @@
+import copy
import glob
import imp
import os
@@ -82,23 +83,18 @@ class BuildExtension(build_ext):
'''A custom build extension for adding compiler-specific options.'''
def build_extensions(self):
- # On some platforms, like Windows, compiler_cxx is not available.
- if hasattr(self.compiler, 'compiler_cxx'):
- compiler = self.compiler.compiler_cxx[0]
- else:
- compiler = os.environ.get('CXX', 'c++')
- check_compiler_abi_compatibility(compiler)
-
+ self._check_abi()
for extension in self.extensions:
- define = '-DTORCH_EXTENSION_NAME={}'.format(extension.name)
- extension.extra_compile_args = [define]
+ self._define_torch_extension_name(extension)
# Register .cu and .cuh as valid source extensions.
self.compiler.src_extensions += ['.cu', '.cuh']
# Save the original _compile method for later.
original_compile = self.compiler._compile
- def wrap_compile(obj, src, ext, cc_args, cflags, pp_opts):
+ def wrap_compile(obj, src, ext, cc_args, extra_postargs, pp_opts):
+ # Copy before we make any modifications.
+ cflags = copy.deepcopy(extra_postargs)
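+ # (distutils reuses the same extra_postargs object for every source
+ # file in the extension, so mutating it in place would leak flags
+ # from one compilation into the next.)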
try:
original_compiler = self.compiler.compiler_so
if _is_cuda_file(src):
@@ -106,10 +102,12 @@ class BuildExtension(build_ext):
self.compiler.set_executable('compiler_so', nvcc)
if isinstance(cflags, dict):
cflags = cflags['nvcc']
- cflags += ['-c', '--compiler-options', "'-fPIC'"]
- else:
- if isinstance(cflags, dict):
+ cflags += ['--compiler-options', "'-fPIC'"]
+ elif isinstance(cflags, dict):
cflags = cflags['cxx']
+ # NVCC does not allow multiple -std to be passed, so we avoid
+ # overriding the option if the user explicitly passed it.
+ if not any(flag.startswith('-std=') for flag in cflags):
cflags.append('-std=c++11')
original_compile(obj, src, ext, cc_args, cflags, pp_opts)
@@ -122,6 +120,22 @@ class BuildExtension(build_ext):
build_ext.build_extensions(self)
+ def _check_abi(self):
+ # On some platforms, like Windows, compiler_cxx is not available.
+ if hasattr(self.compiler, 'compiler_cxx'):
+ compiler = self.compiler.compiler_cxx[0]
+ else:
+ compiler = os.environ.get('CXX', 'c++')
+ check_compiler_abi_compatibility(compiler)
+
+ def _define_torch_extension_name(self, extension):
+ define = '-DTORCH_EXTENSION_NAME={}'.format(extension.name)
+ if isinstance(extension.extra_compile_args, dict):
+ for args in extension.extra_compile_args.values():
+ args.append(define)
+ else:
+ extension.extra_compile_args.append(define)
+
def CppExtension(name, sources, *args, **kwargs):
'''
@@ -178,7 +192,10 @@ def include_paths(cuda=False):
'''
here = os.path.abspath(__file__)
torch_path = os.path.dirname(os.path.dirname(here))
- paths = [os.path.join(torch_path, 'lib', 'include')]
+ lib_include = os.path.join(torch_path, 'lib', 'include')
+ # Some internal (old) Torch headers don't properly prefix their includes,
+ # so we need to pass -Itorch/lib/include/TH as well.
+ paths = [lib_include, os.path.join(lib_include, 'TH')]
if cuda:
paths.append(_join_cuda_home('include'))
return paths
@@ -356,18 +373,22 @@ def _write_ninja_file(path,
# sysconfig.get_paths()['include'] gives us the location of Python.h
includes.append(sysconfig.get_paths()['include'])
- cflags = ['-fPIC', '-std=c++11', '-DTORCH_EXTENSION_NAME={}'.format(name)]
- cflags += ['-I{}'.format(include) for include in includes]
- cflags += extra_cflags
+ common_cflags = ['-DTORCH_EXTENSION_NAME={}'.format(name)]
+ common_cflags += ['-I{}'.format(include) for include in includes]
+
+ cflags = common_cflags + ['-fPIC', '-std=c++11'] + extra_cflags
flags = ['cflags = {}'.format(' '.join(cflags))]
if with_cuda:
- cuda_flags = "--compiler-options '-fPIC'"
- extra_flags = ' '.join(extra_cuda_cflags)
- flags.append('cuda_flags = {} {}'.format(cuda_flags, extra_flags))
+ cuda_flags = common_cflags
+ cuda_flags += ['--compiler-options', "'-fPIC'"]
+ cuda_flags += extra_cuda_cflags
+ if not any(flag.startswith('-std=') for flag in cuda_flags):
+ cuda_flags.append('-std=c++11')
+ flags.append('cuda_flags = {}'.format(' '.join(cuda_flags)))
ldflags = ['-shared'] + extra_ldflags
- # The darwin linker needs explicit consent to ignore unresolved symbols
+ # The darwin linker needs explicit consent to ignore unresolved symbols.
if sys.platform == 'darwin':
ldflags.append('-undefined dynamic_lookup')
flags.append('ldflags = {}'.format(' '.join(ldflags)))
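
The -std guard that now appears in both the setuptools path and this ninja path reduces to the following check (plain-Python sketch):

    def ensure_cxx11(flags):
        # Add -std=c++11 only if the user didn't pass their own -std=;
        # nvcc rejects command lines that contain multiple -std options.
        if not any(flag.startswith('-std=') for flag in flags):
            flags.append('-std=c++11')
        return flags

    assert ensure_cxx11(['-O2']) == ['-O2', '-std=c++11']
    assert ensure_cxx11(['-std=c++14']) == ['-std=c++14']  # user's choice wins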
@@ -393,9 +414,15 @@ def _write_ninja_file(path,
for source_file in sources:
# '/path/to/file.cpp' -> 'file'
file_name = os.path.splitext(os.path.basename(source_file))[0]
- target = '{}.o'.format(file_name)
+ if _is_cuda_file(source_file):
+ rule = 'cuda_compile'
+ # Use a different object filename in case a C++ and CUDA file have
+ # the same filename but different extension (.cpp vs. .cu).
+ target = '{}.cuda.o'.format(file_name)
+ else:
+ rule = 'compile'
+ target = '{}.o'.format(file_name)
object_files.append(target)
- rule = 'cuda_compile' if _is_cuda_file(source_file) else 'compile'
build.append('build {}: {} {}'.format(target, rule, source_file))
library_target = '{}.so'.format(name)
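
The '.cuda.o' suffix is what resolves the filename collision that the new cuda_extension.cu test exercises. A plain-Python sketch of the naming rule introduced above:

    import os

    def object_target(source_file):
        # CUDA objects get a '.cuda.o' suffix so that a foo.cpp / foo.cu
        # pair no longer collides on 'foo.o'. The extension check mirrors
        # _is_cuda_file, which treats .cu and .cuh as CUDA sources.
        file_name = os.path.splitext(os.path.basename(source_file))[0]
        if os.path.splitext(source_file)[1] in ('.cu', '.cuh'):
            return '{}.cuda.o'.format(file_name)
        return '{}.o'.format(file_name)

    assert object_target('cpp_extensions/cuda_extension.cpp') == 'cuda_extension.o'
    assert object_target('cpp_extensions/cuda_extension.cu') == 'cuda_extension.cuda.o'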