summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMarten van Kerkwijk <mhvk@astro.utoronto.ca>2018-03-26 09:23:10 -0400
committerGitHub <noreply@github.com>2018-03-26 09:23:10 -0400
commitf336efc1c026682c3c0c0b21eecdb3df29cca6d1 (patch)
tree965e766b8d61207f7b0b7bd0388c6dcb8161bcff
parent342f0e34c4921a20f193c339e4ca15e0d70a10ae (diff)
parentdeaa28c4b1da392995b8301f69bc2c20e48142b4 (diff)
downloadpython-numpy-f336efc1c026682c3c0c0b21eecdb3df29cca6d1.tar.gz
python-numpy-f336efc1c026682c3c0c0b21eecdb3df29cca6d1.tar.bz2
python-numpy-f336efc1c026682c3c0c0b21eecdb3df29cca6d1.zip
Merge pull request #10795 from eric-wieser/einsum-output-spaces
BUG: Allow spaces in output string of einsum
-rw-r--r--numpy/core/src/multiarray/einsum.c.src29
-rw-r--r--numpy/core/tests/test_einsum.py9
2 files changed, 28 insertions, 10 deletions
diff --git a/numpy/core/src/multiarray/einsum.c.src b/numpy/core/src/multiarray/einsum.c.src
index e482d3280..5dbc30aa9 100644
--- a/numpy/core/src/multiarray/einsum.c.src
+++ b/numpy/core/src/multiarray/einsum.c.src
@@ -1829,9 +1829,10 @@ parse_operand_subscripts(char *subscripts, int length,
break;
}
else {
- PyErr_SetString(PyExc_ValueError,
+ PyErr_Format(PyExc_ValueError,
"einstein sum subscripts string contains a "
- "'.' that is not part of an ellipsis ('...')");
+ "'.' that is not part of an ellipsis ('...') in "
+ "operand %d", iop);
return 0;
}
@@ -1888,6 +1889,12 @@ parse_operand_subscripts(char *subscripts, int length,
return 0;
}
}
+ else if (label == '.') {
+ PyErr_Format(PyExc_ValueError,
+ "einstein sum subscripts string contains a "
+ "'.' that is not part of an ellipsis ('...') in "
+ "operand %d", iop);
+ }
else if (label != ' ') {
PyErr_Format(PyExc_ValueError,
"invalid subscript '%c' in einstein sum "
@@ -2011,7 +2018,8 @@ parse_output_subscripts(char *subscripts, int length,
else {
PyErr_SetString(PyExc_ValueError,
"einstein sum subscripts string contains a "
- "'.' that is not part of an ellipsis ('...')");
+ "'.' that is not part of an ellipsis ('...') "
+ "in the output");
return -1;
}
@@ -2037,8 +2045,15 @@ parse_output_subscripts(char *subscripts, int length,
if (i > 0) {
for (i = 0; i < length; ++i) {
label = subscripts[i];
+ if (label == '.') {
+ PyErr_SetString(PyExc_ValueError,
+ "einstein sum subscripts string contains a "
+ "'.' that is not part of an ellipsis ('...') "
+ "in the output");
+ return -1;
+ }
/* A label for an axis */
- if (label != '.' && label != ' ') {
+ else if (label != ' ') {
if (idim < ndim_left) {
out_labels[idim++] = label;
}
@@ -2049,12 +2064,6 @@ parse_output_subscripts(char *subscripts, int length,
return -1;
}
}
- else {
- PyErr_SetString(PyExc_ValueError,
- "einstein sum subscripts string contains a "
- "'.' that is not part of an ellipsis ('...')");
- return -1;
- }
}
}
diff --git a/numpy/core/tests/test_einsum.py b/numpy/core/tests/test_einsum.py
index 5863c11ad..406b10eab 100644
--- a/numpy/core/tests/test_einsum.py
+++ b/numpy/core/tests/test_einsum.py
@@ -1,5 +1,7 @@
from __future__ import division, absolute_import, print_function
+import itertools
+
import numpy as np
from numpy.testing import (
run_module_suite, assert_, assert_equal, assert_array_equal,
@@ -924,6 +926,13 @@ class TestEinSumPath(object):
opt = np.einsum(*path_test, optimize=exp_path)
assert_almost_equal(noopt, opt)
+ def test_spaces(self):
+ #gh-10794
+ arr = np.array([[1]])
+ for sp in itertools.product(['', ' '], repeat=4):
+ # no error for any spacing
+ np.einsum('{}...a{}->{}...a{}'.format(*sp), arr)
+
if __name__ == "__main__":
run_module_suite()