diff options
Diffstat (limited to 'test/test_nn.py')
-rw-r--r-- | test/test_nn.py | 9 |
1 files changed, 8 insertions, 1 deletions
diff --git a/test/test_nn.py b/test/test_nn.py index cc966d0750..82e740bae0 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -4362,7 +4362,10 @@ class TestNN(NNTestCase): self.assertTrue(packed_enforce_sorted.sorted_indices is None) self.assertTrue(packed_enforce_sorted.unsorted_indices is None) - with self.assertRaisesRegex(RuntimeError, 'has to be sorted in decreasing order'): + with self.assertRaisesRegex(RuntimeError, 'must be sorted in decreasing order'): + rnn_utils.pack_sequence([b, c, a], enforce_sorted=True) + + with self.assertRaisesRegex(RuntimeError, 'You can pass `enforce_sorted=False`'): rnn_utils.pack_sequence([b, c, a], enforce_sorted=True) # more dimensions @@ -4456,6 +4459,10 @@ class TestNN(NNTestCase): if l < 10: self.assertEqual(padded.grad.data[l:, i].abs().sum(), 0) + # test error message + with self.assertRaisesRegex(RuntimeError, 'You can pass `enforce_sorted=False`'): + packed = rnn_utils.pack_padded_sequence(torch.randn(3, 3), [1, 3, 2]) + def _test_variable_sequence(self, device="cpu", dtype=torch.float): def pad(var, length): if var.size(0) == length: |