diff options
author | Gaspar Karm <gkarm@live.com> | 2018-01-16 19:15:10 +0200 |
---|---|---|
committer | Eric Wieser <wieser.eric@gmail.com> | 2018-01-16 09:15:10 -0800 |
commit | b83e9629dc09e3487a6e3f8ba9eeb3288f6add9b (patch) | |
tree | 3b14610956b4cd6c450dfdb05d986314d4c52930 | |
parent | af75d4ca337dc95520d3ccb067e2c66809e8334e (diff) | |
download | python-numpy-b83e9629dc09e3487a6e3f8ba9eeb3288f6add9b.tar.gz python-numpy-b83e9629dc09e3487a6e3f8ba9eeb3288f6add9b.tar.bz2 python-numpy-b83e9629dc09e3487a6e3f8ba9eeb3288f6add9b.zip |
ENH: Implement fft.fftshift/ifftshift with np.roll for improved performance (#10073)
See the PR for benchmarking information
-rw-r--r-- | numpy/fft/helper.py | 43 | ||||
-rw-r--r-- | numpy/fft/tests/test_helper.py | 110 |
2 files changed, 119 insertions, 34 deletions
diff --git a/numpy/fft/helper.py b/numpy/fft/helper.py index 0856d6759..1a1266e12 100644 --- a/numpy/fft/helper.py +++ b/numpy/fft/helper.py @@ -6,11 +6,8 @@ from __future__ import division, absolute_import, print_function import collections import threading - from numpy.compat import integer_types -from numpy.core import ( - asarray, concatenate, arange, take, integer, empty - ) +from numpy.core import integer, empty, arange, asarray, roll # Created by Pearu Peterson, September 2002 @@ -63,19 +60,16 @@ def fftshift(x, axes=None): [-1., -3., -2.]]) """ - tmp = asarray(x) - ndim = tmp.ndim + x = asarray(x) if axes is None: - axes = list(range(ndim)) + axes = tuple(range(x.ndim)) + shift = [dim // 2 for dim in x.shape] elif isinstance(axes, integer_types): - axes = (axes,) - y = tmp - for k in axes: - n = tmp.shape[k] - p2 = (n+1)//2 - mylist = concatenate((arange(p2, n), arange(p2))) - y = take(y, mylist, k) - return y + shift = x.shape[axes] // 2 + else: + shift = [x.shape[ax] // 2 for ax in axes] + + return roll(x, shift, axes) def ifftshift(x, axes=None): @@ -112,19 +106,16 @@ def ifftshift(x, axes=None): [-3., -2., -1.]]) """ - tmp = asarray(x) - ndim = tmp.ndim + x = asarray(x) if axes is None: - axes = list(range(ndim)) + axes = tuple(range(x.ndim)) + shift = [-(dim // 2) for dim in x.shape] elif isinstance(axes, integer_types): - axes = (axes,) - y = tmp - for k in axes: - n = tmp.shape[k] - p2 = n-(n+1)//2 - mylist = concatenate((arange(p2, n), arange(p2))) - y = take(y, mylist, k) - return y + shift = -(x.shape[axes] // 2) + else: + shift = [-(x.shape[ax] // 2) for ax in axes] + + return roll(x, shift, axes) def fftfreq(n, d=1.0): diff --git a/numpy/fft/tests/test_helper.py b/numpy/fft/tests/test_helper.py index f02edf7cc..4a19b8c60 100644 --- a/numpy/fft/tests/test_helper.py +++ b/numpy/fft/tests/test_helper.py @@ -4,13 +4,9 @@ Copied from fftpack.helper by Pearu Peterson, October 2005 """ from __future__ import division, absolute_import, print_function - import numpy as np -from numpy.testing import ( - run_module_suite, assert_array_almost_equal, assert_equal, - ) -from numpy import fft -from numpy import pi +from numpy.testing import run_module_suite, assert_array_almost_equal, assert_equal +from numpy import fft, pi from numpy.fft.helper import _FFTCache @@ -36,10 +32,108 @@ class TestFFTShift(object): shifted = [[-1, -3, -2], [2, 0, 1], [-4, 3, 4]] assert_array_almost_equal(fft.fftshift(freqs, axes=(0, 1)), shifted) assert_array_almost_equal(fft.fftshift(freqs, axes=0), - fft.fftshift(freqs, axes=(0,))) + fft.fftshift(freqs, axes=(0,))) assert_array_almost_equal(fft.ifftshift(shifted, axes=(0, 1)), freqs) assert_array_almost_equal(fft.ifftshift(shifted, axes=0), - fft.ifftshift(shifted, axes=(0,))) + fft.ifftshift(shifted, axes=(0,))) + + assert_array_almost_equal(fft.fftshift(freqs), shifted) + assert_array_almost_equal(fft.ifftshift(shifted), freqs) + + def test_uneven_dims(self): + """ Test 2D input, which has uneven dimension sizes """ + freqs = [ + [0, 1], + [2, 3], + [4, 5] + ] + + # shift in dimension 0 + shift_dim0 = [ + [4, 5], + [0, 1], + [2, 3] + ] + assert_array_almost_equal(fft.fftshift(freqs, axes=0), shift_dim0) + assert_array_almost_equal(fft.ifftshift(shift_dim0, axes=0), freqs) + assert_array_almost_equal(fft.fftshift(freqs, axes=(0,)), shift_dim0) + assert_array_almost_equal(fft.ifftshift(shift_dim0, axes=[0]), freqs) + + # shift in dimension 1 + shift_dim1 = [ + [1, 0], + [3, 2], + [5, 4] + ] + assert_array_almost_equal(fft.fftshift(freqs, axes=1), shift_dim1) + assert_array_almost_equal(fft.ifftshift(shift_dim1, axes=1), freqs) + + # shift in both dimensions + shift_dim_both = [ + [5, 4], + [1, 0], + [3, 2] + ] + assert_array_almost_equal(fft.fftshift(freqs, axes=(0, 1)), shift_dim_both) + assert_array_almost_equal(fft.ifftshift(shift_dim_both, axes=(0, 1)), freqs) + assert_array_almost_equal(fft.fftshift(freqs, axes=[0, 1]), shift_dim_both) + assert_array_almost_equal(fft.ifftshift(shift_dim_both, axes=[0, 1]), freqs) + + # axes=None (default) shift in all dimensions + assert_array_almost_equal(fft.fftshift(freqs, axes=None), shift_dim_both) + assert_array_almost_equal(fft.ifftshift(shift_dim_both, axes=None), freqs) + assert_array_almost_equal(fft.fftshift(freqs), shift_dim_both) + assert_array_almost_equal(fft.ifftshift(shift_dim_both), freqs) + + def test_equal_to_original(self): + """ Test that the new (>=v1.15) implementation (see #10073) is equal to the original (<=v1.14) """ + from numpy.compat import integer_types + from numpy.core import asarray, concatenate, arange, take + + def original_fftshift(x, axes=None): + """ How fftshift was implemented in v1.14""" + tmp = asarray(x) + ndim = tmp.ndim + if axes is None: + axes = list(range(ndim)) + elif isinstance(axes, integer_types): + axes = (axes,) + y = tmp + for k in axes: + n = tmp.shape[k] + p2 = (n + 1) // 2 + mylist = concatenate((arange(p2, n), arange(p2))) + y = take(y, mylist, k) + return y + + def original_ifftshift(x, axes=None): + """ How ifftshift was implemented in v1.14 """ + tmp = asarray(x) + ndim = tmp.ndim + if axes is None: + axes = list(range(ndim)) + elif isinstance(axes, integer_types): + axes = (axes,) + y = tmp + for k in axes: + n = tmp.shape[k] + p2 = n - (n + 1) // 2 + mylist = concatenate((arange(p2, n), arange(p2))) + y = take(y, mylist, k) + return y + + # create possible 2d array combinations and try all possible keywords + # compare output to original functions + for i in range(16): + for j in range(16): + for axes_keyword in [0, 1, None, (0,), (0, 1)]: + inp = np.random.rand(i, j) + + assert_array_almost_equal(fft.fftshift(inp, axes_keyword), + original_fftshift(inp, axes_keyword)) + + assert_array_almost_equal(fft.ifftshift(inp, axes_keyword), + original_ifftshift(inp, axes_keyword)) class TestFFTFreq(object): |