summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAllan Haldane <allan.haldane@gmail.com>2018-02-25 14:52:41 -0500
committerAllan Haldane <allan.haldane@gmail.com>2018-03-05 22:50:47 -0500
commit04d2f0494f0fef2ede1461053b6cfc9bd37aaf2f (patch)
treed4c4e0063ce092a2552c6358a03fcce717a8fd61
parent400607bdd44d7ad23a3fe666c796e9893b2bed46 (diff)
downloadpython-numpy-04d2f0494f0fef2ede1461053b6cfc9bd37aaf2f.tar.gz
python-numpy-04d2f0494f0fef2ede1461053b6cfc9bd37aaf2f.tar.bz2
python-numpy-04d2f0494f0fef2ede1461053b6cfc9bd37aaf2f.zip
BUG: Further back-compat fix for subclassed array repr
Fixes #10663
-rw-r--r--numpy/core/arrayprint.py7
-rw-r--r--numpy/core/tests/test_arrayprint.py53
2 files changed, 47 insertions, 13 deletions
diff --git a/numpy/core/arrayprint.py b/numpy/core/arrayprint.py
index cbe95f51b..7dc73d6de 100644
--- a/numpy/core/arrayprint.py
+++ b/numpy/core/arrayprint.py
@@ -471,14 +471,15 @@ def _array2string(a, options, separator=' ', prefix=""):
# The formatter __init__s in _get_format_function cannot deal with
# subclasses yet, and we also need to avoid recursion issues in
# _formatArray with subclasses which return 0d arrays in place of scalars
- a = asarray(a)
+ data = asarray(a)
+ if a.shape == ():
+ a = data
if a.size > options['threshold']:
summary_insert = "..."
- data = _leading_trailing(a, options['edgeitems'])
+ data = _leading_trailing(data, options['edgeitems'])
else:
summary_insert = ""
- data = a
# find the right formatting function for the array
format_function = _get_format_function(data, **options)
diff --git a/numpy/core/tests/test_arrayprint.py b/numpy/core/tests/test_arrayprint.py
index 88aaa3403..309df8545 100644
--- a/numpy/core/tests/test_arrayprint.py
+++ b/numpy/core/tests/test_arrayprint.py
@@ -5,7 +5,7 @@ import sys, gc
import numpy as np
from numpy.testing import (
- run_module_suite, assert_, assert_equal, assert_raises, assert_warns
+ run_module_suite, assert_, assert_equal, assert_raises, assert_warns, dec
)
import textwrap
@@ -34,6 +34,27 @@ class TestArrayRepr(object):
" [(1,), (1,)]], dtype=[('a', '<i4')])"
)
+ @dec.knownfailureif(True, "See gh-10544")
+ def test_object_subclass(self):
+ class sub(np.ndarray):
+ def __new__(cls, inp):
+ obj = np.asarray(inp).view(cls)
+ return obj
+
+ def __getitem__(self, ind):
+ ret = super(sub, self).__getitem__(ind)
+ return sub(ret)
+
+ # test that object + subclass is OK:
+ x = sub([None, None])
+ assert_equal(repr(x), 'sub([None, None], dtype=object)')
+ assert_equal(str(x), '[None None]')
+
+ x = sub([None, sub([None, None])])
+ assert_equal(repr(x),
+ 'sub([None, sub([None, None], dtype=object)], dtype=object)')
+ assert_equal(str(x), '[None sub([None, None], dtype=object)]')
+
def test_0d_object_subclass(self):
# make sure that subclasses which return 0ds instead
# of scalars don't cause infinite recursion in str
@@ -73,15 +94,27 @@ class TestArrayRepr(object):
assert_equal(repr(x), 'sub(sub(None, dtype=object), dtype=object)')
assert_equal(str(x), 'None')
- # test that object + subclass is OK:
- x = sub([None, None])
- assert_equal(repr(x), 'sub([None, None], dtype=object)')
- assert_equal(str(x), '[None None]')
-
- x = sub([None, sub([None, None])])
- assert_equal(repr(x),
- 'sub([None, sub([None, None], dtype=object)], dtype=object)')
- assert_equal(str(x), '[None sub([None, None], dtype=object)]')
+ # gh-10663
+ class DuckCounter(np.ndarray):
+ def __getitem__(self, item):
+ result = super(DuckCounter, self).__getitem__(item)
+ if not isinstance(result, DuckCounter):
+ result = result[...].view(DuckCounter)
+ return result
+
+ def to_string(self):
+ return {0: 'zero', 1: 'one', 2: 'two'}.get(self.item(), 'many')
+
+ def __str__(self):
+ if self.shape == ():
+ return self.to_string()
+ else:
+ fmt = {'all': lambda x: x.to_string()}
+ return np.array2string(self, formatter=fmt)
+
+ dc = np.arange(5).view(DuckCounter)
+ assert_equal(str(dc), "[zero one two many many]")
+ assert_equal(str(dc[0]), "zero")
def test_self_containing(self):
arr0d = np.array(None)