Restore backward after each batch for grad accum #1917
Conversation
See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1917. Note: links to docs will display an error until the docs builds have been completed.
✅ No failures as of commit 408e521 with merge base e99b890. This comment was automatically generated by Dr. CI and updates every 15 minutes.
@@ -0,0 +1,14 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
Does this really need its own file?
No
Where do you wanna put it then? Otherwise I am gonna copy-paste this in every recipe which is worse imo
@@ -722,7 +732,7 @@ def train(self) -> None:
    # Update the number of steps when the weights are updated
    self.global_step += 1

-   loss_to_log = loss.item()
+   loss_to_log = running_loss.item() / num_tokens
Should probably normalize by local_num_tokens?
Update: I am probably gonna keep it like this since it should be representative of the loss we are actually using to step (even though it means our loss curves will look slightly different than they do today)
I think it makes sense. Will it break all regression tests though?
recipes/full_finetune_distributed.py (outdated)
loss.backward()
local_num_tokens = num_tokens.detach().clone()
torch.distributed.all_reduce(num_tokens)
training.scale_grads(self._model, self._world_size / num_tokens)
there are so many lines taking care of the all_reduce, backward, etc, that it makes me wonder if this should be a utility.
Yeah maybe. In this case I feel like it's important enough (and tricky enough) logic to be done very explicitly. Whatever route we go I will ultimately make it more explicit what's happening here
torchtune/training/_distributed.py (outdated)
@contextlib.contextmanager
def no_sync(model: nn.Module) -> Generator[None, None, None]:
name could be more descriptive, maybe no_grad_sync
cc @andrewor14 for review of the QAT recipe changes
also cc @lindawangg for the KD recipe changes
lgtm
KD changes look good to me
loss.backward()
if not self._optimizer_in_bwd:
    # Get total number of tokens across all ranks to normalize gradients
    torch.distributed.all_reduce(num_tokens)
HE'S A GENIUS
    }
    return loss_values_map[model_type]


@pytest.mark.integration_test
@pytest.mark.parametrize(
    "config, model_type, ckpt_type, micro_batch_size, gradient_accumulation_steps",
    [
        ("llama2/7B_qat_full", "llama2", "hf", 4, 1),
        ("llama3/8B_qat_full", "llama3", "tune", 4, 1),
        # ("llama2/7B_qat_full", "llama2", "hf", 4, 1),
Commented out?
Math!
The only thing that is not crystal clear to me is why TPS would increase. I guess that before we would also count padded tokens, which should make TPS look higher in the older version.
Yeahhhhhhh, seems like the reduce would actually slow down training no?
Thanks for coming up with the best solution to this problem! I left some comments around some potential dtype concerns and readability, but otherwise this looks great!
""" | ||
for p in model.parameters(): | ||
if p.grad is not None: | ||
p.grad *= scaler |
Is there any concern here around overflows for lower dtypes? We could do a scaler range check based on dtype. Or is it better to leave it to the recipe to safely choose scaler values?
Discussed a bit offline, but this should be safe for bf16 and fp32. We can revisit upon integration of other dtypes
# For optimizer in backward, we need to normalize before calling backward
# This case and gradient accumulation are mutually exclusive
if self._optimizer_in_bwd:
I think 783 - 810 could be much more readable as:
if self.optimizer_in_bwd:
    raise if self._clip_grad_norm  # or do this in init
    ...
    current_loss.backward()
elif (idx + 1) % self._gradient_accumulation_steps == 0:
    current_loss.backward()
    ...
    scale_grads()
    if self._clip_grad_norm is not None:
        ...
    self._optimizer.step()
    self._optimizer.zero_grad(...)
This could be used in all the distributed recipes.
@@ -787,15 +771,31 @@ def train(self) -> None:
    # Compute loss
    # Loss is normalized by default so we multiply by the number of tokens
    # This way we can normalize by the total number of tokens if we're accumulating gradients
-   running_loss += self._loss_fn(logits, labels) * current_num_tokens
+   current_loss = self._loss_fn(logits, labels) * current_num_tokens
If there was ever an issue with numerical stability, another option for scaling the loss would be:
if grad_accumulation_step == 0:
    base_num_tokens = current_num_tokens
    torch.distributed.broadcast(base_num_tokens, src=0)
current_loss = loss_fn(logits, labels) * current_num_tokens / base_num_tokens
This might overcomplicate things, but I wanted to leave this here in case it turns out in the future that a reduced gradient/loss is necessary for smaller dtypes.
Agreed, we can take a look at this in the future
@ebsmothers For QAT, should we just update all at once (copy over the changes from full_finetune_distributed)? E.g. I already have a PR that does this, and I can just rebase after your PR is landed: #1854. Right now these two recipes aren't really synced, and if some PRs add changes to qat_distributed but others don't, I'm worried we'll be in a weird in-between state. What do you think?
@andrewor14 if it's all the same to you I may keep the QAT changes in. Mainly because this change will actually impact the memory usage and loss curves in a nontrivial way, so I'd rather not have any divergence between our recipes there, even for a short time. Lmk if that sounds alright to you though
@felipemello1 @joecummings discussed offline but the WPS is not measured in exactly the same way before vs after these changes. In the figures we have "old-version", "main", and "this-pr". In "old-version" I hacked it so that we are counting non-padding tokens only (this way it lines up with the other two). But both "old-version" and "main" are using only the tokens seen on rank 0. On "this-pr" we now have the tokens seen by all ranks, which we then normalize by the number of ranks. So this PR's version should be less noisy, but there will be some slight differences when compared to the other two versions. Regarding the question of all_reduce slowing down training: I can do a more detailed comparison, but we are only reducing a single float value. When compared against various FSDP comms on large tensors happening in forward and backward this should be pretty minimal (but will confirm).
The fix to normalize CE loss by total number of tokens in a step moved our backward call from per-batch to per-step. This means a bunch of activations are hanging around longer than they should and blowing up our memory. We should be able to call backward on the unnormalized values for each batch then manually scale the gradients just before optimizer step.
This is easy enough for our single-device recipes, but for our distributed recipes it's slightly more work. In fact it cannot be done in a way that is both (a) correct and (b) backwards compatible. If you want to know why this is the case, see the "Long digression" section below.
TLDR of the changes:
- Call .backward() after each batch on the unnormalized loss, all-reduce the number of tokens across ranks, and manually scale the gradients just before the optimizer step.
- Gradient accumulation can no longer be combined with optimizer-in-backward (supporting both would require calling .backward() once per step, but we don't wanna do that). So now we'll raise an error (also I disabled a couple test cases that snuck in during this time period).
- Use utils.batch_to_device in the QAT recipe. Previously num_tokens would show up on CPU instead of GPU because labels weren't moved until later.

Test plan
All the single-device recipe tests succeed without any changes. However, for our distributed recipe tests some of the expected values need to change.
Why are you changing the expected values?!
Again, see the "long digression" section below, but we have been normalizing our loss by the local number of tokens, not the number of tokens seen over all ranks. Is this a huge deal? Honestly probably not, but technically it's not correct. How do I know this version is correct? I've run the following command both on main and on this PR:
In both cases I added logging of # of tokens and grad for the token embeddings weight. On main I also commented out the loss normalization so that we can get the raw values. Diff for my changes on main (changes on this PR are just the identical logging).
Note that after the first iteration the results differ due to the difference in how the gradients are calculated, so looking at the logs from just the first iteration:
On main
On this PR
But what does it all mean?
In both the preceding snippets, the first step sees 12 + 59 + 60 + 268 + 14 + 133 + 11 + 6 = 563 tokens.
The grad value logged is just the sum of the elements in the tensor. On main this is -13.8838529586792 on rank 0 and 1.0207021236419678 on rank 1. On this PR it is -0.02466048300266266 on rank 0 and 0.0018129688687622547 on rank 1. In both cases, the value on this PR == the unnormalized value on main / 563, which is what we would expect
End-to-end testing
Llama3 8B full finetune on 4 devices
Repro:
Peak allocated memory drops substantially after this PR (back to where it was before the first gradient accumulation PR, but with actual correct loss calculation now).
Llama 3.2 1B full finetune on 2 devices
We can see that (a) peak allocated memory drops back down to where it was before #1875, (b) tokens per second is the same as #1875 (and faster than before), and (c) the loss curves look similar.
Long digression: data parallel and gradient accumulation aren't so different
First: on a single device
Let's consider a simple example of how to calculate the correctly-normalized CE loss on a single device (i.e. no data parallel) with two gradient accumulation steps. Say our unnormalized cross-entropy loss for the first batch is L1 and our (similarly unnormalized) loss for the second batch is L2, and we have n1 tokens in the first batch and n2 tokens in the second batch. Then the properly-normalized cross-entropy loss to step with will be (L1 + L2) / (n1 + n2). The naive approach taken in #1875 is to just accumulate a running sum of both loss and number of tokens, then call backward on the ratio running_loss / running_num_tokens just before optimizer step. The problem with this is that we only call backward once, so our activations stick around and blow up our memory. What can we do instead?
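To make the problem concrete, here is a minimal, self-contained sketch of that naive pattern with a toy model and made-up data (illustrative only, not the actual recipe code):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(8, 4)
opt = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.CrossEntropyLoss(reduction="mean", ignore_index=-100)
grad_accum_steps = 2

# two fake "batches" with different numbers of non-padding tokens
batches = [
    (torch.randn(3, 8), torch.tensor([0, 1, -100])),
    (torch.randn(3, 8), torch.tensor([2, 3, 1])),
]

running_loss, running_num_tokens = 0.0, 0
for i, (inputs, labels) in enumerate(batches):
    num_tokens = (labels != -100).sum()
    # scale the mean loss back up to an unnormalized (summed) loss, i.e. L1, L2 above
    running_loss = running_loss + loss_fn(model(inputs), labels) * num_tokens
    running_num_tokens += num_tokens
    if (i + 1) % grad_accum_steps == 0:
        # a single backward per step: activations from every accumulated batch
        # have to stay alive until this point
        (running_loss / running_num_tokens).backward()
        opt.step()
        opt.zero_grad()
        running_loss, running_num_tokens = 0.0, 0
```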
Sticking with the single-device case, it's not too hard to fix this. Repeated calls to .backward() accumulate gradients, so e.g. calling (L1 + L2).backward() once is numerically equivalent to calling L1.backward() followed by L2.backward().

Then for the single-device case we can actually take the second approach with L1 and L2. But remember that we still want to normalize by n1 + n2. This isn't too hard: we can just manually scale the gradients ourselves just before the optimizer step (since grad(c*X) = c*grad(X) for any constant c, this is equivalent to scaling the loss).
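A matching sketch of the single-device fix, again with toy names and data rather than the recipe code: backward on each batch's unnormalized loss, then divide the accumulated gradients by the total token count right before stepping.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(8, 4)
opt = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.CrossEntropyLoss(reduction="mean", ignore_index=-100)
grad_accum_steps = 2

batches = [
    (torch.randn(3, 8), torch.tensor([0, 1, -100])),
    (torch.randn(3, 8), torch.tensor([2, 3, 1])),
]

running_num_tokens = 0
for i, (inputs, labels) in enumerate(batches):
    num_tokens = (labels != -100).sum()
    running_num_tokens += num_tokens
    # backward right away on the unnormalized loss, so this batch's activations
    # can be freed; repeated backward calls accumulate into .grad
    (loss_fn(model(inputs), labels) * num_tokens).backward()
    if (i + 1) % grad_accum_steps == 0:
        # grad(c * X) = c * grad(X), so scaling grads here == scaling the loss by 1 / (n1 + n2)
        for p in model.parameters():
            if p.grad is not None:
                p.grad /= running_num_tokens
        opt.step()
        opt.zero_grad()
        running_num_tokens = 0
```

With the same seed and data this produces the same gradients as the previous sketch (up to floating-point rounding), but each batch's activations are released as soon as its backward finishes.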
Adding data parallel
This is where things get slightly messier. Let's extend the example to two devices: using Lij to refer to the unnormalized cross-entropy loss for rank i in its jth batch, and similarly for nij with number of tokens.
So in the first batch the model will see:
On rank 1: loss L11 based on n11 tokens
On rank 2: loss L21 based on n21 tokens
In the second batch, it will see:
On rank 1: loss L12 based on n12 tokens
On rank 2: loss L22 based on n22 tokens
Similarly to the single-device case, the total number of tokens seen across all batches and all ranks will be n11 + n21 + n12 + n22. This means that the properly-normalized loss should be given by (L11 + L21 + L12 + L22) / (n11 + n21 + n12 + n22).
What are we doing today?
Currently we take a similar approach to the single-device case described previously: we accumulate the losses and tokens, then call .backward() on the ratio running_loss / running_num_tokens just before stepping with the optimizer. See the below code:
torchtune/recipes/full_finetune_distributed.py
Lines 724 to 732 in a1bcb97
What's wrong with this?
During data parallel training, the loss on a given rank is based only on the subset of data seen by that rank. Similarly, our calculation of running_num_tokens is based only on the tokens from that rank. This means that when we normalize we are normalizing only over iterations, not over ranks. Put another way, the line
loss = running_loss / num_tokens
in the above snippet will yield (L11 + L12) / (n11 + n12) on rank 1 and (L21 + L22) / (n21 + n22) on rank 2. Finally, we call .backward(), which calculates local grads before firing a hook to sync by reducing over all ranks. The upshot is that our loss winds up as [(L11 + L12) / (n11 + n12)] + [(L21 + L22) / (n21 + n22)], which is definitely not (L11 + L21 + L12 + L22) / (n11 + n21 + n12 + n22).How can we fix it?
How can we fix it?
Aside from the correctness issue described in the previous section, we also need an approach that still calls .backward() on each batch to free the activation memory. How can we do this? Actually it's not so bad: we just need to reduce the number of tokens before our final gradient normalization and let data parallel backward hooks take care of the rest. More explicitly:
First batch:
On rank 1: calculate loss L11 based on n11 tokens
On rank 2: calculate loss L21 based on n21 tokens
Call backward -> this triggers a grad sync and each rank now has the (local) grads for L11 + L21
Second batch:
On rank 1: calculate loss L12 based on n12 tokens
On rank 2: calculate loss L22 based on n22 tokens
Call backward -> this accumulates the grads from L11 + L21 locally, then triggers another sync so that we now have L11 + L21 + L12 + L22
Then we just need n11 + n21 + n12 + n22. Fortunately this is just an all-reduce on running_num_tokens. Then we manually scale the gradients just like in the single-device case.
In summary, we wind up with the following process (a rough code sketch follows the list):
For each batch:
Scale the (already-normalized) loss back up by the local number of tokens and add those tokens to a running count.
Call .backward() right away on this unnormalized loss; the data parallel hooks sync the grads across ranks and the batch's activations are freed.
Before optimizer step:
All-reduce the running token count so every rank has the global total.
Scale the gradients by world_size / total number of tokens (matching the scale_grads call shown in the diff above), then step the optimizer and zero the grads.
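Putting the pieces together, here is a rough, self-contained sketch of that loop using plain DDP over a toy model rather than the FSDP setup in the actual recipe (launch with torchrun; scale_grads below is a local stand-in mirroring the p.grad *= scaler helper shown earlier, and all model/data names are made up):

```python
import torch
import torch.distributed as dist
import torch.nn as nn


def scale_grads(model: nn.Module, scaler: torch.Tensor) -> None:
    # multiply every accumulated grad in place, as in the helper discussed above
    for p in model.parameters():
        if p.grad is not None:
            p.grad *= scaler


dist.init_process_group("gloo")
world_size = dist.get_world_size()
rank = dist.get_rank()
torch.manual_seed(rank)

model = nn.parallel.DistributedDataParallel(nn.Linear(8, 4))
opt = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.CrossEntropyLoss(reduction="mean", ignore_index=-100)
grad_accum_steps = 2

# each rank sees different fake batches with different non-padding token counts
if rank % 2 == 0:
    batches = [(torch.randn(3, 8), torch.tensor([0, 1, -100])),
               (torch.randn(3, 8), torch.tensor([2, 3, 1]))]
else:
    batches = [(torch.randn(3, 8), torch.tensor([1, -100, -100])),
               (torch.randn(3, 8), torch.tensor([0, 2, 3]))]

num_tokens = torch.tensor(0)
for i, (inputs, labels) in enumerate(batches):
    current_num_tokens = (labels != -100).sum()
    num_tokens += current_num_tokens
    # backward per batch on the unnormalized loss; DDP's hooks sync (average) the grads
    (loss_fn(model(inputs), labels) * current_num_tokens).backward()
    if (i + 1) % grad_accum_steps == 0:
        # total tokens across all ranks and all accumulated batches
        dist.all_reduce(num_tokens)
        # DDP averaged the grads over ranks, so multiply by world_size to recover the
        # sum before dividing by the global token count
        scale_grads(model, world_size / num_tokens)
        opt.step()
        opt.zero_grad()
        num_tokens = torch.tensor(0)

dist.destroy_process_group()
```

The world_size / num_tokens factor here is taken from the scale_grads call in the full_finetune_distributed diff above; everything else is a simplified stand-in for the recipe's actual training loop.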