diff options
author | Adam Lerer <alerer@fb.com> | 2016-12-18 16:39:09 -0800 |
---|---|---|
committer | Adam Paszke <adam.paszke@gmail.com> | 2016-12-30 00:14:55 +0100 |
commit | 183b3aacd2db811e8a6e7509f8f306ab5f454272 (patch) | |
tree | 8e436b1b776b08c7e11187bb98d930236626c875 /torch/backends | |
parent | 101950ce926d36131fa342a5b786f922a10d9316 (diff) | |
download | pytorch-183b3aacd2db811e8a6e7509f8f306ab5f454272.tar.gz pytorch-183b3aacd2db811e8a6e7509f8f306ab5f454272.tar.bz2 pytorch-183b3aacd2db811e8a6e7509f8f306ab5f454272.zip |
Hold CuDNN PRNG state between RNN iterations
Diffstat (limited to 'torch/backends')
-rw-r--r-- | torch/backends/cudnn/rnn.py | 25 |
1 files changed, 22 insertions, 3 deletions
diff --git a/torch/backends/cudnn/rnn.py b/torch/backends/cudnn/rnn.py index 0beb481629..24a5cc0893 100644 --- a/torch/backends/cudnn/rnn.py +++ b/torch/backends/cudnn/rnn.py @@ -16,18 +16,34 @@ def get_cudnn_mode(mode): raise Exception("Unknown mode: {}".format(mode)) +class Unserializable(object): + def __init__(self, inner): + self.inner = inner + + def get(self): + return self.inner + + def __getstate__(self): + # Note: can't return {}, because python2 won't call __setstate__ + # if the value evaluates to False + return "<unserializable>" + + def __setstate__(self, state): + self.inner = None + + def init_dropout_descriptor(fn, handle): return cudnn.DropoutDescriptor( handle, fn.dropout, - fn.seed + fn.dropout_seed ) def init_rnn_descriptor(fn): return cudnn.RNNDescriptor( fn.hidden_size, fn.num_layers, - fn.dropout_desc, + fn.dropout_state['desc'].get(), fn.input_mode, fn.bidirectional, fn.mode, @@ -194,7 +210,10 @@ def forward(fn, input, hx, weight, output, hy): y = output # init descriptors - fn.dropout_desc = init_dropout_descriptor(fn, handle) + if ('desc' not in fn.dropout_state) or (fn.dropout_state['desc'].get() is None): + fn.dropout_state['desc'] = Unserializable( + init_dropout_descriptor(fn, handle) + ) fn.rnn_desc = init_rnn_descriptor(fn) fn.x_descs = cudnn.descriptor(x[0], fn.seq_length) fn.y_descs = cudnn.descriptor(y[0], fn.seq_length) |