From c0eeb69dfb7326948f7d0063aa1b4112ff575cee Mon Sep 17 00:00:00 2001
From: Zhen Zhang
Date: Sun, 31 Oct 2021 02:59:44 -0400
Subject: [PATCH] ZeRO3, improved parameter all-gather operation (#1188)

* remove norm(), avoid memcpy after allgather

1) Removing the norm computation in debug printing
2) Changing _all_gather to be a sync op in fetch_sub_module
   Reason: the async version is not actually async, because each
   all_gather calls torch.cuda.synchronize() to guarantee that the
   previous communication op has completed
3) Adding a new function, _allgather_params_split_launch
   The existing _allgather_params performs an explicit memcpy after the
   all-gather op. Avoiding that explicit memory copy on the Python side
   improves performance.

Known issue: `torch.distributed.all_gather` still performs an implicit
memcpy at the end of each `ncclAllgather`.

* WIP: wrapped ncclAllgather as customized op in DS

A micro-benchmark of all-gathering a transformer layer with 9834560
elements in half precision shows an improvement of about 1.1ms on an
aws-p4d instance.

* WIP: integrated into partition_parameters

Performance improvement of 5.1B BERT on aws-p4d:
fwd: 300ms -> 200ms
bwd: 680ms -> 610ms

* Fix format

* cleaned dead code, modified unit test

* removed customized c++ extension

reverted back to using the torch distributed API

* change torch.ones to torch.empty

* typo

* warn if not cuda tensor for allgather

* fix formatting

* fix: move ds_tensor to cuda device

but it is strange that the ds_tensor hasn't already been moved to cuda

* remove try clause on the path for fetching params

Co-authored-by: Olatunji Ruwase
Co-authored-by: Jeff Rasley
---
 .../runtime/zero/partition_parameters.py | 120 ++++++++++++++++--
 deepspeed/runtime/zero/stage3.py         |   2 +-
 2 files changed, 107 insertions(+), 15 deletions(-)

diff --git a/deepspeed/runtime/zero/partition_parameters.py b/deepspeed/runtime/zero/partition_parameters.py
index 868d66d8224c..4e1bd22b5d8c 100755
--- a/deepspeed/runtime/zero/partition_parameters.py
+++ b/deepspeed/runtime/zero/partition_parameters.py
@@ -11,7 +11,8 @@
 import itertools

 import torch
-from torch.distributed.distributed_c10d import _get_global_rank
+from torch.distributed.distributed_c10d import _get_global_rank, group
+import torch.distributed as dist

 from .linear import LinearModuleForZeroStage3, LinearFunctionForZeroStage3
 from .offload_constants import *
@@ -496,6 +497,14 @@ def get_model():
             assert isinstance(module, torch.nn.Module)
             self._convert_to_zero_parameters(module.parameters(recurse=True))

+        self.use_all_gather_base = False
+        try:
+            from torch.distributed.distributed_c10d import _all_gather_base as all_gather
+            self.use_all_gather_base = True
+        except:
+            logger.info(
+                f"_all_gather_base API is not available in torch {torch.__version__}")
+
     def _convert_to_zero_parameters(self, param_list):
         for param in param_list:
             if is_zero_param(param):
@@ -686,7 +695,9 @@ def _all_gather(self, param_list, async_op=False, hierarchy=None):
                 all_gather_list.append(param)

         if not async_op:
-            ret_value = self._allgather_params(all_gather_list, hierarchy=hierarchy)
+            # ret_value = self._allgather_params(all_gather_list, hierarchy=hierarchy)
+            ret_value = self._allgather_params_coalesced(all_gather_list, hierarchy)
+
             for param in all_gather_list:
                 param.ds_status = ZeroParamStatus.AVAILABLE
             return ret_value
@@ -732,8 +743,10 @@ def _partition_param(self, param, buffer=None, has_been_updated=False):
             see_memory_usage(
                 f'Before partitioning param {param.ds_id} {param.shape}',
                 force=False)
+            # param.data does not store anything meaningful in partitioned state
-            param.data = torch.ones(1, dtype=self.dtype).to(param.device)
+            param.data = torch.empty(1, dtype=self.dtype, device=param.device)
+
             see_memory_usage(f'After partitioning param {param.ds_id} {param.shape}',
                              force=False)
@@ -754,7 +767,7 @@ def _partition_param(self, param, buffer=None, has_been_updated=False):
                         numel=partition_size):
                     final_location = OFFLOAD_NVME_DEVICE
                     buffer = self.param_swapper.get_buffer(param, partition_size)
-                    partitioned_tensor = torch.zeros(1,
+                    partitioned_tensor = torch.empty(1,
                                                      dtype=param.dtype,
                                                      device=buffer.device)
                     partitioned_tensor.data = buffer.data
@@ -763,7 +776,7 @@ def _partition_param(self, param, buffer=None, has_been_updated=False):
                     )

                 else:
-                    partitioned_tensor = torch.zeros(
+                    partitioned_tensor = torch.empty(
                         partition_size,
                         dtype=param.dtype,
                         device=OFFLOAD_CPU_DEVICE
@@ -873,22 +886,101 @@ def _allgather_param(self, param, async_op=False, hierarchy=0):
         #            param.ds_numel).view(param.ds_shape)
         #            param.data = replicated_tensor.data
         #            return None
-        partitions = []
-        for i in range(self.world_size):
-            partitions.append(flat_tensor.narrow(0, partition_size * i, partition_size))
+        if self.use_all_gather_base:
+            # try the _all_gather_base on PyTorch master branch
+            handle = dist._all_gather_base(flat_tensor,
+                                           param.ds_tensor,
+                                           group=self.ds_process_group,
+                                           async_op=async_op)
+        else:
+            partitions = []
+            for i in range(self.world_size):
+                partitions.append(
+                    flat_tensor.narrow(0,
+                                       partition_size * i,
+                                       partition_size))

-            if i == torch.distributed.get_rank(group=self.ds_process_group):
-                partitions[i].data.copy_(param.ds_tensor.data, non_blocking=True)
+                if i == dist.get_rank(group=self.ds_process_group):
+                    partitions[i].data.copy_(param.ds_tensor.data, non_blocking=True)

-        handle = torch.distributed.all_gather(partitions,
-                                              partitions[self.rank],
-                                              group=self.ds_process_group,
-                                              async_op=async_op)
+            handle = dist.all_gather(partitions,
+                                     partitions[self.rank],
+                                     group=self.ds_process_group,
+                                     async_op=async_op)

         replicated_tensor = flat_tensor.narrow(0, 0, param.ds_numel).view(param.ds_shape)
         param.data = replicated_tensor.data
         return handle

+    def _allgather_params_coalesced(self, param_list, hierarchy=0):
+        """ blocking call
+        avoid explicit memory copy in _allgather_params
+        """
+        if len(param_list) == 0:
+            return
+        # collect local tensors and partition sizes
+        partition_sizes = []
+        local_tensors = []
+        for param in param_list:
+            partition_sizes.append(param.ds_tensor.ds_numel)
+            local_tensors.append(param.ds_tensor.cuda())
+
+        # allocate memory for allgather params
+        allgather_params = []
+        for psize in partition_sizes:
+            tensor_size = psize * self.world_size
+            flat_tensor = torch.empty(tensor_size,
+                                      dtype=param_list[0].dtype,
+                                      device=self.local_device).view(-1)
+            flat_tensor.requires_grad = False
+            allgather_params.append(flat_tensor)
+
+        # launch
+        launch_handles = []
+        # backend = get_backend(self.ds_process_group)
+        # with _batch_p2p_manager(backend):
+        for param_idx, param in enumerate(param_list):
+            input_tensor = local_tensors[param_idx].view(-1)
+
+            if self.use_all_gather_base:
+                # try the _all_gather_base from Pytorch master
+                h = dist._all_gather_base(allgather_params[param_idx],
+                                          input_tensor,
+                                          group=self.ds_process_group,
+                                          async_op=True)
+            else:
+                output_list = []
+                for i in range(self.world_size):
+                    psize = partition_sizes[param_idx]
+                    partition = allgather_params[param_idx].narrow(0, i * psize, psize)
+                    output_list.append(partition)
+                    if not partition.is_cuda:
+                        logger.warning(
+                            f'param {param_idx}, partition {i} is not on CUDA, partition shape {partition.size()}'
+                        )
+
+                # back to old all_gather function signature
+                h = dist.all_gather(output_list,
+                                    input_tensor,
+                                    group=self.ds_process_group,
+                                    async_op=True)
+            launch_handles.append(h)
+
+        # Wait ensures the operation is enqueued, but not necessarily complete.
+        launch_handles[-1].wait()
+
+        # assign to param.data (not copy)
+        for i, param in enumerate(param_list):
+            gathered_tensor = allgather_params[i]
+            param.data = gathered_tensor.narrow(0,
+                                                0,
+                                                param.ds_numel).view(param.ds_shape).data
+
+        # guarantee the communication to be completed
+        torch.cuda.synchronize()
+
+        return None
+
     def _allgather_params(self, param_list, hierarchy=0):
         if len(param_list) == 0:
             return
diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py
index 5989b0f72a97..221c5e51d340 100755
--- a/deepspeed/runtime/zero/stage3.py
+++ b/deepspeed/runtime/zero/stage3.py
@@ -444,7 +444,7 @@ def fetch_sub_module(self, sub_module):
        self.hierarchy += 1

        # parameters are partitioned and need to be allgathered
-       self._all_gather(partitioned_params, async_op=True)
+       self._all_gather(partitioned_params, async_op=False)

        # parameters are inflight and communication needs to be completed
        if partitioned_params or params_in_flight: