From beac50cf98f450539dcdeee0273cfe5175d45d26 Mon Sep 17 00:00:00 2001 From: Charles Harris Date: Wed, 17 May 2017 13:02:01 -0600 Subject: DEP: Deprecate incorrect behavior of expand_dims. Expand_dims works as documented when the index of the inserted NewAxis in the resulting array satisfies -a.ndim - 1 <= index <= a.ndim. However, when index > a.ndim index is replaced by a.ndim and, when index < -a.ndim - 1, it is replaced by index + a.ndim + 1, which may be negative and results in incorrect placement. The latter two cases are now deprecated. Closes #9100. --- doc/release/1.13.0-notes.rst | 3 +++ numpy/lib/shape_base.py | 21 ++++++++++++++++++--- numpy/lib/tests/test_shape_base.py | 23 ++++++++++++++++++++++- 3 files changed, 43 insertions(+), 4 deletions(-) diff --git a/doc/release/1.13.0-notes.rst b/doc/release/1.13.0-notes.rst index 23173d826..84d1106d6 100644 --- a/doc/release/1.13.0-notes.rst +++ b/doc/release/1.13.0-notes.rst @@ -53,6 +53,9 @@ Deprecations with ``np.minimum``. * Calling ``ndarray.conjugate`` on non-numeric dtypes is deprecated (it should match the behavior of ``np.conjugate``, which throws an error). +* Calling ``expand_dims`` when the ``axis`` keyword does not satisfy + ``-a.ndim - 1 <= axis <= a.ndim``, where ``a`` is the array being reshaped, + is deprecated. Future Changes diff --git a/numpy/lib/shape_base.py b/numpy/lib/shape_base.py index 01d13514a..ea77f40e0 100644 --- a/numpy/lib/shape_base.py +++ b/numpy/lib/shape_base.py @@ -240,14 +240,20 @@ def expand_dims(a, axis): """ Expand the shape of an array. - Insert a new axis, corresponding to a given position in the array shape. + Insert a new axis that will appear at the `axis` position in the expanded + array shape. + + .. note:: Previous to NumPy 1.13.0, neither ``axis < -a.ndim - 1`` nor + ``axis > a.ndim`` raised errors or put the new axis where documented. + Those axis values are now deprecated and will raise an AxisError in the + future. Parameters ---------- a : array_like Input array. axis : int - Position (amongst axes) where new axis is to be inserted. + Position in the expanded axes where the new axis is placed. Returns ------- @@ -291,7 +297,16 @@ def expand_dims(a, axis): """ a = asarray(a) shape = a.shape - axis = normalize_axis_index(axis, a.ndim + 1) + if axis > a.ndim or axis < -a.ndim - 1: + # 2017-05-17, 1.13.0 + warnings.warn("Both axis > a.ndim and axis < -a.ndim - 1 are " + "deprecated and will raise an AxisError in the future.", + DeprecationWarning, stacklevel=2) + # When the deprecation period expires, delete this if block, + if axis < 0: + axis = axis + a.ndim + 1 + # and uncomment the following line. + # axis = normalize_axis_index(axis, a.ndim + 1) return a.reshape(shape[:axis] + (1,) + shape[axis:]) row_stack = vstack diff --git a/numpy/lib/tests/test_shape_base.py b/numpy/lib/tests/test_shape_base.py index 4d06001f4..14406fe21 100644 --- a/numpy/lib/tests/test_shape_base.py +++ b/numpy/lib/tests/test_shape_base.py @@ -1,9 +1,11 @@ from __future__ import division, absolute_import, print_function import numpy as np +import warnings + from numpy.lib.shape_base import ( apply_along_axis, apply_over_axes, array_split, split, hsplit, dsplit, - vsplit, dstack, column_stack, kron, tile + vsplit, dstack, column_stack, kron, tile, expand_dims, ) from numpy.testing import ( run_module_suite, TestCase, assert_, assert_equal, assert_array_equal, @@ -182,6 +184,25 @@ class TestApplyOverAxes(TestCase): assert_array_equal(aoa_a, np.array([[[60], [92], [124]]])) +class TestExpandDims(TestCase): + def test_functionality(self): + s = (2, 3, 4, 5) + a = np.empty(s) + for axis in range(-5, 4): + b = expand_dims(a, axis) + assert_(b.shape[axis] == 1) + assert_(np.squeeze(b).shape == s) + + def test_deprecations(self): + # 2017-05-17, 1.13.0 + s = (2, 3, 4, 5) + a = np.empty(s) + with warnings.catch_warnings(): + warnings.simplefilter("always") + assert_warns(DeprecationWarning, expand_dims, a, -6) + assert_warns(DeprecationWarning, expand_dims, a, 5) + + class TestArraySplit(TestCase): def test_integer_0_split(self): a = np.arange(10) -- cgit v1.2.3