From 3a87dfa2133daf9ba338a86a2c404f0ff1b903de Mon Sep 17 00:00:00 2001 From: Liran Bachar Date: Sun, 28 May 2023 13:08:50 +0300 Subject: [PATCH] move overflow tracker to optimizer.step Don't check overflow in gradients for every bucket. Do overflow check once on grad flat buffer just before optimizer step --- deepspeed/runtime/zero/stage3.py | 33 ++++++++++++++------------------ 1 file changed, 14 insertions(+), 19 deletions(-) diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index 21a85d886b7d..83cca20233c7 100644 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -304,7 +304,6 @@ def __init__(self, if self.offload_optimizer: self.norm_for_param_grads = {} - self.local_overflow = False # stores if a partition has been reduced in this step self.is_partition_reduced = {} @@ -394,20 +393,20 @@ def _setup_for_real_optimizer(self): dtype=self.dtype, device=get_accelerator().current_device_name()) - grad_partitions_flat_buffer = None + self.grad_partitions_flat_buffer = None self.__param_id_to_grad_partition: Dict[int, Tensor] = {} all_params = list(itertools.chain.from_iterable(self.fp16_groups)) - grad_partitions_flat_buffer: Tensor = torch.zeros(sum(p.partition_numel() for p in all_params), + self.grad_partitions_flat_buffer: Tensor = torch.zeros(sum(p.partition_numel() for p in all_params), dtype=self.dtype, device=self.device) if self.offload_optimizer_pin_memory: - grad_partitions_flat_buffer = get_accelerator().pin_memory(grad_partitions_flat_buffer) + self.grad_partitions_flat_buffer = get_accelerator().pin_memory(self.grad_partitions_flat_buffer) offset = 0 for param in all_params: - self.__param_id_to_grad_partition[param.ds_id] = grad_partitions_flat_buffer.narrow( + self.__param_id_to_grad_partition[param.ds_id] = self.grad_partitions_flat_buffer.narrow( 0, offset, param.partition_numel()) offset += param.partition_numel() @@ -1252,15 +1251,6 @@ def partition_grads(self, params_to_release: 
List[Parameter], grad_partitions: L # operations and so it can be used asynchronously grad_buffer = cuda_grad_buffer - if hasattr(self.inf_or_nan_tracker, "logical_or_"): - self.inf_or_nan_tracker.logical_or_(torch.isinf(grad_buffer).any()) - self.inf_or_nan_tracker.logical_or_(torch.isnan(grad_buffer).any()) - else: - # logical_or_ not available in older versions of pytorch - self.inf_or_nan_tracker += torch.isinf(grad_buffer).any() - self.inf_or_nan_tracker += torch.isnan(grad_buffer).any() - self.inf_or_nan_tracker = self.inf_or_nan_tracker > 0 - # offload the gradient partition if applicable if self.offload_optimizer: i, dest_offset, _ = self.grad_position[self.get_param_id(param)] @@ -1591,7 +1581,6 @@ def free_grad_in_param_list(self, param_list): def reset_cpu_buffers(self): self.norm_for_param_grads = {} - self.local_overflow = False def log_timers(self, timer_names): if self.timers is None: @@ -1925,12 +1914,18 @@ def has_overflow_partitioned_grads_serial(self): def has_overflow(self, partition_gradients=True): if partition_gradients: with get_accelerator().stream(self.reduce_and_partition_stream): - self.local_overflow = bool(self.inf_or_nan_tracker.item()) + if hasattr(self.inf_or_nan_tracker, "logical_or_"): + self.inf_or_nan_tracker.logical_or_(torch.isinf(self.grad_partitions_flat_buffer).any()) + self.inf_or_nan_tracker.logical_or_(torch.isnan(self.grad_partitions_flat_buffer).any()) + else: + # logical_or_ not available in older versions of pytorch + self.inf_or_nan_tracker += torch.isinf(self.grad_partitions_flat_buffer).any() + self.inf_or_nan_tracker += torch.isnan(self.grad_partitions_flat_buffer).any() + self.inf_or_nan_tracker = self.inf_or_nan_tracker > 0 + + overflow_gpu = self.inf_or_nan_tracker.clone().to(torch.uint8) self.inf_or_nan_tracker.zero_() - overflow = self.local_overflow - #overflow = self.has_overflow_partitioned_grads_serial() - overflow_gpu = get_accelerator().ByteTensor([overflow]) dist.all_reduce(overflow_gpu, 
op=dist.ReduceOp.MAX, group=self.dp_process_group) else: