-
Notifications
You must be signed in to change notification settings - Fork 442
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Restore backward after each batch for grad accum #1917
Changes from 8 commits
b62af9f
3f8c7aa
494b96b
99acd4e
474b533
32d652d
8e978e7
a878829
83cba27
408e521
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -152,26 +152,10 @@ def __init__(self, cfg: DictConfig) -> None: | |
self._gradient_accumulation_steps = cfg.gradient_accumulation_steps | ||
self._optimizer_in_bwd = cfg.get("optimizer_in_bwd", False) | ||
|
||
# activation checkpointing/offloading | ||
self._enable_activation_checkpointing = cfg.get( | ||
"enable_activation_checkpointing", False | ||
) | ||
self._enable_activation_offloading = cfg.get( | ||
"enable_activation_offloading", False | ||
) | ||
if self._enable_activation_offloading: | ||
if self._device.type != "cuda": | ||
raise RuntimeError( | ||
"enable_activation_offloading should only be True when training on CUDA" | ||
) | ||
if not self._enable_activation_checkpointing: | ||
raise RuntimeError( | ||
"enable_activation_offloading should only be True when enable_activation_checkpointing is True" | ||
) | ||
elif self._enable_activation_checkpointing: | ||
log.info( | ||
"Hint: enable_activation_checkpointing is True, but enable_activation_offloading isn't. " | ||
"Enabling activation offloading should reduce memory further." | ||
if self._gradient_accumulation_steps > 1 and self._optimizer_in_bwd: | ||
raise RuntimeError( | ||
"Gradient accumulation is not supported with optimizer in bwd." | ||
"Please set gradient_accumulation_steps=1, or optimizer_in_bwd=False." | ||
) | ||
|
||
# activation checkpointing/offloading | ||
|
@@ -720,7 +704,7 @@ def train(self) -> None: | |
# clean up before training begins | ||
training.cleanup_before_training() | ||
|
||
_, rank = training.get_world_size_and_rank() | ||
world_size, rank = training.get_world_size_and_rank() | ||
|
||
# zero out the gradients before starting training | ||
if not self._optimizer_in_bwd: | ||
|
@@ -787,15 +771,31 @@ def train(self) -> None: | |
# Compute loss | ||
# Loss is normalized by default so we multiply by the number of tokens | ||
# This way we can normalize by the total number of tokens if we're accumulating gradients | ||
running_loss += self._loss_fn(logits, labels) * current_num_tokens | ||
current_loss = self._loss_fn(logits, labels) * current_num_tokens | ||
|
||
# free logits otherwise it peaks backward memory | ||
del logits | ||
|
||
running_loss += current_loss | ||
felipemello1 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
# For optimizer in backward, we need to normalize before calling backward | ||
# This case and gradient accumulation are mutually exclusive | ||
if self._optimizer_in_bwd: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think 783 - 810 could be much more readable as: if self.optimizer_in_bwd:
raise if self._clip_grad_norm # or do this in init
...
current_loss.backward()
elif (idx + 1) % self._gradient_accumulation_steps == 0:
current_loss.backward()
...
scale_grads()
if self._clip_grad_norm is not None:
...
self._optimizer.step()
self._optimizer.zero_grad(...) This could be used in all the distributed recipes. |
||
torch.distributed.all_reduce(num_tokens) | ||
torch.distributed.all_reduce(running_loss) | ||
current_loss = current_loss / num_tokens | ||
|
||
current_loss.backward() | ||
|
||
# Step with optimizer | ||
if (idx + 1) % self._gradient_accumulation_steps == 0: | ||
loss = running_loss / num_tokens | ||
loss.backward() | ||
if not self._optimizer_in_bwd: | ||
# Get total number of tokens across all ranks to normalize gradients | ||
torch.distributed.all_reduce(num_tokens) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. HE'S A GENIUS |
||
# This will ensure that the logged loss matches what we're optimizing | ||
torch.distributed.all_reduce(running_loss) | ||
# Manually scale the gradients from unnormalized loss by total # of tokens | ||
training.scale_grads(self._model, 1 / num_tokens) | ||
felipemello1 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
if self._clip_grad_norm is not None: | ||
if self._optimizer_in_bwd: | ||
raise NotImplementedError( | ||
|
@@ -812,7 +812,7 @@ def train(self) -> None: | |
# Update the number of steps when the weights are updated | ||
self.global_step += 1 | ||
|
||
loss_to_log = loss.item() | ||
loss_to_log = running_loss.item() / num_tokens | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should probably normalize by local_num_tokens? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Update: I am probably gonna keep it like this since it should be representative of the loss we are actually using to step (even though it means our loss curves will look slightly different than they do today) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. i think it makes sense. Will it break all regression tests though? |
||
pbar.update(1) | ||
pbar.set_description( | ||
f"{curr_epoch + 1}|{self.global_step}|Loss: {loss_to_log}" | ||
|
@@ -833,7 +833,8 @@ def train(self) -> None: | |
else self._optim_ckpt_wrapper | ||
), | ||
), | ||
"tokens_per_second_per_gpu": num_tokens / time_per_step, | ||
"tokens_per_second_per_gpu": num_tokens | ||
/ (time_per_step * world_size), | ||
} | ||
if self._log_peak_memory_stats: | ||
log_dict.update( | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If there was ever a issue with numerical stability, another option for scaling the loss would be:
This might over complicate things but I wanted to leave this here if in the future it turns out a reduced gradient/loss is necessary for smaller dtypes.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Agreed, we can take a look at this in the future