diff options
author | Christian Sarofeen <csarofeen@nvidia.com> | 2017-02-17 10:16:12 -0800 |
---|---|---|
committer | Adam Paszke <adam.paszke@gmail.com> | 2017-02-17 19:16:12 +0100 |
commit | 04aba1caec28dcc7517c85b40addb5a7741683fa (patch) | |
tree | 781b206e859a4ab2c8e25d73d0303aa718a07957 /torch/backends | |
parent | c26b9c0a5eae5fa505e3c2f93c2d7545cbf27c54 (diff) | |
download | pytorch-04aba1caec28dcc7517c85b40addb5a7741683fa.tar.gz pytorch-04aba1caec28dcc7517c85b40addb5a7741683fa.tar.bz2 pytorch-04aba1caec28dcc7517c85b40addb5a7741683fa.zip |
Fix cuDNN dropout desc for multi-gpu (#772)
Diffstat (limited to 'torch/backends')
-rw-r--r-- | torch/backends/cudnn/rnn.py | 7 |
1 files changed, 4 insertions, 3 deletions
diff --git a/torch/backends/cudnn/rnn.py b/torch/backends/cudnn/rnn.py index 6474248a8c..03e309a2e6 100644 --- a/torch/backends/cudnn/rnn.py +++ b/torch/backends/cudnn/rnn.py @@ -47,7 +47,7 @@ def init_rnn_descriptor(fn, handle): handle, fn.hidden_size, fn.num_layers, - fn.dropout_state['desc'].get(), + fn.dropout_state['desc_' + str(torch.cuda.current_device())].get(), fn.input_mode, fn.bidirectional, fn.mode, @@ -217,8 +217,9 @@ def forward(fn, input, hx, weight, output, hy): y = output # init descriptors - if ('desc' not in fn.dropout_state) or (fn.dropout_state['desc'].get() is None): - fn.dropout_state['desc'] = Unserializable( + desc_name = 'desc_' + str(torch.cuda.current_device()) + if (desc_name not in fn.dropout_state) or (fn.dropout_state[desc_name].get() is None): + fn.dropout_state[desc_name] = Unserializable( init_dropout_descriptor(fn, handle) ) fn.rnn_desc = init_rnn_descriptor(fn, handle) |