summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorCharles Harris <charlesr.harris@gmail.com>2017-05-17 13:02:01 -0600
committerCharles Harris <charlesr.harris@gmail.com>2017-05-17 19:03:15 -0600
commitbeac50cf98f450539dcdeee0273cfe5175d45d26 (patch)
tree52ebc11c1a4130592bc227eb385fb513073a7c4a
parentb9e3ac9abb6e435cdf6bbe33e0bc894d6a879a53 (diff)
downloadpython-numpy-beac50cf98f450539dcdeee0273cfe5175d45d26.tar.gz
python-numpy-beac50cf98f450539dcdeee0273cfe5175d45d26.tar.bz2
python-numpy-beac50cf98f450539dcdeee0273cfe5175d45d26.zip
DEP: Deprecate incorrect behavior of expand_dims.
Expand_dims works as documented when the index of the inserted NewAxis in the resulting array satisfies -a.ndim - 1 <= index <= a.ndim. However, when index > a.ndim index is replaced by a.ndim and, when index < -a.ndim - 1, it is replaced by index + a.ndim + 1, which may be negative and results in incorrect placement. The latter two cases are now deprecated. Closes #9100.
-rw-r--r--doc/release/1.13.0-notes.rst3
-rw-r--r--numpy/lib/shape_base.py21
-rw-r--r--numpy/lib/tests/test_shape_base.py23
3 files changed, 43 insertions, 4 deletions
diff --git a/doc/release/1.13.0-notes.rst b/doc/release/1.13.0-notes.rst
index 23173d826..84d1106d6 100644
--- a/doc/release/1.13.0-notes.rst
+++ b/doc/release/1.13.0-notes.rst
@@ -53,6 +53,9 @@ Deprecations
with ``np.minimum``.
* Calling ``ndarray.conjugate`` on non-numeric dtypes is deprecated (it
should match the behavior of ``np.conjugate``, which throws an error).
+* Calling ``expand_dims`` when the ``axis`` keyword does not satisfy
+ ``-a.ndim - 1 <= axis <= a.ndim``, where ``a`` is the array being reshaped,
+ is deprecated.
Future Changes
diff --git a/numpy/lib/shape_base.py b/numpy/lib/shape_base.py
index 01d13514a..ea77f40e0 100644
--- a/numpy/lib/shape_base.py
+++ b/numpy/lib/shape_base.py
@@ -240,14 +240,20 @@ def expand_dims(a, axis):
"""
Expand the shape of an array.
- Insert a new axis, corresponding to a given position in the array shape.
+ Insert a new axis that will appear at the `axis` position in the expanded
+ array shape.
+
+ .. note:: Previous to NumPy 1.13.0, neither ``axis < -a.ndim - 1`` nor
+ ``axis > a.ndim`` raised errors or put the new axis where documented.
+ Those axis values are now deprecated and will raise an AxisError in the
+ future.
Parameters
----------
a : array_like
Input array.
axis : int
- Position (amongst axes) where new axis is to be inserted.
+ Position in the expanded axes where the new axis is placed.
Returns
-------
@@ -291,7 +297,16 @@ def expand_dims(a, axis):
"""
a = asarray(a)
shape = a.shape
- axis = normalize_axis_index(axis, a.ndim + 1)
+ if axis > a.ndim or axis < -a.ndim - 1:
+ # 2017-05-17, 1.13.0
+ warnings.warn("Both axis > a.ndim and axis < -a.ndim - 1 are "
+ "deprecated and will raise an AxisError in the future.",
+ DeprecationWarning, stacklevel=2)
+ # When the deprecation period expires, delete this if block,
+ if axis < 0:
+ axis = axis + a.ndim + 1
+ # and uncomment the following line.
+ # axis = normalize_axis_index(axis, a.ndim + 1)
return a.reshape(shape[:axis] + (1,) + shape[axis:])
row_stack = vstack
diff --git a/numpy/lib/tests/test_shape_base.py b/numpy/lib/tests/test_shape_base.py
index 4d06001f4..14406fe21 100644
--- a/numpy/lib/tests/test_shape_base.py
+++ b/numpy/lib/tests/test_shape_base.py
@@ -1,9 +1,11 @@
from __future__ import division, absolute_import, print_function
import numpy as np
+import warnings
+
from numpy.lib.shape_base import (
apply_along_axis, apply_over_axes, array_split, split, hsplit, dsplit,
- vsplit, dstack, column_stack, kron, tile
+ vsplit, dstack, column_stack, kron, tile, expand_dims,
)
from numpy.testing import (
run_module_suite, TestCase, assert_, assert_equal, assert_array_equal,
@@ -182,6 +184,25 @@ class TestApplyOverAxes(TestCase):
assert_array_equal(aoa_a, np.array([[[60], [92], [124]]]))
+class TestExpandDims(TestCase):
+ def test_functionality(self):
+ s = (2, 3, 4, 5)
+ a = np.empty(s)
+ for axis in range(-5, 4):
+ b = expand_dims(a, axis)
+ assert_(b.shape[axis] == 1)
+ assert_(np.squeeze(b).shape == s)
+
+ def test_deprecations(self):
+ # 2017-05-17, 1.13.0
+ s = (2, 3, 4, 5)
+ a = np.empty(s)
+ with warnings.catch_warnings():
+ warnings.simplefilter("always")
+ assert_warns(DeprecationWarning, expand_dims, a, -6)
+ assert_warns(DeprecationWarning, expand_dims, a, 5)
+
+
class TestArraySplit(TestCase):
def test_integer_0_split(self):
a = np.arange(10)