diff options
author | Stephan Hoyer <shoyer@google.com> | 2017-05-05 21:50:18 -0700 |
---|---|---|
committer | Stephan Hoyer <shoyer@google.com> | 2017-05-07 17:03:40 -0700 |
commit | d51b538ba80d36841cc57911d77ea61cd1d3fb25 (patch) | |
tree | 0c8cbc64c3ae6eb4dd0bda4a810d64902b974324 | |
parent | c9d1f9e467155cec3030b0970816abe928244b9c (diff) | |
download | python-numpy-d51b538ba80d36841cc57911d77ea61cd1d3fb25.tar.gz python-numpy-d51b538ba80d36841cc57911d77ea61cd1d3fb25.tar.bz2 python-numpy-d51b538ba80d36841cc57911d77ea61cd1d3fb25.zip |
ENH: add divmod support to NDArrayOperatorsMixin
-rw-r--r-- | numpy/lib/mixins.py | 8 | ||||
-rw-r--r-- | numpy/lib/tests/test_mixins.py | 92 |
2 files changed, 60 insertions, 40 deletions
diff --git a/numpy/lib/mixins.py b/numpy/lib/mixins.py index bbeed1437..fbdc2edfb 100644 --- a/numpy/lib/mixins.py +++ b/numpy/lib/mixins.py @@ -70,8 +70,7 @@ class NDArrayOperatorsMixin(object): implement. This class does not yet implement the special operators corresponding - to ``divmod`` or ``matmul`` (``@``), because these operation do not yet - have corresponding NumPy ufuncs. + to ``matmul`` (``@``), because ``np.matmul`` is not yet a NumPy ufunc. It is useful for writing classes that do not inherit from `numpy.ndarray`, but that should support arithmetic and numpy universal functions like @@ -161,7 +160,10 @@ class NDArrayOperatorsMixin(object): um.true_divide, 'truediv') __floordiv__, __rfloordiv__, __ifloordiv__ = _numeric_methods( um.floor_divide, 'floordiv') - __mod__, __rmod__, __imod__ = _numeric_methods(um.mod, 'mod') + __mod__, __rmod__, __imod__ = _numeric_methods(um.remainder, 'mod') + __divmod__ = _binary_method(um.divmod, 'divmod') + __rdivmod__ = _reflected_binary_method(um.divmod, 'divmod') + # __idivmod__ does not exist # TODO: handle the optional third argument for __pow__? __pow__, __rpow__, __ipow__ = _numeric_methods(um.power, 'pow') __lshift__, __rlshift__, __ilshift__ = _numeric_methods( diff --git a/numpy/lib/tests/test_mixins.py b/numpy/lib/tests/test_mixins.py index 287d4ed29..db38bdfd6 100644 --- a/numpy/lib/tests/test_mixins.py +++ b/numpy/lib/tests/test_mixins.py @@ -56,11 +56,47 @@ class ArrayLike(np.lib.mixins.NDArrayOperatorsMixin): return '%s(%r)' % (type(self).__name__, self.value) +def wrap_array_like(result): + if type(result) is tuple: + return tuple(ArrayLike(r) for r in result) + else: + return ArrayLike(result) + + def _assert_equal_type_and_value(result, expected, err_msg=None): assert_equal(type(result), type(expected), err_msg=err_msg) - assert_equal(result.value, expected.value, err_msg=err_msg) - assert_equal(getattr(result.value, 'dtype', None), - getattr(expected.value, 'dtype', None), err_msg=err_msg) + if isinstance(result, tuple): + assert_equal(len(result), len(expected), err_msg=err_msg) + for result_item, expected_item in zip(result, expected): + _assert_equal_type_and_value(result_item, expected_item, err_msg) + else: + assert_equal(result.value, expected.value, err_msg=err_msg) + assert_equal(getattr(result.value, 'dtype', None), + getattr(expected.value, 'dtype', None), err_msg=err_msg) + + +_ALL_BINARY_OPERATORS = [ + operator.lt, + operator.le, + operator.eq, + operator.ne, + operator.gt, + operator.ge, + operator.add, + operator.sub, + operator.mul, + operator.truediv, + operator.floordiv, + # TODO: test div on Python 2, only + operator.mod, + divmod, + pow, + operator.lshift, + operator.rshift, + operator.and_, + operator.xor, + operator.or_, +] class TestNDArrayOperatorsMixin(TestCase): @@ -148,52 +184,34 @@ class TestNDArrayOperatorsMixin(TestCase): operator.invert]: _assert_equal_type_and_value(op(array_like), ArrayLike(op(array))) - def test_binary_methods(self): + def test_forward_binary_methods(self): array = np.array([-1, 0, 1, 2]) array_like = ArrayLike(array) - operators = [ - operator.lt, - operator.le, - operator.eq, - operator.ne, - operator.gt, - operator.ge, - operator.add, - operator.sub, - operator.mul, - operator.truediv, - operator.floordiv, - # TODO: test div on Python 2, only - operator.mod, - # divmod is not yet implemented - pow, - operator.lshift, - operator.rshift, - operator.and_, - operator.xor, - operator.or_, - ] - for op in operators: - expected = ArrayLike(op(array, 1)) + for op in _ALL_BINARY_OPERATORS: + expected = wrap_array_like(op(array, 1)) actual = op(array_like, 1) err_msg = 'failed for operator {}'.format(op) _assert_equal_type_and_value(expected, actual, err_msg=err_msg) + def test_reflected_binary_methods(self): + for op in _ALL_BINARY_OPERATORS: + expected = wrap_array_like(op(2, 1)) + actual = op(2, ArrayLike(1)) + err_msg = 'failed for operator {}'.format(op) + _assert_equal_type_and_value(expected, actual, err_msg=err_msg) + def test_ufunc_at(self): array = ArrayLike(np.array([1, 2, 3, 4])) assert_(np.negative.at(array, np.array([0, 1])) is None) _assert_equal_type_and_value(array, ArrayLike([-1, -2, 3, 4])) def test_ufunc_two_outputs(self): - def check(result): - assert_(type(result) is tuple) - assert_equal(len(result), 2) - mantissa, exponent = np.frexp(2 ** -3) - _assert_equal_type_and_value(result[0], ArrayLike(mantissa)) - _assert_equal_type_and_value(result[1], ArrayLike(exponent)) - - check(np.frexp(ArrayLike(2 ** -3))) - check(np.frexp(ArrayLike(np.array(2 ** -3)))) + mantissa, exponent = np.frexp(2 ** -3) + expected = (ArrayLike(mantissa), ArrayLike(exponent)) + _assert_equal_type_and_value( + np.frexp(ArrayLike(2 ** -3)), expected) + _assert_equal_type_and_value( + np.frexp(ArrayLike(np.array(2 ** -3))), expected) if __name__ == "__main__": |