author    Sam Gross <colesbury@gmail.com>    2018-01-19 10:58:13 -0500
committer GitHub <noreply@github.com>        2018-01-19 10:58:13 -0500
commit    870ef8e95f3d9a39368b629e10d2a554c60f9059 (patch)
tree      5414acdee21986926e59af6b849e55b60e1010e7 /tools/autograd
parent    b6eb7d7ba0437c280bd0eebcd853c2c38b36b280 (diff)
Implement record_stream on Variable (#4728)
The function record_stream is currently only defined on Tensor in
TensorCuda.cwrap. It would be best to implement this in ATen and
automatically bind it to Python, but we're missing ATen types to
represent CUDA streams.
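
For context, record_stream tells the CUDA caching allocator that a tensor's memory is in use on another stream, so the allocator will not hand that memory back out to new allocations until the recorded stream's pending work completes. A minimal usage sketch of the new binding (variable and stream names are illustrative; assumes a CUDA-enabled build):

    import torch
    from torch.autograd import Variable

    side_stream = torch.cuda.Stream()

    # Allocated on the current (default) stream.
    v = Variable(torch.cuda.FloatTensor(16, 16))

    with torch.cuda.stream(side_stream):
        # Work queued on the side stream that reads v.
        w = v * 2

    # Record that v is used by side_stream so the caching allocator
    # does not reuse its storage before side_stream catches up.
    v.record_stream(side_stream)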
Diffstat (limited to 'tools/autograd')
 tools/autograd/templates/python_variable_methods.cpp | 22 ++++++++++++++++++++++
 1 file changed, 22 insertions(+), 0 deletions(-)
diff --git a/tools/autograd/templates/python_variable_methods.cpp b/tools/autograd/templates/python_variable_methods.cpp
index d3f5eb699c..f519d1002b 100644
--- a/tools/autograd/templates/python_variable_methods.cpp
+++ b/tools/autograd/templates/python_variable_methods.cpp
@@ -7,6 +7,9 @@
 #include "torch/csrc/Size.h"
 #include "torch/csrc/autograd/python_variable.h"
 #include "torch/csrc/autograd/utils/wrap_outputs.h"
+#ifdef WITH_CUDA
+#include "torch/csrc/cuda/Stream.h"
+#endif
 #include "torch/csrc/utils/object_ptr.h"
 #include "torch/csrc/utils/python_arg_parser.h"
 #include "torch/csrc/utils/python_numbers.h"
@@ -405,6 +408,24 @@ static PyObject * THPVariable_numpy(PyObject* self, PyObject* arg)
   END_HANDLE_TH_ERRORS
 }
 
+// TODO: move this to ATen. We would need to expose Stream objects in ATen.
+static PyObject * THPVariable_record_stream(PyObject* self, PyObject* arg)
+{
+  HANDLE_TH_ERRORS
+#ifdef WITH_CUDA
+  auto& self_ = reinterpret_cast<THPVariable*>(self)->cdata;
+  if (!THCPStream_Check(arg)) {
+    return PyErr_Format(PyExc_TypeError, "expected Stream object");
+  }
+  void* data = self_.data_ptr();
+  THCCachingAllocator_recordStream(data, ((THCPStream*)arg)->cdata);
+  Py_RETURN_NONE;
+#else
+  throw std::runtime_error("PyTorch compiled without CUDA support");
+#endif
+  END_HANDLE_TH_ERRORS
+}
+
 static PyObject * THPVariable_map_(PyObject* self, PyObject* args, PyObject* kwargs)
 {
   HANDLE_TH_ERRORS
@@ -577,6 +598,7 @@ PyMethodDef variable_methods[] = {
   {"nelement", (PyCFunction)THPVariable_numel, METH_NOARGS, NULL},
   {"new", (PyCFunction)THPVariable_new, METH_VARARGS | METH_KEYWORDS, NULL},
   {"numpy", (PyCFunction)THPVariable_numpy, METH_NOARGS, NULL},
+  {"record_stream", (PyCFunction)THPVariable_record_stream, METH_O, NULL},
   {"short", (PyCFunction)THPVariable_short, METH_NOARGS, NULL},
   {"size", (PyCFunction)THPVariable_size, METH_VARARGS | METH_KEYWORDS, NULL},
   {"storage", (PyCFunction)THPVariable_storage, METH_NOARGS, NULL},
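
The binding is registered with METH_O, so Python passes exactly one positional argument straight through to the C function. From the checks in the diff, the observable behavior should be roughly as follows (illustrative sketch; the first two calls assume a CUDA build):

    v = Variable(torch.cuda.FloatTensor(4))
    v.record_stream(torch.cuda.Stream())  # ok, returns None

    v.record_stream(42)
    # TypeError: expected Stream object

    # On a build without CUDA (WITH_CUDA undefined), any call raises:
    # RuntimeError: PyTorch compiled without CUDA support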