summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGaspar Karm <gkarm@live.com>2018-01-16 19:15:10 +0200
committerEric Wieser <wieser.eric@gmail.com>2018-01-16 09:15:10 -0800
commitb83e9629dc09e3487a6e3f8ba9eeb3288f6add9b (patch)
tree3b14610956b4cd6c450dfdb05d986314d4c52930
parentaf75d4ca337dc95520d3ccb067e2c66809e8334e (diff)
downloadpython-numpy-b83e9629dc09e3487a6e3f8ba9eeb3288f6add9b.tar.gz
python-numpy-b83e9629dc09e3487a6e3f8ba9eeb3288f6add9b.tar.bz2
python-numpy-b83e9629dc09e3487a6e3f8ba9eeb3288f6add9b.zip
ENH: Implement fft.fftshift/ifftshift with np.roll for improved performance (#10073)
See the PR for benchmarking information
-rw-r--r--numpy/fft/helper.py43
-rw-r--r--numpy/fft/tests/test_helper.py110
2 files changed, 119 insertions, 34 deletions
diff --git a/numpy/fft/helper.py b/numpy/fft/helper.py
index 0856d6759..1a1266e12 100644
--- a/numpy/fft/helper.py
+++ b/numpy/fft/helper.py
@@ -6,11 +6,8 @@ from __future__ import division, absolute_import, print_function
import collections
import threading
-
from numpy.compat import integer_types
-from numpy.core import (
- asarray, concatenate, arange, take, integer, empty
- )
+from numpy.core import integer, empty, arange, asarray, roll
# Created by Pearu Peterson, September 2002
@@ -63,19 +60,16 @@ def fftshift(x, axes=None):
[-1., -3., -2.]])
"""
- tmp = asarray(x)
- ndim = tmp.ndim
+ x = asarray(x)
if axes is None:
- axes = list(range(ndim))
+ axes = tuple(range(x.ndim))
+ shift = [dim // 2 for dim in x.shape]
elif isinstance(axes, integer_types):
- axes = (axes,)
- y = tmp
- for k in axes:
- n = tmp.shape[k]
- p2 = (n+1)//2
- mylist = concatenate((arange(p2, n), arange(p2)))
- y = take(y, mylist, k)
- return y
+ shift = x.shape[axes] // 2
+ else:
+ shift = [x.shape[ax] // 2 for ax in axes]
+
+ return roll(x, shift, axes)
def ifftshift(x, axes=None):
@@ -112,19 +106,16 @@ def ifftshift(x, axes=None):
[-3., -2., -1.]])
"""
- tmp = asarray(x)
- ndim = tmp.ndim
+ x = asarray(x)
if axes is None:
- axes = list(range(ndim))
+ axes = tuple(range(x.ndim))
+ shift = [-(dim // 2) for dim in x.shape]
elif isinstance(axes, integer_types):
- axes = (axes,)
- y = tmp
- for k in axes:
- n = tmp.shape[k]
- p2 = n-(n+1)//2
- mylist = concatenate((arange(p2, n), arange(p2)))
- y = take(y, mylist, k)
- return y
+ shift = -(x.shape[axes] // 2)
+ else:
+ shift = [-(x.shape[ax] // 2) for ax in axes]
+
+ return roll(x, shift, axes)
def fftfreq(n, d=1.0):
diff --git a/numpy/fft/tests/test_helper.py b/numpy/fft/tests/test_helper.py
index f02edf7cc..4a19b8c60 100644
--- a/numpy/fft/tests/test_helper.py
+++ b/numpy/fft/tests/test_helper.py
@@ -4,13 +4,9 @@ Copied from fftpack.helper by Pearu Peterson, October 2005
"""
from __future__ import division, absolute_import, print_function
-
import numpy as np
-from numpy.testing import (
- run_module_suite, assert_array_almost_equal, assert_equal,
- )
-from numpy import fft
-from numpy import pi
+from numpy.testing import run_module_suite, assert_array_almost_equal, assert_equal
+from numpy import fft, pi
from numpy.fft.helper import _FFTCache
@@ -36,10 +32,108 @@ class TestFFTShift(object):
shifted = [[-1, -3, -2], [2, 0, 1], [-4, 3, 4]]
assert_array_almost_equal(fft.fftshift(freqs, axes=(0, 1)), shifted)
assert_array_almost_equal(fft.fftshift(freqs, axes=0),
- fft.fftshift(freqs, axes=(0,)))
+ fft.fftshift(freqs, axes=(0,)))
assert_array_almost_equal(fft.ifftshift(shifted, axes=(0, 1)), freqs)
assert_array_almost_equal(fft.ifftshift(shifted, axes=0),
- fft.ifftshift(shifted, axes=(0,)))
+ fft.ifftshift(shifted, axes=(0,)))
+
+ assert_array_almost_equal(fft.fftshift(freqs), shifted)
+ assert_array_almost_equal(fft.ifftshift(shifted), freqs)
+
+ def test_uneven_dims(self):
+ """ Test 2D input, which has uneven dimension sizes """
+ freqs = [
+ [0, 1],
+ [2, 3],
+ [4, 5]
+ ]
+
+ # shift in dimension 0
+ shift_dim0 = [
+ [4, 5],
+ [0, 1],
+ [2, 3]
+ ]
+ assert_array_almost_equal(fft.fftshift(freqs, axes=0), shift_dim0)
+ assert_array_almost_equal(fft.ifftshift(shift_dim0, axes=0), freqs)
+ assert_array_almost_equal(fft.fftshift(freqs, axes=(0,)), shift_dim0)
+ assert_array_almost_equal(fft.ifftshift(shift_dim0, axes=[0]), freqs)
+
+ # shift in dimension 1
+ shift_dim1 = [
+ [1, 0],
+ [3, 2],
+ [5, 4]
+ ]
+ assert_array_almost_equal(fft.fftshift(freqs, axes=1), shift_dim1)
+ assert_array_almost_equal(fft.ifftshift(shift_dim1, axes=1), freqs)
+
+ # shift in both dimensions
+ shift_dim_both = [
+ [5, 4],
+ [1, 0],
+ [3, 2]
+ ]
+ assert_array_almost_equal(fft.fftshift(freqs, axes=(0, 1)), shift_dim_both)
+ assert_array_almost_equal(fft.ifftshift(shift_dim_both, axes=(0, 1)), freqs)
+ assert_array_almost_equal(fft.fftshift(freqs, axes=[0, 1]), shift_dim_both)
+ assert_array_almost_equal(fft.ifftshift(shift_dim_both, axes=[0, 1]), freqs)
+
+ # axes=None (default) shift in all dimensions
+ assert_array_almost_equal(fft.fftshift(freqs, axes=None), shift_dim_both)
+ assert_array_almost_equal(fft.ifftshift(shift_dim_both, axes=None), freqs)
+ assert_array_almost_equal(fft.fftshift(freqs), shift_dim_both)
+ assert_array_almost_equal(fft.ifftshift(shift_dim_both), freqs)
+
+ def test_equal_to_original(self):
+ """ Test that the new (>=v1.15) implementation (see #10073) is equal to the original (<=v1.14) """
+ from numpy.compat import integer_types
+ from numpy.core import asarray, concatenate, arange, take
+
+ def original_fftshift(x, axes=None):
+ """ How fftshift was implemented in v1.14"""
+ tmp = asarray(x)
+ ndim = tmp.ndim
+ if axes is None:
+ axes = list(range(ndim))
+ elif isinstance(axes, integer_types):
+ axes = (axes,)
+ y = tmp
+ for k in axes:
+ n = tmp.shape[k]
+ p2 = (n + 1) // 2
+ mylist = concatenate((arange(p2, n), arange(p2)))
+ y = take(y, mylist, k)
+ return y
+
+ def original_ifftshift(x, axes=None):
+ """ How ifftshift was implemented in v1.14 """
+ tmp = asarray(x)
+ ndim = tmp.ndim
+ if axes is None:
+ axes = list(range(ndim))
+ elif isinstance(axes, integer_types):
+ axes = (axes,)
+ y = tmp
+ for k in axes:
+ n = tmp.shape[k]
+ p2 = n - (n + 1) // 2
+ mylist = concatenate((arange(p2, n), arange(p2)))
+ y = take(y, mylist, k)
+ return y
+
+ # create possible 2d array combinations and try all possible keywords
+ # compare output to original functions
+ for i in range(16):
+ for j in range(16):
+ for axes_keyword in [0, 1, None, (0,), (0, 1)]:
+ inp = np.random.rand(i, j)
+
+ assert_array_almost_equal(fft.fftshift(inp, axes_keyword),
+ original_fftshift(inp, axes_keyword))
+
+ assert_array_almost_equal(fft.ifftshift(inp, axes_keyword),
+ original_ifftshift(inp, axes_keyword))
class TestFFTFreq(object):