diff options
author | Gaspar Karm <gkarm@live.com> | 2018-01-16 19:15:10 +0200 |
---|---|---|
committer | Eric Wieser <wieser.eric@gmail.com> | 2018-01-16 09:15:10 -0800 |
commit | b83e9629dc09e3487a6e3f8ba9eeb3288f6add9b (patch) | |
tree | 3b14610956b4cd6c450dfdb05d986314d4c52930 /numpy/fft/helper.py | |
parent | af75d4ca337dc95520d3ccb067e2c66809e8334e (diff) | |
download | python-numpy-b83e9629dc09e3487a6e3f8ba9eeb3288f6add9b.tar.gz python-numpy-b83e9629dc09e3487a6e3f8ba9eeb3288f6add9b.tar.bz2 python-numpy-b83e9629dc09e3487a6e3f8ba9eeb3288f6add9b.zip |
ENH: Implement fft.fftshift/ifftshift with np.roll for improved performance (#10073)
See the PR for benchmarking information
Diffstat (limited to 'numpy/fft/helper.py')
-rw-r--r-- | numpy/fft/helper.py | 43 |
1 files changed, 17 insertions, 26 deletions
diff --git a/numpy/fft/helper.py b/numpy/fft/helper.py index 0856d6759..1a1266e12 100644 --- a/numpy/fft/helper.py +++ b/numpy/fft/helper.py @@ -6,11 +6,8 @@ from __future__ import division, absolute_import, print_function import collections import threading - from numpy.compat import integer_types -from numpy.core import ( - asarray, concatenate, arange, take, integer, empty - ) +from numpy.core import integer, empty, arange, asarray, roll # Created by Pearu Peterson, September 2002 @@ -63,19 +60,16 @@ def fftshift(x, axes=None): [-1., -3., -2.]]) """ - tmp = asarray(x) - ndim = tmp.ndim + x = asarray(x) if axes is None: - axes = list(range(ndim)) + axes = tuple(range(x.ndim)) + shift = [dim // 2 for dim in x.shape] elif isinstance(axes, integer_types): - axes = (axes,) - y = tmp - for k in axes: - n = tmp.shape[k] - p2 = (n+1)//2 - mylist = concatenate((arange(p2, n), arange(p2))) - y = take(y, mylist, k) - return y + shift = x.shape[axes] // 2 + else: + shift = [x.shape[ax] // 2 for ax in axes] + + return roll(x, shift, axes) def ifftshift(x, axes=None): @@ -112,19 +106,16 @@ def ifftshift(x, axes=None): [-3., -2., -1.]]) """ - tmp = asarray(x) - ndim = tmp.ndim + x = asarray(x) if axes is None: - axes = list(range(ndim)) + axes = tuple(range(x.ndim)) + shift = [-(dim // 2) for dim in x.shape] elif isinstance(axes, integer_types): - axes = (axes,) - y = tmp - for k in axes: - n = tmp.shape[k] - p2 = n-(n+1)//2 - mylist = concatenate((arange(p2, n), arange(p2))) - y = take(y, mylist, k) - return y + shift = -(x.shape[axes] // 2) + else: + shift = [-(x.shape[ax] // 2) for ax in axes] + + return roll(x, shift, axes) def fftfreq(n, d=1.0): |