summaryrefslogtreecommitdiff
path: root/numpy/ma/tests/test_extras.py
diff options
context:
space:
mode:
authorDaniel da Silva <mail@danieldasilva.org>2015-04-03 21:29:24 -0400
committerDaniel da Silva <mail@danieldasilva.org>2015-05-03 22:18:18 -0400
commit883d052e3eb9b45a4bb87e7e84f487e0c9e5c882 (patch)
tree201f09008b172571f7e478f3fb9eeadca01def20 /numpy/ma/tests/test_extras.py
parent147c60f83f401037ff29593826d2c5729a73c2c5 (diff)
downloadpython-numpy-883d052e3eb9b45a4bb87e7e84f487e0c9e5c882.tar.gz
python-numpy-883d052e3eb9b45a4bb87e7e84f487e0c9e5c882.tar.bz2
python-numpy-883d052e3eb9b45a4bb87e7e84f487e0c9e5c882.zip
ENH: Introduce np.ma.compress_nd(), generalizes np.ma.compress_rowcols()
Provides a way to supress slices along an abitrary tuple of dimensions.
Diffstat (limited to 'numpy/ma/tests/test_extras.py')
-rw-r--r--numpy/ma/tests/test_extras.py122
1 files changed, 117 insertions, 5 deletions
diff --git a/numpy/ma/tests/test_extras.py b/numpy/ma/tests/test_extras.py
index ee8e6bc18..b6749ae9e 100644
--- a/numpy/ma/tests/test_extras.py
+++ b/numpy/ma/tests/test_extras.py
@@ -29,7 +29,7 @@ from numpy.ma.extras import (
cov, corrcoef, median, average,
unique, setxor1d, setdiff1d, union1d, intersect1d, in1d, ediff1d,
apply_over_axes, apply_along_axis,
- compress_rowcols, mask_rowcols,
+ compress_nd, compress_rowcols, mask_rowcols,
clump_masked, clump_unmasked,
flatnotmasked_contiguous, notmasked_contiguous, notmasked_edges,
masked_all, masked_all_like)
@@ -347,10 +347,122 @@ class TestNotMasked(TestCase):
assert_equal(tmp[2][-2], slice(0, 6, None))
-class Test2DFunctions(TestCase):
- # Tests 2D functions
- def test_compress2d(self):
- # Tests compress2d
+class TestCompressFunctions(TestCase):
+
+ def test_compress_nd(self):
+ # Tests compress_nd
+ x = np.array(list(range(3*4*5))).reshape(3, 4, 5)
+ m = np.zeros((3,4,5)).astype(bool)
+ m[1,1,1] = True
+ x = array(x, mask=m)
+
+ # axis=None
+ a = compress_nd(x)
+ assert_equal(a, [[[ 0, 2, 3 , 4],
+ [10, 12, 13, 14],
+ [15, 17, 18, 19]],
+ [[40, 42, 43, 44],
+ [50, 52, 53, 54],
+ [55, 57, 58, 59]]])
+
+ # axis=0
+ a = compress_nd(x, 0)
+ assert_equal(a, [[[ 0, 1, 2, 3, 4],
+ [ 5, 6, 7, 8, 9],
+ [10, 11, 12, 13, 14],
+ [15, 16, 17, 18, 19]],
+ [[40, 41, 42, 43, 44],
+ [45, 46, 47, 48, 49],
+ [50, 51, 52, 53, 54],
+ [55, 56, 57, 58, 59]]])
+
+ # axis=1
+ a = compress_nd(x, 1)
+ assert_equal(a, [[[ 0, 1, 2, 3, 4],
+ [10, 11, 12, 13, 14],
+ [15, 16, 17, 18, 19]],
+ [[20, 21, 22, 23, 24],
+ [30, 31, 32, 33, 34],
+ [35, 36, 37, 38, 39]],
+ [[40, 41, 42, 43, 44],
+ [50, 51, 52, 53, 54],
+ [55, 56, 57, 58, 59]]])
+
+ a2 = compress_nd(x, (1,))
+ a3 = compress_nd(x, -2)
+ a4 = compress_nd(x, (-2,))
+ assert_equal(a, a2)
+ assert_equal(a, a3)
+ assert_equal(a, a4)
+
+ # axis=2
+ a = compress_nd(x, 2)
+ assert_equal(a, [[[ 0, 2, 3, 4],
+ [ 5, 7, 8, 9],
+ [10, 12, 13, 14],
+ [15, 17, 18, 19]],
+ [[20, 22, 23, 24],
+ [25, 27, 28, 29],
+ [30, 32, 33, 34],
+ [35, 37, 38, 39]],
+ [[40, 42, 43, 44],
+ [45, 47, 48, 49],
+ [50, 52, 53, 54],
+ [55, 57, 58, 59]]])
+
+ a2 = compress_nd(x, (2,))
+ a3 = compress_nd(x, -1)
+ a4 = compress_nd(x, (-1,))
+ assert_equal(a, a2)
+ assert_equal(a, a3)
+ assert_equal(a, a4)
+
+ # axis=(0, 1)
+ a = compress_nd(x, (0, 1))
+ assert_equal(a, [[[ 0, 1, 2, 3, 4],
+ [10, 11, 12, 13, 14],
+ [15, 16, 17, 18, 19]],
+ [[40, 41, 42, 43, 44],
+ [50, 51, 52, 53, 54],
+ [55, 56, 57, 58, 59]]])
+ a2 = compress_nd(x, (0, -2))
+ assert_equal(a, a2)
+
+ # axis=(1, 2)
+ a = compress_nd(x, (1, 2))
+ assert_equal(a, [[[ 0, 2, 3, 4],
+ [10, 12, 13, 14],
+ [15, 17, 18, 19]],
+ [[20, 22, 23, 24],
+ [30, 32, 33, 34],
+ [35, 37, 38, 39]],
+ [[40, 42, 43, 44],
+ [50, 52, 53, 54],
+ [55, 57, 58, 59]]])
+
+ a2 = compress_nd(x, (-2, 2))
+ a3 = compress_nd(x, (1, -1))
+ a4 = compress_nd(x, (-2, -1))
+ assert_equal(a, a2)
+ assert_equal(a, a3)
+ assert_equal(a, a4)
+
+ # axis=(0, 2)
+ a = compress_nd(x, (0, 2))
+ assert_equal(a, [[[ 0, 2, 3, 4],
+ [ 5, 7, 8, 9],
+ [10, 12, 13, 14],
+ [15, 17, 18, 19]],
+ [[40, 42, 43, 44],
+ [45, 47, 48, 49],
+ [50, 52, 53, 54],
+ [55, 57, 58, 59]]])
+
+ a2 = compress_nd(x, (0, -1))
+ assert_equal(a, a2)
+
+ def test_compress_rowcols(self):
+ # Tests compress_rowcols
x = array(np.arange(9).reshape(3, 3),
mask=[[1, 0, 0], [0, 0, 0], [0, 0, 0]])
assert_equal(compress_rowcols(x), [[4, 5], [7, 8]])