summaryrefslogtreecommitdiff
path: root/torch/backends
diff options
context:
space:
mode:
authorChristian Sarofeen <csarofeen@nvidia.com>2017-02-17 10:16:12 -0800
committerAdam Paszke <adam.paszke@gmail.com>2017-02-17 19:16:12 +0100
commit04aba1caec28dcc7517c85b40addb5a7741683fa (patch)
tree781b206e859a4ab2c8e25d73d0303aa718a07957 /torch/backends
parentc26b9c0a5eae5fa505e3c2f93c2d7545cbf27c54 (diff)
downloadpytorch-04aba1caec28dcc7517c85b40addb5a7741683fa.tar.gz
pytorch-04aba1caec28dcc7517c85b40addb5a7741683fa.tar.bz2
pytorch-04aba1caec28dcc7517c85b40addb5a7741683fa.zip
Fix cuDNN dropout desc for multi-gpu (#772)
Diffstat (limited to 'torch/backends')
-rw-r--r--torch/backends/cudnn/rnn.py7
1 files changed, 4 insertions, 3 deletions
diff --git a/torch/backends/cudnn/rnn.py b/torch/backends/cudnn/rnn.py
index 6474248a8c..03e309a2e6 100644
--- a/torch/backends/cudnn/rnn.py
+++ b/torch/backends/cudnn/rnn.py
@@ -47,7 +47,7 @@ def init_rnn_descriptor(fn, handle):
handle,
fn.hidden_size,
fn.num_layers,
- fn.dropout_state['desc'].get(),
+ fn.dropout_state['desc_' + str(torch.cuda.current_device())].get(),
fn.input_mode,
fn.bidirectional,
fn.mode,
@@ -217,8 +217,9 @@ def forward(fn, input, hx, weight, output, hy):
y = output
# init descriptors
- if ('desc' not in fn.dropout_state) or (fn.dropout_state['desc'].get() is None):
- fn.dropout_state['desc'] = Unserializable(
+ desc_name = 'desc_' + str(torch.cuda.current_device())
+ if (desc_name not in fn.dropout_state) or (fn.dropout_state[desc_name].get() is None):
+ fn.dropout_state[desc_name] = Unserializable(
init_dropout_descriptor(fn, handle)
)
fn.rnn_desc = init_rnn_descriptor(fn, handle)