diff options
author | Eric Wieser <wieser.eric@gmail.com> | 2017-03-26 11:22:43 +0100 |
---|---|---|
committer | Eric Wieser <wieser.eric@gmail.com> | 2017-03-28 20:44:49 +0100 |
commit | e3ed705e5d91b584e9191a20f3a4780d354271ff (patch) | |
tree | 3135198039924e2b01a875570e8dc85980af6f22 /numpy/lib | |
parent | 539d4f7ef561ac86ea4f3b81bf1eb9b3ac03b67f (diff) | |
download | python-numpy-e3ed705e5d91b584e9191a20f3a4780d354271ff.tar.gz python-numpy-e3ed705e5d91b584e9191a20f3a4780d354271ff.tar.bz2 python-numpy-e3ed705e5d91b584e9191a20f3a4780d354271ff.zip |
MAINT: Use _validate_axis inside _ureduce
This fixes an omission where duplicate axes would only be detected when positive
Diffstat (limited to 'numpy/lib')
-rw-r--r-- | numpy/lib/function_base.py | 31 | ||||
-rw-r--r-- | numpy/lib/tests/test_function_base.py | 19 | ||||
-rw-r--r-- | numpy/lib/tests/test_nanfunctions.py | 16 |
3 files changed, 33 insertions, 33 deletions
diff --git a/numpy/lib/function_base.py b/numpy/lib/function_base.py index 2a8a13caa..5b3af311b 100644 --- a/numpy/lib/function_base.py +++ b/numpy/lib/function_base.py @@ -12,7 +12,7 @@ from numpy.core import linspace, atleast_1d, atleast_2d, transpose from numpy.core.numeric import ( ones, zeros, arange, concatenate, array, asarray, asanyarray, empty, empty_like, ndarray, around, floor, ceil, take, dot, where, intp, - integer, isscalar, absolute + integer, isscalar, absolute, AxisError ) from numpy.core.umath import ( pi, multiply, add, arctan2, frompyfunc, cos, less_equal, sqrt, sin, @@ -3972,21 +3972,15 @@ def _ureduce(a, func, **kwargs): if axis is not None: keepdim = list(a.shape) nd = a.ndim - try: - axis = operator.index(axis) - if axis >= nd or axis < -nd: - raise IndexError("axis %d out of bounds (%d)" % (axis, a.ndim)) - keepdim[axis] = 1 - except TypeError: - sax = set() - for x in axis: - if x >= nd or x < -nd: - raise IndexError("axis %d out of bounds (%d)" % (x, nd)) - if x in sax: - raise ValueError("duplicate value in axis") - sax.add(x % nd) - keepdim[x] = 1 - keep = sax.symmetric_difference(frozenset(range(nd))) + axis = _nx._validate_axis(axis, nd) + + for ax in axis: + keepdim[ax] = 1 + + if len(axis) == 1: + kwargs['axis'] = axis[0] + else: + keep = set(range(nd)) - set(axis) nkeep = len(keep) # swap axis that should not be reduced to front for i, s in enumerate(sorted(keep)): @@ -4742,7 +4736,8 @@ def delete(arr, obj, axis=None): if ndim != 1: arr = arr.ravel() ndim = arr.ndim - axis = ndim - 1 + axis = -1 + if ndim == 0: # 2013-09-24, 1.9 warnings.warn( @@ -4753,6 +4748,8 @@ def delete(arr, obj, axis=None): else: return arr.copy(order=arrorder) + axis = normalize_axis_index(axis, ndim) + slobj = [slice(None)]*ndim N = arr.shape[axis] newshape = list(arr.shape) diff --git a/numpy/lib/tests/test_function_base.py b/numpy/lib/tests/test_function_base.py index 6c6ed5941..7c07606f6 100644 --- a/numpy/lib/tests/test_function_base.py +++ b/numpy/lib/tests/test_function_base.py @@ -2973,11 +2973,14 @@ class TestPercentile(TestCase): def test_extended_axis_invalid(self): d = np.ones((3, 5, 7, 11)) - assert_raises(IndexError, np.percentile, d, axis=-5, q=25) - assert_raises(IndexError, np.percentile, d, axis=(0, -5), q=25) - assert_raises(IndexError, np.percentile, d, axis=4, q=25) - assert_raises(IndexError, np.percentile, d, axis=(0, 4), q=25) + assert_raises(np.AxisError, np.percentile, d, axis=-5, q=25) + assert_raises(np.AxisError, np.percentile, d, axis=(0, -5), q=25) + assert_raises(np.AxisError, np.percentile, d, axis=4, q=25) + assert_raises(np.AxisError, np.percentile, d, axis=(0, 4), q=25) + # each of these refers to the same axis twice assert_raises(ValueError, np.percentile, d, axis=(1, 1), q=25) + assert_raises(ValueError, np.percentile, d, axis=(-1, -1), q=25) + assert_raises(ValueError, np.percentile, d, axis=(3, -1), q=25) def test_keepdims(self): d = np.ones((3, 5, 7, 11)) @@ -3349,10 +3352,10 @@ class TestMedian(TestCase): def test_extended_axis_invalid(self): d = np.ones((3, 5, 7, 11)) - assert_raises(IndexError, np.median, d, axis=-5) - assert_raises(IndexError, np.median, d, axis=(0, -5)) - assert_raises(IndexError, np.median, d, axis=4) - assert_raises(IndexError, np.median, d, axis=(0, 4)) + assert_raises(np.AxisError, np.median, d, axis=-5) + assert_raises(np.AxisError, np.median, d, axis=(0, -5)) + assert_raises(np.AxisError, np.median, d, axis=4) + assert_raises(np.AxisError, np.median, d, axis=(0, 4)) assert_raises(ValueError, np.median, d, axis=(1, 1)) def test_keepdims(self): diff --git a/numpy/lib/tests/test_nanfunctions.py b/numpy/lib/tests/test_nanfunctions.py index 2b310457b..1678e1091 100644 --- a/numpy/lib/tests/test_nanfunctions.py +++ b/numpy/lib/tests/test_nanfunctions.py @@ -684,10 +684,10 @@ class TestNanFunctions_Median(TestCase): def test_extended_axis_invalid(self): d = np.ones((3, 5, 7, 11)) - assert_raises(IndexError, np.nanmedian, d, axis=-5) - assert_raises(IndexError, np.nanmedian, d, axis=(0, -5)) - assert_raises(IndexError, np.nanmedian, d, axis=4) - assert_raises(IndexError, np.nanmedian, d, axis=(0, 4)) + assert_raises(np.AxisError, np.nanmedian, d, axis=-5) + assert_raises(np.AxisError, np.nanmedian, d, axis=(0, -5)) + assert_raises(np.AxisError, np.nanmedian, d, axis=4) + assert_raises(np.AxisError, np.nanmedian, d, axis=(0, 4)) assert_raises(ValueError, np.nanmedian, d, axis=(1, 1)) def test_float_special(self): @@ -843,10 +843,10 @@ class TestNanFunctions_Percentile(TestCase): def test_extended_axis_invalid(self): d = np.ones((3, 5, 7, 11)) - assert_raises(IndexError, np.nanpercentile, d, q=5, axis=-5) - assert_raises(IndexError, np.nanpercentile, d, q=5, axis=(0, -5)) - assert_raises(IndexError, np.nanpercentile, d, q=5, axis=4) - assert_raises(IndexError, np.nanpercentile, d, q=5, axis=(0, 4)) + assert_raises(np.AxisError, np.nanpercentile, d, q=5, axis=-5) + assert_raises(np.AxisError, np.nanpercentile, d, q=5, axis=(0, -5)) + assert_raises(np.AxisError, np.nanpercentile, d, q=5, axis=4) + assert_raises(np.AxisError, np.nanpercentile, d, q=5, axis=(0, 4)) assert_raises(ValueError, np.nanpercentile, d, q=5, axis=(1, 1)) def test_multiple_percentiles(self): |