author    Teng Li <teng-li@users.noreply.github.com>    2018-02-21 14:59:53 -0800
committer Adam Paszke <adam.paszke@gmail.com>           2018-02-21 23:59:52 +0100
commit    579de82bcf055e582262e316c2c680e2947ca58a (patch)
tree      e7e99fa18e9934b2b014a908c7e1171f51b0486c /torch
parent    069f66e26778874d65d7e4baeafaa2fe5f8efca7 (diff)
DDP: 10% NCCL backend perf improvement with mixed-precision support (#5064)
Diffstat (limited to 'torch')
-rw-r--r--  torch/nn/parallel/distributed.py  125
1 file changed, 103 insertions, 22 deletions
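For context, the sketch below approximates, in a single process, the bucketed reduction scheme this commit introduces: gradients are grouped into fixed-size buckets with torch._utils._take_tensors, each bucket is flattened into one contiguous tensor, reduced once, then unflattened and copied back into the original gradient tensors. This is not part of the patch; a local sum stands in for dist.all_reduce_multigpu so the example runs without a process group, and the tensor sizes and world size are illustrative.

import torch
from torch._utils import (_take_tensors, _flatten_dense_tensors,
                          _unflatten_dense_tensors)

bucket_size = 256 * 1024 * 1024   # bytes, mirrors nccl_reduce_bucket_size below
world_size = 2                    # illustrative

# Pretend these are one replica's .grad tensors; mixed precision is fine
# because _take_tensors keeps float and half tensors in separate buckets.
grads = [torch.randn(1024), torch.randn(2048), torch.randn(512).half()]

for bucket in _take_tensors(grads, bucket_size):
    # One contiguous tensor per bucket, so the backend reduces it in one call
    flat = _flatten_dense_tensors(bucket)
    # Stand-in for dist.all_reduce_multigpu(...) followed by the divide by
    # world size that the patch applies to device 0's result.
    flat = sum(flat.clone() for _ in range(world_size)) / world_size
    # Write the reduced values back into the original gradient tensors
    for grad, reduced in zip(bucket, _unflatten_dense_tensors(flat, bucket)):
        grad.copy_(reduced)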
diff --git a/torch/nn/parallel/distributed.py b/torch/nn/parallel/distributed.py
index e40dfb1478..853dc90fa9 100644
--- a/torch/nn/parallel/distributed.py
+++ b/torch/nn/parallel/distributed.py
@@ -112,9 +112,14 @@ class DistributedDataParallel(Module):
self.output_device = output_device
self.broadcast_buffers = broadcast_buffers
+ # Flag used by the NCCL backend to make sure we only reduce gradients
+ # once in the execution engine
+ self.need_reduction = False
+
MB = 1024 * 1024
# used for intra-node param sync and inter-node sync as well
self.broadcast_bucket_size = 10 * MB
+ self.nccl_reduce_bucket_size = 256 * MB
# Sync params and buffers
module_states = list(self.module.state_dict().values())
@@ -135,11 +140,15 @@ class DistributedDataParallel(Module):
else:
self._module_copies = [self.module]
- # Currently NCCL backend only supports single reduction thread/bucket
+ # For the NCCL backend, since every NCCL call is asynchronous, we
+ # directly enqueue all the NCCL reduction calls onto the default
+ # CUDA stream without spawning separate reduction threads.
+ # This achieves the best performance.
if dist._backend == dist.dist_backend.NCCL:
- bucket_bytes_cap = float('inf')
- else:
- bucket_bytes_cap = 1 * MB
+ self._register_nccl_grad_hook()
+ return
+
+ bucket_bytes_cap = 1 * MB
# This is a triply-nested list where the "dimensions" are: devices, buckets, bucket_elems
param_buckets = []
@@ -149,7 +158,6 @@ class DistributedDataParallel(Module):
self.bucket_sizes = []
self.bucket_map = {}
- param_types = set()
# We transpose param_buckets, so the loop is over buckets.
# param_buckets_tuple is a doubly-nested list with "dims": devices, bucket_elems
@@ -161,10 +169,8 @@ class DistributedDataParallel(Module):
if idx == 0:
# Bucket parameter type tracking
bucket_param_type = param_tuple[0].type()
- param_types.add(bucket_param_type)
# Only gloo and nccl support half-precision
if bucket_param_type == torch.cuda.HalfTensor and \
- dist._backend != dist.dist_backend.NCCL and \
dist._backend != dist.dist_backend.GLOO:
raise RuntimeError("DistributedDataParallel currently only "
"supports half precision parameters "
@@ -175,13 +181,6 @@ class DistributedDataParallel(Module):
self.bucket_map[p] = bucket_idx
self.bucket_sizes[bucket_idx] += 1
- # TODO, adding mixed precision support in NCCL reduction code path
- # This is because NCCL backend doesn't support multiple reduction
- # bucket.
- if len(param_types) > 1 and dist._backend == dist.dist_backend.NCCL:
- raise RuntimeError("DistributedDataParallel currently doesn't "
- "support mixed precision type for NCCL backend")
-
self.buckets = [[[] for _ in range(len(self.device_ids))] for _ in range(len(self.bucket_sizes))]
self.bucket_events = [[None] * len(self.device_ids) for _ in range(len(self.bucket_sizes))]
self.reduced = [False] * len(self.bucket_sizes)
@@ -193,16 +192,22 @@ class DistributedDataParallel(Module):
def __getstate__(self):
attrs = copy.copy(self.__dict__)
- del attrs['_grad_accs'], attrs['_reduction_queues'], attrs['_reduction_streams'], \
- attrs['_reduction_threads'], attrs['_nccl_streams'], attrs['_default_streams']
+ if dist._backend != dist.dist_backend.NCCL:
+ del attrs['_grad_accs'], attrs['_reduction_queues'], \
+ attrs['_reduction_streams'], attrs['_reduction_threads'], \
+ attrs['_nccl_streams'], attrs['_default_streams']
return attrs
def __setstate__(self, state):
super(DistributedDataParallel, self).__setstate__(state)
- self._register_grad_hooks()
- self._start_reduction_threads()
+ if dist._backend == dist.dist_backend.NCCL:
+ self._register_nccl_grad_hook()
+ else:
+ self._register_grad_hooks()
+ self._start_reduction_threads()
def forward(self, *inputs, **kwargs):
+ self.need_reduction = True
inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
self._sync_params()
if len(self.device_ids) == 1:
@@ -274,7 +279,86 @@ class DistributedDataParallel(Module):
grad_acc.register_hook(self._make_param_hook(p, device_idx))
self._grad_accs.append(grad_acc)
+ def _register_nccl_grad_hook(self):
+ """
+ This function registers the callback all-reduction function for the
+ NCCL backend. All gradients will be all reduced in one single step.
+ The NCCL reduction will directly be enqueued into the
+ default CUDA stream. Therefore, no synchronization is needed.
+ """
+ # Creating a new group
+ self.nccl_reduction_group_id = dist.new_group()
+
+ def reduction_fn_nccl():
+ # This function only needs to be called once
+ if not self.need_reduction:
+ return
+
+ self.need_reduction = False
+ all_grads = [[] for _ in range(len(self._module_copies))]
+ all_grads_buckets_iters = []
+
+ # Bucketing all the gradients
+ for dev_idx, module in enumerate(self._module_copies):
+ for param in module.parameters():
+ if not param.requires_grad or param.grad is None:
+ continue
+ if param.grad.requires_grad:
+ raise RuntimeError("DistributedDataParallel only works "
+ "with gradients that don't require "
+ "grad")
+ # Adding the gradients for reduction
+ all_grads[dev_idx].append(param.grad.data)
+
+ # Now bucket this device's gradients
+ dev_grads_buckets = _take_tensors(all_grads[dev_idx],
+ self.nccl_reduce_bucket_size)
+
+ all_grads_buckets_iters.append(dev_grads_buckets)
+
+ # Now reduce each bucket one after another
+ for grads_batch in zip(*all_grads_buckets_iters):
+ grads_batch_coalesced = []
+ # Coalesce each bucket
+ for dev_idx, dev_grads_batch in enumerate(grads_batch):
+ dev_id = self.device_ids[dev_idx]
+ with torch.cuda.device(dev_id):
+ dev_grads_batch_coalesced = _flatten_dense_tensors(dev_grads_batch)
+ grads_batch_coalesced.append(dev_grads_batch_coalesced)
+
+ # We will only use device 0's results, but this single op should be
+ # faster than doing the following two operations sequentially:
+ # (1) intra-node reduce to the lead GPU, followed by
+ # (2) inter-node all-reduce across the lead GPUs of all nodes
+ dist.all_reduce_multigpu(grads_batch_coalesced,
+ group=self.nccl_reduction_group_id)
+
+ # Now work only on the first device of self.device_ids and
+ # uncoalesce the gradients for each bucket
+ grads_batch_coalesced[0] /= dist.get_world_size()
+ grads_batch_reduced = _unflatten_dense_tensors(grads_batch_coalesced[0], grads_batch[0])
+ for grad, reduced in zip(grads_batch[0], grads_batch_reduced):
+ grad.copy_(reduced)
+
+ # Clear the gradients on the replicas to save memory
+ for module in self._module_copies[1:]:
+ for param in module.parameters():
+ if param.requires_grad:
+ param.grad = None
+ param.data.set_()
+
+ # Now register the reduction hook on the parameters
+ for p in self.module.parameters():
+ if not p.requires_grad:
+ continue
+
+ def allreduce_hook(*unused):
+ Variable._execution_engine.queue_callback(reduction_fn_nccl)
+
+ p.register_hook(allreduce_hook)
+
def _make_param_hook(self, param, device_idx):
+
bucket_idx = self.bucket_map[param]
def distributed_data_parallel_hook(*unused):
@@ -349,10 +433,7 @@ class DistributedDataParallel(Module):
# We only use the first device for distributed reductions
dist._register_stream(reduction_streams[0])
- if dist._backend == dist.dist_backend.NCCL:
- group_id = dist.group.WORLD
- else:
- group_id = dist.new_group()
+ group_id = dist.new_group()
self._reduction_threads.append(threading.Thread(
target=self._reduction_thread_fn,
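A note on the mechanism: the new hook does not reduce anything itself. Every parameter hook queues the same callback on the autograd engine, and the need_reduction flag ensures only the first queued invocation does work, once the whole backward pass has produced all gradients. The standalone sketch below mirrors that pattern outside DistributedDataParallel; the toy Linear model, the module-level flag, and the print statement are illustrative assumptions.

import torch
from torch.autograd import Variable

model = torch.nn.Linear(4, 2)
need_reduction = False

def reduction_fn():
    # The engine runs queued callbacks after the backward pass finishes,
    # i.e. when every parameter's .grad is ready; the flag makes the body
    # run only once even though the callback was queued per parameter.
    global need_reduction
    if not need_reduction:
        return
    need_reduction = False
    # Stand-in for the bucketed all-reduce in the patch
    print([p.grad.norm().item() for p in model.parameters()])

def allreduce_hook(*unused):
    # Each parameter's hook queues the same callback on the execution engine
    Variable._execution_engine.queue_callback(reduction_fn)

for p in model.parameters():
    if p.requires_grad:
        p.register_hook(allreduce_hook)

need_reduction = True                      # set in forward() in the patch
model(torch.randn(3, 4)).sum().backward()  # reduction_fn's body runs once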