diff options
author | Richard Zou <zou3519@gmail.com> | 2019-01-18 07:56:17 -0800 |
---|---|---|
committer | Facebook Github Bot <facebook-github-bot@users.noreply.github.com> | 2019-01-18 07:58:54 -0800 |
commit | ed0a761c824d5e325e39afe954f92cc14c3c88c1 (patch) | |
tree | 809e3094fc71f57f13509b4be71d185e61b4f0db /test/test_nn.py | |
parent | b4bc55beefda3a0724b0fb83c04b6bbd8dd46c77 (diff) | |
download | pytorch-ed0a761c824d5e325e39afe954f92cc14c3c88c1.tar.gz pytorch-ed0a761c824d5e325e39afe954f92cc14c3c88c1.tar.bz2 pytorch-ed0a761c824d5e325e39afe954f92cc14c3c88c1.zip |
Improve pack_sequence and pack_padded_sequence error message (#16084)
Summary:
Mention that if enforce_sorted=True, the user can set
enforce_sorted=False. This is a new flag that is probably hard to
discover unless one throughly reads the docs.
Fixes #15567
Pull Request resolved: https://github.com/pytorch/pytorch/pull/16084
Differential Revision: D13701118
Pulled By: zou3519
fbshipit-source-id: c9aeb47ae9769d28b0051bcedb8f2f51a5a5c260
Diffstat (limited to 'test/test_nn.py')
-rw-r--r-- | test/test_nn.py | 9 |
1 files changed, 8 insertions, 1 deletions
diff --git a/test/test_nn.py b/test/test_nn.py index cc966d0750..82e740bae0 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -4362,7 +4362,10 @@ class TestNN(NNTestCase): self.assertTrue(packed_enforce_sorted.sorted_indices is None) self.assertTrue(packed_enforce_sorted.unsorted_indices is None) - with self.assertRaisesRegex(RuntimeError, 'has to be sorted in decreasing order'): + with self.assertRaisesRegex(RuntimeError, 'must be sorted in decreasing order'): + rnn_utils.pack_sequence([b, c, a], enforce_sorted=True) + + with self.assertRaisesRegex(RuntimeError, 'You can pass `enforce_sorted=False`'): rnn_utils.pack_sequence([b, c, a], enforce_sorted=True) # more dimensions @@ -4456,6 +4459,10 @@ class TestNN(NNTestCase): if l < 10: self.assertEqual(padded.grad.data[l:, i].abs().sum(), 0) + # test error message + with self.assertRaisesRegex(RuntimeError, 'You can pass `enforce_sorted=False`'): + packed = rnn_utils.pack_padded_sequence(torch.randn(3, 3), [1, 3, 2]) + def _test_variable_sequence(self, device="cpu", dtype=torch.float): def pad(var, length): if var.size(0) == length: |