diff options
author | Travis Oliphant <oliphant@enthought.com> | 2005-12-30 22:06:12 +0000 |
---|---|---|
committer | Travis Oliphant <oliphant@enthought.com> | 2005-12-30 22:06:12 +0000 |
commit | 2719abddd4e61211f4e8bfa914d82280e179f77f (patch) | |
tree | 75011ec7dde5079d95a88c9cf55e198a7ad41bbf | |
parent | 4112ff35f848c199e5e2f12f08984e1d9392c1ca (diff) | |
download | python-numpy-2719abddd4e61211f4e8bfa914d82280e179f77f.tar.gz python-numpy-2719abddd4e61211f4e8bfa914d82280e179f77f.tar.bz2 python-numpy-2719abddd4e61211f4e8bfa914d82280e179f77f.zip |
Changed sort to in-place --- uses copy for now.
-rw-r--r-- | scipy/base/function_base.py | 14 | ||||
-rw-r--r-- | scipy/base/include/scipy/arrayobject.h | 7 | ||||
-rw-r--r-- | scipy/base/oldnumeric.py | 5 | ||||
-rw-r--r-- | scipy/base/src/arraymethods.c | 17 | ||||
-rw-r--r-- | scipy/base/src/multiarraymodule.c | 120 | ||||
-rw-r--r-- | scipy/base/tests/test_function_base.py | 67 |
6 files changed, 176 insertions, 54 deletions
diff --git a/scipy/base/function_base.py b/scipy/base/function_base.py index 864b3d9e2..2ebf3f90d 100644 --- a/scipy/base/function_base.py +++ b/scipy/base/function_base.py @@ -1,4 +1,5 @@ -__all__ = ['logspace', 'linspace', 'round_', + +l__all__ = ['logspace', 'linspace', 'round_', 'select', 'piecewise', 'trim_zeros', 'copy', 'iterable', 'base_repr', 'binary_repr', 'diff', 'gradient', 'angle', 'unwrap', 'sort_complex', 'disp', @@ -17,7 +18,8 @@ from numeric import ScalarType, dot, where, newaxis from umath import pi, multiply, add, arctan2, maximum, minimum, frompyfunc, \ isnan, absolute, cos, less_equal, sqrt, sin, mod from oldnumeric import ravel, nonzero, choose, \ - sometrue, alltrue, reshape, any, all, typecodes, ArrayType, squeeze + sometrue, alltrue, reshape, any, all, typecodes, ArrayType, squeeze,\ + sort from type_check import ScalarType, isscalar from shape_base import atleast_1d from twodim_base import diag @@ -111,13 +113,13 @@ def histogram(a, bins=10, range=None, normed=False): if not iterable(bins): if range is None: range = (a.min(), a.max()) - mn, mx = [a+0.0 for a in range] + mn, mx = [mi+0.0 for mi in range] if mn == mx: mn -= 0.5 mx += 0.5 - bins = linspace(mn, mx, bins) + bins = linspace(mn, mx, bins, endpoint=False) - n = a.sort().searchsorted(bins) + n = sort(a).searchsorted(bins) n = concatenate([n, [len(a)]]) n = n[1:]-n[:-1] @@ -451,7 +453,7 @@ def unwrap(p, discont=pi, axis=-1): _nx.putmask(ddmod, (ddmod==-pi) & (dd > 0), pi) ph_correct = ddmod - dd; _nx.putmask(ph_correct, abs(dd)<discont, 0) - up = array(p, copy=True, typecode='d') + up = array(p, copy=True, dtype='d') up[slice1] = p[slice1] + ph_correct.cumsum(axis) return up diff --git a/scipy/base/include/scipy/arrayobject.h b/scipy/base/include/scipy/arrayobject.h index c289add38..c0dcbd53d 100644 --- a/scipy/base/include/scipy/arrayobject.h +++ b/scipy/base/include/scipy/arrayobject.h @@ -208,6 +208,13 @@ enum PyArray_TYPECHAR { PyArray_BOOLLTR = '?', PyArray_COMPLEXLTR = 'c' }; +typedef enum { + PyArray_QUICKSORT, + PyArray_TIMSORT, + PyArray_HEAPSORT, + PyArray_MERGESORT, +} PyArray_SORTKIND; + /* Define bit-width array types and typedefs */ #define MAX_INT8 127 diff --git a/scipy/base/oldnumeric.py b/scipy/base/oldnumeric.py index 4b055ba16..521cbd6eb 100644 --- a/scipy/base/oldnumeric.py +++ b/scipy/base/oldnumeric.py @@ -204,8 +204,9 @@ def transpose(a, axes=None): def sort(a, axis=-1): """sort(a,axis=-1) returns array with elements sorted along given axis. """ - a = array(a, copy=False) - return a.sort(axis) + a = array(a, copy=True) + a.sort(axis) + return a def argsort(a, axis=-1): """argsort(a,axis=-1) return the indices into a of the sorted array diff --git a/scipy/base/src/arraymethods.c b/scipy/base/src/arraymethods.c index 18ccc35ee..d778dd445 100644 --- a/scipy/base/src/arraymethods.c +++ b/scipy/base/src/arraymethods.c @@ -709,28 +709,31 @@ array_choose(PyArrayObject *self, PyObject *args) return _ARET(PyArray_Choose(self, choices)); } -static char doc_sort[] = "a.sort(<None>)"; +static char doc_sort[] = "a.sort(<-1>) sorts in place along axis. Return is None."; static PyObject * array_sort(PyArrayObject *self, PyObject *args) { - int axis=MAX_DIMS; + int axis=-1; + int val; if (!PyArg_ParseTuple(args, "|O&", PyArray_AxisConverter, &axis)) return NULL; - return _ARET(PyArray_Sort(self, axis)); + val = PyArray_Sort(self, axis, PyArray_QUICKSORT); + if (val < 0) return NULL; + Py_INCREF(Py_None); + return Py_None; } -static char doc_argsort[] = "a.argsort(<None>)\n"\ +static char doc_argsort[] = "a.argsort(<-1>)\n"\ " Return the indexes into a that would sort it along the"\ - " given axis (or <None> if the sorting should be done"\ - " in terms of a.flat"; + " given axis"; static PyObject * array_argsort(PyArrayObject *self, PyObject *args) { - int axis=MAX_DIMS; + int axis=-1; if (!PyArg_ParseTuple(args, "|O&", PyArray_AxisConverter, &axis)) return NULL; diff --git a/scipy/base/src/multiarraymodule.c b/scipy/base/src/multiarraymodule.c index 55f68f1db..b8a04e523 100644 --- a/scipy/base/src/multiarraymodule.c +++ b/scipy/base/src/multiarraymodule.c @@ -1575,6 +1575,12 @@ qsortCompare (const void *a, const void *b) return global_obj->descr->f->compare(a,b,global_obj); } +/* Consumes reference to ap (op gets it) + op contains a version of the array with axes swapped if + local variable axis is not the last dimension. + orign must be defined locally. +*/ + #define SWAPAXES(op, ap) { \ orign = (ap)->nd-1; \ if (axis != orign) { \ @@ -1585,7 +1591,12 @@ qsortCompare (const void *a, const void *b) else (op) = (ap); \ } -#define SWAPBACK(op, ap) { \ +/* Consumes reference to ap (op gets it) + origin must be previously defined locally. + SWAPAXES must have been called previously. + op contains the swapped version of the array. +*/ +#define SWAPBACK(op, ap) { \ if (axis != orign) { \ (op) = (PyAO *)PyArray_SwapAxes((ap), axis, orign); \ Py_DECREF((ap)); \ @@ -1594,59 +1605,85 @@ qsortCompare (const void *a, const void *b) else (op) = (ap); \ } +#define SWAPAXES2(op, ap) { \ + orign = (ap)->nd-1; \ + if (axis != orign) { \ + (op) = (PyAO *)PyArray_SwapAxes((ap), axis, orign); \ + Py_DECREF((ap)); \ + if ((op) == NULL) return -1; \ + } \ + else (op) = (ap); \ + } + +#define SWAPBACK2(op, ap) { \ + if (axis != orign) { \ + (op) = (PyAO *)PyArray_SwapAxes((ap), axis, orign); \ + Py_DECREF((ap)); \ + if ((op) == NULL) return -1; \ + } \ + else (op) = (ap); \ + } + /*MULTIARRAY_API Sort an array */ -static PyObject * -PyArray_Sort(PyArrayObject *op, int axis) +static int +PyArray_Sort(PyArrayObject *op, int axis, PyArray_SORTKIND which) { PyArrayObject *ap=NULL, *store_arr=NULL; + PyArrayObject *save=op; char *ip; int i, n, m, elsize, orign; - if ((ap = (PyAO*) _check_axis(op, &axis, 0))==NULL) return NULL; - - SWAPAXES(op, ap); - - ap = (PyArrayObject *)PyArray_FromAny((PyObject *)op, - NULL, 1, 0, ENSURECOPY); + n = op->nd; + if (axis < 0) axis += n; + if ((axis < 0) || (axis >= n)) { + PyErr_Format(PyExc_ValueError, + "axis(=%d) out of bounds", axis); + return -1; + } - Py_DECREF(op); + SWAPAXES2(ap, op); - if (ap == NULL) return NULL; + op = (PyArrayObject *)PyArray_FromAny((PyObject *)ap, + NULL, 1, 0, + ENSURECOPY); + + if (op == NULL) return -1; - if (ap->descr->f->compare == NULL) { + if (op->descr->f->compare == NULL) { PyErr_SetString(PyExc_TypeError, - "compare not supported for type"); - Py_DECREF(ap); - return NULL; + "sort not supported for type"); + Py_DECREF(op); + return -1; } - elsize = ap->descr->elsize; - m = ap->dimensions[ap->nd-1]; + elsize = op->descr->elsize; + m = op->dimensions[op->nd-1]; if (m == 0) goto finish; - n = PyArray_SIZE(ap)/m; + n = PyArray_SIZE(op)/m; /* Store global -- allows re-entry -- restore before leaving*/ store_arr = global_obj; - global_obj = ap; + global_obj = op; - for (ip=ap->data, i=0; i<n; i++, ip+=elsize*m) { + for (ip=op->data, i=0; i<n; i++, ip+=elsize*m) { qsort(ip, m, elsize, qsortCompare); } global_obj = store_arr; if (PyErr_Occurred()) { - Py_DECREF(ap); - return NULL; + Py_DECREF(op); + return -1; } finish: - SWAPBACK(op, ap); + SWAPBACK2(ap, op); - return (PyObject *)op; + PyArray_CopyInto(save, ap); + return 0; } @@ -1675,38 +1712,43 @@ PyArray_ArgSort(PyArrayObject *op, int axis) int argsort_elsize; char *store_ptr; - if ((ap = (PyAO *)_check_axis(op, &axis, 0))==NULL) return NULL; + n = op->nd; + if (axis < 0) axis += n; + if ((axis < 0) || (axis >= n)) { + PyErr_Format(PyExc_ValueError, + "axis(=%d) out of bounds", axis); + return NULL; + } - SWAPAXES(op, ap); + SWAPAXES(ap, op); - ap = (PyArrayObject *)PyArray_ContiguousFromAny((PyObject *)op, + op = (PyArrayObject *)PyArray_ContiguousFromAny((PyObject *)ap, PyArray_NOTYPE, 1, 0); - Py_DECREF(op); - if (ap == NULL) return NULL; + if (op == NULL) return NULL; - ret = (PyArrayObject *)PyArray_New(ap->ob_type, ap->nd, - ap->dimensions, PyArray_INTP, - NULL, NULL, 0, 0, (PyObject *)ap); + ret = (PyArrayObject *)PyArray_New(op->ob_type, op->nd, + op->dimensions, PyArray_INTP, + NULL, NULL, 0, 0, (PyObject *)op); if (ret == NULL) goto fail; - if (ap->descr->f->compare == NULL) { + if (op->descr->f->compare == NULL) { PyErr_SetString(PyExc_TypeError, "compare not supported for type"); goto fail; } ip = (intp *)ret->data; - argsort_elsize = ap->descr->elsize; - m = ap->dimensions[ap->nd-1]; + argsort_elsize = op->descr->elsize; + m = op->dimensions[op->nd-1]; if (m == 0) goto finish; - n = PyArray_SIZE(ap)/m; + n = PyArray_SIZE(op)/m; store_ptr = global_data; - global_data = ap->data; + global_data = op->data; store = global_obj; - global_obj = ap; + global_obj = op; for (i=0; i<n; i++, ip+=m, global_data += m*argsort_elsize) { for(j=0; j<m; j++) ip[j] = j; qsort((char *)ip, m, sizeof(intp), @@ -1717,7 +1759,7 @@ PyArray_ArgSort(PyArrayObject *op, int axis) finish: - Py_DECREF(ap); + Py_DECREF(op); SWAPBACK(op, ret); return (PyObject *)op; diff --git a/scipy/base/tests/test_function_base.py b/scipy/base/tests/test_function_base.py index 77beb14f7..a4c3ef7e2 100644 --- a/scipy/base/tests/test_function_base.py +++ b/scipy/base/tests/test_function_base.py @@ -263,6 +263,73 @@ class test_vectorize(ScipyTestCase): assert_array_equal(r,[5,8,1,4]) + +class test_unwrap(ScipyTestCase): + def check_simple(self): + #check that unwrap removes jumps greather that 2*pi + assert_array_equal(unwrap([1,1+2*pi]),[1,1]) + #check that unwrap maintans continuity + assert(all(diff(unwrap(rand(10)*100))<pi)) + + +class test_filterwindows(ScipyTestCase): + def check_hanning(self): + #check symmetry + w=hanning(10) + assert_array_almost_equal(w,flipud(w),7) + #check known value + assert_almost_equal(sum(w),4.500,4) + + def check_hamming(self): + #check symmetry + w=hamming(10) + assert_array_almost_equal(w,flipud(w),7) + #check known value + assert_almost_equal(sum(w),4.9400,4) + + def check_bartlett(self): + #check symmetry + w=bartlett(10) + assert_array_almost_equal(w,flipud(w),7) + #check known value + assert_almost_equal(sum(w),4.4444,4) + + def check_blackman(self): + #check symmetry + w=blackman(10) + assert_array_almost_equal(w,flipud(w),7) + #check known value + assert_almost_equal(sum(w),3.7800,4) + + +class test_trapz(ScipyTestCase): + def check_simple(self): + r=trapz(exp(-1/2*(arange(-10,10,.1))**2)/sqrt(2*pi),dx=0.1) + #check integral of normal equals 1 + assert_almost_equal(sum(r),1,7) + +class test_sinc(ScipyTestCase): + def check_simple(self): + assert(sinc(0)==1) + w=sinc(linspace(-1,1,100)) + #check symmetry + assert_array_almost_equal(w,flipud(w),7) + +class test_histogram(ScipyTestCase): + def check_simple(self): + n=100 + v=rand(n) + (a,b)=histogram(v) + #check if the sum of the bins equals the number of samples + assert(sum(a)==n) + #check that the bin counts are evenly spaced when the data is from a linear function + (a,b)=histogram(linspace(0,10,100)) + assert(all(a==10)) + + + + + def compare_results(res,desired): for i in range(len(desired)): assert_array_equal(res[i],desired[i]) |