summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--numpy/lib/index_tricks.py4
-rw-r--r--numpy/ma/extras.py55
-rw-r--r--numpy/ma/tests/test_extras.py5
3 files changed, 12 insertions, 52 deletions
diff --git a/numpy/lib/index_tricks.py b/numpy/lib/index_tricks.py
index 1fd530f33..58d3e0dcf 100644
--- a/numpy/lib/index_tricks.py
+++ b/numpy/lib/index_tricks.py
@@ -237,6 +237,8 @@ class AxisConcatenator(object):
For detailed documentation on usage, see `r_`.
"""
+ # allow ma.mr_ to override this
+ concatenate = staticmethod(_nx.concatenate)
def _retval(self, res):
if self.matrix:
@@ -345,7 +347,7 @@ class AxisConcatenator(object):
for k in scalars:
objs[k] = objs[k].astype(final_dtype)
- res = _nx.concatenate(tuple(objs), axis=self.axis)
+ res = self.concatenate(tuple(objs), axis=self.axis)
return self._retval(res)
def __len__(self):
diff --git a/numpy/ma/extras.py b/numpy/ma/extras.py
index e100e471c..10b9634a3 100644
--- a/numpy/ma/extras.py
+++ b/numpy/ma/extras.py
@@ -1461,60 +1461,15 @@ class MAxisConcatenator(AxisConcatenator):
mr_class
"""
-
- def __init__(self, axis=0):
- AxisConcatenator.__init__(self, axis, matrix=False)
+ concatenate = staticmethod(concatenate)
def __getitem__(self, key):
+ # matrix builder syntax, like 'a, b; c, d'
if isinstance(key, str):
raise MAError("Unavailable for masked array.")
- if not isinstance(key, tuple):
- key = (key,)
- objs = []
- scalars = []
- final_dtypedescr = None
- for k in range(len(key)):
- scalar = False
- if isinstance(key[k], slice):
- step = key[k].step
- start = key[k].start
- stop = key[k].stop
- if start is None:
- start = 0
- if step is None:
- step = 1
- if isinstance(step, complex):
- size = int(abs(step))
- newobj = np.linspace(start, stop, num=size)
- else:
- newobj = np.arange(start, stop, step)
- elif isinstance(key[k], str):
- if (key[k] in 'rc'):
- self.matrix = True
- self.col = (key[k] == 'c')
- continue
- try:
- self.axis = int(key[k])
- continue
- except (ValueError, TypeError):
- raise ValueError("Unknown special directive")
- elif type(key[k]) in np.ScalarType:
- newobj = asarray([key[k]])
- scalars.append(k)
- scalar = True
- else:
- newobj = key[k]
- objs.append(newobj)
- if isinstance(newobj, ndarray) and not scalar:
- if final_dtypedescr is None:
- final_dtypedescr = newobj.dtype
- elif newobj.dtype > final_dtypedescr:
- final_dtypedescr = newobj.dtype
- if final_dtypedescr is not None:
- for k in scalars:
- objs[k] = objs[k].astype(final_dtypedescr)
- res = concatenate(tuple(objs), axis=self.axis)
- return self._retval(res)
+
+ return super(MAxisConcatenator, self).__getitem__(key)
+
class mr_class(MAxisConcatenator):
"""
diff --git a/numpy/ma/tests/test_extras.py b/numpy/ma/tests/test_extras.py
index e7ebd8b82..7de21ff59 100644
--- a/numpy/ma/tests/test_extras.py
+++ b/numpy/ma/tests/test_extras.py
@@ -14,7 +14,8 @@ import itertools
import numpy as np
from numpy.testing import (
- TestCase, run_module_suite, assert_warns, suppress_warnings
+ TestCase, run_module_suite, assert_warns, suppress_warnings,
+ assert_raises
)
from numpy.ma.testutils import (
assert_, assert_array_equal, assert_equal, assert_almost_equal
@@ -304,6 +305,8 @@ class TestConcatenator(TestCase):
assert_array_equal(d[5:,:], b_2)
assert_array_equal(d.mask, np.r_[m_1, m_2])
+ def test_matrix_builder(self):
+ assert_raises(np.ma.MAError, lambda: mr_['1, 2; 3, 4'])
class TestNotMasked(TestCase):
# Tests notmasked_edges and notmasked_contiguous.