summaryrefslogtreecommitdiff
path: root/tools
diff options
context:
space:
mode:
authorSam Gross <colesbury@gmail.com>2018-01-19 09:39:26 -0500
committerSoumith Chintala <soumith@gmail.com>2018-01-19 09:39:26 -0500
commitdb6be0e1f1f90d9298e95c0e612e73e28840425f (patch)
tree4d7a9e5e675c81430b880ca22da3cdbb2234f21a /tools
parentb997474a4f7dc0c1aa2e29049317169dff0a8680 (diff)
downloadpytorch-db6be0e1f1f90d9298e95c0e612e73e28840425f.tar.gz
pytorch-db6be0e1f1f90d9298e95c0e612e73e28840425f.tar.bz2
pytorch-db6be0e1f1f90d9298e95c0e612e73e28840425f.zip
Fix call to THPUtils_parseSlice (#4732)
* Fix call to THPUtils_parseSlice THPUtils_parseSlice returns a bool * Add Variable.__index__ * Add test
Diffstat (limited to 'tools')
-rw-r--r--tools/autograd/templates/python_variable_methods.cpp15
1 files changed, 15 insertions, 0 deletions
diff --git a/tools/autograd/templates/python_variable_methods.cpp b/tools/autograd/templates/python_variable_methods.cpp
index 8ec2883051..d3f5eb699c 100644
--- a/tools/autograd/templates/python_variable_methods.cpp
+++ b/tools/autograd/templates/python_variable_methods.cpp
@@ -256,6 +256,20 @@ static PyObject * THPVariable_integral_scalar(PyObject* self, PyObject* args) {
END_HANDLE_TH_ERRORS
}
+// This is the __index__ function in Python which is similar to __int__, but
+// called when used as a slice.
+static PyObject * THPVariable_index_scalar(PyObject* self, PyObject* args) {
+ HANDLE_TH_ERRORS
+ auto& self_ = reinterpret_cast<THPVariable*>(self)->cdata;
+ // TODO: change the condition to `self_.dim() != 0` once we expose scalars
+ // in PyTorch.
+ if (!isIntegralType(self_.type().scalarType()) || self_.numel() != 1) {
+ throw TypeError("only integer tensors of a single element can be converted to an index");
+ }
+ return wrap(dispatch_to_CLong(self_));
+ END_HANDLE_TH_ERRORS
+}
+
static Tensor dispatch_invert(const Tensor & self) {
AutoNoGIL no_gil;
AutoGPU auto_gpu(self);
@@ -535,6 +549,7 @@ PyMethodDef variable_methods[] = {
{"__float__", (PyCFunction)THPVariable_float_scalar, METH_NOARGS, NULL},
{"__int__", (PyCFunction)THPVariable_integral_scalar, METH_NOARGS, NULL},
{"__long__", (PyCFunction)THPVariable_integral_scalar, METH_NOARGS, NULL},
+ {"__index__", (PyCFunction)THPVariable_index_scalar, METH_NOARGS, NULL},
{"__invert__", (PyCFunction)THPVariable_invert, METH_NOARGS, NULL},
{"__nonzero__", (PyCFunction)THPVariable_is_nonzero, METH_NOARGS, NULL},
{"__matmul__", (PyCFunction)THPVariable_matmul, METH_VARARGS | METH_KEYWORDS, NULL},