
Restore backward after each batch for grad accum #1917

Merged
10 commits merged into pytorch:main on Oct 31, 2024

Conversation

@ebsmothers (Contributor) commented Oct 29, 2024

The fix to normalize CE loss by the total number of tokens in a step moved our backward call from per-batch to per-step. This means a bunch of activations hang around longer than they should and blow up our memory. Instead, we should be able to call backward on the unnormalized loss for each batch and then manually scale the gradients just before the optimizer step.

This is easy enough for our single-device recipes, but for our distributed recipes it's slightly more work. In fact it cannot be done in a way that is both (a) correct and (b) backwards compatible. If you want to know why this is the case, see the "Long digression" section below.

TLDR of the changes:

  1. Gradient accumulation logic changes: We now (a) call backward on every batch's unnormalized loss, (b) accumulate num_tokens over every batch, (c) (distributed only) all_reduce num_tokens on the last batch prior to stepping, then (d) manually scale gradients prior to stepping with the optimizer. This reverts the memory regression introduced by #1875 (Normalize CE loss by total number of (non-padding) tokens) while keeping the gradient accumulation logic correct.
  2. Logging changes: If we are stepping based on total_loss / total_tokens (where the totals are over all batches and all ranks), that's what should show up in the logs. Similarly, we now use the number of tokens over all ranks instead of just rank 0 (though we still normalize tokens/sec to per-GPU).
  3. Optimizer in backward: This shouldn't be supported when gradient accumulation is enabled. Previously we didn't raise an explicit error about this (it also happened to work when we called .backward() only once per step, but we don't want to rely on that). Now we raise an explicit error (see the sketch after this list), and I disabled a couple of test cases that snuck in during this period.
  4. Minor changes to KD and QAT recipes: Integrate utils.batch_to_device into the QAT recipe. Previously num_tokens wound up on CPU instead of GPU because labels weren't moved to the device until later.
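
For item 3, the new check is conceptually just a guard like the following (a sketch with assumed names, not the recipes' exact code):

def validate_grad_accum_config(optimizer_in_bwd: bool, gradient_accumulation_steps: int) -> None:
    # Optimizer-in-backward applies and frees gradients during backward itself,
    # so there is nothing left to accumulate across batches.
    if optimizer_in_bwd and gradient_accumulation_steps > 1:
        raise ValueError(
            "Gradient accumulation is not supported with optimizer-in-backward; "
            "set gradient_accumulation_steps=1 or disable optimizer_in_bwd."
        )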

Test plan

All the single-device recipe tests succeed without any changes. However, for our distributed recipe tests some of the expected values need to change.

Why are you changing the expected values?!

Again, see the "long digression" section below, but we have been normalizing our loss by the local number of tokens, not the number of tokens seen over all ranks. Is this a huge deal? Honestly probably not, but technically it's not correct. How do I know this version is correct? I've run the following command both on main and on this PR:

pytest -m integration_test tests/recipes/test_full_finetune_distributed.py -k 'test_loss[False-llama2/7B_full-llama2-hf-1-4]'

In both cases I added logging of # of tokens and grad for the token embeddings weight. On main I also commented out the loss normalization so that we can get the raw values. Diff for my changes on main (changes on this PR are just the identical logging).

Note that after the first iteration the results differ due to the difference in how the gradients are calculated, so looking at the logs from just the first iteration:

On main

rank: 1, idx: 0, num_tokens: 12
rank: 0, idx: 0, num_tokens: 59
rank: 0, idx: 1, num_tokens: 60
rank: 1, idx: 1, num_tokens: 268
rank: 0, idx: 2, num_tokens: 14
rank: 1, idx: 2, num_tokens: 133
rank: 0, idx: 3, num_tokens: 11
rank: 1, idx: 3, num_tokens: 6
rank: 0, unnormalized grad: DTensor(local_tensor=-13.8838529586792, device_mesh=DeviceMesh('cuda', [0, 1]), placements=(Partial(sum),))
rank: 1, unnormalized grad: DTensor(local_tensor=1.0207021236419678, device_mesh=DeviceMesh('cuda', [0, 1]), placements=(Partial(sum),))

On this PR

rank: 1, idx: 0, num_tokens: 12
rank: 0, idx: 0, num_tokens: 59
rank: 0, idx: 1, num_tokens: 60
rank: 1, idx: 1, num_tokens: 268
rank: 0, idx: 2, num_tokens: 14
rank: 1, idx: 2, num_tokens: 133
rank: 0, idx: 3, num_tokens: 11
rank: 1, idx: 3, num_tokens: 6
rank: 0, grad: DTensor(local_tensor=-0.02466048300266266, device_mesh=DeviceMesh('cuda', [0, 1]), placements=(Partial(sum),))
rank: 1, grad: DTensor(local_tensor=0.0018129688687622547, device_mesh=DeviceMesh('cuda', [0, 1]), placements=(Partial(sum),))

But what does it all mean?

In both the preceding snippets, the first step sees 12 + 59 + 60 + 268 + 14 + 133 + 11 + 6 = 563 tokens.

The grad value logged is just the sum of the elements in the tensor. On main this is -13.8838529586792 on rank 0 and 1.0207021236419678 on rank 1. On this PR it is -0.02466048300266266 on rank 0 and 0.0018129688687622547 on rank 1. In both cases, the value on this PR equals the unnormalized value on main divided by 563, which is what we would expect.
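
As a quick sanity check on that arithmetic (plain Python on the logged values, nothing recipe-specific):

total_tokens = 12 + 59 + 60 + 268 + 14 + 133 + 11 + 6  # = 563
print(-13.8838529586792 / total_tokens)   # ~ -0.0246605, matching the rank 0 grad on this PR
print(1.0207021236419678 / total_tokens)  # ~ 0.0018130, matching the rank 1 grad on this PR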

End-to-end testing

Llama3 8B full finetune on 4 devices

Repro:

tune run --nproc_per_node 4 full_finetune_distributed --config llama3/8B_full \
max_steps_per_epoch=500 gradient_accumulation_steps=4

Peak allocated memory drops substantially after this PR (back to where it was before the first gradient accumulation PR, but with actual correct loss calculation now).

[Screenshot: peak allocated memory comparison]

Llama 3.2 1B full finetune on 2 devices

tune run --nproc_per_node 2 full_finetune_distributed --config llama3_2/1B_full max_steps_per_epoch=500

We can see that (a) peak allocated memory drops back down to where it was before #1875, (b) tokens per second is the same as #1875 (and faster than before), and (c) the loss curves look similar.

[Screenshots: peak allocated memory, tokens per second, and loss curves]

Long digression: data parallel and gradient accumulation aren't so different

First: on a single device

Let's consider a simple example of how to calculate the correctly-normalized CE loss on a single device (i.e. no data parallel) with two gradient accumulation steps. Say our unnormalized cross-entropy loss for the first batch is L1 and our (similarly unnormalized) loss for the second batch is L2, and we have n1 tokens in the first batch and n2 tokens in the second batch. Then the properly-normalized cross-entropy loss to step with will be (L1 + L2) / (n1 + n2). The naive approach taken in #1875 is to just accumulate a running sum of both loss and number of tokens, then call backward on the ratio running_loss / running_num_tokens just before optimizer step. The problem with this is that we only call backward once, so our activations stick around and blow up our memory. What can we do instead?

Sticking with the single device case, it's not too hard to fix this. Repeated calls to .backward() accumulate gradients, so e.g. the following two are numerically equivalent:

# Single backward call (note: "model(batch)" here is shorthand for a forward pass
# plus the unnormalized loss computation)
loss_1 = model(batch_1)
loss_2 = model(batch_2)
summed_loss = loss_1 + loss_2
summed_loss.backward()

# Multiple backward calls: gradients accumulate in .grad, so this is numerically
# equivalent, but the activations for batch_1 are freed after its backward
loss_1 = model(batch_1)
loss_1.backward()
loss_2 = model(batch_2)
loss_2.backward()

Then for the single-device case we can take the second approach with L1 and L2. But remember that we still want to normalize by n1 + n2. This isn't too hard: since grad(c*X) = c*grad(X) for any constant c, scaling the gradients just before the optimizer step is equivalent to scaling the loss, so we can just do that scaling manually ourselves.
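
A minimal single-device sketch of this (assuming a model that returns logits, a sum-reduction cross-entropy loss, and batches with "tokens"/"labels" keys; the names here are illustrative, not the recipes' actual code):

def grad_accum_step(model, optimizer, loss_fn, batches, ignore_index=-100):
    """Backward per batch, then scale grads by 1 / total_num_tokens before stepping."""
    total_num_tokens = 0
    for batch in batches:
        logits = model(batch["tokens"])
        # Count the non-padding tokens in this batch (labels assumed to use ignore_index for padding)
        total_num_tokens += (batch["labels"] != ignore_index).sum().item()
        # Unnormalized, sum-reduced loss; calling backward now frees this batch's activations
        loss = loss_fn(logits.view(-1, logits.size(-1)), batch["labels"].view(-1))
        loss.backward()

    # grad(c * L) = c * grad(L), so scaling the grads here is equivalent to scaling the loss
    for p in model.parameters():
        if p.grad is not None:
            p.grad.div_(total_num_tokens)

    optimizer.step()
    optimizer.zero_grad(set_to_none=True)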

Adding data parallel

This is where things get slightly messier. Let's extend the example to two devices: using Lij to refer to the unnormalized cross-entropy loss for rank i in its jth batch, and similarly for nij with number of tokens.

So in the first batch the model will see:
On rank 1: loss L11 based on n11 tokens
On rank 2: loss L21 based on n21 tokens

In the second batch, it will see:
On rank 1: loss L12 based on n12 tokens
On rank 2: loss L22 based on n22 tokens

Similarly to the single-device case, the total number of tokens seen across all batches and all ranks will be n11 + n21 + n12 + n22. This means that the properly-normalized loss should be given by (L11 + L21 + L12 + L22) / (n11 + n21 + n12 + n22).

What are we doing today?

Currently we take a similar approach to the single-device case described previously: we accumulate the losses and tokens, then call .backward() on the ratio running_loss / running_num_tokens just before stepping with the optimizer. See the below code:

running_loss += self._loss_fn(logits, labels) * current_num_tokens
# free logits otherwise it peaks backward memory
del logits

# Step with optimizer
if (idx + 1) % self._gradient_accumulation_steps == 0:
    loss = running_loss / num_tokens
    loss.backward()

What's wrong with this?

During data parallel training, the loss on a given rank is based only on the subset of data seen by that rank. Similarly, our calculation of running_num_tokens is based only on the tokens from that rank. This means that when we normalize we are normalizing only over iterations, not over ranks. Put another way, the line loss = running_loss / num_tokens in the above snippet will yield (L11 + L12) / (n11 + n12) on rank 1 and (L21 + L22) / (n21 + n22) on rank 2. Finally, we call .backward(), which calculates local grads before firing a hook to sync by reducing over all ranks. The upshot is that our loss winds up as [(L11 + L12) / (n11 + n12)] + [(L21 + L22) / (n21 + n22)], which is definitely not (L11 + L21 + L12 + L22) / (n11 + n21 + n12 + n22).

How can we fix it?

Aside from the correctness issue described in the previous section, we also need an approach that still calls .backward() on each batch to free the activation memory. How can we do this? Actually it's not so bad: we just need to reduce the number of tokens before our final gradient normalization and let the data parallel backward hooks take care of the rest. More explicitly:

First batch:
On rank 1: calculate loss L11 based on n11 tokens
On rank 2: calculate loss L21 based on n21 tokens
Call backward -> this triggers a grad sync and each rank now has the (local) grads for L11 + L21

Second batch:
On rank 1: calculate loss L12 based on n12 tokens
On rank 2: calculate loss L22 based on n22 tokens
Call backward -> this computes the grads for L12 + L22 and accumulates them on top of the L11 + L21 grads already present, then triggers another sync so that each rank now has the grads for L11 + L21 + L12 + L22

Then we just need n11 + n21 + n12 + n22. Fortunately this is just an all-reduce on running_num_tokens. Then we manually scale the gradients just like in the single-device case.

In summary, we wind up with the following process:

For each batch:

  • Call backward on the unnormalized losses
  • Keep a running tally of the number of tokens seen (per rank)

Before optimizer step:

  • [data parallel only] Reduce the running tally of number of tokens over all ranks
  • Scale gradients by 1 / total_num_tokens
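
Putting the process above together, here is a rough end-to-end sketch for the distributed case (names are illustrative, and the model is assumed to already be wrapped in FSDP/DDP so that backward syncs grads across ranks):

import torch
import torch.distributed as dist

def distributed_grad_accum_step(model, optimizer, loss_fn, batches, device, ignore_index=-100):
    """Backward per batch, all_reduce the token count, scale grads, then step."""
    num_tokens = torch.tensor(0.0, device=device)
    for batch in batches:
        logits = model(batch["tokens"])
        labels = batch["labels"]
        num_tokens += (labels != ignore_index).sum()
        # Unnormalized loss; each backward frees this batch's activations and lets
        # the data parallel hooks sync grads across ranks
        loss = loss_fn(logits.view(-1, logits.size(-1)), labels.view(-1))
        loss.backward()

    # Total tokens over all batches *and* all ranks (n11 + n21 + n12 + n22 in the example above)
    dist.all_reduce(num_tokens)

    # Scale grads by 1 / total_num_tokens. If the data parallel backward averages grads
    # across ranks instead of summing them, multiply by world_size as well (which is why
    # the recipe scales by world_size / num_tokens).
    for p in model.parameters():
        if p.grad is not None:
            p.grad.div_(num_tokens)

    optimizer.step()
    optimizer.zero_grad(set_to_none=True)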

pytorch-bot commented Oct 29, 2024

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1917

✅ No Failures

As of commit 408e521 with merge base e99b890:
💚 Looks good so far! There are no failures yet. 💚


@facebook-github-bot added the CLA Signed label Oct 29, 2024
@ebsmothers ebsmothers marked this pull request as draft October 29, 2024 13:59
@@ -0,0 +1,14 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
Contributor

Does this really need its own file?

Contributor

No

Contributor Author

Where do you wanna put it then? Otherwise I am gonna copy-paste this in every recipe which is worse imo

@@ -722,7 +732,7 @@ def train(self) -> None:
# Update the number of steps when the weights are updated
self.global_step += 1

loss_to_log = loss.item()
loss_to_log = running_loss.item() / num_tokens
Contributor Author

Should probably normalize by local_num_tokens?

Contributor Author

Update: I am probably gonna keep it like this since it should be representative of the loss we are actually using to step (even though it means our loss curves will look slightly different than they do today)

Contributor

I think it makes sense. Will it break all regression tests though?

loss.backward()
local_num_tokens = num_tokens.detach().clone()
torch.distributed.all_reduce(num_tokens)
training.scale_grads(self._model, self._world_size / num_tokens)
Contributor

there are so many lines taking care of the all_reduce, backward, etc, that it makes me wonder if this should be a utility.

Contributor Author

Yeah maybe. In this case I feel like it's important enough (and tricky enough) logic to be done very explicitly. Whatever route we go I will ultimately make it more explicit what's happening here



@contextlib.contextmanager
def no_sync(model: nn.Module) -> Generator[None, None, None]:
Contributor

name could be more descriptive, maybe no_grad_sync

@ebsmothers ebsmothers changed the title [very wip] restore backward after each batch for grad accum Restore backward after each batch for grad accum Oct 30, 2024
@ebsmothers ebsmothers marked this pull request as ready for review October 31, 2024 00:48
@ebsmothers (Contributor Author)

cc @andrewor14 for review of the QAT recipe changes

@ebsmothers (Contributor Author)

also cc @lindawangg for the KD recipe changes

@felipemello1 (Contributor)

lgtm

@lindawangg (Contributor) left a comment

KD changes look good to me

loss.backward()
if not self._optimizer_in_bwd:
    # Get total number of tokens across all ranks to normalize gradients
    torch.distributed.all_reduce(num_tokens)
Contributor

HE'S A GENIUS

}
return loss_values_map[model_type]

@pytest.mark.integration_test
@pytest.mark.parametrize(
    "config, model_type, ckpt_type, micro_batch_size, gradient_accumulation_steps",
    [
        ("llama2/7B_qat_full", "llama2", "hf", 4, 1),
        ("llama3/8B_qat_full", "llama3", "tune", 4, 1),
        # ("llama2/7B_qat_full", "llama2", "hf", 4, 1),
Contributor

Commented out?

@joecummings (Contributor) left a comment

Math!

@felipemello1 (Contributor)

The only thing that is not crystal clear to me is why TPS would increase. I guess that before we would also count padded tokens, which should make TPS look higher in the older version.

@joecummings (Contributor)

The only thing that is not crystal clear to me is why TPS would increase. I guess that before we would also count padded tokens, which should make TPS look higher in the older version.

Yeahhhhhhh, seems like the reduce would actually slow down training no?

@pbontrager (Contributor) left a comment

Thanks for coming up with the best solution to this problem! I left some comments around some potential dtype concerns and readability, but otherwise this looks great!

"""
for p in model.parameters():
if p.grad is not None:
p.grad *= scaler
Contributor

Is there any concern here around overflows for lower dtypes? We could do a scaler range check based on dtype. Or is it better to leave it to the recipe to safely choose scaler values?

Contributor Author

Discussed a bit offline, but this should be safe for bf16 and fp32. We can revisit upon integration of other dtypes


# For optimizer in backward, we need to normalize before calling backward
# This case and gradient accumulation are mutually exclusive
if self._optimizer_in_bwd:
@pbontrager (Contributor) commented Oct 31, 2024

I think 783 - 810 could be much more readable as:

if self.optimizer_in_bwd:
	raise if self._clip_grad_norm # or do this in init
	...
	current_loss.backward()
elif (idx + 1) % self._gradient_accumulation_steps == 0:
	current_loss.backward()
	...
	scale_grads()
	if self._clip_grad_norm is not None:
		...
	self._optimizer.step()
	self._optimizer.zero_grad(...)

This could be used in all the distributed recipes.

@@ -787,15 +771,31 @@ def train(self) -> None:
# Compute loss
# Loss is normalized by default so we multiply by the number of tokens
# This way we can normalize by the total number of tokens if we're accumulating gradients
running_loss += self._loss_fn(logits, labels) * current_num_tokens
current_loss = self._loss_fn(logits, labels) * current_num_tokens
@pbontrager (Contributor) commented Oct 31, 2024

If there were ever an issue with numerical stability, another option for scaling the loss would be:

if grad_accumulation_step == 0:
	base_num_tokens = current_num_tokens
	torch.distributed.broadcast(base_num_tokens, src=0)

current_loss = loss_fn(logits, labels) * current_num_tokens / base_num_tokens

This might overcomplicate things, but I wanted to leave it here in case a reduced gradient/loss turns out to be necessary for smaller dtypes in the future.

Contributor Author

Agreed, we can take a look at this in the future

@andrewor14 (Contributor)

@ebsmothers For QAT, should we just update all at once (copy over the changes from full_finetune_distributed)? E.g. I already have a PR that does this, and I can just rebase after your PR is landed: #1854. Right now these two recipes aren't really synced, and if some PRs add changes to qat_distributed but others don't, I'm worried we'll be in a weird in-between state. What do you think?

@ebsmothers (Contributor Author)

@andrewor14 if it's all the same to you I may keep the QAT changes in. Mainly because this change will actually impact the memory usage and loss curves in a nontrivial way, so I'd rather not have any divergence between our recipes there, even for a short time. Lmk if that sounds alright to you though

@ebsmothers (Contributor Author)

@felipemello1 @joecummings discussed offline but the WPS is not measured in exactly the same way before vs after these changes. In the figures we have "old-version", "main", and "this-pr". In "old-version" I hacked it so that we are counting non-padding tokens only (this way it lines up with the other two). But both "old-version" and "main" are using only the tokens seen on rank 0. On "this-pr" we now have the tokens seen by all ranks, which we then normalize by the number of ranks. So this PR's version should be less noisy, but there will be some slight differences when compared to the other two versions.

Regarding the question of all_reduce slowing down training: I can do a more detailed comparison, but we are only reducing a single float value. When compared against various FSDP comms on large tensors happening in forward and backward this should be pretty minimal (but will confirm).

@joecummings mentioned this pull request Oct 31, 2024
@ebsmothers ebsmothers merged commit 2fa6a54 into pytorch:main Oct 31, 2024
17 checks passed