summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorCharles Harris <charlesr.harris@gmail.com>2016-10-28 11:08:40 -0400
committerGitHub <noreply@github.com>2016-10-28 11:08:40 -0400
commite908bfa9977a45b311ef09f03551aa1780686739 (patch)
tree921f6aa27d01ba8817089b092093726bae962111
parentf303ccebaa2953e506c4991825cda2ac8b5d2fb9 (diff)
parent9a90abf995d0d8d9e96992a083dc55a41a93254f (diff)
downloadpython-numpy-e908bfa9977a45b311ef09f03551aa1780686739.tar.gz
python-numpy-e908bfa9977a45b311ef09f03551aa1780686739.tar.bz2
python-numpy-e908bfa9977a45b311ef09f03551aa1780686739.zip
Merge pull request #8218 from mattharrigan/ediff1d-performance
BUG: ediff1d should return subclasses
-rw-r--r--numpy/lib/arraysetops.py27
-rw-r--r--numpy/lib/tests/test_arraysetops.py2
2 files changed, 17 insertions, 12 deletions
diff --git a/numpy/lib/arraysetops.py b/numpy/lib/arraysetops.py
index e63e09546..836f4583f 100644
--- a/numpy/lib/arraysetops.py
+++ b/numpy/lib/arraysetops.py
@@ -81,28 +81,31 @@ def ediff1d(ary, to_end=None, to_begin=None):
# force a 1d array
ary = np.asanyarray(ary).ravel()
- # get the length of the diff'd values
- l = len(ary) - 1
- if l < 0:
- # force length to be non negative, match previous API
- # should this be an warning or deprecated?
- l = 0
+ # fast track default case
+ if to_begin is None and to_end is None:
+ return ary[1:] - ary[:-1]
if to_begin is None:
- to_begin = np.array([])
+ l_begin = 0
else:
to_begin = np.asanyarray(to_begin).ravel()
+ l_begin = len(to_begin)
if to_end is None:
- to_end = np.array([])
+ l_end = 0
else:
to_end = np.asanyarray(to_end).ravel()
+ l_end = len(to_end)
# do the calculation in place and copy to_begin and to_end
- result = np.empty(l + len(to_begin) + len(to_end), dtype=ary.dtype)
- result[:len(to_begin)] = to_begin
- result[len(to_begin) + l:] = to_end
- np.subtract(ary[1:], ary[:-1], result[len(to_begin):len(to_begin) + l])
+ l_diff = max(len(ary) - 1, 0)
+ result = np.empty(l_diff + l_begin + l_end, dtype=ary.dtype)
+ result = ary.__array_wrap__(result)
+ if l_begin > 0:
+ result[:l_begin] = to_begin
+ if l_end > 0:
+ result[l_begin + l_diff:] = to_end
+ np.subtract(ary[1:], ary[:-1], result[l_begin:l_begin + l_diff])
return result
diff --git a/numpy/lib/tests/test_arraysetops.py b/numpy/lib/tests/test_arraysetops.py
index b75a2b060..75918fbee 100644
--- a/numpy/lib/tests/test_arraysetops.py
+++ b/numpy/lib/tests/test_arraysetops.py
@@ -175,6 +175,8 @@ class TestSetOps(TestCase):
assert_array_equal([1,7,8], ediff1d(two_elem, to_end=[7,8]))
assert_array_equal([7,1], ediff1d(two_elem, to_begin=7))
assert_array_equal([5,6,1], ediff1d(two_elem, to_begin=[5,6]))
+ assert(isinstance(ediff1d(np.matrix(1)), np.matrix))
+ assert(isinstance(ediff1d(np.matrix(1), to_begin=1), np.matrix))
def test_in1d(self):
# we use two different sizes for the b array here to test the