diff options
author | Antony Lee <anntzer.lee@gmail.com> | 2016-02-19 15:02:35 -0800 |
---|---|---|
committer | Antony Lee <anntzer.lee@gmail.com> | 2016-02-19 17:29:49 -0800 |
commit | 91a86f6715604183741f84d429a3a5c2fc7d7e9e (patch) | |
tree | 1ac82c3c1fdbf9fdb0731301d535e3ed07bcd332 | |
parent | 6d3b34fed6d5c1af6eb02cd47d31ee48a15b582d (diff) | |
download | python-numpy-91a86f6715604183741f84d429a3a5c2fc7d7e9e.tar.gz python-numpy-91a86f6715604183741f84d429a3a5c2fc7d7e9e.tar.bz2 python-numpy-91a86f6715604183741f84d429a3a5c2fc7d7e9e.zip |
Clarify error on repr failure in assert_equal.
assert_equal(np.array([0, 1]), np.matrix([0, 1]))
used to print
x: array([0, 1])
y: [repr failed]
now prints
x: array([0, 1])
y: [repr failed for <matrix>: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()]
-rw-r--r-- | numpy/testing/tests/test_utils.py | 13 | ||||
-rw-r--r-- | numpy/testing/utils.py | 4 |
2 files changed, 15 insertions, 2 deletions
diff --git a/numpy/testing/tests/test_utils.py b/numpy/testing/tests/test_utils.py index 7de57d408..fe1f411c4 100644 --- a/numpy/testing/tests/test_utils.py +++ b/numpy/testing/tests/test_utils.py @@ -227,6 +227,19 @@ class TestEqual(TestArrayEqual): self._assert_func(x, x) self._test_not_equal(x, y) + def test_error_message(self): + try: + self._assert_func(np.array([1, 2]), np.matrix([1, 2])) + except AssertionError as e: + self.assertEqual( + str(e), + "\nArrays are not equal\n\n" + "(shapes (2,), (1, 2) mismatch)\n" + " x: array([1, 2])\n" + " y: [repr failed for <matrix>: The truth value of an array " + "with more than one element is ambiguous. Use a.any() or " + "a.all()]") + class TestArrayAlmostEqual(_GenericTest, unittest.TestCase): diff --git a/numpy/testing/utils.py b/numpy/testing/utils.py index f2588788d..133330a12 100644 --- a/numpy/testing/utils.py +++ b/numpy/testing/utils.py @@ -252,8 +252,8 @@ def build_err_msg(arrays, err_msg, header='Items are not equal:', try: r = r_func(a) - except: - r = '[repr failed]' + except Exception as exc: + r = '[repr failed for <{}>: {}]'.format(type(a).__name__, exc) if r.count('\n') > 3: r = '\n'.join(r.splitlines()[:3]) r += '...' |