diff options
author | Eric Wieser <wieser.eric@gmail.com> | 2018-03-24 23:27:10 -0700 |
---|---|---|
committer | Eric Wieser <wieser.eric@gmail.com> | 2018-03-25 11:32:56 -0700 |
commit | deaa28c4b1da392995b8301f69bc2c20e48142b4 (patch) | |
tree | d5e7ad4154040d4faf059e9456f4ef33435c02cd /numpy | |
parent | e4d678a2f5859d29a853d617e9e5bbd4b6241898 (diff) | |
download | python-numpy-deaa28c4b1da392995b8301f69bc2c20e48142b4.tar.gz python-numpy-deaa28c4b1da392995b8301f69bc2c20e48142b4.tar.bz2 python-numpy-deaa28c4b1da392995b8301f69bc2c20e48142b4.zip |
BUG: Allow spaces in output string of einsum
Also produce more useful error messages
Fixes gh-10794
Diffstat (limited to 'numpy')
-rw-r--r-- | numpy/core/src/multiarray/einsum.c.src | 29 | ||||
-rw-r--r-- | numpy/core/tests/test_einsum.py | 9 |
2 files changed, 28 insertions, 10 deletions
diff --git a/numpy/core/src/multiarray/einsum.c.src b/numpy/core/src/multiarray/einsum.c.src index 7db606194..f71cf17e7 100644 --- a/numpy/core/src/multiarray/einsum.c.src +++ b/numpy/core/src/multiarray/einsum.c.src @@ -1829,9 +1829,10 @@ parse_operand_subscripts(char *subscripts, int length, break; } else { - PyErr_SetString(PyExc_ValueError, + PyErr_Format(PyExc_ValueError, "einstein sum subscripts string contains a " - "'.' that is not part of an ellipsis ('...')"); + "'.' that is not part of an ellipsis ('...') in " + "operand %d", iop); return 0; } @@ -1888,6 +1889,12 @@ parse_operand_subscripts(char *subscripts, int length, return 0; } } + else if (label == '.') { + PyErr_Format(PyExc_ValueError, + "einstein sum subscripts string contains a " + "'.' that is not part of an ellipsis ('...') in " + "operand %d", iop); + } else if (label != ' ') { PyErr_Format(PyExc_ValueError, "invalid subscript '%c' in einstein sum " @@ -2011,7 +2018,8 @@ parse_output_subscripts(char *subscripts, int length, else { PyErr_SetString(PyExc_ValueError, "einstein sum subscripts string contains a " - "'.' that is not part of an ellipsis ('...')"); + "'.' that is not part of an ellipsis ('...') " + "in the output"); return -1; } @@ -2037,8 +2045,15 @@ parse_output_subscripts(char *subscripts, int length, if (i > 0) { for (i = 0; i < length; ++i) { label = subscripts[i]; + if (label == '.') { + PyErr_SetString(PyExc_ValueError, + "einstein sum subscripts string contains a " + "'.' that is not part of an ellipsis ('...') " + "in the output"); + return -1; + } /* A label for an axis */ - if (label != '.' && label != ' ') { + else if (label != ' ') { if (idim < ndim_left) { out_labels[idim++] = label; } @@ -2049,12 +2064,6 @@ parse_output_subscripts(char *subscripts, int length, return -1; } } - else { - PyErr_SetString(PyExc_ValueError, - "einstein sum subscripts string contains a " - "'.' that is not part of an ellipsis ('...')"); - return -1; - } } } diff --git a/numpy/core/tests/test_einsum.py b/numpy/core/tests/test_einsum.py index 9bd85fdb9..bdcd0c852 100644 --- a/numpy/core/tests/test_einsum.py +++ b/numpy/core/tests/test_einsum.py @@ -1,5 +1,7 @@ from __future__ import division, absolute_import, print_function +import itertools + import numpy as np from numpy.testing import ( run_module_suite, assert_, assert_equal, assert_array_equal, @@ -918,6 +920,13 @@ class TestEinSumPath(object): opt = np.einsum(*path_test, optimize=exp_path) assert_almost_equal(noopt, opt) + def test_spaces(self): + #gh-10794 + arr = np.array([[1]]) + for sp in itertools.product(['', ' '], repeat=4): + # no error for any spacing + np.einsum('{}...a{}->{}...a{}'.format(*sp), arr) + if __name__ == "__main__": run_module_suite() |