| author | Adam Lerer <alerer@fb.com> | 2016-10-17 12:18:55 -0700 |
|---|---|---|
| committer | Adam Lerer <alerer@fb.com> | 2016-10-23 21:11:22 -0700 |
| commit | b5d13296c65e4b3cd5aa9715cf58df0fc043454e (patch) | |
| tree | debe0dc9c35b048db4d212ae9ee512a2f3794023 /torch/backends | |
| parent | 86288265add5225be6de7870a88941937d9475de (diff) | |
addressing comments
Diffstat (limited to 'torch/backends')
| -rw-r--r-- | torch/backends/cudnn/__init__.py | 2 |
| -rw-r--r-- | torch/backends/cudnn/rnn.py | 140 |
2 files changed, 81 insertions, 61 deletions
```diff
diff --git a/torch/backends/cudnn/__init__.py b/torch/backends/cudnn/__init__.py
index f3d8be610f..3575f6bd35 100644
--- a/torch/backends/cudnn/__init__.py
+++ b/torch/backends/cudnn/__init__.py
@@ -280,7 +280,7 @@ def int_array(itr):
     return array_type(*itr)
 
 def descriptor(tensor, N=None):
-    if N:
+    if N is not None:
         descriptor = TensorDescriptorArray(N)
     else:
         descriptor = TensorDescriptor()
diff --git a/torch/backends/cudnn/rnn.py b/torch/backends/cudnn/rnn.py
index e57f4b7643..abba8187f4 100644
--- a/torch/backends/cudnn/rnn.py
+++ b/torch/backends/cudnn/rnn.py
@@ -3,8 +3,20 @@ import torch.backends.cudnn as cudnn
 from torch.backends.cudnn import check_error
 import ctypes
 
+
+def get_cudnn_mode(mode):
+    if mode == 'RNN_RELU':
+        return cudnn.CUDNN_RNN_RELU
+    elif mode == 'RNN_TANH':
+        return cudnn.CUDNN_RNN_TANH
+    elif mode == 'LSTM':
+        return cudnn.CUDNN_LSTM
+    elif mode == 'GRU':
+        return cudnn.CUDNN_GRU
+    else:
+        raise Exception("Unknown mode: {}".format(mode))
 
-def initDropoutDescriptor(fn, handle):
+def init_dropout_descriptor(fn, handle):
     dropout_desc = cudnn.DropoutDescriptor()
 
     dropout_states_size = ctypes.c_long()
@@ -22,7 +34,7 @@ def initDropoutDescriptor(fn, handle):
 
     return dropout_desc
 
 
-def initRNNDescriptor(fn):
+def init_rnn_descriptor(fn):
     rnn_desc = cudnn.RNNDescriptor()
     rnn_desc.set(
@@ -37,26 +49,26 @@ def initRNNDescriptor(fn):
     return rnn_desc
 
 
-def initWeightDescriptor(fn, weight):
+def init_weight_descriptor(fn, weight):
     w_desc = cudnn.FilterDescriptor()
     w_view = weight.view(-1, 1, 1)  # seems that filters require >=3 dimensions
     w_desc.set(w_view)
     return w_desc
 
 
-def _inputSize(fn):
+def _input_size(fn):
     return (fn.seq_length, fn.mini_batch, fn.input_size)
 
 
-def _hiddenSize(fn):
+def _hidden_size(fn):
     return (fn.num_layers * fn.num_directions, fn.mini_batch, fn.hidden_size)
 
 
-def _outputSize(fn):
+def _output_size(fn):
     return (fn.seq_length, fn.mini_batch, fn.hidden_size * fn.num_directions)
 
 
-def getNumWeights(handle, rnn_desc, x_desc, datatype):
+def get_num_weights(handle, rnn_desc, x_desc, datatype):
     weight_size = ctypes.c_long()
     check_error(cudnn.lib.cudnnGetRNNParamsSize(
         handle,
@@ -70,24 +82,30 @@ def getNumWeights(handle, rnn_desc, x_desc, datatype):
     return weight_size.value // elem_size
 
 
-def getParameters(fn, handle, weight_buf):
+def get_parameters(fn, handle, weight_buf):
+    """Returns weight and bias tensors for each layer of the RNN. These tensors
+    are views on the underlying weight buffer allocated by CuDNN.
+
+    Note: for LSTM and GRU, which have multiple parameters of each type (4 and 3,
+    respectively), these parameters are concatenated along the first dimension.
+    These parameters are returned in a consistent order by CuDNN:
+        (reset, forget, cell, output) for LSTM
+        (reset, input, new) for GRU
+    Args:
+        fn: The RNN function object holding the RNN state
+        handle: a CuDNN handle
+        weight_buf: a 1D tensor containing the CuDNN-allocated weight (or grad_weight) buffer
+    Returns:
+        parameters: [(weight_ih, weight_hh, bias_ih, bias_hh)*], with length equal to the num_layers.
+    """
     cudnn_methods = [
         cudnn.lib.cudnnGetRNNLinLayerMatrixParams,
         cudnn.lib.cudnnGetRNNLinLayerBiasParams
     ]
 
-    # if fn.mode == cudnn.CUDNN_RNN_RELU or fn.mode == cudnn.CUDNN_RNN_TANH:
-    #     linear_name = ["ih", "hh"]
-    # elif fn.mode == cudnn.CUDNN_LSTM:
-    #     linear_name = ["ii", "if", "ic", "io", "hi", "hf", "hc", "ho"]
-    # elif fn.mode == cudnn.CUDNN_GRU:
-    #     linear_name = ["ir", "iu", "ic", "hr", "hu", "hc"]
-    # else:
-    #     raise Exception("Unknown mode: {}".format(fn.mode))
-
-    params = []
-    num_linear_layers = _numLinearLayers(fn)
+    params = []
+    num_linear_layers = _num_linear_layers(fn)
 
     num_layers = fn.num_directions * fn.num_layers
     for layer in range(num_layers):
         layer_params = []
@@ -134,7 +152,7 @@ def getParameters(fn, handle, weight_buf):
                 assert(filter_dim_a.prod() == filter_dim_a[0])
                 param = fn.weight_buf.new().set_(
                     weight_buf.storage(), offset,
-                    filter_dim_a[0] * num_linear_layers / 2, filter_dim_a[2])
+                    filter_dim_a[0] * num_linear_layers // 2, filter_dim_a[2])
                 layer_params.append(param)
             else:
                 assert(cur_offset == offset)
@@ -170,16 +188,18 @@ def forward(fn, input, hx, weight, output, hy):
         input = input.transpose(0, 1)
 
     if input.dim() != 3:
-        raise Exception(
-            'input must have 3 dimensions: seq_length, mini_batch, input_size')
+        raise RuntimeError(
+            'input must have 3 dimensions, got {}'.format(input.dim()))
     if fn.input_size != input.size(2):
-        raise Exception('input.size(2) must be equal to input_size provided')
+        raise RuntimeError('input.size(2) must be equal to input_size. Expected {}, got {}'.format(
+            fn.input_size, input.size(2)
+        ))
     if fn.dropout != 0 and cudnn.lib.version < 5103:
-        raise Exception('dropout supported only in cudnn v5.1 and above')
+        raise RuntimeError('dropout supported only in cudnn v5.1 and above')
 
     fn.seq_length, fn.mini_batch, fn.input_size = input.size()
-    hidden_size = _hiddenSize(fn)
-    output_size = _outputSize(fn)
+    hidden_size = _hidden_size(fn)
+    output_size = _output_size(fn)
     x = input.contiguous()
     output.resize_(*output_size)
     hy.resize_(*hidden_size).zero_()
@@ -188,33 +208,33 @@ def forward(fn, input, hx, weight, output, hy):
     y = output
 
     # init descriptors
-    fn.dropout_desc = initDropoutDescriptor(fn, handle)
-    fn.rnn_desc = initRNNDescriptor(fn)
-    fn.x_descs = cudnn.descriptor(x[0], fn.seq_length)
-    fn.y_descs = cudnn.descriptor(y[0], fn.seq_length)
-    fn.hx_desc = cudnn.descriptor(hx)
-    fn.hy_desc = cudnn.descriptor(hx)
-    fn.cx_desc = cudnn.descriptor(cx) if cx else None
-    fn.cy_desc = cudnn.descriptor(cx) if cx else None
+    fn.dropout_desc = init_dropout_descriptor(fn, handle)
+    fn.rnn_desc = init_rnn_descriptor(fn)
+    fn.x_descs = cudnn.descriptor(x[0], fn.seq_length)
+    fn.y_descs = cudnn.descriptor(y[0], fn.seq_length)
+    fn.hx_desc = cudnn.descriptor(hx)
+    fn.hy_desc = cudnn.descriptor(hx)
+    fn.cx_desc = cudnn.descriptor(cx) if cx else None
+    fn.cy_desc = cudnn.descriptor(cx) if cx else None
 
     # create the weight buffer and copy the weights into it
-    num_weights = getNumWeights(
+    num_weights = get_num_weights(
         handle, fn.rnn_desc, fn.x_descs[0], fn.datatype)
     fn.weight_buf = input.new(num_weights)
-    fn.w_desc = initWeightDescriptor(fn, fn.weight_buf)
+    fn.w_desc = init_weight_descriptor(fn, fn.weight_buf)
     w = fn.weight_buf
     # this zero might not seem necessary, but it is in the case
     # where biases are disabled; then they won't be copied and must be zero'd.
     # Alternatively, _copyParams could be written more carefully.
     w.zero_()
-    params = getParameters(fn, handle, w)
+    params = get_parameters(fn, handle, w)
     _copyParams(weight, params)
 
     if tuple(hx.size()) != hidden_size:
-        raise Exception('Expected hx size {}, got {}'.format(
+        raise RuntimeError('Expected hidden size {}, got {}'.format(
             hidden_size, tuple(hx.size())))
     if cx and tuple(cx.size()) != hidden_size:
-        raise Exception('Expected cx size {}, got {}'.format(
+        raise RuntimeError('Expected cell size {}, got {}'.format(
             hidden_size, tuple(cx.size())))
 
     workspace_size = ctypes.c_long()
@@ -286,9 +306,9 @@ def backward_grad(fn, input, hx, weight, output, grad_output, grad_hy, grad_inpu
         grad_output = grad_output.transpose(0, 1)
         output = output.transpose(0, 1)
 
-    input_size = _inputSize(fn)
-    hidden_size = _hiddenSize(fn)
-    output_size = _outputSize(fn)
+    input_size = _input_size(fn)
+    hidden_size = _hidden_size(fn)
+    output_size = _output_size(fn)
 
     x = input.contiguous()
     dy = grad_output.contiguous()
@@ -301,26 +321,26 @@ def backward_grad(fn, input, hx, weight, output, grad_output, grad_hy, grad_inpu
     dcx = grad_cx.resize_(*hidden_size) if grad_cx else None
 
     if fn.dropout != 0 and lib.version < 5103:
-        raise Exception('dropout supported only in cudnn v 5.1 and above')
+        raise RuntimeError('dropout supported only in cudnn v 5.1 and above')
     if not fn.train:
-        raise Exception('backward_grad can only be called when training!')
+        raise RuntimeError('backward_grad can only be called when training!')
     if tuple(input.size()) != input_size:
-        raise Exception('Expected input size {}, got {}'.format(
+        raise RuntimeError('Expected input size {}, got {}'.format(
             input_size, tuple(input.size())))
-    if tuple(output.size()) != _outputSize(fn):
-        raise Exception('Expected output size {}, got {}'.format(
+    if tuple(output.size()) != _output_size(fn):
+        raise RuntimeError('Expected output size {}, got {}'.format(
             output_size, output.size()))
     if hx and tuple(hx.size()) != hidden_size:
-        raise Exception('Expected hx size {}, got {}'.format(
+        raise RuntimeError('Expected hidden size {}, got {}'.format(
             hidden_size, hx.size()))
     if cx and tuple(cx.size()) != hidden_size:
-        raise Exception('Expected cx size {}, got {}'.format(
+        raise RuntimeError('Expected cell size {}, got {}'.format(
             hidden_size, cx.size()))
     if dhy and tuple(dhy.size()) != hidden_size:
-        raise Exception('Expected dhy size {}, got {}'.format(
+        raise RuntimeError('Expected d_hidden size {}, got {}'.format(
             hidden_size, dhy.size()))
     if dcy and tuple(dcy.size()) != hidden_size:
-        raise Exception('Expected dcy size {}, got {}'.format(
+        raise RuntimeError('Expected d_cell size {}, got {}'.format(
             hidden_size, dcy.size()))
 
     check_error(cudnn.lib.cudnnRNNBackwardData(
@@ -345,7 +365,7 @@ def backward_grad(fn, input, hx, weight, output, grad_output, grad_hy, grad_inpu
         grad_input = grad_input.transpose(0, 1)
 
 
-def _numLinearLayers(fn):
+def _num_linear_layers(fn):
     if fn.mode == cudnn.CUDNN_LSTM:
         return 8
     elif fn.mode == cudnn.CUDNN_GRU:
@@ -355,7 +375,7 @@ def _numLinearLayers(fn):
     elif fn.mode == cudnn.CUDNN_RNN_TANH:
         return 2
     else:
-        raise Exception('Unknown mode: {}'.format(fn.mode))
+        raise RuntimeError('Unknown mode: {}'.format(fn.mode))
 
 
 def backward_weight(fn, input, hx, output, weight, grad_weight):
@@ -371,19 +391,19 @@ def backward_weight(fn, input, hx, output, weight, grad_weight):
         input = input.transpose(1, 2)
         output = output.transpose(1, 2)
 
-    input_size = _inputSize(fn)
-    hidden_size = _hiddenSize(fn)
+    input_size = _input_size(fn)
+    hidden_size = _hidden_size(fn)
 
     if not fn.train:
-        raise Exception('backward_weight can only be called when training!')
+        raise RuntimeError('backward_weight can only be called when training!')
     if fn.dropout != 0 and lib.version < 5103:
-        raise Exception('dropout supported only in cudnn v 5.1 and above')
+        raise RuntimeError('dropout supported only in cudnn v 5.1 and above')
     if tuple(input.size()) != input_size:
-        raise Exception('Expected input size {}, got {}'.format(
+        raise RuntimeError('Expected input size {}, got {}'.format(
             input_size, tuple(input.size())))
     if not fn.train:
-        raise Exception('backward_weight can only be called when training!')
+        raise RuntimeError('backward_weight can only be called when training!')
     if tuple(hx.size()) != hidden_size:
-        raise Exception('Expected input size {}, got {}'.format(
+        raise RuntimeError('Expected hidden size {}, got {}'.format(
             hidden_size, hx.size()))
     x = input.contiguous()
@@ -403,6 +423,6 @@ def backward_weight(fn, input, hx, output, weight, grad_weight):
     ))
 
     # copy the weights from the weight_buf into grad_weight
-    grad_params = getParameters(fn, handle, dw)
+    grad_params = get_parameters(fn, handle, dw)
     _copyParams(grad_params, grad_weight)
     return grad_weight
```
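Two of the smaller fixes above are easy to gloss over. The `__init__.py` change swaps a truthiness test for an identity test: `if N:` and `if N is not None:` disagree whenever `N == 0`. A minimal standalone sketch of the difference (the function and return strings here are illustrative, not code from the patch):

```python
def pick_descriptor(N=None):
    # Truthiness test: N == 0 is falsy, so an explicit request for an
    # (empty) descriptor array wrongly takes the single-descriptor path.
    if N:
        return 'TensorDescriptorArray({})'.format(N)
    return 'TensorDescriptor()'


def pick_descriptor_fixed(N=None):
    # Identity test: only an omitted argument (None) selects the
    # single-descriptor path; N == 0 still selects the array path.
    if N is not None:
        return 'TensorDescriptorArray({})'.format(N)
    return 'TensorDescriptor()'


assert pick_descriptor(0) == 'TensorDescriptor()'               # surprising
assert pick_descriptor_fixed(0) == 'TensorDescriptorArray(0)'   # intended
```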
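The `/` to `//` change in `get_parameters` is a Python 3 compatibility fix: true division returns a float, while the tensor `set_()` call needs an integer size. A quick illustration with made-up values:

```python
filter_dim0 = 512        # hypothetical first filter dimension
num_linear_layers = 8    # e.g. an LSTM layer

size_true = filter_dim0 * num_linear_layers / 2    # 2048.0 on Python 3
size_floor = filter_dim0 * num_linear_layers // 2  # 2048 on both 2 and 3

assert size_floor == 2048 and isinstance(size_floor, int)
# size_true is a float on Python 3, which integer size arguments reject.
```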
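Finally, the new `get_parameters` docstring notes that for LSTM and GRU the per-gate matrices come back from CuDNN concatenated along the first dimension. A hedged sketch of how a caller could split such a tensor back into gates (modern `torch` API; the sizes and names are illustrative, not values produced by this patch):

```python
import torch

input_size, hidden_size = 3, 4

# For an LSTM, weight_ih stacks the four gate matrices along dim 0,
# giving shape (4 * hidden_size, input_size).
weight_ih = torch.randn(4 * hidden_size, input_size)

# chunk(4, 0) recovers the individual gate matrices, in the fixed
# per-gate order the docstring describes.
w_reset, w_forget, w_cell, w_output = weight_ih.chunk(4, 0)
assert w_reset.size() == (hidden_size, input_size)
```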