author     Sam Gross <colesbury@gmail.com>  2016-09-27 17:55:04 -0400
committer  GitHub <noreply@github.com>      2016-09-27 17:55:04 -0400
commit     779a4600309ab0f4059f72b144bd5f7a934c947d (patch)
tree       6747923f8770154ad12d0f15f6f1d92aa09ee333 /torch/backends
parent     5107f2312659a8a005dc6cc0f1c3f1556172c09c (diff)
Add cuDNN support for convolutions (#36)
Diffstat (limited to 'torch/backends')
-rw-r--r--  torch/backends/__init__.py        |   0
-rw-r--r--  torch/backends/cudnn/__init__.py  | 293
-rw-r--r--  torch/backends/cudnn/conv.py      | 131
3 files changed, 424 insertions, 0 deletions
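The new torch/backends/cudnn/__init__.py below loads libcudnn through ctypes, caches one cudnnHandle_t wrapper per CUDA device, and memoizes autotuned algorithm choices keyed on descriptor tuples. A minimal usage sketch of that entry-point surface (hypothetical driver code, not part of this commit; it assumes a CUDA build with libcudnn.so.5.1.3 discoverable via LD_LIBRARY_PATH):

    import torch.cuda
    import torch.backends.cudnn as cudnn

    x = torch.cuda.FloatTensor(16, 3, 32, 32)

    # Opt in to exhaustive algorithm search; chosen algorithms are memoized
    # per (input, weight, convolution) descriptor key inside the module.
    cudnn.benchmark = True

    if cudnn.is_acceptable(x):
        # Callers (e.g. a conv Function) would route to the cuDNN path here;
        # get_handle() returns the cached per-device handle wrapper.
        handle = cudnn.get_handle()

is_acceptable returns False both for unsupported tensor types and when the library fails to load, so callers can fall back to the existing THNN path with a single check.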
diff --git a/torch/backends/__init__.py b/torch/backends/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
--- /dev/null
+++ b/torch/backends/__init__.py
diff --git a/torch/backends/cudnn/__init__.py b/torch/backends/cudnn/__init__.py
new file mode 100644
index 0000000000..c71cad3ad6
--- /dev/null
+++ b/torch/backends/cudnn/__init__.py
@@ -0,0 +1,293 @@
+import ctypes
+import torch.cuda
+import warnings
+
+lib = None
+libname = 'libcudnn.so.5.1.3'
+
+
+def _loadlib():
+    global lib
+    lib = ctypes.cdll.LoadLibrary(libname)
+    lib.cudnnGetErrorString.restype = ctypes.c_char_p
+
+
+def is_acceptable(tensor):
+    if not (isinstance(tensor, torch.cuda.HalfTensor) or
+            isinstance(tensor, torch.cuda.FloatTensor) or
+            isinstance(tensor, torch.cuda.DoubleTensor)):
+        return False
+
+    if lib is None:
+        try:
+            _loadlib()
+        except Exception:
+            warnings.warn('cuDNN library not found. Check your LD_LIBRARY_PATH')
+            return False
+    return True
+
+
+_handles = {}
+
+benchmark = False
+verbose = False
+workspace_limit = None
+
+CUDNN_DATA_FLOAT = 0
+CUDNN_DATA_DOUBLE = 1
+CUDNN_DATA_HALF = 2
+
+CUDNN_CONVOLUTION = 0
+CUDNN_CROSS_CORRELATION = 1
+
+CUDNN_CONVOLUTION_FWD_NO_WORKSPACE = 0
+CUDNN_CONVOLUTION_FWD_PREFER_FASTEST = 1
+CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT = 2
+
+CUDNN_CONVOLUTION_BWD_FILTER_NO_WORKSPACE = 0
+CUDNN_CONVOLUTION_BWD_FILTER_PREFER_FASTEST = 1
+CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT = 2
+
+CUDNN_CONVOLUTION_BWD_DATA_NO_WORKSPACE = 0
+CUDNN_CONVOLUTION_BWD_DATA_PREFER_FASTEST = 1
+CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT = 2
+
+CUDNN_TENSOR_NCHW = 0
+CUDNN_TENSOR_NHWC = 1
+
+
+class CuDNNHandle:
+    def __init__(self):
+        ptr = ctypes.c_void_p()
+        check_error(lib.cudnnCreate(ctypes.byref(ptr)))
+        self._as_parameter_ = ptr
+
+    def __del__(self):
+        check_error(lib.cudnnDestroy(self))
+
+class CuDNNError(RuntimeError):
+    def __init__(self, status):
+        self.status = status
+        msg = '{}: {}'.format(status, get_error_string(status))
+        super(CuDNNError, self).__init__(msg)
+
+class TensorDescriptor:
+    def __init__(self):
+        ptr = ctypes.c_void_p()
+        check_error(lib.cudnnCreateTensorDescriptor(ctypes.byref(ptr)))
+        self._as_parameter_ = ptr
+
+    def __del__(self):
+        check_error(lib.cudnnDestroyTensorDescriptor(self))
+
+    def set(self, tensor):
+        self._type = tensor.type()
+        self._size = tensor.size()
+        self._stride = tensor.stride()
+        check_error(lib.cudnnSetTensorNdDescriptor(
+            self, _typemap[tensor.type()], tensor.dim(),
+            int_array(tensor.size()), int_array(tensor.stride())))
+
+    def as_tuple(self):
+        return (self._type, tuple(self._size), tuple(self._stride))
+
+class ConvolutionDescriptor:
+    def __init__(self):
+        ptr = ctypes.c_void_p()
+        check_error(lib.cudnnCreateConvolutionDescriptor(ctypes.byref(ptr)))
+        self._as_parameter_ = ptr
+
+    def __del__(self):
+        check_error(lib.cudnnDestroyConvolutionDescriptor(self))
+
+    def set(self, typename, pad, stride):
+        self._pad = pad
+        self._stride = stride
+        upscale = int_array([1, 1])
+        check_error(lib.cudnnSetConvolutionNdDescriptor(
+            self, 2, int_array(pad), int_array(stride), upscale,
+            CUDNN_CROSS_CORRELATION, _typemap[typename]))
+
+    def as_tuple(self):
+        return (self._pad, self._stride)
+
+class FilterDescriptor:
+    def __init__(self):
+        ptr = ctypes.c_void_p()
+        check_error(lib.cudnnCreateFilterDescriptor(ctypes.byref(ptr)))
+        self._as_parameter_ = ptr
+
+    def __del__(self):
+        check_error(lib.cudnnDestroyFilterDescriptor(self))
+
+    def set(self, weight):
+        self._size = weight.size()
+        datatype = _typemap[weight.type()]
+        check_error(lib.cudnnSetFilterNdDescriptor(
+            self, datatype, CUDNN_TENSOR_NCHW, 4, int_array(weight.size())))
+
+    def as_tuple(self):
+        return tuple(self._size)
+
+class ConvolutionAlgoPerf(ctypes.Structure):
+    _fields_ = [
+        ("algo", ctypes.c_int),
+        ("status", ctypes.c_int),
+        ("time", ctypes.c_float),
+        ("memory", ctypes.c_size_t),
+    ]
+
+def check_error(status):
+    if status is not 0:
+        raise CuDNNError(status)
+
+def get_error_string(status):
+    return lib.cudnnGetErrorString(status)
+
+def get_handle():
+    if lib is None:
+        _loadlib()
+    current_device = torch.cuda.current_device()
+    handle = _handles.get(current_device, None)
+    if handle is None:
+        handle = CuDNNHandle()
+        _handles[current_device] = handle
+    return handle
+
+_typemap = {
+    'torch.cuda.HalfTensor': CUDNN_DATA_HALF,
+    'torch.cuda.FloatTensor': CUDNN_DATA_FLOAT,
+    'torch.cuda.DoubleTensor': CUDNN_DATA_DOUBLE,
+}
+
+def c_type(tensor):
+    if isinstance(tensor, torch.cuda.HalfTensor):
+        return ctypes.c_float
+    elif isinstance(tensor, torch.cuda.FloatTensor):
+        return ctypes.c_float
+    elif isinstance(tensor, torch.cuda.DoubleTensor):
+        return ctypes.c_double
+    else:
+        raise ValueError("unknown type '{}'".format(type(tensor)))
+
+def int_array(itr):
+    array_type = ctypes.c_int * len(itr)
+    return array_type(*itr)
+
+def descriptor(tensor):
+    descriptor = TensorDescriptor()
+    if tensor.dim() == 2:
+        tensor = tensor.view(tensor.size(0), tensor.size(1), 1, 1)
+    elif tensor.dim() == 3:
+        tensor = tensor.view(tensor.size(0), tensor.size(1), tensor.size(2), 1)
+    descriptor.set(tensor)
+    return descriptor
+
+_autotuner_forward = {}
+_autotuner_backward_data = {}
+_autotuner_backward_filter = {}
+
+def convolution_autotuner_key(idesc, weight_desc, conv_desc):
+    return (idesc.as_tuple(), weight_desc.as_tuple(), conv_desc.as_tuple())
+
+def convolution_forward_algorithm(idesc, weight_desc, conv_desc, odesc):
+    k = convolution_autotuner_key(idesc, weight_desc, conv_desc)
+    if k in _autotuner_forward:
+        return _autotuner_forward[k]
+
+    if benchmark:
+        perf_results = ConvolutionAlgoPerf()
+        algo_count = ctypes.c_int()
+        check_error(lib.cudnnFindConvolutionForwardAlgorithm(
+            get_handle(), idesc, weight_desc, conv_desc, odesc, 1,
+            ctypes.byref(algo_count), ctypes.byref(perf_results)))
+        _autotuner_forward[k] = perf_results.algo
+        return perf_results.algo
+
+    search_mode = CUDNN_CONVOLUTION_FWD_PREFER_FASTEST
+    wlimit = 0
+    if workspace_limit is not None:
+        wlimit = workspace_limit
+        search_mode = CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT
+
+    fwd_alg = ctypes.c_int()
+    check_error(lib.cudnnGetConvolutionForwardAlgorithm(
+        get_handle(), idesc, weight_desc, conv_desc, odesc, search_mode,
+        wlimit, ctypes.byref(fwd_alg)))
+    return fwd_alg
+
+def convolution_forward_workspace_size(*args):
+    check_error(lib.cudnnGetConvolutionForwardWorkspaceSize(*args))
+
+def convolution_forward(*args):
+    check_error(lib.cudnnConvolutionForward(*args))
+
+def convolution_backward_data(*args):
+    return check_error(lib.cudnnConvolutionBackwardData(*args))
+
+def convolution_backward_data_algorithm(weight_desc, odesc, conv_desc, idesc):
+    k = convolution_autotuner_key(idesc, weight_desc, conv_desc)
+    if k in _autotuner_backward_data:
+        return _autotuner_backward_data[k]
+
+    if benchmark:
+        perf_results = ConvolutionAlgoPerf()
+        algo_count = ctypes.c_int()
+        check_error(lib.cudnnFindConvolutionBackwardDataAlgorithm(
+            get_handle(), weight_desc, odesc, conv_desc, idesc, 1,
+            ctypes.byref(algo_count), ctypes.byref(perf_results)))
+        _autotuner_backward_data[k] = perf_results.algo
+        return perf_results.algo
+
+    search_mode = CUDNN_CONVOLUTION_BWD_DATA_PREFER_FASTEST
+    wlimit = 0
+    if workspace_limit is not None:
+        wlimit = workspace_limit
+        search_mode = CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT
+
+    bwd_data_alg = ctypes.c_int()
+    check_error(lib.cudnnGetConvolutionBackwardDataAlgorithm(
+        get_handle(), weight_desc, odesc, conv_desc, idesc, search_mode,
+        wlimit, ctypes.byref(bwd_data_alg)))
+    return bwd_data_alg
+
+def convolution_backward_data_workspace_size(*args):
+    return check_error(lib.cudnnGetConvolutionBackwardDataWorkspaceSize(*args))
+
+def convolution_backward_filter(*args):
+    return check_error(lib.cudnnConvolutionBackwardFilter(*args))
+
+def convolution_backward_filter_algorithm(idesc, odesc, conv_desc, weight_desc):
+    k = convolution_autotuner_key(idesc, weight_desc, conv_desc)
+    if k in _autotuner_backward_filter:
+        return _autotuner_backward_filter[k]
+
+    if benchmark:
+        perf_results = ConvolutionAlgoPerf()
+        algo_count = ctypes.c_int()
+        check_error(lib.cudnnFindConvolutionBackwardFilterAlgorithm(
+            get_handle(), idesc, odesc, conv_desc, weight_desc, 1,
+            ctypes.byref(algo_count), ctypes.byref(perf_results)))
+        _autotuner_backward_filter[k] = perf_results.algo
+        return perf_results.algo
+
+    search_mode = CUDNN_CONVOLUTION_BWD_FILTER_PREFER_FASTEST
+    wlimit = 0
+    if workspace_limit is not None:
+        wlimit = workspace_limit
+        search_mode = CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT
+
+    bwd_filter_alg = ctypes.c_int()
+    check_error(lib.cudnnGetConvolutionBackwardFilterAlgorithm(
+        get_handle(), idesc, odesc, conv_desc, weight_desc, search_mode,
+        wlimit, ctypes.byref(bwd_filter_alg)))
+    return bwd_filter_alg
+
+def convolution_backward_filter_workspace_size(*args):
+    return check_error(lib.cudnnGetConvolutionBackwardFilterWorkspaceSize(*args))
+
+def convolution_backward_bias(*args):
+    check_error(lib.cudnnConvolutionBackwardBias(*args))
+
+def add_tensor(*args):
+    check_error(lib.cudnnAddTensor(*args))
diff --git a/torch/backends/cudnn/conv.py b/torch/backends/cudnn/conv.py
new file mode 100644
index 0000000000..93f756fd73
--- /dev/null
+++ b/torch/backends/cudnn/conv.py
@@ -0,0 +1,131 @@
+import torch.cuda
+import torch.backends.cudnn as cudnn
+import ctypes
+
+def forward(fn, input, weight, bias, output):
+    handle = cudnn.get_handle()
+    out_channels, in_channels = weight.size(0), weight.size(1)
+
+    inslice = input.narrow(1, 0, in_channels // fn.groups)
+    outslice = output.narrow(1, 0, out_channels // fn.groups)
+    weight_slice = (
+        weight.narrow(0, 0, out_channels // fn.groups)
+        .narrow(1, 0, in_channels // fn.groups)
+    )
+
+    fn.input_offset = inslice[0].numel() * input.element_size()
+    fn.output_offset = outslice[0].numel() * output.element_size()
+    fn.weight_offset = weight_slice.numel() * weight.element_size()
+
+    fn.idesc = cudnn.descriptor(inslice)
+    fn.odesc = cudnn.descriptor(outslice)
+    fn.odesc_bias = cudnn.descriptor(output)
+
+    fn.wdesc = cudnn.FilterDescriptor()
+    fn.wdesc.set(weight_slice)
+
+    fn.conv_desc = cudnn.ConvolutionDescriptor()
+    fn.conv_desc.set(weight.type(), fn.pad, fn.stride)
+
+    fwd_alg = cudnn.convolution_forward_algorithm(
+        fn.idesc, fn.wdesc, fn.conv_desc, fn.odesc)
+
+    workspace_size = ctypes.c_size_t()
+    cudnn.convolution_forward_workspace_size(
+        cudnn.get_handle(), fn.idesc, fn.wdesc, fn.conv_desc,
+        fn.odesc, fwd_alg, ctypes.byref(workspace_size))
+
+    workspace = torch.cuda.ByteStorage(workspace_size.value)
+
+    alpha = cudnn.c_type(input)(1)
+    beta = cudnn.c_type(output)(0)
+    for g in range(fn.groups):
+        input_ptr = ctypes.c_void_p(input.data_ptr() + g * fn.input_offset)
+        weight_ptr = ctypes.c_void_p(weight.data_ptr() + g * fn.weight_offset)
+        output_ptr = ctypes.c_void_p(output.data_ptr() + g * fn.output_offset)
+        workspace_ptr = ctypes.c_void_p(workspace.data_ptr())
+
+        cudnn.convolution_forward(
+            handle, ctypes.byref(alpha), fn.idesc, input_ptr, fn.wdesc,
+            weight_ptr, fn.conv_desc, fwd_alg, workspace_ptr,
+            workspace_size, ctypes.byref(beta), fn.odesc, output_ptr)
+
+    if bias is not None:
+        alpha = cudnn.c_type(input)(1)
+        beta = cudnn.c_type(output)(1)
+
+        fn.bias_desc = cudnn.descriptor(bias.view(1, bias.size(0), 1, 1))
+        cudnn.add_tensor(
+            handle, ctypes.byref(alpha), fn.bias_desc,
+            ctypes.c_void_p(bias.data_ptr()), ctypes.byref(beta),
+            fn.odesc_bias, ctypes.c_void_p(output.data_ptr()))
+
+    return output
+
+def backward_data(fn, grad_output, input, weight):
+    handle = cudnn.get_handle()
+    grad_input = input.new().resize_as_(input)
+
+    bwd_data_alg = cudnn.convolution_backward_data_algorithm(
+        fn.wdesc, fn.odesc, fn.conv_desc, fn.idesc)
+
+    workspace_size = ctypes.c_size_t()
+    cudnn.convolution_backward_data_workspace_size(
+        handle, fn.wdesc, fn.odesc, fn.conv_desc, fn.idesc,
+        bwd_data_alg, ctypes.byref(workspace_size))
+
+    workspace = torch.cuda.ByteStorage(workspace_size.value)
+
+    alpha = cudnn.c_type(input)(1)
+    beta = cudnn.c_type(input)(0)
+    for g in range(fn.groups):
+        cudnn.convolution_backward_data(
+            handle, ctypes.byref(alpha), fn.wdesc,
+            ctypes.c_void_p(weight.data_ptr() + g * fn.weight_offset),
+            fn.odesc,
+            ctypes.c_void_p(grad_output.data_ptr() + g * fn.output_offset),
+            fn.conv_desc, bwd_data_alg, ctypes.c_void_p(workspace.data_ptr()),
+            workspace_size, ctypes.byref(beta), fn.idesc,
+            ctypes.c_void_p(grad_input.data_ptr() + g * fn.input_offset))
+
+    return grad_input
+
+def backward_filter(fn, grad_output, input, weight):
+    handle = cudnn.get_handle()
+    grad_weight = weight.new().resize_as_(weight)
+
+    bwd_filter_alg = cudnn.convolution_backward_filter_algorithm(
+        fn.idesc, fn.odesc, fn.conv_desc, fn.wdesc)
+
+    workspace_size = ctypes.c_size_t()
+    cudnn.convolution_backward_filter_workspace_size(
+        handle, fn.idesc, fn.odesc, fn.conv_desc, fn.wdesc,
+        bwd_filter_alg, ctypes.byref(workspace_size))
+
+    workspace = torch.cuda.ByteStorage(workspace_size.value)
+
+    alpha = cudnn.c_type(input)(1)
+    beta = cudnn.c_type(input)(0)
+    for g in range(fn.groups):
+        cudnn.convolution_backward_filter(
+            handle, ctypes.byref(alpha), fn.idesc,
+            ctypes.c_void_p(input.data_ptr() + g * fn.input_offset),
+            fn.odesc,
+            ctypes.c_void_p(grad_output.data_ptr() + g * fn.output_offset),
+            fn.conv_desc, bwd_filter_alg,
+            ctypes.c_void_p(workspace.data_ptr()), workspace_size,
+            ctypes.byref(beta), fn.wdesc,
+            ctypes.c_void_p(grad_weight.data_ptr() + g * fn.weight_offset))
+
+    return grad_weight
+
+def backward_bias(fn, grad_output, bias):
+    grad_bias = bias.new().resize_as_(bias)
+    alpha = cudnn.c_type(grad_output)(1)
+    beta = cudnn.c_type(grad_output)(0)
+
+    cudnn.convolution_backward_bias(
+        cudnn.get_handle(), ctypes.byref(alpha), fn.odesc_bias,
+        ctypes.c_void_p(grad_output.data_ptr()), ctypes.byref(beta),
+        fn.bias_desc, ctypes.c_void_p(grad_bias.data_ptr()))
+    return grad_bias
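For reference, the grouped-convolution handling in conv.py works by describing a single group's channel slice in the cuDNN descriptors and then advancing raw device pointers by a fixed per-group byte offset on each loop iteration. A standalone sketch of that offset arithmetic (illustrative shapes and names, not code from this commit):

    import ctypes
    import torch.cuda

    groups = 2
    input = torch.cuda.FloatTensor(8, 4, 16, 16)           # N, C, H, W
    inslice = input.narrow(1, 0, input.size(1) // groups)  # one group's channels

    # Bytes spanned by one sample's slice of a group: NCHW storage keeps a
    # sample's channels contiguous, so input.data_ptr() + g * input_offset
    # points at group g's first channel, while the descriptor's batch stride
    # (taken from the full tensor) still steps over all C channels per sample.
    input_offset = inslice[0].numel() * input.element_size()

    for g in range(groups):
        group_ptr = ctypes.c_void_p(input.data_ptr() + g * input_offset)

This is why forward() builds fn.idesc, fn.odesc, and fn.wdesc from the narrowed slices but issues one cudnnConvolutionForward call per group against shifted pointers into the full tensors.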