author     Sam Gross <colesbury@gmail.com>   2018-01-19 10:58:13 -0500
committer  GitHub <noreply@github.com>       2018-01-19 10:58:13 -0500
commit     870ef8e95f3d9a39368b629e10d2a554c60f9059 (patch)
tree       5414acdee21986926e59af6b849e55b60e1010e7 /tools/autograd
parent     b6eb7d7ba0437c280bd0eebcd853c2c38b36b280 (diff)
Implement record_stream on Variable (#4728)
The function record_stream is currently only defined on Tensor in TensorCuda.cwrap. It would be best to implement this in ATen and automatically bind it to Python, but we're missing ATen types to represent CUDA streams.
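For readers unfamiliar with the method being bound, here is a minimal Python-side sketch of how record_stream is typically used (the tensor and stream names are illustrative, not part of this patch). A tensor whose memory was allocated on one stream but which is consumed on another must notify the caching allocator, so that freeing the tensor does not let its memory be reused before the side stream's pending kernels finish.

import torch

s = torch.cuda.Stream()
x = torch.randn(1024, 1024).cuda()   # memory comes from the default stream's pool

with torch.cuda.stream(s):
    y = x * 2                        # kernel enqueued on side stream s

# Mark x as in use on s: the caching allocator will not recycle x's
# memory for other allocations until s's outstanding work completes.
x.record_stream(s)

This also explains the shape of the binding below: the method takes exactly one argument (registered with METH_O) and rejects anything that is not a THCPStream, which is the Stream check a user would hit from Python.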
Diffstat (limited to 'tools/autograd')
-rw-r--r--   tools/autograd/templates/python_variable_methods.cpp   22
1 file changed, 22 insertions, 0 deletions
diff --git a/tools/autograd/templates/python_variable_methods.cpp b/tools/autograd/templates/python_variable_methods.cpp
index d3f5eb699c..f519d1002b 100644
--- a/tools/autograd/templates/python_variable_methods.cpp
+++ b/tools/autograd/templates/python_variable_methods.cpp
@@ -7,6 +7,9 @@
#include "torch/csrc/Size.h"
#include "torch/csrc/autograd/python_variable.h"
#include "torch/csrc/autograd/utils/wrap_outputs.h"
+#ifdef WITH_CUDA
+#include "torch/csrc/cuda/Stream.h"
+#endif
#include "torch/csrc/utils/object_ptr.h"
#include "torch/csrc/utils/python_arg_parser.h"
#include "torch/csrc/utils/python_numbers.h"
@@ -405,6 +408,24 @@ static PyObject * THPVariable_numpy(PyObject* self, PyObject* arg)
END_HANDLE_TH_ERRORS
}
+// TODO: move this to ATen. We would need to expose Stream objects in ATen.
+static PyObject * THPVariable_record_stream(PyObject* self, PyObject* arg)
+{
+ HANDLE_TH_ERRORS
+#ifdef WITH_CUDA
+ auto& self_ = reinterpret_cast<THPVariable*>(self)->cdata;
+ if (!THCPStream_Check(arg)) {
+ return PyErr_Format(PyExc_TypeError, "expected Stream object");
+ }
+ void* data = self_.data_ptr();
+ THCCachingAllocator_recordStream(data, ((THCPStream*)arg)->cdata);
+ Py_RETURN_NONE;
+#else
+ throw std::runtime_error("PyTorch compiled without CUDA support");
+#endif
+ END_HANDLE_TH_ERRORS
+}
+
static PyObject * THPVariable_map_(PyObject* self, PyObject* args, PyObject* kwargs)
{
HANDLE_TH_ERRORS
@@ -577,6 +598,7 @@ PyMethodDef variable_methods[] = {
{"nelement", (PyCFunction)THPVariable_numel, METH_NOARGS, NULL},
{"new", (PyCFunction)THPVariable_new, METH_VARARGS | METH_KEYWORDS, NULL},
{"numpy", (PyCFunction)THPVariable_numpy, METH_NOARGS, NULL},
+ {"record_stream", (PyCFunction)THPVariable_record_stream, METH_O, NULL},
{"short", (PyCFunction)THPVariable_short, METH_NOARGS, NULL},
{"size", (PyCFunction)THPVariable_size, METH_VARARGS | METH_KEYWORDS, NULL},
{"storage", (PyCFunction)THPVariable_storage, METH_NOARGS, NULL},