diff options
author | Janusz Marcinkiewicz <virrages@gmail.com> | 2016-12-19 20:49:43 +0100 |
---|---|---|
committer | Adam Paszke <adam.paszke@gmail.com> | 2017-01-31 01:58:09 +0100 |
commit | 76520512e7e267dfeb610cfcc3d5ce2f50c3c351 (patch) | |
tree | 338d49c48750dea414d3cb5242620d6528b31492 /torch/csrc/distributed | |
parent | 66de96588216b945fc475f52e353cdf915d2c52f (diff) | |
download | pytorch-76520512e7e267dfeb610cfcc3d5ce2f50c3c351.tar.gz pytorch-76520512e7e267dfeb610cfcc3d5ce2f50c3c351.tar.bz2 pytorch-76520512e7e267dfeb610cfcc3d5ce2f50c3c351.zip |
DataChannel tests rewrite (#42); DataChannel `isend` and `irecv` implementation (#44)
Diffstat (limited to 'torch/csrc/distributed')
-rw-r--r-- | torch/csrc/distributed/Module.cpp | 66 | ||||
-rw-r--r-- | torch/csrc/distributed/THDP.h | 1 |
2 files changed, 66 insertions, 1 deletions
diff --git a/torch/csrc/distributed/Module.cpp b/torch/csrc/distributed/Module.cpp index e81a7cc79a..be2b54aa85 100644 --- a/torch/csrc/distributed/Module.cpp +++ b/torch/csrc/distributed/Module.cpp @@ -115,6 +115,11 @@ static THDTensorDescriptor* _makeDescriptor(PyObject *obj) "type ") + std::string(THPUtils_typename(obj))); } +static THDRequest* _unpackRequest(PyObject *obj) +{ + return static_cast<THDRequest*>(THPWrapper_get(obj)); +} + static THDReduceOp _getReduceOp(PyObject *obj) { auto it = obj2reduceop.find(obj); @@ -137,6 +142,36 @@ static THDGroup _getGroup(PyObject *obj) return it->second; } +PyObject* THDPModule_isend(PyObject *_unused, PyObject *args) +{ + HANDLE_TH_ERRORS + if (PyTuple_GET_SIZE(args) != 2 || !THPModule_isTensor(PyTuple_GET_ITEM(args, 0)) || + !THPUtils_checkLong(PyTuple_GET_ITEM(args, 1))) { + THPUtils_invalidArguments(args, "send", 1, "(tensor input, int dst_rank)"); + return NULL; + } + + THDPTensorDesc desc = _makeDescriptor(PyTuple_GET_ITEM(args, 0)); + int dst_rank = THPUtils_unpackLong(PyTuple_GET_ITEM(args, 1)); + return THPWrapper_New(THDIsend(desc, dst_rank), (void(*)(void*))THDRequest_free); + END_HANDLE_TH_ERRORS +} + +PyObject* THDPModule_irecv(PyObject *_unused, PyObject *args) +{ + HANDLE_TH_ERRORS + if (PyTuple_GET_SIZE(args) != 2 || !THPModule_isTensor(PyTuple_GET_ITEM(args, 0)) || + !THPUtils_checkLong(PyTuple_GET_ITEM(args, 1))) { + THPUtils_invalidArguments(args, "recv", 1, "(tensor output, int src_rank)"); + return NULL; + } + + THDPTensorDesc desc = _makeDescriptor(PyTuple_GET_ITEM(args, 0)); + int src_rank = THPUtils_unpackLong(PyTuple_GET_ITEM(args, 1)); + return THPWrapper_New(THDIrecv(desc, src_rank), (void(*)(void*))THDRequest_free); + END_HANDLE_TH_ERRORS +} + PyObject* THDPModule_send(PyObject *_unused, PyObject *args) { HANDLE_TH_ERRORS @@ -164,7 +199,7 @@ PyObject* THDPModule_recv(PyObject *_unused, PyObject *args) THDPTensorDesc desc = _makeDescriptor(PyTuple_GET_ITEM(args, 0)); int src_rank = THPUtils_unpackLong(PyTuple_GET_ITEM(args, 1)); - THDReceive(desc, src_rank); + THDRecv(desc, src_rank); Py_RETURN_NONE; END_HANDLE_TH_ERRORS } @@ -431,6 +466,31 @@ invalid_arguments: END_HANDLE_TH_ERRORS } +PyObject* THDPModule_requestIsCompleted(PyObject *_unused, PyObject *_req) +{ + HANDLE_TH_ERRORS + if (!THPWrapper_check(_req)) { + THPUtils_invalidArguments(_req, "requestIsCompleted", 1, "(request req)"); + return NULL; + } + + return PyBool_FromLong(THDRequest_isCompleted(_unpackRequest(_req))); + END_HANDLE_TH_ERRORS +} + +PyObject* THDPModule_requestWait(PyObject *_unused, PyObject *_req) +{ + HANDLE_TH_ERRORS + if (!THPWrapper_check(_req)) { + THPUtils_invalidArguments(_req, "requestWait", 1, "(request req)"); + return NULL; + } + + THDRequest_wait(_unpackRequest(_req)); + Py_RETURN_NONE; + END_HANDLE_TH_ERRORS +} + PyObject* THDPModule_initExtension(PyObject *_unused, PyObject *args) { if (PyTuple_GET_SIZE(args) != 3) { THPUtils_invalidArguments(args, "initExtension", 1, "(bool is_master_worker, reduce_op obj, group obj)"); @@ -478,6 +538,8 @@ static struct PyMethodDef _THDPModule_methods[] = { {"_dist_init_master_worker", (PyCFunction)THDPModule_initMasterWorker, METH_O, NULL}, {"_dist_get_rank", (PyCFunction)THDPModule_getRank, METH_NOARGS, NULL}, {"_dist_get_num_processes", (PyCFunction)THDPModule_getNumProcesses, METH_NOARGS, NULL}, + {"_dist_isend", (PyCFunction)THDPModule_isend, METH_VARARGS, NULL}, + {"_dist_irecv", (PyCFunction)THDPModule_irecv, METH_VARARGS, NULL}, {"_dist_send", (PyCFunction)THDPModule_send, METH_VARARGS, NULL}, {"_dist_recv", (PyCFunction)THDPModule_recv, METH_VARARGS, NULL}, {"_dist_all_reduce", (PyCFunction)THDPModule_allReduce, METH_VARARGS, NULL}, @@ -490,6 +552,8 @@ static struct PyMethodDef _THDPModule_methods[] = { {"_dist_scatter_recv", (PyCFunction)THDPModule_scatterRecv, METH_VARARGS, NULL}, {"_dist_barrier", (PyCFunction)THDPModule_barrier, METH_O, NULL}, {"_dist_new_group", (PyCFunction)THDPModule_newGroup, METH_VARARGS, NULL}, + {"_dist_request_is_completed", (PyCFunction)THDPModule_requestIsCompleted, METH_O, NULL}, + {"_dist_request_wait", (PyCFunction)THDPModule_requestWait, METH_O, NULL}, {NULL} }; diff --git a/torch/csrc/distributed/THDP.h b/torch/csrc/distributed/THDP.h index e3f18f2090..058466d22e 100644 --- a/torch/csrc/distributed/THDP.h +++ b/torch/csrc/distributed/THDP.h @@ -7,6 +7,7 @@ #include "Module.h" #include "Storage.h" #include "Tensor.h" +#include "../PtrWrapper.h" #ifdef _THP_CORE #include "utils.h" #endif |