summaryrefslogtreecommitdiff
path: root/numpy/ma/tests/test_extras.py
diff options
context:
space:
mode:
authorCharles Harris <charlesr.harris@gmail.com>2016-09-05 12:58:06 -0600
committerCharles Harris <charlesr.harris@gmail.com>2016-09-05 16:36:14 -0600
commitad5b13a585cb2b43b7f6ef7a30124bec82dfa359 (patch)
tree741ffb503d72aac0de2e83f01c7192ffed504d1a /numpy/ma/tests/test_extras.py
parent66f313f5762eeecc4171e494a7779c3e247fb584 (diff)
downloadpython-numpy-ad5b13a585cb2b43b7f6ef7a30124bec82dfa359.tar.gz
python-numpy-ad5b13a585cb2b43b7f6ef7a30124bec82dfa359.tar.bz2
python-numpy-ad5b13a585cb2b43b7f6ef7a30124bec82dfa359.zip
TST: Add ma.median tests for valid axis.
Diffstat (limited to 'numpy/ma/tests/test_extras.py')
-rw-r--r--numpy/ma/tests/test_extras.py32
1 files changed, 32 insertions, 0 deletions
diff --git a/numpy/ma/tests/test_extras.py b/numpy/ma/tests/test_extras.py
index 6d56d4dc6..27fac3d63 100644
--- a/numpy/ma/tests/test_extras.py
+++ b/numpy/ma/tests/test_extras.py
@@ -10,6 +10,7 @@ Adapted from the original test_ma by Pierre Gerard-Marchant
from __future__ import division, absolute_import, print_function
import warnings
+import itertools
import numpy as np
from numpy.testing import (
@@ -684,6 +685,37 @@ class TestMedian(TestCase):
assert_equal(ma_x.shape, (2,), "shape mismatch")
assert_(type(ma_x) is MaskedArray)
+ def test_axis_argument_errors(self):
+ msg = "mask = %s, ndim = %s, axis = %s, overwrite_input = %s"
+ for ndmin in range(5):
+ for mask in [False, True]:
+ x = array(1, ndmin=ndmin, mask=mask)
+
+ # Valid axis values should not raise exception
+ args = itertools.product(range(-ndmin, ndmin), [False, True])
+ for axis, over in args:
+ try:
+ np.ma.median(x, axis=axis, overwrite_input=over)
+ except:
+ raise AssertionError(msg % (mask, ndmin, axis, over))
+
+ # Invalid axis values should raise exception
+ args = itertools.product([-(ndmin + 1), ndmin], [False, True])
+ for axis, over in args:
+ try:
+ np.ma.median(x, axis=axis, overwrite_input=over)
+ except IndexError:
+ pass
+ else:
+ raise AssertionError(msg % (mask, ndmin, axis, over))
+
+ def test_masked_0d(self):
+ # Check values
+ x = array(1, mask=False)
+ assert_equal(np.ma.median(x), 1)
+ x = array(1, mask=True)
+ assert_equal(np.ma.median(x), np.ma.masked)
+
def test_masked_1d(self):
x = array(np.arange(5), mask=True)
assert_equal(np.ma.median(x), np.ma.masked)