author    Teng Li <teng-li@users.noreply.github.com>  2018-04-15 18:53:10 -0700
committer Soumith Chintala <soumith@gmail.com>        2018-04-15 21:53:10 -0400
commit    f5beff334bb511ff50d606e4ae5b47938723cd13 (patch)
tree      84e0eccc1441b7dee03fa3ccc0db7a7b8c8c545c /docs
parent    5463a4a3199b89aa1e944068694fa0f50635807b (diff)
Added distributed docs on NCCL2 backend/functions and launch module (#6579)
Diffstat (limited to 'docs')
 -rw-r--r--  docs/source/distributed.rst | 123
 1 file changed, 98 insertions(+), 25 deletions(-)
diff --git a/docs/source/distributed.rst b/docs/source/distributed.rst
index 27decd0f99..23846f18b1 100644
--- a/docs/source/distributed.rst
+++ b/docs/source/distributed.rst
@@ -7,35 +7,35 @@ Distributed communication package - torch.distributed
.. automodule:: torch.distributed
.. currentmodule:: torch.distributed
-Currently torch.distributed supports three backends, each with
+Currently torch.distributed supports four backends, each with
different capabilities. The table below shows which functions are available
for use with CPU / CUDA tensors.
MPI supports CUDA only if the implementation used to build PyTorch supports it.
-+------------+-----------+-----------+-----------+
-| Backend | ``tcp`` | ``gloo`` | ``mpi`` |
-+------------+-----+-----+-----+-----+-----+-----+
-| Device | CPU | GPU | CPU | GPU | CPU | GPU |
-+============+=====+=====+=====+=====+=====+=====+
-| send | ✓ | ✘ | ✘ | ✘ | ✓ | ? |
-+------------+-----+-----+-----+-----+-----+-----+
-| recv | ✓ | ✘ | ✘ | ✘ | ✓ | ? |
-+------------+-----+-----+-----+-----+-----+-----+
-| broadcast | ✓ | ✘ | ✓ | ✓ | ✓ | ? |
-+------------+-----+-----+-----+-----+-----+-----+
-| all_reduce | ✓ | ✘ | ✓ | ✓ | ✓ | ? |
-+------------+-----+-----+-----+-----+-----+-----+
-| reduce | ✓ | ✘ | ✘ | ✘ | ✓ | ? |
-+------------+-----+-----+-----+-----+-----+-----+
-| all_gather | ✓ | ✘ | ✘ | ✘ | ✓ | ? |
-+------------+-----+-----+-----+-----+-----+-----+
-| gather | ✓ | ✘ | ✘ | ✘ | ✓ | ? |
-+------------+-----+-----+-----+-----+-----+-----+
-| scatter | ✓ | ✘ | ✘ | ✘ | ✓ | ? |
-+------------+-----+-----+-----+-----+-----+-----+
-| barrier | ✓ | ✘ | ✓ | ✓ | ✓ | ? |
-+------------+-----+-----+-----+-----+-----+-----+
++------------+-----------+-----------+-----------+-----------+
+| Backend | ``tcp`` | ``gloo`` | ``mpi`` | ``nccl`` |
++------------+-----+-----+-----+-----+-----+-----+-----+-----+
+| Device | CPU | GPU | CPU | GPU | CPU | GPU | CPU | GPU |
++============+=====+=====+=====+=====+=====+=====+=====+=====+
+| send | ✓ | ✘ | ✘ | ✘ | ✓ | ? | ✘ | ✘ |
++------------+-----+-----+-----+-----+-----+-----+-----+-----+
+| recv | ✓ | ✘ | ✘ | ✘ | ✓ | ? | ✘ | ✘ |
++------------+-----+-----+-----+-----+-----+-----+-----+-----+
+| broadcast | ✓ | ✘ | ✓ | ✓ | ✓ | ? | ✘ | ✓ |
++------------+-----+-----+-----+-----+-----+-----+-----+-----+
+| all_reduce | ✓ | ✘ | ✓ | ✓ | ✓ | ? | ✘ | ✓ |
++------------+-----+-----+-----+-----+-----+-----+-----+-----+
+| reduce | ✓ | ✘ | ✘ | ✘ | ✓ | ? | ✘ | ✓ |
++------------+-----+-----+-----+-----+-----+-----+-----+-----+
+| all_gather | ✓ | ✘ | ✘ | ✘ | ✓ | ? | ✘ | ✓ |
++------------+-----+-----+-----+-----+-----+-----+-----+-----+
+| gather | ✓ | ✘ | ✘ | ✘ | ✓ | ? | ✘ | ✓ |
++------------+-----+-----+-----+-----+-----+-----+-----+-----+
+| scatter | ✓ | ✘ | ✘ | ✘ | ✓ | ? | ✘ | ✓ |
++------------+-----+-----+-----+-----+-----+-----+-----+-----+
+| barrier | ✓ | ✘ | ✓ | ✓ | ✓ | ? | ✘ | ✘ |
++------------+-----+-----+-----+-----+-----+-----+-----+-----+
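+
+As a minimal sketch of how a backend from the table above is selected (the
+master address ``192.168.1.1:23456``, the world size of 4, and the rank of 0
+are placeholder assumptions, not prescribed values), the backend is chosen
+when the process group is initialized, and only the functions marked as
+supported for that backend/device combination may then be used::
+
+    import torch
+    import torch.distributed as dist
+
+    # Assumed setup: 4 processes, this one being rank 0, rendezvousing over TCP.
+    dist.init_process_group(backend="gloo",
+                            init_method="tcp://192.168.1.1:23456",
+                            world_size=4,
+                            rank=0)
+
+    # With the "gloo" backend, broadcast works for both CPU and GPU tensors.
+    tensor = torch.zeros(1)
+    dist.broadcast(tensor, src=0)
+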
.. _distributed-basics:
@@ -173,7 +173,7 @@ as they should never be created manually, but they are guaranteed to support two
* ``is_completed()`` - returns True if the operation has finished
* ``wait()`` - will block the process until the operation is finished.
``is_completed()`` is guaranteed to return True once it returns.
-
+
When using the MPI backend, :func:`~torch.distributed.isend` and :func:`~torch.distributed.irecv`
support non-overtaking, which provides some guarantees on message ordering. For more detail, see
http://mpi-forum.org/docs/mpi-2.2/mpi22-report/node54.htm#Node54
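+
+A minimal sketch of how the request object returned by these functions is
+typically used (assuming a group of two already-initialized processes)::
+
+    import torch
+    import torch.distributed as dist
+
+    tensor = torch.zeros(1)
+    if dist.get_rank() == 0:
+        tensor += 1
+        # Non-blocking send; returns a request object.
+        req = dist.isend(tensor=tensor, dst=1)
+    else:
+        # Non-blocking receive; returns a request object.
+        req = dist.irecv(tensor=tensor, src=0)
+    # Block until the operation finishes; is_completed() then returns True.
+    req.wait()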
@@ -199,3 +199,76 @@ Collective functions
.. autofunction:: barrier
+Multi-GPU collective functions
+------------------------------
+
+When using the NCCL backend with more than one GPU on each node,
+:func:`~torch.distributed.broadcast_multigpu`,
+:func:`~torch.distributed.all_reduce_multigpu`,
+:func:`~torch.distributed.reduce_multigpu`, and
+:func:`~torch.distributed.all_gather_multigpu` support distributed collective
+operations among the GPUs within each node. These functions can potentially
+improve the overall distributed training performance and are easy to use:
+simply pass a list of tensors, where each tensor must reside on a separate
+GPU device of the host where the function is called. Note that the length of
+the tensor list needs to be identical across all distributed processes. Also
+note that the multi-GPU collective functions are currently only supported by
+the NCCL backend.
+
+For example, suppose the system we use for distributed training has 2 nodes,
+each of which has 8 GPUs. On each of the 16 GPUs, there is a tensor that we
+would like to all-reduce. The following code can serve as a reference:
+
+Code running on Node 0
+
+::
+
+    import torch
+    import torch.distributed as dist
+
+    dist.init_process_group(backend="nccl",
+                            init_method="file:///distributed_test",
+                            world_size=2,
+                            rank=0)
+    tensor_list = []
+    for dev_idx in range(torch.cuda.device_count()):
+        tensor_list.append(torch.FloatTensor([1]).cuda(dev_idx))
+
+    dist.all_reduce_multigpu(tensor_list)
+
+Code running on Node 1
+
+::
+
+    import torch
+    import torch.distributed as dist
+
+    dist.init_process_group(backend="nccl",
+                            init_method="file:///distributed_test",
+                            world_size=2,
+                            rank=1)
+    tensor_list = []
+    for dev_idx in range(torch.cuda.device_count()):
+        tensor_list.append(torch.FloatTensor([1]).cuda(dev_idx))
+
+    dist.all_reduce_multigpu(tensor_list)
+
+After the call, all 16 tensors on the two nodes will have the all-reduced value
+of 16.
+
+.. autofunction:: broadcast_multigpu
+
+.. autofunction:: all_reduce_multigpu
+
+.. autofunction:: reduce_multigpu
+
+.. autofunction:: all_gather_multigpu
+
+
+Launch utility
+--------------
+
+The ``torch.distributed`` package also provides a launch utility in
+``torch.distributed.launch``.
+
+.. automodule:: torch.distributed.launch
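+
+As a rough usage sketch (``YOUR_TRAINING_SCRIPT.py``, its arguments, and
+``NUM_GPUS_YOU_HAVE`` are placeholders), a single-node job with one process
+per GPU could be started as::
+
+    python -m torch.distributed.launch --nproc_per_node=NUM_GPUS_YOU_HAVE \
+        YOUR_TRAINING_SCRIPT.py --arg1 --arg2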