author     Teng Li <teng-li@users.noreply.github.com>    2018-02-21 14:59:53 -0800
committer  Adam Paszke <adam.paszke@gmail.com>           2018-02-21 23:59:52 +0100
commit     579de82bcf055e582262e316c2c680e2947ca58a (patch)
tree       e7e99fa18e9934b2b014a908c7e1171f51b0486c /torch
parent     069f66e26778874d65d7e4baeafaa2fe5f8efca7 (diff)
DDP: ~10% NCCL backend perf improvement, with mixed-precision support (#5064)
Diffstat (limited to 'torch')
-rw-r--r-- | torch/nn/parallel/distributed.py | 125
1 file changed, 103 insertions, 22 deletions
diff --git a/torch/nn/parallel/distributed.py b/torch/nn/parallel/distributed.py
index e40dfb1478..853dc90fa9 100644
--- a/torch/nn/parallel/distributed.py
+++ b/torch/nn/parallel/distributed.py
@@ -112,9 +112,14 @@ class DistributedDataParallel(Module):
         self.output_device = output_device
         self.broadcast_buffers = broadcast_buffers
 
+        # Flag used by the NCCL backend to make sure we only reduce gradients
+        # one time in the execution engine
+        self.need_reduction = False
+
         MB = 1024 * 1024
         # used for intra-node param sync and inter-node sync as well
         self.broadcast_bucket_size = 10 * MB
+        self.nccl_reduce_bucket_size = 256 * MB
 
         # Sync params and buffers
         module_states = list(self.module.state_dict().values())
@@ -135,11 +140,15 @@ class DistributedDataParallel(Module):
         else:
             self._module_copies = [self.module]
 
-        # Currently NCCL backend only supports single reduction thread/bucket
+        # For NCCL backend, since every single NCCL call is asynchoronous, we
+        # therefore directly enqueue all the NCCL reduction calls to the
+        # default CUDA stream without spawning up other reduction threads.
+        # This achieves the best performance.
         if dist._backend == dist.dist_backend.NCCL:
-            bucket_bytes_cap = float('inf')
-        else:
-            bucket_bytes_cap = 1 * MB
+            self._register_nccl_grad_hook()
+            return
+
+        bucket_bytes_cap = 1 * MB
 
         # This is a triply-nested list where the "dimensions" are: devices, buckets, bucket_elems
         param_buckets = []
@@ -149,7 +158,6 @@ class DistributedDataParallel(Module):
 
         self.bucket_sizes = []
         self.bucket_map = {}
-        param_types = set()
 
         # We transpose param_buckets, so the loop is over buckets.
         # param_buckets_tuple is a doubly-nested list with "dims": devices, bucket_elems
@@ -161,10 +169,8 @@ class DistributedDataParallel(Module):
                 if idx == 0:
                     # Bucket parameter type tracking
                     bucket_param_type = param_tuple[0].type()
-                    param_types.add(bucket_param_type)
                     # Only gloo and nccl support half-precision
                     if bucket_param_type == torch.cuda.HalfTensor and \
-                            dist._backend != dist.dist_backend.NCCL and \
                             dist._backend != dist.dist_backend.GLOO:
                         raise RuntimeError("DistributedDataParallel currently only "
                                            "supports half precision parameters "
@@ -175,13 +181,6 @@ class DistributedDataParallel(Module):
                     self.bucket_map[p] = bucket_idx
                 self.bucket_sizes[bucket_idx] += 1
 
-        # TODO, adding mixed precision support in NCCL reduction code path
-        # This is because NCCL backend doesn't support multiple reduction
-        # bucket.
-        if len(param_types) > 1 and dist._backend == dist.dist_backend.NCCL:
-            raise RuntimeError("DistributedDataParallel currently doesn't "
-                               "support mixed precision type for NCCL backend")
-
         self.buckets = [[[] for _ in range(len(self.device_ids))] for _ in range(len(self.bucket_sizes))]
         self.bucket_events = [[None] * len(self.device_ids) for _ in range(len(self.bucket_sizes))]
         self.reduced = [False] * len(self.bucket_sizes)
@@ -193,16 +192,22 @@ class DistributedDataParallel(Module):
 
     def __getstate__(self):
         attrs = copy.copy(self.__dict__)
-        del attrs['_grad_accs'], attrs['_reduction_queues'], attrs['_reduction_streams'], \
-            attrs['_reduction_threads'], attrs['_nccl_streams'], attrs['_default_streams']
+        if dist._backend != dist.dist_backend.NCCL:
+            del attrs['_grad_accs'], attrs['_reduction_queues'], \
+                attrs['_reduction_streams'], attrs['_reduction_threads'], \
+                attrs['_nccl_streams'], attrs['_default_streams']
         return attrs
 
     def __setstate__(self, state):
         super(DistributedDataParallel, self).__setstate__(state)
-        self._register_grad_hooks()
-        self._start_reduction_threads()
+        if dist._backend == dist.dist_backend.NCCL:
+            self._register_nccl_grad_hook()
+        else:
+            self._register_grad_hooks()
+            self._start_reduction_threads()
 
     def forward(self, *inputs, **kwargs):
+        self.need_reduction = True
         inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
         self._sync_params()
         if len(self.device_ids) == 1:
@@ -274,7 +279,86 @@ class DistributedDataParallel(Module):
                     grad_acc.register_hook(self._make_param_hook(p, device_idx))
                     self._grad_accs.append(grad_acc)
 
+    def _register_nccl_grad_hook(self):
+        """
+        This function registers the callback all-reduction function for the
+        NCCL backend. All gradients will be all reduced in one single step.
+        The NCCL reduction will directly be enqueued into the
+        default CUDA stream. Therefore, no synchronization is needed.
+        """
+        # Creating a new group
+        self.nccl_reduction_group_id = dist.new_group()
+
+        def reduction_fn_nccl():
+            # This function only needs to be called once
+            if not self.need_reduction:
+                return
+
+            self.need_reduction = False
+            all_grads = [[] for _ in range(len(self._module_copies))]
+            all_grads_buckets_iters = []
+
+            # Bucketing all the gradients
+            for dev_idx, module in enumerate(self._module_copies):
+                for param in module.parameters():
+                    if not param.requires_grad or param.grad is None:
+                        continue
+                    if param.grad.requires_grad:
+                        raise RuntimeError("DistributedDataParallel only works "
+                                           "with gradients that don't require "
+                                           "grad")
+                    # Adding the gradients for reduction
+                    all_grads[dev_idx].append(param.grad.data)
+
+                # Now bucketing the parameters
+                dev_grads_buckets = _take_tensors(all_grads[dev_idx],
+                                                  self.nccl_reduce_bucket_size)
+
+                all_grads_buckets_iters.append(dev_grads_buckets)
+
+            # Now reduce each bucket one after another
+            for grads_batch in zip(*all_grads_buckets_iters):
+                grads_batch_coalesced = []
+                # Coalesce each bucket
+                for dev_idx, dev_grads_batch in enumerate(grads_batch):
+                    dev_id = self.device_ids[dev_idx]
+                    with torch.cuda.device(dev_id):
+                        dev_grads_batch_coalesced = _flatten_dense_tensors(dev_grads_batch)
+                        grads_batch_coalesced.append(dev_grads_batch_coalesced)
+
+                # We will only use device 0's results, but this single op should be
+                # faster than doing the following two operation sequentially:
+                # (1) intra-node reduce to lead GPU, followed by
+                # (2) inter-node allreduce for all the first lead GPUs in all nodes
+                dist.all_reduce_multigpu(grads_batch_coalesced,
+                                         group=self.nccl_reduction_group_id)
+
+                # Now only work on the first device of self.device_ids, uncoalesce
+                # the gradients for each bucket
+                grads_batch_coalesced[0] /= dist.get_world_size()
+                grads_batch_reduced = _unflatten_dense_tensors(grads_batch_coalesced[0], grads_batch[0])
+                for grad, reduced in zip(grads_batch[0], grads_batch_reduced):
+                    grad.copy_(reduced)
+
+            # clear the gradients and save memory for replicas
+            for module in self._module_copies[1:]:
+                for param in module.parameters():
+                    if param.requires_grad:
+                        param.grad = None
+                        param.data.set_()
+
+        # Now register the reduction hook on the parameters
+        for p in self.module.parameters():
+            if not p.requires_grad:
+                continue
+
+            def allreduce_hook(*unused):
+                Variable._execution_engine.queue_callback(reduction_fn_nccl)
+
+            p.register_hook(allreduce_hook)
+
     def _make_param_hook(self, param, device_idx):
+
         bucket_idx = self.bucket_map[param]
 
         def distributed_data_parallel_hook(*unused):
@@ -349,10 +433,7 @@ class DistributedDataParallel(Module):
             # We only use the first device for distributed reductions
             dist._register_stream(reduction_streams[0])
 
-            if dist._backend == dist.dist_backend.NCCL:
-                group_id = dist.group.WORLD
-            else:
-                group_id = dist.new_group()
+            group_id = dist.new_group()
 
             self._reduction_threads.append(threading.Thread(
                 target=self._reduction_thread_fn,