summary | refs | log | tree | commit | diff
path: root/torch/backends
diff options
context:
space:
mode:
author: Adam Lerer <alerer@fb.com>, 2016-12-18 16:39:09 -0800
committer: Adam Paszke <adam.paszke@gmail.com>, 2016-12-30 00:14:55 +0100
commit: 183b3aacd2db811e8a6e7509f8f306ab5f454272 (patch)
tree: 8e436b1b776b08c7e11187bb98d930236626c875 /torch/backends
parent: 101950ce926d36131fa342a5b786f922a10d9316 (diff)
download: pytorch-183b3aacd2db811e8a6e7509f8f306ab5f454272.tar.gz
pytorch-183b3aacd2db811e8a6e7509f8f306ab5f454272.tar.bz2
pytorch-183b3aacd2db811e8a6e7509f8f306ab5f454272.zip
Hold CuDNN PRNG state between RNN iterations
Diffstat (limited to 'torch/backends')
-rw-r--r-- torch/backends/cudnn/rnn.py | 25
1 file changed, 22 insertions, 3 deletions
diff --git a/torch/backends/cudnn/rnn.py b/torch/backends/cudnn/rnn.py
index 0beb481629..24a5cc0893 100644
--- a/torch/backends/cudnn/rnn.py
+++ b/torch/backends/cudnn/rnn.py
@@ -16,18 +16,34 @@ def get_cudnn_mode(mode):
raise Exception("Unknown mode: {}".format(mode))
+class Unserializable(object):
+ def __init__(self, inner):
+ self.inner = inner
+
+ def get(self):
+ return self.inner
+
+ def __getstate__(self):
+ # Note: can't return {}, because python2 won't call __setstate__
+ # if the value evaluates to False
+ return "<unserializable>"
+
+ def __setstate__(self, state):
+ self.inner = None
+
+
def init_dropout_descriptor(fn, handle):
return cudnn.DropoutDescriptor(
handle,
fn.dropout,
- fn.seed
+ fn.dropout_seed
)
def init_rnn_descriptor(fn):
return cudnn.RNNDescriptor(
fn.hidden_size,
fn.num_layers,
- fn.dropout_desc,
+ fn.dropout_state['desc'].get(),
fn.input_mode,
fn.bidirectional,
fn.mode,
@@ -194,7 +210,10 @@ def forward(fn, input, hx, weight, output, hy):
y = output
# init descriptors
- fn.dropout_desc = init_dropout_descriptor(fn, handle)
+ if ('desc' not in fn.dropout_state) or (fn.dropout_state['desc'].get() is None):
+ fn.dropout_state['desc'] = Unserializable(
+ init_dropout_descriptor(fn, handle)
+ )
fn.rnn_desc = init_rnn_descriptor(fn)
fn.x_descs = cudnn.descriptor(x[0], fn.seq_length)
fn.y_descs = cudnn.descriptor(y[0], fn.seq_length)