diff --git a/deepspeed/runtime/zero/partition_parameters.py b/deepspeed/runtime/zero/partition_parameters.py
index 868d66d8224c..4e1bd22b5d8c 100755
--- a/deepspeed/runtime/zero/partition_parameters.py
+++ b/deepspeed/runtime/zero/partition_parameters.py
@@ -11,7 +11,8 @@
 import itertools
 
 import torch
-from torch.distributed.distributed_c10d import _get_global_rank
+from torch.distributed.distributed_c10d import _get_global_rank, group
+import torch.distributed as dist
 
 from .linear import LinearModuleForZeroStage3, LinearFunctionForZeroStage3
 from .offload_constants import *
@@ -496,6 +497,14 @@ def get_model():
         assert isinstance(module, torch.nn.Module)
         self._convert_to_zero_parameters(module.parameters(recurse=True))
 
+        self.use_all_gather_base = False
+        try:
+            from torch.distributed.distributed_c10d import _all_gather_base as all_gather
+            self.use_all_gather_base = True
+        except:
+            logger.info(
+                f"_all_gather_base API is not available in torch {torch.__version__}")
+
     def _convert_to_zero_parameters(self, param_list):
         for param in param_list:
             if is_zero_param(param):
@@ -686,7 +695,9 @@ def _all_gather(self, param_list, async_op=False, hierarchy=None):
                 all_gather_list.append(param)
 
         if not async_op:
-            ret_value = self._allgather_params(all_gather_list, hierarchy=hierarchy)
+            # ret_value = self._allgather_params(all_gather_list, hierarchy=hierarchy)
+            ret_value = self._allgather_params_coalesced(all_gather_list, hierarchy)
+
             for param in all_gather_list:
                 param.ds_status = ZeroParamStatus.AVAILABLE
             return ret_value
@@ -732,8 +743,10 @@ def _partition_param(self, param, buffer=None, has_been_updated=False):
             see_memory_usage(
                 f'Before partitioning param {param.ds_id} {param.shape}',
                 force=False)
 
+            # param.data does not store anything meaningful in partitioned state
-            param.data = torch.ones(1, dtype=self.dtype).to(param.device)
+            param.data = torch.empty(1, dtype=self.dtype, device=param.device)
+
             see_memory_usage(f'After partitioning param {param.ds_id} {param.shape}',
                              force=False)
@@ -754,7 +767,7 @@ def _partition_param(self, param, buffer=None, has_been_updated=False):
                     numel=partition_size):
                 final_location = OFFLOAD_NVME_DEVICE
                 buffer = self.param_swapper.get_buffer(param, partition_size)
-                partitioned_tensor = torch.zeros(1,
+                partitioned_tensor = torch.empty(1,
                                                  dtype=param.dtype,
                                                  device=buffer.device)
                 partitioned_tensor.data = buffer.data
@@ -763,7 +776,7 @@
                 )
 
             else:
-                partitioned_tensor = torch.zeros(
+                partitioned_tensor = torch.empty(
                     partition_size,
                     dtype=param.dtype,
                     device=OFFLOAD_CPU_DEVICE
@@ -873,22 +886,101 @@ def _allgather_param(self, param, async_op=False, hierarchy=0):
        #            param.ds_numel).view(param.ds_shape)
        #    param.data = replicated_tensor.data
        #    return None
-        partitions = []
-        for i in range(self.world_size):
-            partitions.append(flat_tensor.narrow(0, partition_size * i, partition_size))
+        if self.use_all_gather_base:
+            # try the _all_gather_base on PyTorch master branch
+            handle = dist._all_gather_base(flat_tensor,
+                                           param.ds_tensor,
+                                           group=self.ds_process_group,
+                                           async_op=async_op)
+        else:
+            partitions = []
+            for i in range(self.world_size):
+                partitions.append(
+                    flat_tensor.narrow(0,
+                                       partition_size * i,
+                                       partition_size))
 
-            if i == torch.distributed.get_rank(group=self.ds_process_group):
-                partitions[i].data.copy_(param.ds_tensor.data, non_blocking=True)
+                if i == dist.get_rank(group=self.ds_process_group):
+                    partitions[i].data.copy_(param.ds_tensor.data, non_blocking=True)
 
-        handle = torch.distributed.all_gather(partitions,
-                                              partitions[self.rank],
-                                              group=self.ds_process_group,
-                                              async_op=async_op)
+            handle = dist.all_gather(partitions,
+                                     partitions[self.rank],
+                                     group=self.ds_process_group,
+                                     async_op=async_op)
 
         replicated_tensor = flat_tensor.narrow(0,
                                                0,
                                                param.ds_numel).view(param.ds_shape)
         param.data = replicated_tensor.data
         return handle
 
+    def _allgather_params_coalesced(self, param_list, hierarchy=0):
+        """ blocking call
+        avoid explicit memory copy in _allgather_params
+        """
+        if len(param_list) == 0:
+            return
+        # collect local tensors and partition sizes
+        partition_sizes = []
+        local_tensors = []
+        for param in param_list:
+            partition_sizes.append(param.ds_tensor.ds_numel)
+            local_tensors.append(param.ds_tensor.cuda())
+
+        # allocate memory for allgather params
+        allgather_params = []
+        for psize in partition_sizes:
+            tensor_size = psize * self.world_size
+            flat_tensor = torch.empty(tensor_size,
+                                      dtype=param_list[0].dtype,
+                                      device=self.local_device).view(-1)
+            flat_tensor.requires_grad = False
+            allgather_params.append(flat_tensor)
+
+        # launch
+        launch_handles = []
+        # backend = get_backend(self.ds_process_group)
+        # with _batch_p2p_manager(backend):
+        for param_idx, param in enumerate(param_list):
+            input_tensor = local_tensors[param_idx].view(-1)
+
+            if self.use_all_gather_base:
+                # try the _all_gather_base from Pytorch master
+                h = dist._all_gather_base(allgather_params[param_idx],
+                                          input_tensor,
+                                          group=self.ds_process_group,
+                                          async_op=True)
+            else:
+                output_list = []
+                for i in range(self.world_size):
+                    psize = partition_sizes[param_idx]
+                    partition = allgather_params[param_idx].narrow(0, i * psize, psize)
+                    output_list.append(partition)
+                    if not partition.is_cuda:
+                        logger.warning(
+                            f'param {param_idx}, partition {i} is not on CUDA, partition shape {partition.size()}'
+                        )
+
+                # back to old all_gather function signature
+                h = dist.all_gather(output_list,
+                                    input_tensor,
+                                    group=self.ds_process_group,
+                                    async_op=True)
+            launch_handles.append(h)
+
+        # Wait ensures the operation is enqueued, but not necessarily complete.
+        launch_handles[-1].wait()
+
+        # assign to param.data (not copy)
+        for i, param in enumerate(param_list):
+            gathered_tensor = allgather_params[i]
+            param.data = gathered_tensor.narrow(0,
+                                                0,
+                                                param.ds_numel).view(param.ds_shape).data
+
+        # guarantee the communication to be completed
+        torch.cuda.synchronize()
+
+        return None
+
     def _allgather_params(self, param_list, hierarchy=0):
         if len(param_list) == 0:
             return
diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py
index 5989b0f72a97..221c5e51d340 100755
--- a/deepspeed/runtime/zero/stage3.py
+++ b/deepspeed/runtime/zero/stage3.py
@@ -444,7 +444,7 @@ def fetch_sub_module(self, sub_module):
             self.hierarchy += 1
 
             # parameters are partitioned and need to be allgathered
-            self._all_gather(partitioned_params, async_op=True)
+            self._all_gather(partitioned_params, async_op=False)
 
         # parameters are inflight and communication needs to be completed
         if partitioned_params or params_in_flight:
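
A minimal standalone sketch (not part of the patch) of the pattern the diff introduces: probe for the fused `_all_gather_base` collective and fall back to the list-based `all_gather` on older torch builds. The helper `gather_shard`, the process group `pg`, and the `shard` tensor are illustrative stand-ins for `self.ds_process_group` and `param.ds_tensor`, assuming an already-initialized process group with equal-sized shards on every rank.

import torch
import torch.distributed as dist

# Availability probe, mirroring the check the patch adds in __init__.
try:
    from torch.distributed.distributed_c10d import _all_gather_base  # noqa: F401
    use_all_gather_base = True
except ImportError:
    use_all_gather_base = False


def gather_shard(shard: torch.Tensor, pg) -> torch.Tensor:
    # Gather equal-sized 1-D shards from every rank into one flat tensor.
    world_size = dist.get_world_size(group=pg)
    flat = torch.empty(shard.numel() * world_size,
                       dtype=shard.dtype,
                       device=shard.device)
    if use_all_gather_base:
        # Fused variant: writes straight into the flat output buffer,
        # avoiding the per-rank output list and extra copies.
        dist._all_gather_base(flat, shard, group=pg)
    else:
        # Fallback: gather into world_size contiguous views of `flat`.
        partitions = [
            flat.narrow(0, i * shard.numel(), shard.numel())
            for i in range(world_size)
        ]
        dist.all_gather(partitions, shard, group=pg)
    return flat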