summaryrefslogtreecommitdiff
path: root/test/test_nn.py
diff options
context:
space:
mode:
Diffstat (limited to 'test/test_nn.py')
-rw-r--r--test/test_nn.py9
1 files changed, 8 insertions, 1 deletions
diff --git a/test/test_nn.py b/test/test_nn.py
index cc966d0750..82e740bae0 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -4362,7 +4362,10 @@ class TestNN(NNTestCase):
self.assertTrue(packed_enforce_sorted.sorted_indices is None)
self.assertTrue(packed_enforce_sorted.unsorted_indices is None)
- with self.assertRaisesRegex(RuntimeError, 'has to be sorted in decreasing order'):
+ with self.assertRaisesRegex(RuntimeError, 'must be sorted in decreasing order'):
+ rnn_utils.pack_sequence([b, c, a], enforce_sorted=True)
+
+ with self.assertRaisesRegex(RuntimeError, 'You can pass `enforce_sorted=False`'):
rnn_utils.pack_sequence([b, c, a], enforce_sorted=True)
# more dimensions
@@ -4456,6 +4459,10 @@ class TestNN(NNTestCase):
if l < 10:
self.assertEqual(padded.grad.data[l:, i].abs().sum(), 0)
+ # test error message
+ with self.assertRaisesRegex(RuntimeError, 'You can pass `enforce_sorted=False`'):
+ packed = rnn_utils.pack_padded_sequence(torch.randn(3, 3), [1, 3, 2])
+
def _test_variable_sequence(self, device="cpu", dtype=torch.float):
def pad(var, length):
if var.size(0) == length: