summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAntony Lee <anntzer.lee@gmail.com>2016-02-19 15:02:35 -0800
committerAntony Lee <anntzer.lee@gmail.com>2016-02-19 17:29:49 -0800
commit91a86f6715604183741f84d429a3a5c2fc7d7e9e (patch)
tree1ac82c3c1fdbf9fdb0731301d535e3ed07bcd332
parent6d3b34fed6d5c1af6eb02cd47d31ee48a15b582d (diff)
downloadpython-numpy-91a86f6715604183741f84d429a3a5c2fc7d7e9e.tar.gz
python-numpy-91a86f6715604183741f84d429a3a5c2fc7d7e9e.tar.bz2
python-numpy-91a86f6715604183741f84d429a3a5c2fc7d7e9e.zip
Clarify error on repr failure in assert_equal.
assert_equal(np.array([0, 1]), np.matrix([0, 1])) used to print x: array([0, 1]) y: [repr failed] now prints x: array([0, 1]) y: [repr failed for <matrix>: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()]
-rw-r--r--numpy/testing/tests/test_utils.py13
-rw-r--r--numpy/testing/utils.py4
2 files changed, 15 insertions, 2 deletions
diff --git a/numpy/testing/tests/test_utils.py b/numpy/testing/tests/test_utils.py
index 7de57d408..fe1f411c4 100644
--- a/numpy/testing/tests/test_utils.py
+++ b/numpy/testing/tests/test_utils.py
@@ -227,6 +227,19 @@ class TestEqual(TestArrayEqual):
self._assert_func(x, x)
self._test_not_equal(x, y)
+ def test_error_message(self):
+ try:
+ self._assert_func(np.array([1, 2]), np.matrix([1, 2]))
+ except AssertionError as e:
+ self.assertEqual(
+ str(e),
+ "\nArrays are not equal\n\n"
+ "(shapes (2,), (1, 2) mismatch)\n"
+ " x: array([1, 2])\n"
+ " y: [repr failed for <matrix>: The truth value of an array "
+ "with more than one element is ambiguous. Use a.any() or "
+ "a.all()]")
+
class TestArrayAlmostEqual(_GenericTest, unittest.TestCase):
diff --git a/numpy/testing/utils.py b/numpy/testing/utils.py
index f2588788d..133330a12 100644
--- a/numpy/testing/utils.py
+++ b/numpy/testing/utils.py
@@ -252,8 +252,8 @@ def build_err_msg(arrays, err_msg, header='Items are not equal:',
try:
r = r_func(a)
- except:
- r = '[repr failed]'
+ except Exception as exc:
+ r = '[repr failed for <{}>: {}]'.format(type(a).__name__, exc)
if r.count('\n') > 3:
r = '\n'.join(r.splitlines()[:3])
r += '...'