diff --git a/deepspeed/runtime/zero/partition_parameters.py b/deepspeed/runtime/zero/partition_parameters.py index db986c747970..cee5701502e2 100755 --- a/deepspeed/runtime/zero/partition_parameters.py +++ b/deepspeed/runtime/zero/partition_parameters.py @@ -11,6 +11,7 @@ import functools import itertools from typing import List +import logging import torch from torch import Tensor from deepspeed import comm as dist @@ -898,7 +899,8 @@ def all_gather_coalesced(params: Iterable[Parameter], safe_mode: bool = False) - # to debug correctness issues. params = sorted(params, key=lambda p: p.ds_id) - debug_rank0(f"-allgather_coalesced: {[p.ds_id for p in params]}") + if logger.isEnabledFor(logging.DEBUG): + debug_rank0(f"-allgather_coalesced: {[p.ds_id for p in params]}") if safe_mode: # ensure that same list (with same ordering) of parameters are diff --git a/deepspeed/runtime/zero/partitioned_param_coordinator.py b/deepspeed/runtime/zero/partitioned_param_coordinator.py index ff2cfff8f8c0..8bf999458d8e 100644 --- a/deepspeed/runtime/zero/partitioned_param_coordinator.py +++ b/deepspeed/runtime/zero/partitioned_param_coordinator.py @@ -15,6 +15,7 @@ from deepspeed.runtime.swap_tensor.partitioned_param_swapper import PartitionedParamStatus from deepspeed.utils.debug import debug_module2name_id, debug_param2name_id from deepspeed.accelerator import get_accelerator +import logging def debug_rank0(message: str) -> None: @@ -235,25 +236,28 @@ def fetch_sub_module(self, current_submodule: Module) -> None: 2. kick off fetch for next few parameters we will need later (prefetch) 3. block on parameters in immediately required sub module """ - debug_rank0( - f"{self.__step_id}: M{current_submodule.id}({type(current_submodule).__name__}) P{[p.ds_id for p in iter_params(current_submodule)]} " - + str({ - "avail": f"{self.__n_available_params:.1e}", - "queue_sz": f"{len(self.__param_queue or [])}", - "inflight": [p.ds_id for p in self.__inflight_param_registry], - })) + if logger.isEnabledFor(logging.DEBUG): + debug_rank0( + f"{self.__step_id}: M{current_submodule.id}({type(current_submodule).__name__}) P{[p.ds_id for p in iter_params(current_submodule)]} " + + str({ + "avail": f"{self.__n_available_params:.1e}", + "queue_sz": f"{len(self.__param_queue or [])}", + "inflight": [p.ds_id for p in self.__inflight_param_registry], + })) params_to_fetch = frozenset(iter_params(current_submodule)) # kick off all gather for params in the immediately required submodule - for param in params_to_fetch: - debug_rank0(f"-fetch: {param.ds_summary()}") + if logger.isEnabledFor(logging.DEBUG): + for param in params_to_fetch: + debug_rank0(f"-fetch: {param.ds_summary()}") self.__all_gather_params(params_to_fetch) # wait for parameters in the immediately needed submodule to become available for param in params_to_fetch: param.ds_active_sub_modules.add(current_submodule.id) - debug_rank0(f"-wait: {param.ds_summary()}") + if logger.isEnabledFor(logging.DEBUG): + debug_rank0(f"-wait: {param.ds_summary()}") if param in self.__inflight_param_registry: with get_accelerator().stream(self.__allgather_stream): while self.__ongoing_fetch_events and self.__ongoing_fetch_events[0].query(): @@ -328,8 +332,9 @@ def _is_currently_on_nvme(param): params_to_prefetch.add(param_in_trace.param) numel_prefetching += param_in_trace.param.ds_numel - for param in params_to_prefetch: - debug_rank0(f"-prefetch: {param.ds_summary()}") + if logger.isEnabledFor(logging.DEBUG): + for param in params_to_prefetch: + debug_rank0(f"-prefetch: {param.ds_summary()}") self.__all_gather_params(params_to_prefetch) if self.__prefetch_nvme: @@ -394,7 +399,8 @@ def __all_gather_params(self, params: Set[Parameter]) -> None: @instrument_w_nvtx def __release_param(self, param: Parameter) -> None: if param.ds_status == ZeroParamStatus.AVAILABLE and not param.ds_active_sub_modules: - debug_rank0(f"-release: {param.ds_summary()}") + if logger.isEnabledFor(logging.DEBUG): + debug_rank0(f"-release: {param.ds_summary()}") param.partition() self.__n_available_params -= param.ds_numel diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index c5359a827282..7448cb51ae25 100644 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -261,7 +261,6 @@ def __init__(self, if self.swap_optimizer: self._configure_tensor_swapping(offload_optimizer_config, aio_config) - self.params_in_ipg_bucket = [] self.is_gradient_accumulation_boundary: bool = True self.param_reduce_events: Deque[get_accelerator().Event] = collections.deque() @@ -277,7 +276,6 @@ def __init__(self, self.grads_in_ipg_bucket = [] self.params_in_ipg_bucket = [] - self.params_already_reduced = [] self.is_gradient_accumulation_boundary = True self._release_ipg_buffers() self.previous_reduced_grads = None @@ -291,7 +289,6 @@ def __init__(self, unique_id = id(param) self.param_id[unique_id] = count self.param_dict[count] = param - self.params_already_reduced.append(False) count = count + 1 #Largest partitioned param @@ -307,7 +304,6 @@ def __init__(self, if self.offload_optimizer: self.norm_for_param_grads = {} - self.local_overflow = False # stores if a partition has been reduced in this step self.is_partition_reduced = {} @@ -397,20 +393,20 @@ def _setup_for_real_optimizer(self): dtype=self.dtype, device=get_accelerator().current_device_name()) - grad_partitions_flat_buffer = None + self.grad_partitions_flat_buffer = None self.__param_id_to_grad_partition: Dict[int, Tensor] = {} all_params = list(itertools.chain.from_iterable(self.fp16_groups)) - grad_partitions_flat_buffer: Tensor = torch.zeros(sum(p.partition_numel() for p in all_params), - dtype=self.dtype, - device=self.device) + self.grad_partitions_flat_buffer: Tensor = torch.zeros(sum(p.partition_numel() for p in all_params), + dtype=self.dtype, + device=self.device) if self.offload_optimizer_pin_memory: - grad_partitions_flat_buffer = get_accelerator().pin_memory(grad_partitions_flat_buffer) + self.grad_partitions_flat_buffer = get_accelerator().pin_memory(self.grad_partitions_flat_buffer) offset = 0 for param in all_params: - self.__param_id_to_grad_partition[param.ds_id] = grad_partitions_flat_buffer.narrow( + self.__param_id_to_grad_partition[param.ds_id] = self.grad_partitions_flat_buffer.narrow( 0, offset, param.partition_numel()) offset += param.partition_numel() @@ -966,11 +962,6 @@ def independent_gradient_partition_epilogue(self): self.reduce_and_partition_stream.synchronize() - # if dist.get_rank() == 0: - # logger.info("Params already reduced %s", self.params_already_reduced) - for i in range(len(self.params_already_reduced)): - self.params_already_reduced[i] = False - #in case of cpu offload, averaged gradients are already in fp32_partitioned_groups_flat.grad #TODO: use a similar code path for both cpu_offload and non-cpu offload if not self.offload_optimizer: @@ -1045,18 +1036,11 @@ def reduce_independent_p_g_buckets_and_remove_grads(self, param, i): # 0). Otherwise if the incoming param.ds_numel is large, this branch may get triggered on a # garbage data and `self.average_tensor()` will crash because its params_to_reduce will be # empty, while reduction_list will have that garbage data. - if self.elements_in_ipg_bucket > 0 and self.elements_in_ipg_bucket + param.ds_numel > self.reduce_bucket_size: + if self.elements_in_ipg_bucket + param.ds_numel > self.reduce_bucket_size and self.elements_in_ipg_bucket > 0: self.report_ipg_memory_usage("In ipg_remove_grads before reduce_ipg_grads", param.ds_numel) self.__reduce_and_partition_ipg_grads() - param_id = self.get_param_id(param) - - assert self.params_already_reduced[param_id] == False, \ - f"The parameter {param_id} has already been reduced. \ - Gradient computed twice for this partition. \ - Multiple gradient reduction is currently not supported" - self.__add_grad_to_ipg_bucket(param) @instrument_w_nvtx @@ -1087,8 +1071,6 @@ def __reduce_and_partition_ipg_grads(self, safe_mode: bool = False) -> None: raise RuntimeError(f"{param.grad.numel()} != {param.ds_numel} Cannot reduce scatter " f"gradients whose size is not same as the params") - self.params_in_ipg_bucket.sort(key=lambda p: p.ds_id) - assert len(set(p.ds_id for p in self.params_in_ipg_bucket)) == len(self.params_in_ipg_bucket) while self.param_reduce_events and self.param_reduce_events[0].query(): @@ -1100,7 +1082,13 @@ def __reduce_and_partition_ipg_grads(self, safe_mode: bool = False) -> None: if safe_mode: assert_ints_same_as_other_ranks([p.ds_id for p in self.params_in_ipg_bucket]) - grad_partitions = self.__avg_scatter_grads(self.params_in_ipg_bucket) + if self.contiguous_gradients and not self.reduce_scatter: + grad_bucket = self.__ipg_bucket_flat_buffer.narrow(0, 0, self.elements_in_ipg_bucket) + grad_partitions = self.__avg_scatter_contiguous_grads(grad_bucket) + else: + self.params_in_ipg_bucket.sort(key=lambda p: p.ds_id) + grad_partitions = self.__avg_scatter_grads(self.params_in_ipg_bucket) + self.partition_grads(self.params_in_ipg_bucket, grad_partitions) self.params_in_ipg_bucket.clear() @@ -1109,6 +1097,47 @@ def __reduce_and_partition_ipg_grads(self, safe_mode: bool = False) -> None: event.record() self.param_reduce_events.append(event) + @instrument_w_nvtx + def __avg_scatter_contiguous_grads(self, buffer_to_reduce: Tensor) -> List[Tensor]: + dtype = buffer_to_reduce.dtype + if self.communication_data_type == self.dtype: + buffer_to_reduce = buffer_to_reduce.to(self.communication_data_type) + if self.postscale_gradients and self.gradient_predivide_factor != 1.0: + buffer_to_reduce = buffer_to_reduce.div_(self.gradient_predivide_factor) + + world_sz = dist.get_world_size(self.dp_process_group) + rank = dist.get_rank(self.dp_process_group) + buffer_to_reduce.div_(world_sz) + + dist.all_reduce(buffer_to_reduce, group=self.dp_process_group) + + if self.postscale_gradients and self.gradient_predivide_factor != world_sz: + buffer_to_reduce = buffer_to_reduce.mul(self.gradient_predivide_factor) + + if self.communication_data_type != self.dtype: + buffer_to_reduce = buffer_to_reduce.to(self.dtype) + + grad_partitions = [] + grad_offset_in_buffer = 0 + for param in self.params_in_ipg_bucket: + grad = param.grad + chunk_sz = math.ceil(grad.numel() / world_sz) + + start_offset = grad_offset_in_buffer + min(rank * chunk_sz, grad.numel()) + end_offset = grad_offset_in_buffer + min(rank * chunk_sz + chunk_sz, grad.numel()) + + partition = buffer_to_reduce[start_offset:end_offset] + if param.partition_numel() != partition.numel(): + padded_partition = torch.empty(param.partition_numel(), device=grad.device, dtype=grad.dtype) + if partition.numel() > 0: + padded_partition[:partition.numel()] = partition + grad_partitions.append(padded_partition) + else: + grad_partitions.append(partition) + grad_offset_in_buffer += grad.numel() + + return grad_partitions + @instrument_w_nvtx def __avg_scatter_grads(self, params_to_reduce: List[Parameter]) -> List[Tensor]: """average gradients and scatter partitions across ranks""" @@ -1223,15 +1252,6 @@ def partition_grads(self, params_to_release: List[Parameter], grad_partitions: L # operations and so it can be used asynchronously grad_buffer = cuda_grad_buffer - if hasattr(self.inf_or_nan_tracker, "logical_or_"): - self.inf_or_nan_tracker.logical_or_(torch.isinf(grad_buffer).any()) - self.inf_or_nan_tracker.logical_or_(torch.isnan(grad_buffer).any()) - else: - # logical_or_ not available in older versions of pytorch - self.inf_or_nan_tracker += torch.isinf(grad_buffer).any() - self.inf_or_nan_tracker += torch.isnan(grad_buffer).any() - self.inf_or_nan_tracker = self.inf_or_nan_tracker > 0 - # offload the gradient partition if applicable if self.offload_optimizer: i, dest_offset, _ = self.grad_position[self.get_param_id(param)] @@ -1567,7 +1587,6 @@ def free_grad_in_param_list(self, param_list): def reset_cpu_buffers(self): self.norm_for_param_grads = {} - self.local_overflow = False def log_timers(self, timer_names): if self.timers is None: @@ -1901,12 +1920,19 @@ def has_overflow_partitioned_grads_serial(self): def has_overflow(self, partition_gradients=True): if partition_gradients: with get_accelerator().stream(self.reduce_and_partition_stream): - self.local_overflow = bool(self.inf_or_nan_tracker.item()) + if hasattr(self.inf_or_nan_tracker, "logical_or_"): + self.inf_or_nan_tracker.logical_or_(torch.isinf(self.grad_partitions_flat_buffer).any()) + self.inf_or_nan_tracker.logical_or_(torch.isnan(self.grad_partitions_flat_buffer).any()) + else: + # logical_or_ not available in older versions of pytorch + self.inf_or_nan_tracker += torch.isinf(self.grad_partitions_flat_buffer).any() + self.inf_or_nan_tracker += torch.isnan(self.grad_partitions_flat_buffer).any() + self.inf_or_nan_tracker = self.inf_or_nan_tracker > 0 + + overflow_gpu = self.inf_or_nan_tracker.clone().to(torch.uint8) self.inf_or_nan_tracker.zero_() - overflow = self.local_overflow - #overflow = self.has_overflow_partitioned_grads_serial() - overflow_gpu = get_accelerator().ByteTensor([overflow]) + get_accelerator().default_stream().wait_stream(self.reduce_and_partition_stream) dist.all_reduce(overflow_gpu, op=dist.ReduceOp.MAX, group=self.dp_process_group) else: diff --git a/tests/unit/runtime/zero/test_zero.py b/tests/unit/runtime/zero/test_zero.py index 5773c060cf74..28576f1f4b74 100644 --- a/tests/unit/runtime/zero/test_zero.py +++ b/tests/unit/runtime/zero/test_zero.py @@ -736,10 +736,11 @@ def create_tensor(vals, dtype: torch.dtype = None) -> Tensor: @pytest.mark.parametrize("init_context_manager", [True, False]) +@pytest.mark.parametrize("reduce_scatter", [True, False]) class TestZero3ParamPartitioningLargeParam(DistributedTest): world_size = 4 - def test(self, init_context_manager: bool, param_sz: int = 8100) -> None: + def test(self, init_context_manager: bool, reduce_scatter: bool, param_sz: int = 8100) -> None: class LargeParamModel(Module): @@ -767,6 +768,7 @@ def forward(self, x: Tensor) -> Tensor: "stage3_max_reuse_distance": 0, "contiguous_gradients": True, "overlap_comm": True, + "reduce_scatter": reduce_scatter, }, "optimizer": { "type": "Adam",