diff options
author | Eric Wieser <wieser.eric@gmail.com> | 2017-03-22 23:10:43 +0000 |
---|---|---|
committer | Eric Wieser <wieser.eric@gmail.com> | 2017-05-05 21:36:03 +0100 |
commit | e6b8e75547af0cc4d38af458eff5e5d6c14102b8 (patch) | |
tree | ddb9dc76e81382ccce6d7f5f28c2f558b3658fa5 | |
parent | 69b0c42bca27dd5d5522de306bcd7db7deccbfad (diff) | |
download | python-numpy-e6b8e75547af0cc4d38af458eff5e5d6c14102b8.tar.gz python-numpy-e6b8e75547af0cc4d38af458eff5e5d6c14102b8.tar.bz2 python-numpy-e6b8e75547af0cc4d38af458eff5e5d6c14102b8.zip |
MAINT: Remove code duplicated from np.r_ in np.ma.mr_
Also adds a test for the disabled-by-design behaviour - this would return
raw matrices, not masked arrays
-rw-r--r-- | numpy/lib/index_tricks.py | 4 | ||||
-rw-r--r-- | numpy/ma/extras.py | 55 | ||||
-rw-r--r-- | numpy/ma/tests/test_extras.py | 5 |
3 files changed, 12 insertions, 52 deletions
diff --git a/numpy/lib/index_tricks.py b/numpy/lib/index_tricks.py index 1fd530f33..58d3e0dcf 100644 --- a/numpy/lib/index_tricks.py +++ b/numpy/lib/index_tricks.py @@ -237,6 +237,8 @@ class AxisConcatenator(object): For detailed documentation on usage, see `r_`. """ + # allow ma.mr_ to override this + concatenate = staticmethod(_nx.concatenate) def _retval(self, res): if self.matrix: @@ -345,7 +347,7 @@ class AxisConcatenator(object): for k in scalars: objs[k] = objs[k].astype(final_dtype) - res = _nx.concatenate(tuple(objs), axis=self.axis) + res = self.concatenate(tuple(objs), axis=self.axis) return self._retval(res) def __len__(self): diff --git a/numpy/ma/extras.py b/numpy/ma/extras.py index e100e471c..10b9634a3 100644 --- a/numpy/ma/extras.py +++ b/numpy/ma/extras.py @@ -1461,60 +1461,15 @@ class MAxisConcatenator(AxisConcatenator): mr_class """ - - def __init__(self, axis=0): - AxisConcatenator.__init__(self, axis, matrix=False) + concatenate = staticmethod(concatenate) def __getitem__(self, key): + # matrix builder syntax, like 'a, b; c, d' if isinstance(key, str): raise MAError("Unavailable for masked array.") - if not isinstance(key, tuple): - key = (key,) - objs = [] - scalars = [] - final_dtypedescr = None - for k in range(len(key)): - scalar = False - if isinstance(key[k], slice): - step = key[k].step - start = key[k].start - stop = key[k].stop - if start is None: - start = 0 - if step is None: - step = 1 - if isinstance(step, complex): - size = int(abs(step)) - newobj = np.linspace(start, stop, num=size) - else: - newobj = np.arange(start, stop, step) - elif isinstance(key[k], str): - if (key[k] in 'rc'): - self.matrix = True - self.col = (key[k] == 'c') - continue - try: - self.axis = int(key[k]) - continue - except (ValueError, TypeError): - raise ValueError("Unknown special directive") - elif type(key[k]) in np.ScalarType: - newobj = asarray([key[k]]) - scalars.append(k) - scalar = True - else: - newobj = key[k] - objs.append(newobj) - if isinstance(newobj, ndarray) and not scalar: - if final_dtypedescr is None: - final_dtypedescr = newobj.dtype - elif newobj.dtype > final_dtypedescr: - final_dtypedescr = newobj.dtype - if final_dtypedescr is not None: - for k in scalars: - objs[k] = objs[k].astype(final_dtypedescr) - res = concatenate(tuple(objs), axis=self.axis) - return self._retval(res) + + return super(MAxisConcatenator, self).__getitem__(key) + class mr_class(MAxisConcatenator): """ diff --git a/numpy/ma/tests/test_extras.py b/numpy/ma/tests/test_extras.py index e7ebd8b82..7de21ff59 100644 --- a/numpy/ma/tests/test_extras.py +++ b/numpy/ma/tests/test_extras.py @@ -14,7 +14,8 @@ import itertools import numpy as np from numpy.testing import ( - TestCase, run_module_suite, assert_warns, suppress_warnings + TestCase, run_module_suite, assert_warns, suppress_warnings, + assert_raises ) from numpy.ma.testutils import ( assert_, assert_array_equal, assert_equal, assert_almost_equal @@ -304,6 +305,8 @@ class TestConcatenator(TestCase): assert_array_equal(d[5:,:], b_2) assert_array_equal(d.mask, np.r_[m_1, m_2]) + def test_matrix_builder(self): + assert_raises(np.ma.MAError, lambda: mr_['1, 2; 3, 4']) class TestNotMasked(TestCase): # Tests notmasked_edges and notmasked_contiguous. |