summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorEric Wieser <wieser.eric@gmail.com>2018-03-25 06:27:10 (GMT)
committerEric Wieser <wieser.eric@gmail.com>2018-03-25 18:32:56 (GMT)
commitdeaa28c4b1da392995b8301f69bc2c20e48142b4 (patch)
treed5e7ad4154040d4faf059e9456f4ef33435c02cd
parente4d678a2f5859d29a853d617e9e5bbd4b6241898 (diff)
downloadpython-numpy-deaa28c4b1da392995b8301f69bc2c20e48142b4.zip
python-numpy-deaa28c4b1da392995b8301f69bc2c20e48142b4.tar.gz
python-numpy-deaa28c4b1da392995b8301f69bc2c20e48142b4.tar.bz2
BUG: Allow spaces in output string of einsum
Also produce more useful error messages Fixes gh-10794
-rw-r--r--numpy/core/src/multiarray/einsum.c.src29
-rw-r--r--numpy/core/tests/test_einsum.py9
2 files changed, 28 insertions, 10 deletions
diff --git a/numpy/core/src/multiarray/einsum.c.src b/numpy/core/src/multiarray/einsum.c.src
index 7db6061..f71cf17 100644
--- a/numpy/core/src/multiarray/einsum.c.src
+++ b/numpy/core/src/multiarray/einsum.c.src
@@ -1829,9 +1829,10 @@ parse_operand_subscripts(char *subscripts, int length,
break;
}
else {
- PyErr_SetString(PyExc_ValueError,
+ PyErr_Format(PyExc_ValueError,
"einstein sum subscripts string contains a "
- "'.' that is not part of an ellipsis ('...')");
+ "'.' that is not part of an ellipsis ('...') in "
+ "operand %d", iop);
return 0;
}
@@ -1888,6 +1889,12 @@ parse_operand_subscripts(char *subscripts, int length,
return 0;
}
}
+ else if (label == '.') {
+ PyErr_Format(PyExc_ValueError,
+ "einstein sum subscripts string contains a "
+ "'.' that is not part of an ellipsis ('...') in "
+ "operand %d", iop);
+ }
else if (label != ' ') {
PyErr_Format(PyExc_ValueError,
"invalid subscript '%c' in einstein sum "
@@ -2011,7 +2018,8 @@ parse_output_subscripts(char *subscripts, int length,
else {
PyErr_SetString(PyExc_ValueError,
"einstein sum subscripts string contains a "
- "'.' that is not part of an ellipsis ('...')");
+ "'.' that is not part of an ellipsis ('...') "
+ "in the output");
return -1;
}
@@ -2037,8 +2045,15 @@ parse_output_subscripts(char *subscripts, int length,
if (i > 0) {
for (i = 0; i < length; ++i) {
label = subscripts[i];
+ if (label == '.') {
+ PyErr_SetString(PyExc_ValueError,
+ "einstein sum subscripts string contains a "
+ "'.' that is not part of an ellipsis ('...') "
+ "in the output");
+ return -1;
+ }
/* A label for an axis */
- if (label != '.' && label != ' ') {
+ else if (label != ' ') {
if (idim < ndim_left) {
out_labels[idim++] = label;
}
@@ -2049,12 +2064,6 @@ parse_output_subscripts(char *subscripts, int length,
return -1;
}
}
- else {
- PyErr_SetString(PyExc_ValueError,
- "einstein sum subscripts string contains a "
- "'.' that is not part of an ellipsis ('...')");
- return -1;
- }
}
}
diff --git a/numpy/core/tests/test_einsum.py b/numpy/core/tests/test_einsum.py
index 9bd85fd..bdcd0c8 100644
--- a/numpy/core/tests/test_einsum.py
+++ b/numpy/core/tests/test_einsum.py
@@ -1,5 +1,7 @@
from __future__ import division, absolute_import, print_function
+import itertools
+
import numpy as np
from numpy.testing import (
run_module_suite, assert_, assert_equal, assert_array_equal,
@@ -918,6 +920,13 @@ class TestEinSumPath(object):
opt = np.einsum(*path_test, optimize=exp_path)
assert_almost_equal(noopt, opt)
+ def test_spaces(self):
+ #gh-10794
+ arr = np.array([[1]])
+ for sp in itertools.product(['', ' '], repeat=4):
+ # no error for any spacing
+ np.einsum('{}...a{}->{}...a{}'.format(*sp), arr)
+
if __name__ == "__main__":
run_module_suite()