Skip to content

Commit

Permalink
move overflow tracker to optimizer.step
Browse files Browse the repository at this point in the history
Don't check overflow in gradients for every bucket.
Do overflow chack once on grad flat buffer just before optimizer step
  • Loading branch information
BacharL committed May 29, 2023
1 parent bd4d724 commit d6a8711
Showing 1 changed file with 16 additions and 21 deletions.
37 changes: 16 additions & 21 deletions deepspeed/runtime/zero/stage3.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,6 @@ def __init__(self,

if self.offload_optimizer:
self.norm_for_param_grads = {}
self.local_overflow = False

# stores if a partition has been reduced in this step
self.is_partition_reduced = {}
Expand Down Expand Up @@ -394,20 +393,20 @@ def _setup_for_real_optimizer(self):
dtype=self.dtype,
device=get_accelerator().current_device_name())

grad_partitions_flat_buffer = None
self.grad_partitions_flat_buffer = None
self.__param_id_to_grad_partition: Dict[int, Tensor] = {}

all_params = list(itertools.chain.from_iterable(self.fp16_groups))

grad_partitions_flat_buffer: Tensor = torch.zeros(sum(p.partition_numel() for p in all_params),
dtype=self.dtype,
device=self.device)
self.grad_partitions_flat_buffer: Tensor = torch.zeros(sum(p.partition_numel() for p in all_params),
dtype=self.dtype,
device=self.device)
if self.offload_optimizer_pin_memory:
grad_partitions_flat_buffer = get_accelerator().pin_memory(grad_partitions_flat_buffer)
self.grad_partitions_flat_buffer = get_accelerator().pin_memory(self.grad_partitions_flat_buffer)

offset = 0
for param in all_params:
self.__param_id_to_grad_partition[param.ds_id] = grad_partitions_flat_buffer.narrow(
self.__param_id_to_grad_partition[param.ds_id] = self.grad_partitions_flat_buffer.narrow(
0, offset, param.partition_numel())
offset += param.partition_numel()

Expand Down Expand Up @@ -1252,15 +1251,6 @@ def partition_grads(self, params_to_release: List[Parameter], grad_partitions: L
# operations and so it can be used asynchronously
grad_buffer = cuda_grad_buffer

if hasattr(self.inf_or_nan_tracker, "logical_or_"):
self.inf_or_nan_tracker.logical_or_(torch.isinf(grad_buffer).any())
self.inf_or_nan_tracker.logical_or_(torch.isnan(grad_buffer).any())
else:
# logical_or_ not available in older versions of pytorch
self.inf_or_nan_tracker += torch.isinf(grad_buffer).any()
self.inf_or_nan_tracker += torch.isnan(grad_buffer).any()
self.inf_or_nan_tracker = self.inf_or_nan_tracker > 0

# offload the gradient partition if applicable
if self.offload_optimizer:
i, dest_offset, _ = self.grad_position[self.get_param_id(param)]
Expand Down Expand Up @@ -1591,7 +1581,6 @@ def free_grad_in_param_list(self, param_list):

def reset_cpu_buffers(self):
self.norm_for_param_grads = {}
self.local_overflow = False

def log_timers(self, timer_names):
if self.timers is None:
Expand Down Expand Up @@ -1925,12 +1914,18 @@ def has_overflow_partitioned_grads_serial(self):
def has_overflow(self, partition_gradients=True):
if partition_gradients:
with get_accelerator().stream(self.reduce_and_partition_stream):
self.local_overflow = bool(self.inf_or_nan_tracker.item())
if hasattr(self.inf_or_nan_tracker, "logical_or_"):
self.inf_or_nan_tracker.logical_or_(torch.isinf(self.grad_partitions_flat_buffer).any())
self.inf_or_nan_tracker.logical_or_(torch.isnan(self.grad_partitions_flat_buffer).any())
else:
# logical_or_ not available in older versions of pytorch
self.inf_or_nan_tracker += torch.isinf(self.grad_partitions_flat_buffer).any()
self.inf_or_nan_tracker += torch.isnan(self.grad_partitions_flat_buffer).any()
self.inf_or_nan_tracker = self.inf_or_nan_tracker > 0

overflow_gpu = self.inf_or_nan_tracker.clone().to(torch.uint8)
self.inf_or_nan_tracker.zero_()

overflow = self.local_overflow
#overflow = self.has_overflow_partitioned_grads_serial()
overflow_gpu = get_accelerator().ByteTensor([overflow])
dist.all_reduce(overflow_gpu, op=dist.ReduceOp.MAX, group=self.dp_process_group)

else:
Expand Down

0 comments on commit d6a8711

Please sign in to comment.