summaryrefslogtreecommitdiff
path: root/torch/csrc/distributed
diff options
context:
space:
mode:
authorJanusz Marcinkiewicz <virrages@gmail.com>2016-12-19 20:49:43 +0100
committerAdam Paszke <adam.paszke@gmail.com>2017-01-31 01:58:09 +0100
commit76520512e7e267dfeb610cfcc3d5ce2f50c3c351 (patch)
tree338d49c48750dea414d3cb5242620d6528b31492 /torch/csrc/distributed
parent66de96588216b945fc475f52e353cdf915d2c52f (diff)
downloadpytorch-76520512e7e267dfeb610cfcc3d5ce2f50c3c351.tar.gz
pytorch-76520512e7e267dfeb610cfcc3d5ce2f50c3c351.tar.bz2
pytorch-76520512e7e267dfeb610cfcc3d5ce2f50c3c351.zip
DataChannel tests rewrite (#42); DataChannel `isend` and `irecv` implementation (#44)
Diffstat (limited to 'torch/csrc/distributed')
-rw-r--r--torch/csrc/distributed/Module.cpp66
-rw-r--r--torch/csrc/distributed/THDP.h1
2 files changed, 66 insertions, 1 deletions
diff --git a/torch/csrc/distributed/Module.cpp b/torch/csrc/distributed/Module.cpp
index e81a7cc79a..be2b54aa85 100644
--- a/torch/csrc/distributed/Module.cpp
+++ b/torch/csrc/distributed/Module.cpp
@@ -115,6 +115,11 @@ static THDTensorDescriptor* _makeDescriptor(PyObject *obj)
"type ") + std::string(THPUtils_typename(obj)));
}
+static THDRequest* _unpackRequest(PyObject *obj)
+{
+ return static_cast<THDRequest*>(THPWrapper_get(obj));
+}
+
static THDReduceOp _getReduceOp(PyObject *obj)
{
auto it = obj2reduceop.find(obj);
@@ -137,6 +142,36 @@ static THDGroup _getGroup(PyObject *obj)
return it->second;
}
+PyObject* THDPModule_isend(PyObject *_unused, PyObject *args)
+{
+ HANDLE_TH_ERRORS
+ if (PyTuple_GET_SIZE(args) != 2 || !THPModule_isTensor(PyTuple_GET_ITEM(args, 0)) ||
+ !THPUtils_checkLong(PyTuple_GET_ITEM(args, 1))) {
+ THPUtils_invalidArguments(args, "send", 1, "(tensor input, int dst_rank)");
+ return NULL;
+ }
+
+ THDPTensorDesc desc = _makeDescriptor(PyTuple_GET_ITEM(args, 0));
+ int dst_rank = THPUtils_unpackLong(PyTuple_GET_ITEM(args, 1));
+ return THPWrapper_New(THDIsend(desc, dst_rank), (void(*)(void*))THDRequest_free);
+ END_HANDLE_TH_ERRORS
+}
+
+PyObject* THDPModule_irecv(PyObject *_unused, PyObject *args)
+{
+ HANDLE_TH_ERRORS
+ if (PyTuple_GET_SIZE(args) != 2 || !THPModule_isTensor(PyTuple_GET_ITEM(args, 0)) ||
+ !THPUtils_checkLong(PyTuple_GET_ITEM(args, 1))) {
+ THPUtils_invalidArguments(args, "recv", 1, "(tensor output, int src_rank)");
+ return NULL;
+ }
+
+ THDPTensorDesc desc = _makeDescriptor(PyTuple_GET_ITEM(args, 0));
+ int src_rank = THPUtils_unpackLong(PyTuple_GET_ITEM(args, 1));
+ return THPWrapper_New(THDIrecv(desc, src_rank), (void(*)(void*))THDRequest_free);
+ END_HANDLE_TH_ERRORS
+}
+
PyObject* THDPModule_send(PyObject *_unused, PyObject *args)
{
HANDLE_TH_ERRORS
@@ -164,7 +199,7 @@ PyObject* THDPModule_recv(PyObject *_unused, PyObject *args)
THDPTensorDesc desc = _makeDescriptor(PyTuple_GET_ITEM(args, 0));
int src_rank = THPUtils_unpackLong(PyTuple_GET_ITEM(args, 1));
- THDReceive(desc, src_rank);
+ THDRecv(desc, src_rank);
Py_RETURN_NONE;
END_HANDLE_TH_ERRORS
}
@@ -431,6 +466,31 @@ invalid_arguments:
END_HANDLE_TH_ERRORS
}
+PyObject* THDPModule_requestIsCompleted(PyObject *_unused, PyObject *_req)
+{
+ HANDLE_TH_ERRORS
+ if (!THPWrapper_check(_req)) {
+ THPUtils_invalidArguments(_req, "requestIsCompleted", 1, "(request req)");
+ return NULL;
+ }
+
+ return PyBool_FromLong(THDRequest_isCompleted(_unpackRequest(_req)));
+ END_HANDLE_TH_ERRORS
+}
+
+PyObject* THDPModule_requestWait(PyObject *_unused, PyObject *_req)
+{
+ HANDLE_TH_ERRORS
+ if (!THPWrapper_check(_req)) {
+ THPUtils_invalidArguments(_req, "requestWait", 1, "(request req)");
+ return NULL;
+ }
+
+ THDRequest_wait(_unpackRequest(_req));
+ Py_RETURN_NONE;
+ END_HANDLE_TH_ERRORS
+}
+
PyObject* THDPModule_initExtension(PyObject *_unused, PyObject *args) {
if (PyTuple_GET_SIZE(args) != 3) {
THPUtils_invalidArguments(args, "initExtension", 1, "(bool is_master_worker, reduce_op obj, group obj)");
@@ -478,6 +538,8 @@ static struct PyMethodDef _THDPModule_methods[] = {
{"_dist_init_master_worker", (PyCFunction)THDPModule_initMasterWorker, METH_O, NULL},
{"_dist_get_rank", (PyCFunction)THDPModule_getRank, METH_NOARGS, NULL},
{"_dist_get_num_processes", (PyCFunction)THDPModule_getNumProcesses, METH_NOARGS, NULL},
+ {"_dist_isend", (PyCFunction)THDPModule_isend, METH_VARARGS, NULL},
+ {"_dist_irecv", (PyCFunction)THDPModule_irecv, METH_VARARGS, NULL},
{"_dist_send", (PyCFunction)THDPModule_send, METH_VARARGS, NULL},
{"_dist_recv", (PyCFunction)THDPModule_recv, METH_VARARGS, NULL},
{"_dist_all_reduce", (PyCFunction)THDPModule_allReduce, METH_VARARGS, NULL},
@@ -490,6 +552,8 @@ static struct PyMethodDef _THDPModule_methods[] = {
{"_dist_scatter_recv", (PyCFunction)THDPModule_scatterRecv, METH_VARARGS, NULL},
{"_dist_barrier", (PyCFunction)THDPModule_barrier, METH_O, NULL},
{"_dist_new_group", (PyCFunction)THDPModule_newGroup, METH_VARARGS, NULL},
+ {"_dist_request_is_completed", (PyCFunction)THDPModule_requestIsCompleted, METH_O, NULL},
+ {"_dist_request_wait", (PyCFunction)THDPModule_requestWait, METH_O, NULL},
{NULL}
};
diff --git a/torch/csrc/distributed/THDP.h b/torch/csrc/distributed/THDP.h
index e3f18f2090..058466d22e 100644
--- a/torch/csrc/distributed/THDP.h
+++ b/torch/csrc/distributed/THDP.h
@@ -7,6 +7,7 @@
#include "Module.h"
#include "Storage.h"
#include "Tensor.h"
+#include "../PtrWrapper.h"
#ifdef _THP_CORE
#include "utils.h"
#endif