-rw-r--r--  setup.py                                       |  4
-rw-r--r--  test/cpp_extensions/cuda_extension.cpp         |  2
-rw-r--r--  test/cpp_extensions/cuda_extension.cu          | 29
-rw-r--r--  test/cpp_extensions/cuda_extension_kernel.cu   |  2
-rw-r--r--  test/cpp_extensions/cuda_extension_kernel2.cu  | 23
-rw-r--r--  test/cpp_extensions/setup.py                   |  7
-rw-r--r--  test/test_cpp_extensions.py                    |  2
-rw-r--r--  torch/utils/cpp_extension.py                   | 73
8 files changed, 115 insertions(+), 27 deletions(-)
diff --git a/setup.py b/setup.py
--- a/setup.py
+++ b/setup.py
@@ -864,11 +864,15 @@ if __name__ == '__main__':
         'lib/torch_shm_manager',
         'lib/*.h',
         'lib/include/ATen/*.h',
+        'lib/include/ATen/cuda/*.cuh',
+        'lib/include/ATen/cudnn/*.h',
+        'lib/include/ATen/cuda/detail/*.cuh',
         'lib/include/pybind11/*.h',
         'lib/include/pybind11/detail/*.h',
         'lib/include/TH/*.h',
         'lib/include/TH/generic/*.h',
         'lib/include/THC/*.h',
+        'lib/include/THC/*.cuh',
         'lib/include/THC/generic/*.h',
         'lib/include/torch/csrc/*.h',
         'lib/include/torch/csrc/autograd/*.h',
diff --git a/test/cpp_extensions/cuda_extension.cpp b/test/cpp_extensions/cuda_extension.cpp
index f772843d51..4c1703a873 100644
--- a/test/cpp_extensions/cuda_extension.cpp
+++ b/test/cpp_extensions/cuda_extension.cpp
@@ -1,6 +1,6 @@
 #include <torch/torch.h>
 
-// Declare the function from cuda_extension_kernel.cu. It will be compiled
+// Declare the function from cuda_extension.cu. It will be compiled
 // separately with nvcc and linked with the object file of cuda_extension.cpp
 // into one shared library.
 void sigmoid_add_cuda(const float* x, const float* y, float* output, int size);
diff --git a/test/cpp_extensions/cuda_extension.cu b/test/cpp_extensions/cuda_extension.cu
new file mode 100644
index 0000000000..29511af8a0
--- /dev/null
+++ b/test/cpp_extensions/cuda_extension.cu
@@ -0,0 +1,29 @@
+// NOTE: This is a copy of cuda_extension_kernel.cu. It's kept here to test
+// collision handling when a C++ file and CUDA file share the same filename.
+// Setuptools can't deal with this at all, so the setup.py-based test uses
+// cuda_extension_kernel.cu and the JIT test uses this file. Symlinks don't
+// work well on Windows, so this is the most thorough solution right now.
+
+#include <cuda.h>
+#include <cuda_runtime.h>
+
+#include <ATen/ATen.h>
+
+__global__ void sigmoid_add_kernel(
+    const float* __restrict__ x,
+    const float* __restrict__ y,
+    float* __restrict__ output,
+    const int size) {
+  const int index = blockIdx.x * blockDim.x + threadIdx.x;
+  if (index < size) {
+    const float sigmoid_x = 1.0f / (1.0f + __expf(-x[index]));
+    const float sigmoid_y = 1.0f / (1.0f + __expf(-y[index]));
+    output[index] = sigmoid_x + sigmoid_y;
+  }
+}
+
+void sigmoid_add_cuda(const float* x, const float* y, float* output, int size) {
+  const int threads = 1024;
+  const int blocks = (size + threads - 1) / threads;
+  sigmoid_add_kernel<<<blocks, threads>>>(x, y, output, size);
+}
diff --git a/test/cpp_extensions/cuda_extension_kernel.cu b/test/cpp_extensions/cuda_extension_kernel.cu
index 686cfaa494..6602199898 100644
--- a/test/cpp_extensions/cuda_extension_kernel.cu
+++ b/test/cpp_extensions/cuda_extension_kernel.cu
@@ -1,6 +1,8 @@
 #include <cuda.h>
 #include <cuda_runtime.h>
 
+#include <ATen/ATen.h>
+
 __global__ void sigmoid_add_kernel(
     const float* __restrict__ x,
     const float* __restrict__ y,
diff --git a/test/cpp_extensions/cuda_extension_kernel2.cu b/test/cpp_extensions/cuda_extension_kernel2.cu
new file mode 100644
index 0000000000..817bdf64ac
--- /dev/null
+++ b/test/cpp_extensions/cuda_extension_kernel2.cu
@@ -0,0 +1,23 @@
+#include <cuda.h>
+#include <cuda_runtime.h>
+
+#include <ATen/ATen.h>
+
+__global__ void tanh_add_kernel(
+    const float* __restrict__ x,
+    const float* __restrict__ y,
+    float* __restrict__ output,
+    const int size) {
+  const int index = blockIdx.x * blockDim.x + threadIdx.x;
+  if (index < size) {
+    const float tanh_x = 2.0f / (1.0f + __expf(-2.0f * x[index])) - 1;
+    const float tanh_y = 2.0f / (1.0f + __expf(-2.0f * y[index])) - 1;
+    output[index] = tanh_x + tanh_y;
+  }
+}
+
+void tanh_add_cuda(const float* x, const float* y, float* output, int size) {
+  const int threads = 1024;
+  const int blocks = (size + threads - 1) / threads;
+  tanh_add_kernel<<<blocks, threads>>>(x, y, output, size);
+}
diff --git a/test/cpp_extensions/setup.py b/test/cpp_extensions/setup.py
index 014f074a48..d85dbd1cbc 100644
--- a/test/cpp_extensions/setup.py
+++ b/test/cpp_extensions/setup.py
@@ -10,8 +10,11 @@ ext_modules = [
 
 if torch.cuda.is_available():
     extension = CUDAExtension(
-        'torch_test_cuda_extension',
-        ['cuda_extension.cpp', 'cuda_extension_kernel.cu'],
+        'torch_test_cuda_extension', [
+            'cuda_extension.cpp',
+            'cuda_extension_kernel.cu',
+            'cuda_extension_kernel2.cu',
+        ],
         extra_compile_args={'cxx': ['-g'], 'nvcc': ['-O2']})
     ext_modules.append(extension)
 
diff --git a/test/test_cpp_extensions.py b/test/test_cpp_extensions.py
index 3ff66a5382..bbf82d881e 100644
--- a/test/test_cpp_extensions.py
+++ b/test/test_cpp_extensions.py
@@ -67,7 +67,7 @@ class TestCppExtension(common.TestCase):
             name='torch_test_cuda_extension',
             sources=[
                 'cpp_extensions/cuda_extension.cpp',
-                'cpp_extensions/cuda_extension_kernel.cu'
+                'cpp_extensions/cuda_extension.cu'
             ],
             extra_cuda_cflags=['-O2'],
             verbose=True)
diff --git a/torch/utils/cpp_extension.py b/torch/utils/cpp_extension.py
index 66f3aeb7ad..058e404072 100644
--- a/torch/utils/cpp_extension.py
+++ b/torch/utils/cpp_extension.py
@@ -1,3 +1,4 @@
+import copy
 import glob
 import imp
 import os
@@ -82,23 +83,18 @@ class BuildExtension(build_ext):
     '''A custom build extension for adding compiler-specific options.'''
 
     def build_extensions(self):
-        # On some platforms, like Windows, compiler_cxx is not available.
-        if hasattr(self.compiler, 'compiler_cxx'):
-            compiler = self.compiler.compiler_cxx[0]
-        else:
-            compiler = os.environ.get('CXX', 'c++')
-        check_compiler_abi_compatibility(compiler)
-
+        self._check_abi()
         for extension in self.extensions:
-            define = '-DTORCH_EXTENSION_NAME={}'.format(extension.name)
-            extension.extra_compile_args = [define]
+            self._define_torch_extension_name(extension)
 
         # Register .cu and .cuh as valid source extensions.
         self.compiler.src_extensions += ['.cu', '.cuh']
         # Save the original _compile method for later.
         original_compile = self.compiler._compile
 
-        def wrap_compile(obj, src, ext, cc_args, cflags, pp_opts):
+        def wrap_compile(obj, src, ext, cc_args, extra_postargs, pp_opts):
+            # Copy before we make any modifications.
+            cflags = copy.deepcopy(extra_postargs)
             try:
                 original_compiler = self.compiler.compiler_so
                 if _is_cuda_file(src):
@@ -106,10 +102,12 @@
                     self.compiler.set_executable('compiler_so', nvcc)
                     if isinstance(cflags, dict):
                         cflags = cflags['nvcc']
-                    cflags += ['-c', '--compiler-options', "'-fPIC'"]
-                else:
-                    if isinstance(cflags, dict):
+                    cflags += ['--compiler-options', "'-fPIC'"]
+                elif isinstance(cflags, dict):
                     cflags = cflags['cxx']
+                # NVCC does not allow multiple -std to be passed, so we avoid
+                # overriding the option if the user explicitly passed it.
+                if not any(flag.startswith('-std=') for flag in cflags):
                     cflags.append('-std=c++11')
 
                 original_compile(obj, src, ext, cc_args, cflags, pp_opts)
@@ -122,6 +120,22 @@
 
         build_ext.build_extensions(self)
 
+    def _check_abi(self):
+        # On some platforms, like Windows, compiler_cxx is not available.
+        if hasattr(self.compiler, 'compiler_cxx'):
+            compiler = self.compiler.compiler_cxx[0]
+        else:
+            compiler = os.environ.get('CXX', 'c++')
+        check_compiler_abi_compatibility(compiler)
+
+    def _define_torch_extension_name(self, extension):
+        define = '-DTORCH_EXTENSION_NAME={}'.format(extension.name)
+        if isinstance(extension.extra_compile_args, dict):
+            for args in extension.extra_compile_args.values():
+                args.append(define)
+        else:
+            extension.extra_compile_args.append(define)
+
 
 def CppExtension(name, sources, *args, **kwargs):
     '''
@@ -178,7 +192,10 @@ def include_paths(cuda=False):
     '''
     here = os.path.abspath(__file__)
     torch_path = os.path.dirname(os.path.dirname(here))
-    paths = [os.path.join(torch_path, 'lib', 'include')]
+    lib_include = os.path.join(torch_path, 'lib', 'include')
+    # Some internal (old) Torch headers don't properly prefix their includes,
+    # so we need to pass -Itorch/lib/include/TH as well.
+    paths = [lib_include, os.path.join(lib_include, 'TH')]
     if cuda:
         paths.append(_join_cuda_home('include'))
     return paths
@@ -356,18 +373,22 @@ def _write_ninja_file(path,
     # sysconfig.get_paths()['include'] gives us the location of Python.h
     includes.append(sysconfig.get_paths()['include'])
 
-    cflags = ['-fPIC', '-std=c++11', '-DTORCH_EXTENSION_NAME={}'.format(name)]
-    cflags += ['-I{}'.format(include) for include in includes]
-    cflags += extra_cflags
+    common_cflags = ['-DTORCH_EXTENSION_NAME={}'.format(name)]
+    common_cflags += ['-I{}'.format(include) for include in includes]
+
+    cflags = common_cflags + ['-fPIC', '-std=c++11'] + extra_cflags
     flags = ['cflags = {}'.format(' '.join(cflags))]
 
     if with_cuda:
-        cuda_flags = "--compiler-options '-fPIC'"
-        extra_flags = ' '.join(extra_cuda_cflags)
-        flags.append('cuda_flags = {} {}'.format(cuda_flags, extra_flags))
+        cuda_flags = common_cflags
+        cuda_flags += ['--compiler-options', "'-fPIC'"]
+        cuda_flags += extra_cuda_cflags
+        if not any(flag.startswith('-std=') for flag in cuda_flags):
+            cuda_flags.append('-std=c++11')
+        flags.append('cuda_flags = {}'.format(' '.join(cuda_flags)))
 
     ldflags = ['-shared'] + extra_ldflags
-    # The darwin linker needs explicit consent to ignore unresolved symbols
+    # The darwin linker needs explicit consent to ignore unresolved symbols.
     if sys.platform == 'darwin':
         ldflags.append('-undefined dynamic_lookup')
     flags.append('ldflags = {}'.format(' '.join(ldflags)))
@@ -393,9 +414,15 @@ def _write_ninja_file(path,
     for source_file in sources:
         # '/path/to/file.cpp' -> 'file'
         file_name = os.path.splitext(os.path.basename(source_file))[0]
-        target = '{}.o'.format(file_name)
+        if _is_cuda_file(source_file):
+            rule = 'cuda_compile'
+            # Use a different object filename in case a C++ and CUDA file have
+            # the same filename but different extension (.cpp vs. .cu).
+            target = '{}.cuda.o'.format(file_name)
+        else:
+            rule = 'compile'
+            target = '{}.o'.format(file_name)
         object_files.append(target)
-        rule = 'cuda_compile' if _is_cuda_file(source_file) else 'compile'
         build.append('build {}: {} {}'.format(target, rule, source_file))
 
     library_target = '{}.so'.format(name)
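For context, a minimal sketch of the JIT path this change enables. It mirrors the updated test in test/test_cpp_extensions.py; the bound sigmoid_add function and the printed values are assumptions based on that test, not shown in this diff. Because cuda_extension.cpp and cuda_extension.cu share a basename, the generated ninja file now writes cuda_extension.o and cuda_extension.cuda.o instead of having both sources collide on one object file.

```python
# Sketch of the JIT build flow; assumes the extension binds a sigmoid_add
# function, as the test suite does.
import torch
from torch.utils.cpp_extension import load

module = load(
    name='torch_test_cuda_extension',
    sources=[
        # Same basename, different extensions: compiled to
        # cuda_extension.o and cuda_extension.cuda.o respectively.
        'cpp_extensions/cuda_extension.cpp',
        'cpp_extensions/cuda_extension.cu',
    ],
    extra_cuda_cflags=['-O2'],  # no -std= given, so -std=c++11 is appended
    verbose=True)

x = torch.ones(100).cuda()
y = torch.ones(100).cuda()
# sigmoid(1) + sigmoid(1) is roughly 1.4621 for every element.
print(module.sigmoid_add(x, y))
```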
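And a sketch of the setuptools path under the new _define_torch_extension_name behavior, using hypothetical package and file names: when extra_compile_args is a dict, the -DTORCH_EXTENSION_NAME define is now appended to every per-compiler argument list rather than replacing extra_compile_args wholesale, and a user-supplied -std= flag suppresses the default -std=c++11.

```python
# Hypothetical setup.py modeled on test/cpp_extensions/setup.py; 'my_ext'
# and the source file names are illustrative only. Note the .cpp and .cu
# files have distinct basenames, since setuptools cannot handle collisions.
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

setup(
    name='my_ext',
    ext_modules=[
        CUDAExtension(
            'my_ext',
            ['my_ext.cpp', 'my_ext_kernel.cu'],
            # With dict-form args, BuildExtension appends
            # -DTORCH_EXTENSION_NAME=my_ext to both lists.
            extra_compile_args={
                'cxx': ['-g'],
                # An explicit -std= flag is respected; -std=c++11 is only
                # appended when no -std= flag is present.
                'nvcc': ['-O2', '-std=c++11'],
            }),
    ],
    cmdclass={'build_ext': BuildExtension})
```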