summaryrefslogtreecommitdiff
path: root/test/test_nn.py
diff options
context:
space:
mode:
authorRichard Zou <zou3519@gmail.com>2019-01-18 07:56:17 -0800
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>2019-01-18 07:58:54 -0800
commited0a761c824d5e325e39afe954f92cc14c3c88c1 (patch)
tree809e3094fc71f57f13509b4be71d185e61b4f0db /test/test_nn.py
parentb4bc55beefda3a0724b0fb83c04b6bbd8dd46c77 (diff)
downloadpytorch-ed0a761c824d5e325e39afe954f92cc14c3c88c1.tar.gz
pytorch-ed0a761c824d5e325e39afe954f92cc14c3c88c1.tar.bz2
pytorch-ed0a761c824d5e325e39afe954f92cc14c3c88c1.zip
Improve pack_sequence and pack_padded_sequence error message (#16084)
Summary: Mention that if enforce_sorted=True, the user can set enforce_sorted=False. This is a new flag that is probably hard to discover unless one throughly reads the docs. Fixes #15567 Pull Request resolved: https://github.com/pytorch/pytorch/pull/16084 Differential Revision: D13701118 Pulled By: zou3519 fbshipit-source-id: c9aeb47ae9769d28b0051bcedb8f2f51a5a5c260
Diffstat (limited to 'test/test_nn.py')
-rw-r--r--test/test_nn.py9
1 files changed, 8 insertions, 1 deletions
diff --git a/test/test_nn.py b/test/test_nn.py
index cc966d0750..82e740bae0 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -4362,7 +4362,10 @@ class TestNN(NNTestCase):
self.assertTrue(packed_enforce_sorted.sorted_indices is None)
self.assertTrue(packed_enforce_sorted.unsorted_indices is None)
- with self.assertRaisesRegex(RuntimeError, 'has to be sorted in decreasing order'):
+ with self.assertRaisesRegex(RuntimeError, 'must be sorted in decreasing order'):
+ rnn_utils.pack_sequence([b, c, a], enforce_sorted=True)
+
+ with self.assertRaisesRegex(RuntimeError, 'You can pass `enforce_sorted=False`'):
rnn_utils.pack_sequence([b, c, a], enforce_sorted=True)
# more dimensions
@@ -4456,6 +4459,10 @@ class TestNN(NNTestCase):
if l < 10:
self.assertEqual(padded.grad.data[l:, i].abs().sum(), 0)
+ # test error message
+ with self.assertRaisesRegex(RuntimeError, 'You can pass `enforce_sorted=False`'):
+ packed = rnn_utils.pack_padded_sequence(torch.randn(3, 3), [1, 3, 2])
+
def _test_variable_sequence(self, device="cpu", dtype=torch.float):
def pad(var, length):
if var.size(0) == length: