diff options
author | Sam Gross <colesbury@gmail.com> | 2018-01-19 09:39:26 -0500 |
---|---|---|
committer | Soumith Chintala <soumith@gmail.com> | 2018-01-19 09:39:26 -0500 |
commit | db6be0e1f1f90d9298e95c0e612e73e28840425f (patch) | |
tree | 4d7a9e5e675c81430b880ca22da3cdbb2234f21a /tools | |
parent | b997474a4f7dc0c1aa2e29049317169dff0a8680 (diff) | |
download | pytorch-db6be0e1f1f90d9298e95c0e612e73e28840425f.tar.gz pytorch-db6be0e1f1f90d9298e95c0e612e73e28840425f.tar.bz2 pytorch-db6be0e1f1f90d9298e95c0e612e73e28840425f.zip |
Fix call to THPUtils_parseSlice (#4732)
* Fix call to THPUtils_parseSlice
THPUtils_parseSlice returns a bool
* Add Variable.__index__
* Add test
Diffstat (limited to 'tools')
-rw-r--r-- | tools/autograd/templates/python_variable_methods.cpp | 15 |
1 files changed, 15 insertions, 0 deletions
diff --git a/tools/autograd/templates/python_variable_methods.cpp b/tools/autograd/templates/python_variable_methods.cpp index 8ec2883051..d3f5eb699c 100644 --- a/tools/autograd/templates/python_variable_methods.cpp +++ b/tools/autograd/templates/python_variable_methods.cpp @@ -256,6 +256,20 @@ static PyObject * THPVariable_integral_scalar(PyObject* self, PyObject* args) { END_HANDLE_TH_ERRORS } +// This is the __index__ function in Python which is similar to __int__, but +// called when used as a slice. +static PyObject * THPVariable_index_scalar(PyObject* self, PyObject* args) { + HANDLE_TH_ERRORS + auto& self_ = reinterpret_cast<THPVariable*>(self)->cdata; + // TODO: change the condition to `self_.dim() != 0` once we expose scalars + // in PyTorch. + if (!isIntegralType(self_.type().scalarType()) || self_.numel() != 1) { + throw TypeError("only integer tensors of a single element can be converted to an index"); + } + return wrap(dispatch_to_CLong(self_)); + END_HANDLE_TH_ERRORS +} + static Tensor dispatch_invert(const Tensor & self) { AutoNoGIL no_gil; AutoGPU auto_gpu(self); @@ -535,6 +549,7 @@ PyMethodDef variable_methods[] = { {"__float__", (PyCFunction)THPVariable_float_scalar, METH_NOARGS, NULL}, {"__int__", (PyCFunction)THPVariable_integral_scalar, METH_NOARGS, NULL}, {"__long__", (PyCFunction)THPVariable_integral_scalar, METH_NOARGS, NULL}, + {"__index__", (PyCFunction)THPVariable_index_scalar, METH_NOARGS, NULL}, {"__invert__", (PyCFunction)THPVariable_invert, METH_NOARGS, NULL}, {"__nonzero__", (PyCFunction)THPVariable_is_nonzero, METH_NOARGS, NULL}, {"__matmul__", (PyCFunction)THPVariable_matmul, METH_VARARGS | METH_KEYWORDS, NULL}, |