summaryrefslogtreecommitdiff
path: root/numpy/lib
diff options
context:
space:
mode:
authorEric Wieser <wieser.eric@gmail.com>2017-03-26 11:22:43 +0100
committerEric Wieser <wieser.eric@gmail.com>2017-03-28 20:44:49 +0100
commite3ed705e5d91b584e9191a20f3a4780d354271ff (patch)
tree3135198039924e2b01a875570e8dc85980af6f22 /numpy/lib
parent539d4f7ef561ac86ea4f3b81bf1eb9b3ac03b67f (diff)
downloadpython-numpy-e3ed705e5d91b584e9191a20f3a4780d354271ff.tar.gz
python-numpy-e3ed705e5d91b584e9191a20f3a4780d354271ff.tar.bz2
python-numpy-e3ed705e5d91b584e9191a20f3a4780d354271ff.zip
MAINT: Use _validate_axis inside _ureduce
This fixes an omission where duplicate axes would only be detected when positive
Diffstat (limited to 'numpy/lib')
-rw-r--r--numpy/lib/function_base.py31
-rw-r--r--numpy/lib/tests/test_function_base.py19
-rw-r--r--numpy/lib/tests/test_nanfunctions.py16
3 files changed, 33 insertions, 33 deletions
diff --git a/numpy/lib/function_base.py b/numpy/lib/function_base.py
index 2a8a13caa..5b3af311b 100644
--- a/numpy/lib/function_base.py
+++ b/numpy/lib/function_base.py
@@ -12,7 +12,7 @@ from numpy.core import linspace, atleast_1d, atleast_2d, transpose
from numpy.core.numeric import (
ones, zeros, arange, concatenate, array, asarray, asanyarray, empty,
empty_like, ndarray, around, floor, ceil, take, dot, where, intp,
- integer, isscalar, absolute
+ integer, isscalar, absolute, AxisError
)
from numpy.core.umath import (
pi, multiply, add, arctan2, frompyfunc, cos, less_equal, sqrt, sin,
@@ -3972,21 +3972,15 @@ def _ureduce(a, func, **kwargs):
if axis is not None:
keepdim = list(a.shape)
nd = a.ndim
- try:
- axis = operator.index(axis)
- if axis >= nd or axis < -nd:
- raise IndexError("axis %d out of bounds (%d)" % (axis, a.ndim))
- keepdim[axis] = 1
- except TypeError:
- sax = set()
- for x in axis:
- if x >= nd or x < -nd:
- raise IndexError("axis %d out of bounds (%d)" % (x, nd))
- if x in sax:
- raise ValueError("duplicate value in axis")
- sax.add(x % nd)
- keepdim[x] = 1
- keep = sax.symmetric_difference(frozenset(range(nd)))
+ axis = _nx._validate_axis(axis, nd)
+
+ for ax in axis:
+ keepdim[ax] = 1
+
+ if len(axis) == 1:
+ kwargs['axis'] = axis[0]
+ else:
+ keep = set(range(nd)) - set(axis)
nkeep = len(keep)
# swap axis that should not be reduced to front
for i, s in enumerate(sorted(keep)):
@@ -4742,7 +4736,8 @@ def delete(arr, obj, axis=None):
if ndim != 1:
arr = arr.ravel()
ndim = arr.ndim
- axis = ndim - 1
+ axis = -1
+
if ndim == 0:
# 2013-09-24, 1.9
warnings.warn(
@@ -4753,6 +4748,8 @@ def delete(arr, obj, axis=None):
else:
return arr.copy(order=arrorder)
+ axis = normalize_axis_index(axis, ndim)
+
slobj = [slice(None)]*ndim
N = arr.shape[axis]
newshape = list(arr.shape)
diff --git a/numpy/lib/tests/test_function_base.py b/numpy/lib/tests/test_function_base.py
index 6c6ed5941..7c07606f6 100644
--- a/numpy/lib/tests/test_function_base.py
+++ b/numpy/lib/tests/test_function_base.py
@@ -2973,11 +2973,14 @@ class TestPercentile(TestCase):
def test_extended_axis_invalid(self):
d = np.ones((3, 5, 7, 11))
- assert_raises(IndexError, np.percentile, d, axis=-5, q=25)
- assert_raises(IndexError, np.percentile, d, axis=(0, -5), q=25)
- assert_raises(IndexError, np.percentile, d, axis=4, q=25)
- assert_raises(IndexError, np.percentile, d, axis=(0, 4), q=25)
+ assert_raises(np.AxisError, np.percentile, d, axis=-5, q=25)
+ assert_raises(np.AxisError, np.percentile, d, axis=(0, -5), q=25)
+ assert_raises(np.AxisError, np.percentile, d, axis=4, q=25)
+ assert_raises(np.AxisError, np.percentile, d, axis=(0, 4), q=25)
+ # each of these refers to the same axis twice
assert_raises(ValueError, np.percentile, d, axis=(1, 1), q=25)
+ assert_raises(ValueError, np.percentile, d, axis=(-1, -1), q=25)
+ assert_raises(ValueError, np.percentile, d, axis=(3, -1), q=25)
def test_keepdims(self):
d = np.ones((3, 5, 7, 11))
@@ -3349,10 +3352,10 @@ class TestMedian(TestCase):
def test_extended_axis_invalid(self):
d = np.ones((3, 5, 7, 11))
- assert_raises(IndexError, np.median, d, axis=-5)
- assert_raises(IndexError, np.median, d, axis=(0, -5))
- assert_raises(IndexError, np.median, d, axis=4)
- assert_raises(IndexError, np.median, d, axis=(0, 4))
+ assert_raises(np.AxisError, np.median, d, axis=-5)
+ assert_raises(np.AxisError, np.median, d, axis=(0, -5))
+ assert_raises(np.AxisError, np.median, d, axis=4)
+ assert_raises(np.AxisError, np.median, d, axis=(0, 4))
assert_raises(ValueError, np.median, d, axis=(1, 1))
def test_keepdims(self):
diff --git a/numpy/lib/tests/test_nanfunctions.py b/numpy/lib/tests/test_nanfunctions.py
index 2b310457b..1678e1091 100644
--- a/numpy/lib/tests/test_nanfunctions.py
+++ b/numpy/lib/tests/test_nanfunctions.py
@@ -684,10 +684,10 @@ class TestNanFunctions_Median(TestCase):
def test_extended_axis_invalid(self):
d = np.ones((3, 5, 7, 11))
- assert_raises(IndexError, np.nanmedian, d, axis=-5)
- assert_raises(IndexError, np.nanmedian, d, axis=(0, -5))
- assert_raises(IndexError, np.nanmedian, d, axis=4)
- assert_raises(IndexError, np.nanmedian, d, axis=(0, 4))
+ assert_raises(np.AxisError, np.nanmedian, d, axis=-5)
+ assert_raises(np.AxisError, np.nanmedian, d, axis=(0, -5))
+ assert_raises(np.AxisError, np.nanmedian, d, axis=4)
+ assert_raises(np.AxisError, np.nanmedian, d, axis=(0, 4))
assert_raises(ValueError, np.nanmedian, d, axis=(1, 1))
def test_float_special(self):
@@ -843,10 +843,10 @@ class TestNanFunctions_Percentile(TestCase):
def test_extended_axis_invalid(self):
d = np.ones((3, 5, 7, 11))
- assert_raises(IndexError, np.nanpercentile, d, q=5, axis=-5)
- assert_raises(IndexError, np.nanpercentile, d, q=5, axis=(0, -5))
- assert_raises(IndexError, np.nanpercentile, d, q=5, axis=4)
- assert_raises(IndexError, np.nanpercentile, d, q=5, axis=(0, 4))
+ assert_raises(np.AxisError, np.nanpercentile, d, q=5, axis=-5)
+ assert_raises(np.AxisError, np.nanpercentile, d, q=5, axis=(0, -5))
+ assert_raises(np.AxisError, np.nanpercentile, d, q=5, axis=4)
+ assert_raises(np.AxisError, np.nanpercentile, d, q=5, axis=(0, 4))
assert_raises(ValueError, np.nanpercentile, d, q=5, axis=(1, 1))
def test_multiple_percentiles(self):