summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTravis Oliphant <oliphant@enthought.com>2005-12-30 22:06:12 +0000
committerTravis Oliphant <oliphant@enthought.com>2005-12-30 22:06:12 +0000
commit2719abddd4e61211f4e8bfa914d82280e179f77f (patch)
tree75011ec7dde5079d95a88c9cf55e198a7ad41bbf
parent4112ff35f848c199e5e2f12f08984e1d9392c1ca (diff)
downloadpython-numpy-2719abddd4e61211f4e8bfa914d82280e179f77f.tar.gz
python-numpy-2719abddd4e61211f4e8bfa914d82280e179f77f.tar.bz2
python-numpy-2719abddd4e61211f4e8bfa914d82280e179f77f.zip
Changed sort to in-place --- uses copy for now.
-rw-r--r--scipy/base/function_base.py14
-rw-r--r--scipy/base/include/scipy/arrayobject.h7
-rw-r--r--scipy/base/oldnumeric.py5
-rw-r--r--scipy/base/src/arraymethods.c17
-rw-r--r--scipy/base/src/multiarraymodule.c120
-rw-r--r--scipy/base/tests/test_function_base.py67
6 files changed, 176 insertions, 54 deletions
diff --git a/scipy/base/function_base.py b/scipy/base/function_base.py
index 864b3d9e2..2ebf3f90d 100644
--- a/scipy/base/function_base.py
+++ b/scipy/base/function_base.py
@@ -1,4 +1,5 @@
-__all__ = ['logspace', 'linspace', 'round_',
+
+l__all__ = ['logspace', 'linspace', 'round_',
'select', 'piecewise', 'trim_zeros',
'copy', 'iterable', 'base_repr', 'binary_repr',
'diff', 'gradient', 'angle', 'unwrap', 'sort_complex', 'disp',
@@ -17,7 +18,8 @@ from numeric import ScalarType, dot, where, newaxis
from umath import pi, multiply, add, arctan2, maximum, minimum, frompyfunc, \
isnan, absolute, cos, less_equal, sqrt, sin, mod
from oldnumeric import ravel, nonzero, choose, \
- sometrue, alltrue, reshape, any, all, typecodes, ArrayType, squeeze
+ sometrue, alltrue, reshape, any, all, typecodes, ArrayType, squeeze,\
+ sort
from type_check import ScalarType, isscalar
from shape_base import atleast_1d
from twodim_base import diag
@@ -111,13 +113,13 @@ def histogram(a, bins=10, range=None, normed=False):
if not iterable(bins):
if range is None:
range = (a.min(), a.max())
- mn, mx = [a+0.0 for a in range]
+ mn, mx = [mi+0.0 for mi in range]
if mn == mx:
mn -= 0.5
mx += 0.5
- bins = linspace(mn, mx, bins)
+ bins = linspace(mn, mx, bins, endpoint=False)
- n = a.sort().searchsorted(bins)
+ n = sort(a).searchsorted(bins)
n = concatenate([n, [len(a)]])
n = n[1:]-n[:-1]
@@ -451,7 +453,7 @@ def unwrap(p, discont=pi, axis=-1):
_nx.putmask(ddmod, (ddmod==-pi) & (dd > 0), pi)
ph_correct = ddmod - dd;
_nx.putmask(ph_correct, abs(dd)<discont, 0)
- up = array(p, copy=True, typecode='d')
+ up = array(p, copy=True, dtype='d')
up[slice1] = p[slice1] + ph_correct.cumsum(axis)
return up
diff --git a/scipy/base/include/scipy/arrayobject.h b/scipy/base/include/scipy/arrayobject.h
index c289add38..c0dcbd53d 100644
--- a/scipy/base/include/scipy/arrayobject.h
+++ b/scipy/base/include/scipy/arrayobject.h
@@ -208,6 +208,13 @@ enum PyArray_TYPECHAR { PyArray_BOOLLTR = '?',
PyArray_COMPLEXLTR = 'c'
};
+typedef enum {
+ PyArray_QUICKSORT,
+ PyArray_TIMSORT,
+ PyArray_HEAPSORT,
+ PyArray_MERGESORT,
+} PyArray_SORTKIND;
+
/* Define bit-width array types and typedefs */
#define MAX_INT8 127
diff --git a/scipy/base/oldnumeric.py b/scipy/base/oldnumeric.py
index 4b055ba16..521cbd6eb 100644
--- a/scipy/base/oldnumeric.py
+++ b/scipy/base/oldnumeric.py
@@ -204,8 +204,9 @@ def transpose(a, axes=None):
def sort(a, axis=-1):
"""sort(a,axis=-1) returns array with elements sorted along given axis.
"""
- a = array(a, copy=False)
- return a.sort(axis)
+ a = array(a, copy=True)
+ a.sort(axis)
+ return a
def argsort(a, axis=-1):
"""argsort(a,axis=-1) return the indices into a of the sorted array
diff --git a/scipy/base/src/arraymethods.c b/scipy/base/src/arraymethods.c
index 18ccc35ee..d778dd445 100644
--- a/scipy/base/src/arraymethods.c
+++ b/scipy/base/src/arraymethods.c
@@ -709,28 +709,31 @@ array_choose(PyArrayObject *self, PyObject *args)
return _ARET(PyArray_Choose(self, choices));
}
-static char doc_sort[] = "a.sort(<None>)";
+static char doc_sort[] = "a.sort(<-1>) sorts in place along axis. Return is None.";
static PyObject *
array_sort(PyArrayObject *self, PyObject *args)
{
- int axis=MAX_DIMS;
+ int axis=-1;
+ int val;
if (!PyArg_ParseTuple(args, "|O&", PyArray_AxisConverter,
&axis)) return NULL;
- return _ARET(PyArray_Sort(self, axis));
+ val = PyArray_Sort(self, axis, PyArray_QUICKSORT);
+ if (val < 0) return NULL;
+ Py_INCREF(Py_None);
+ return Py_None;
}
-static char doc_argsort[] = "a.argsort(<None>)\n"\
+static char doc_argsort[] = "a.argsort(<-1>)\n"\
" Return the indexes into a that would sort it along the"\
- " given axis (or <None> if the sorting should be done"\
- " in terms of a.flat";
+ " given axis";
static PyObject *
array_argsort(PyArrayObject *self, PyObject *args)
{
- int axis=MAX_DIMS;
+ int axis=-1;
if (!PyArg_ParseTuple(args, "|O&", PyArray_AxisConverter,
&axis)) return NULL;
diff --git a/scipy/base/src/multiarraymodule.c b/scipy/base/src/multiarraymodule.c
index 55f68f1db..b8a04e523 100644
--- a/scipy/base/src/multiarraymodule.c
+++ b/scipy/base/src/multiarraymodule.c
@@ -1575,6 +1575,12 @@ qsortCompare (const void *a, const void *b)
return global_obj->descr->f->compare(a,b,global_obj);
}
+/* Consumes reference to ap (op gets it)
+ op contains a version of the array with axes swapped if
+ local variable axis is not the last dimension.
+ orign must be defined locally.
+*/
+
#define SWAPAXES(op, ap) { \
orign = (ap)->nd-1; \
if (axis != orign) { \
@@ -1585,7 +1591,12 @@ qsortCompare (const void *a, const void *b)
else (op) = (ap); \
}
-#define SWAPBACK(op, ap) { \
+/* Consumes reference to ap (op gets it)
+ origin must be previously defined locally.
+ SWAPAXES must have been called previously.
+ op contains the swapped version of the array.
+*/
+#define SWAPBACK(op, ap) { \
if (axis != orign) { \
(op) = (PyAO *)PyArray_SwapAxes((ap), axis, orign); \
Py_DECREF((ap)); \
@@ -1594,59 +1605,85 @@ qsortCompare (const void *a, const void *b)
else (op) = (ap); \
}
+#define SWAPAXES2(op, ap) { \
+ orign = (ap)->nd-1; \
+ if (axis != orign) { \
+ (op) = (PyAO *)PyArray_SwapAxes((ap), axis, orign); \
+ Py_DECREF((ap)); \
+ if ((op) == NULL) return -1; \
+ } \
+ else (op) = (ap); \
+ }
+
+#define SWAPBACK2(op, ap) { \
+ if (axis != orign) { \
+ (op) = (PyAO *)PyArray_SwapAxes((ap), axis, orign); \
+ Py_DECREF((ap)); \
+ if ((op) == NULL) return -1; \
+ } \
+ else (op) = (ap); \
+ }
+
/*MULTIARRAY_API
Sort an array
*/
-static PyObject *
-PyArray_Sort(PyArrayObject *op, int axis)
+static int
+PyArray_Sort(PyArrayObject *op, int axis, PyArray_SORTKIND which)
{
PyArrayObject *ap=NULL, *store_arr=NULL;
+ PyArrayObject *save=op;
char *ip;
int i, n, m, elsize, orign;
- if ((ap = (PyAO*) _check_axis(op, &axis, 0))==NULL) return NULL;
-
- SWAPAXES(op, ap);
-
- ap = (PyArrayObject *)PyArray_FromAny((PyObject *)op,
- NULL, 1, 0, ENSURECOPY);
+ n = op->nd;
+ if (axis < 0) axis += n;
+ if ((axis < 0) || (axis >= n)) {
+ PyErr_Format(PyExc_ValueError,
+ "axis(=%d) out of bounds", axis);
+ return -1;
+ }
- Py_DECREF(op);
+ SWAPAXES2(ap, op);
- if (ap == NULL) return NULL;
+ op = (PyArrayObject *)PyArray_FromAny((PyObject *)ap,
+ NULL, 1, 0,
+ ENSURECOPY);
+
+ if (op == NULL) return -1;
- if (ap->descr->f->compare == NULL) {
+ if (op->descr->f->compare == NULL) {
PyErr_SetString(PyExc_TypeError,
- "compare not supported for type");
- Py_DECREF(ap);
- return NULL;
+ "sort not supported for type");
+ Py_DECREF(op);
+ return -1;
}
- elsize = ap->descr->elsize;
- m = ap->dimensions[ap->nd-1];
+ elsize = op->descr->elsize;
+ m = op->dimensions[op->nd-1];
if (m == 0) goto finish;
- n = PyArray_SIZE(ap)/m;
+ n = PyArray_SIZE(op)/m;
/* Store global -- allows re-entry -- restore before leaving*/
store_arr = global_obj;
- global_obj = ap;
+ global_obj = op;
- for (ip=ap->data, i=0; i<n; i++, ip+=elsize*m) {
+ for (ip=op->data, i=0; i<n; i++, ip+=elsize*m) {
qsort(ip, m, elsize, qsortCompare);
}
global_obj = store_arr;
if (PyErr_Occurred()) {
- Py_DECREF(ap);
- return NULL;
+ Py_DECREF(op);
+ return -1;
}
finish:
- SWAPBACK(op, ap);
+ SWAPBACK2(ap, op);
- return (PyObject *)op;
+ PyArray_CopyInto(save, ap);
+ return 0;
}
@@ -1675,38 +1712,43 @@ PyArray_ArgSort(PyArrayObject *op, int axis)
int argsort_elsize;
char *store_ptr;
- if ((ap = (PyAO *)_check_axis(op, &axis, 0))==NULL) return NULL;
+ n = op->nd;
+ if (axis < 0) axis += n;
+ if ((axis < 0) || (axis >= n)) {
+ PyErr_Format(PyExc_ValueError,
+ "axis(=%d) out of bounds", axis);
+ return NULL;
+ }
- SWAPAXES(op, ap);
+ SWAPAXES(ap, op);
- ap = (PyArrayObject *)PyArray_ContiguousFromAny((PyObject *)op,
+ op = (PyArrayObject *)PyArray_ContiguousFromAny((PyObject *)ap,
PyArray_NOTYPE,
1, 0);
- Py_DECREF(op);
- if (ap == NULL) return NULL;
+ if (op == NULL) return NULL;
- ret = (PyArrayObject *)PyArray_New(ap->ob_type, ap->nd,
- ap->dimensions, PyArray_INTP,
- NULL, NULL, 0, 0, (PyObject *)ap);
+ ret = (PyArrayObject *)PyArray_New(op->ob_type, op->nd,
+ op->dimensions, PyArray_INTP,
+ NULL, NULL, 0, 0, (PyObject *)op);
if (ret == NULL) goto fail;
- if (ap->descr->f->compare == NULL) {
+ if (op->descr->f->compare == NULL) {
PyErr_SetString(PyExc_TypeError,
"compare not supported for type");
goto fail;
}
ip = (intp *)ret->data;
- argsort_elsize = ap->descr->elsize;
- m = ap->dimensions[ap->nd-1];
+ argsort_elsize = op->descr->elsize;
+ m = op->dimensions[op->nd-1];
if (m == 0) goto finish;
- n = PyArray_SIZE(ap)/m;
+ n = PyArray_SIZE(op)/m;
store_ptr = global_data;
- global_data = ap->data;
+ global_data = op->data;
store = global_obj;
- global_obj = ap;
+ global_obj = op;
for (i=0; i<n; i++, ip+=m, global_data += m*argsort_elsize) {
for(j=0; j<m; j++) ip[j] = j;
qsort((char *)ip, m, sizeof(intp),
@@ -1717,7 +1759,7 @@ PyArray_ArgSort(PyArrayObject *op, int axis)
finish:
- Py_DECREF(ap);
+ Py_DECREF(op);
SWAPBACK(op, ret);
return (PyObject *)op;
diff --git a/scipy/base/tests/test_function_base.py b/scipy/base/tests/test_function_base.py
index 77beb14f7..a4c3ef7e2 100644
--- a/scipy/base/tests/test_function_base.py
+++ b/scipy/base/tests/test_function_base.py
@@ -263,6 +263,73 @@ class test_vectorize(ScipyTestCase):
assert_array_equal(r,[5,8,1,4])
+
+class test_unwrap(ScipyTestCase):
+ def check_simple(self):
+ #check that unwrap removes jumps greather that 2*pi
+ assert_array_equal(unwrap([1,1+2*pi]),[1,1])
+ #check that unwrap maintans continuity
+ assert(all(diff(unwrap(rand(10)*100))<pi))
+
+
+class test_filterwindows(ScipyTestCase):
+ def check_hanning(self):
+ #check symmetry
+ w=hanning(10)
+ assert_array_almost_equal(w,flipud(w),7)
+ #check known value
+ assert_almost_equal(sum(w),4.500,4)
+
+ def check_hamming(self):
+ #check symmetry
+ w=hamming(10)
+ assert_array_almost_equal(w,flipud(w),7)
+ #check known value
+ assert_almost_equal(sum(w),4.9400,4)
+
+ def check_bartlett(self):
+ #check symmetry
+ w=bartlett(10)
+ assert_array_almost_equal(w,flipud(w),7)
+ #check known value
+ assert_almost_equal(sum(w),4.4444,4)
+
+ def check_blackman(self):
+ #check symmetry
+ w=blackman(10)
+ assert_array_almost_equal(w,flipud(w),7)
+ #check known value
+ assert_almost_equal(sum(w),3.7800,4)
+
+
+class test_trapz(ScipyTestCase):
+ def check_simple(self):
+ r=trapz(exp(-1/2*(arange(-10,10,.1))**2)/sqrt(2*pi),dx=0.1)
+ #check integral of normal equals 1
+ assert_almost_equal(sum(r),1,7)
+
+class test_sinc(ScipyTestCase):
+ def check_simple(self):
+ assert(sinc(0)==1)
+ w=sinc(linspace(-1,1,100))
+ #check symmetry
+ assert_array_almost_equal(w,flipud(w),7)
+
+class test_histogram(ScipyTestCase):
+ def check_simple(self):
+ n=100
+ v=rand(n)
+ (a,b)=histogram(v)
+ #check if the sum of the bins equals the number of samples
+ assert(sum(a)==n)
+ #check that the bin counts are evenly spaced when the data is from a linear function
+ (a,b)=histogram(linspace(0,10,100))
+ assert(all(a==10))
+
+
+
+
+
def compare_results(res,desired):
for i in range(len(desired)):
assert_array_equal(res[i],desired[i])