summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorStephan Hoyer <shoyer@google.com>2017-05-05 21:50:18 -0700
committerStephan Hoyer <shoyer@google.com>2017-05-07 17:03:40 -0700
commitd51b538ba80d36841cc57911d77ea61cd1d3fb25 (patch)
tree0c8cbc64c3ae6eb4dd0bda4a810d64902b974324
parentc9d1f9e467155cec3030b0970816abe928244b9c (diff)
downloadpython-numpy-d51b538ba80d36841cc57911d77ea61cd1d3fb25.tar.gz
python-numpy-d51b538ba80d36841cc57911d77ea61cd1d3fb25.tar.bz2
python-numpy-d51b538ba80d36841cc57911d77ea61cd1d3fb25.zip
ENH: add divmod support to NDArrayOperatorsMixin
-rw-r--r--numpy/lib/mixins.py8
-rw-r--r--numpy/lib/tests/test_mixins.py92
2 files changed, 60 insertions, 40 deletions
diff --git a/numpy/lib/mixins.py b/numpy/lib/mixins.py
index bbeed1437..fbdc2edfb 100644
--- a/numpy/lib/mixins.py
+++ b/numpy/lib/mixins.py
@@ -70,8 +70,7 @@ class NDArrayOperatorsMixin(object):
implement.
This class does not yet implement the special operators corresponding
- to ``divmod`` or ``matmul`` (``@``), because these operation do not yet
- have corresponding NumPy ufuncs.
+ to ``matmul`` (``@``), because ``np.matmul`` is not yet a NumPy ufunc.
It is useful for writing classes that do not inherit from `numpy.ndarray`,
but that should support arithmetic and numpy universal functions like
@@ -161,7 +160,10 @@ class NDArrayOperatorsMixin(object):
um.true_divide, 'truediv')
__floordiv__, __rfloordiv__, __ifloordiv__ = _numeric_methods(
um.floor_divide, 'floordiv')
- __mod__, __rmod__, __imod__ = _numeric_methods(um.mod, 'mod')
+ __mod__, __rmod__, __imod__ = _numeric_methods(um.remainder, 'mod')
+ __divmod__ = _binary_method(um.divmod, 'divmod')
+ __rdivmod__ = _reflected_binary_method(um.divmod, 'divmod')
+ # __idivmod__ does not exist
# TODO: handle the optional third argument for __pow__?
__pow__, __rpow__, __ipow__ = _numeric_methods(um.power, 'pow')
__lshift__, __rlshift__, __ilshift__ = _numeric_methods(
diff --git a/numpy/lib/tests/test_mixins.py b/numpy/lib/tests/test_mixins.py
index 287d4ed29..db38bdfd6 100644
--- a/numpy/lib/tests/test_mixins.py
+++ b/numpy/lib/tests/test_mixins.py
@@ -56,11 +56,47 @@ class ArrayLike(np.lib.mixins.NDArrayOperatorsMixin):
return '%s(%r)' % (type(self).__name__, self.value)
+def wrap_array_like(result):
+ if type(result) is tuple:
+ return tuple(ArrayLike(r) for r in result)
+ else:
+ return ArrayLike(result)
+
+
def _assert_equal_type_and_value(result, expected, err_msg=None):
assert_equal(type(result), type(expected), err_msg=err_msg)
- assert_equal(result.value, expected.value, err_msg=err_msg)
- assert_equal(getattr(result.value, 'dtype', None),
- getattr(expected.value, 'dtype', None), err_msg=err_msg)
+ if isinstance(result, tuple):
+ assert_equal(len(result), len(expected), err_msg=err_msg)
+ for result_item, expected_item in zip(result, expected):
+ _assert_equal_type_and_value(result_item, expected_item, err_msg)
+ else:
+ assert_equal(result.value, expected.value, err_msg=err_msg)
+ assert_equal(getattr(result.value, 'dtype', None),
+ getattr(expected.value, 'dtype', None), err_msg=err_msg)
+
+
+_ALL_BINARY_OPERATORS = [
+ operator.lt,
+ operator.le,
+ operator.eq,
+ operator.ne,
+ operator.gt,
+ operator.ge,
+ operator.add,
+ operator.sub,
+ operator.mul,
+ operator.truediv,
+ operator.floordiv,
+ # TODO: test div on Python 2, only
+ operator.mod,
+ divmod,
+ pow,
+ operator.lshift,
+ operator.rshift,
+ operator.and_,
+ operator.xor,
+ operator.or_,
+]
class TestNDArrayOperatorsMixin(TestCase):
@@ -148,52 +184,34 @@ class TestNDArrayOperatorsMixin(TestCase):
operator.invert]:
_assert_equal_type_and_value(op(array_like), ArrayLike(op(array)))
- def test_binary_methods(self):
+ def test_forward_binary_methods(self):
array = np.array([-1, 0, 1, 2])
array_like = ArrayLike(array)
- operators = [
- operator.lt,
- operator.le,
- operator.eq,
- operator.ne,
- operator.gt,
- operator.ge,
- operator.add,
- operator.sub,
- operator.mul,
- operator.truediv,
- operator.floordiv,
- # TODO: test div on Python 2, only
- operator.mod,
- # divmod is not yet implemented
- pow,
- operator.lshift,
- operator.rshift,
- operator.and_,
- operator.xor,
- operator.or_,
- ]
- for op in operators:
- expected = ArrayLike(op(array, 1))
+ for op in _ALL_BINARY_OPERATORS:
+ expected = wrap_array_like(op(array, 1))
actual = op(array_like, 1)
err_msg = 'failed for operator {}'.format(op)
_assert_equal_type_and_value(expected, actual, err_msg=err_msg)
+ def test_reflected_binary_methods(self):
+ for op in _ALL_BINARY_OPERATORS:
+ expected = wrap_array_like(op(2, 1))
+ actual = op(2, ArrayLike(1))
+ err_msg = 'failed for operator {}'.format(op)
+ _assert_equal_type_and_value(expected, actual, err_msg=err_msg)
+
def test_ufunc_at(self):
array = ArrayLike(np.array([1, 2, 3, 4]))
assert_(np.negative.at(array, np.array([0, 1])) is None)
_assert_equal_type_and_value(array, ArrayLike([-1, -2, 3, 4]))
def test_ufunc_two_outputs(self):
- def check(result):
- assert_(type(result) is tuple)
- assert_equal(len(result), 2)
- mantissa, exponent = np.frexp(2 ** -3)
- _assert_equal_type_and_value(result[0], ArrayLike(mantissa))
- _assert_equal_type_and_value(result[1], ArrayLike(exponent))
-
- check(np.frexp(ArrayLike(2 ** -3)))
- check(np.frexp(ArrayLike(np.array(2 ** -3))))
+ mantissa, exponent = np.frexp(2 ** -3)
+ expected = (ArrayLike(mantissa), ArrayLike(exponent))
+ _assert_equal_type_and_value(
+ np.frexp(ArrayLike(2 ** -3)), expected)
+ _assert_equal_type_and_value(
+ np.frexp(ArrayLike(np.array(2 ** -3))), expected)
if __name__ == "__main__":