summaryrefslogtreecommitdiff
path: root/numpy/fft/helper.py
diff options
context:
space:
mode:
authorGaspar Karm <gkarm@live.com>2018-01-16 19:15:10 +0200
committerEric Wieser <wieser.eric@gmail.com>2018-01-16 09:15:10 -0800
commitb83e9629dc09e3487a6e3f8ba9eeb3288f6add9b (patch)
tree3b14610956b4cd6c450dfdb05d986314d4c52930 /numpy/fft/helper.py
parentaf75d4ca337dc95520d3ccb067e2c66809e8334e (diff)
downloadpython-numpy-b83e9629dc09e3487a6e3f8ba9eeb3288f6add9b.tar.gz
python-numpy-b83e9629dc09e3487a6e3f8ba9eeb3288f6add9b.tar.bz2
python-numpy-b83e9629dc09e3487a6e3f8ba9eeb3288f6add9b.zip
ENH: Implement fft.fftshift/ifftshift with np.roll for improved performance (#10073)
See the PR for benchmarking information
Diffstat (limited to 'numpy/fft/helper.py')
-rw-r--r--numpy/fft/helper.py43
1 files changed, 17 insertions, 26 deletions
diff --git a/numpy/fft/helper.py b/numpy/fft/helper.py
index 0856d6759..1a1266e12 100644
--- a/numpy/fft/helper.py
+++ b/numpy/fft/helper.py
@@ -6,11 +6,8 @@ from __future__ import division, absolute_import, print_function
import collections
import threading
-
from numpy.compat import integer_types
-from numpy.core import (
- asarray, concatenate, arange, take, integer, empty
- )
+from numpy.core import integer, empty, arange, asarray, roll
# Created by Pearu Peterson, September 2002
@@ -63,19 +60,16 @@ def fftshift(x, axes=None):
[-1., -3., -2.]])
"""
- tmp = asarray(x)
- ndim = tmp.ndim
+ x = asarray(x)
if axes is None:
- axes = list(range(ndim))
+ axes = tuple(range(x.ndim))
+ shift = [dim // 2 for dim in x.shape]
elif isinstance(axes, integer_types):
- axes = (axes,)
- y = tmp
- for k in axes:
- n = tmp.shape[k]
- p2 = (n+1)//2
- mylist = concatenate((arange(p2, n), arange(p2)))
- y = take(y, mylist, k)
- return y
+ shift = x.shape[axes] // 2
+ else:
+ shift = [x.shape[ax] // 2 for ax in axes]
+
+ return roll(x, shift, axes)
def ifftshift(x, axes=None):
@@ -112,19 +106,16 @@ def ifftshift(x, axes=None):
[-3., -2., -1.]])
"""
- tmp = asarray(x)
- ndim = tmp.ndim
+ x = asarray(x)
if axes is None:
- axes = list(range(ndim))
+ axes = tuple(range(x.ndim))
+ shift = [-(dim // 2) for dim in x.shape]
elif isinstance(axes, integer_types):
- axes = (axes,)
- y = tmp
- for k in axes:
- n = tmp.shape[k]
- p2 = n-(n+1)//2
- mylist = concatenate((arange(p2, n), arange(p2)))
- y = take(y, mylist, k)
- return y
+ shift = -(x.shape[axes] // 2)
+ else:
+ shift = [-(x.shape[ax] // 2) for ax in axes]
+
+ return roll(x, shift, axes)
def fftfreq(n, d=1.0):