Restore backward after each batch for grad accum #1917

Merged on Oct 31, 2024 (10 commits)
53 changes: 27 additions & 26 deletions recipes/full_finetune_distributed.py
@@ -152,26 +152,10 @@ def __init__(self, cfg: DictConfig) -> None:
self._gradient_accumulation_steps = cfg.gradient_accumulation_steps
self._optimizer_in_bwd = cfg.get("optimizer_in_bwd", False)

# activation checkpointing/offloading
self._enable_activation_checkpointing = cfg.get(
"enable_activation_checkpointing", False
)
self._enable_activation_offloading = cfg.get(
"enable_activation_offloading", False
)
if self._enable_activation_offloading:
if self._device.type != "cuda":
raise RuntimeError(
"enable_activation_offloading should only be True when training on CUDA"
)
if not self._enable_activation_checkpointing:
raise RuntimeError(
"enable_activation_offloading should only be True when enable_activation_checkpointing is True"
)
elif self._enable_activation_checkpointing:
log.info(
"Hint: enable_activation_checkpointing is True, but enable_activation_offloading isn't. "
"Enabling activation offloading should reduce memory further."
if self._gradient_accumulation_steps > 1 and self._optimizer_in_bwd:
raise RuntimeError(
"Gradient accumulation is not supported with optimizer in bwd."
"Please set gradient_accumulation_steps=1, or optimizer_in_bwd=False."
)

# activation checkpointing/offloading
@@ -720,7 +704,7 @@ def train(self) -> None:
# clean up before training begins
training.cleanup_before_training()

_, rank = training.get_world_size_and_rank()
world_size, rank = training.get_world_size_and_rank()

# zero out the gradients before starting training
if not self._optimizer_in_bwd:
@@ -787,15 +771,31 @@ def train(self) -> None:
# Compute loss
# Loss is normalized by default so we multiply by the number of tokens
# This way we can normalize by the total number of tokens if we're accumulating gradients
running_loss += self._loss_fn(logits, labels) * current_num_tokens
current_loss = self._loss_fn(logits, labels) * current_num_tokens
Contributor @pbontrager commented on Oct 31, 2024:
If there were ever an issue with numerical stability, another option for scaling the loss would be:

if grad_accumulation_step == 0:
	base_num_tokens = current_num_tokens
	torch.distributed.broadcast(base_num_tokens, src=0)

current_loss = loss_fn(logits, labels) * current_num_tokens / base_num_tokens

This might overcomplicate things, but I wanted to leave it here in case a reduced gradient/loss turns out to be necessary for smaller dtypes in the future.

Contributor (author) replied:
Agreed, we can take a look at this in the future
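
For context, a hedged sketch of how that suggestion could be wired up, as illustration only: the helper name, the state dict, and the ignore_index argument are assumptions rather than part of this PR, and it presumes torch.distributed is already initialized and labels is a tensor on the right device.

import torch.distributed as dist

def loss_with_broadcast_scaling(loss_fn, logits, labels, ignore_index, step_in_window, state):
    # Unmasked tokens contributed by this rank's micro-batch
    current_num_tokens = (labels != ignore_index).sum()
    # At the start of each accumulation window, fix a shared denominator (rank 0's count)
    if step_in_window == 0:
        state["base_num_tokens"] = current_num_tokens.clone()
        dist.broadcast(state["base_num_tokens"], src=0)
    # The scaled loss stays roughly O(1) instead of growing with the raw token count,
    # which is what may help numerical stability in smaller dtypes
    return loss_fn(logits, labels) * current_num_tokens / state["base_num_tokens"]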


# free logits otherwise it peaks backward memory
del logits

running_loss += current_loss

# For optimizer in backward, we need to normalize before calling backward
# This case and gradient accumulation are mutually exclusive
if self._optimizer_in_bwd:
Contributor @pbontrager commented on Oct 31, 2024:
I think lines 783-810 could be much more readable as:

if self.optimizer_in_bwd:
	raise if self._clip_grad_norm # or do this in init
	...
	current_loss.backward()
elif (idx + 1) % self._gradient_accumulation_steps == 0:
	current_loss.backward()
	...
	scale_grads()
	if self._clip_grad_norm is not None:
		...
	self._optimizer.step()
	self._optimizer.zero_grad(...)

This could be used in all the distributed recipes.
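
A hedged reading of that sketch with the elided steps filled in from the operations this diff already uses; the exact placement of the all_reduce calls and the clipping guard is an assumption, not the PR's final code.

if self._optimizer_in_bwd:
    if self._clip_grad_norm is not None:  # or validate this once in __init__
        raise NotImplementedError("Gradient clipping is not supported with optimizer in bwd.")
    # The fused optimizer steps during backward, so normalize before calling it
    torch.distributed.all_reduce(num_tokens)
    torch.distributed.all_reduce(running_loss)
    current_loss = current_loss / num_tokens
    current_loss.backward()
else:
    # Plain gradient accumulation: backward every micro-batch, step at the boundary
    current_loss.backward()
    if (idx + 1) % self._gradient_accumulation_steps == 0:
        torch.distributed.all_reduce(num_tokens)
        torch.distributed.all_reduce(running_loss)
        training.scale_grads(self._model, 1 / num_tokens)
        if self._clip_grad_norm is not None:
            grad_norm = torch.nn.utils.clip_grad_norm_(
                self._model.parameters(), max_norm=float(self._clip_grad_norm)
            )  # grad_norm is logged later in the recipe
        self._optimizer.step()
        self._optimizer.zero_grad(set_to_none=True)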

torch.distributed.all_reduce(num_tokens)
torch.distributed.all_reduce(running_loss)
current_loss = current_loss / num_tokens

current_loss.backward()

# Step with optimizer
if (idx + 1) % self._gradient_accumulation_steps == 0:
loss = running_loss / num_tokens
loss.backward()
if not self._optimizer_in_bwd:
# Get total number of tokens across all ranks to normalize gradients
torch.distributed.all_reduce(num_tokens)
Contributor commented:
HE'S A GENIUS

# This will ensure that the logged loss matches what we're optimizing
torch.distributed.all_reduce(running_loss)
# Manually scale the gradients from unnormalized loss by total # of tokens
training.scale_grads(self._model, 1 / num_tokens)
if self._clip_grad_norm is not None:
if self._optimizer_in_bwd:
raise NotImplementedError(
@@ -812,7 +812,7 @@ def train(self) -> None:
# Update the number of steps when the weights are updated
self.global_step += 1

loss_to_log = loss.item()
loss_to_log = running_loss.item() / num_tokens
Contributor (author) commented:
Should probably normalize by local_num_tokens?

Contributor (author) followed up:
Update: I am probably gonna keep it like this since it should be representative of the loss we are actually using to step (even though it means our loss curves will look slightly different than they do today)

Contributor replied:

I think it makes sense. Will it break all regression tests though?
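
To make the trade-off concrete (running_loss and num_tokens are the post-all_reduce values from the diff above; the _local names are hypothetical):

# As in this diff: log the globally normalized loss, the same quantity the step uses
loss_to_log = running_loss.item() / num_tokens
# Alternative raised above: normalize by this rank's own token count instead
# loss_to_log = running_loss_local.item() / num_tokens_local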

pbar.update(1)
pbar.set_description(
f"{curr_epoch + 1}|{self.global_step}|Loss: {loss_to_log}"
@@ -833,7 +833,8 @@ def train(self) -> None:
else self._optim_ckpt_wrapper
),
),
"tokens_per_second_per_gpu": num_tokens / time_per_step,
"tokens_per_second_per_gpu": num_tokens
/ (time_per_step * world_size),
}
if self._log_peak_memory_stats:
log_dict.update(
9 changes: 5 additions & 4 deletions recipes/full_finetune_single_device.py
@@ -686,12 +686,13 @@ def train(self) -> None:

# Loss is normalized by default so we multiply by the number of tokens
# This way we can normalize by the total number of tokens if we're accumulating gradients
running_loss += self._loss_step(batch) * current_num_tokens
current_loss = self._loss_step(batch) * current_num_tokens
running_loss += current_loss
current_loss.backward()

# Step with optimizer
if (idx + 1) % self._gradient_accumulation_steps == 0:
loss = running_loss / num_tokens
loss.backward()
training.scale_grads(self._model, 1 / num_tokens)
if self._clip_grad_norm is not None:
grad_norm = torch.nn.utils.clip_grad_norm_(
self._model.parameters(),
Expand All @@ -706,7 +707,7 @@ def train(self) -> None:
self._lr_scheduler.step()
self.global_step += 1

loss_to_log = loss.item()
loss_to_log = running_loss.item() / num_tokens
pbar.update(1)
pbar.set_description(
f"{curr_epoch + 1}|{self.global_step}|Loss: {loss_to_log}"
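The single-device recipes show the new pattern in its simplest form. Below is a hedged, runnable toy sketch of the same idea; the model, optimizer, and toy_batches generator are placeholders invented for illustration, and the manual grad multiply stands in for training.scale_grads.

import torch

model = torch.nn.Linear(8, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
accum_steps = 4
running_loss, num_tokens = 0.0, 0

for step, (x, y, n_tok) in enumerate(toy_batches()):  # toy_batches is hypothetical
    # Unnormalized loss for this micro-batch (sum, not mean)
    loss = torch.nn.functional.mse_loss(model(x), y, reduction="sum")
    running_loss += loss.detach()
    num_tokens += n_tok
    # Backward per micro-batch: this frees the batch's autograd graph immediately,
    # instead of keeping every graph alive until the end of the accumulation window
    loss.backward()

    if (step + 1) % accum_steps == 0:
        # Undo the missing normalization once, on the accumulated gradients
        for p in model.parameters():
            p.grad.mul_(1.0 / num_tokens)
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)
        print(running_loss.item() / num_tokens)  # matches the quantity being optimized
        running_loss, num_tokens = 0.0, 0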
24 changes: 17 additions & 7 deletions recipes/knowledge_distillation_distributed.py
@@ -821,7 +821,7 @@ def train(self) -> None:
# clean up before training begins
training.cleanup_before_training()

_, rank = training.get_world_size_and_rank()
world_size, rank = training.get_world_size_and_rank()

# zero out the gradients before starting training
self._optimizer.zero_grad()
@@ -857,7 +857,7 @@ def train(self) -> None:
):
torch.cuda.memory._record_memory_history()

batch = {k: v.to(self._device) for k, v in batch.items()}
utils.batch_to_device(batch, self._device)

# Calculate the number of unmasked tokens in the current batch
# and increment the total number of tokens seen in the step
@@ -869,13 +869,22 @@
class_loss, kd_loss = self._loss_step(batch)
running_class_loss += class_loss * current_num_tokens
running_kd_loss += kd_loss * current_num_tokens
current_loss = (
1 - self._kd_ratio
) * class_loss + self._kd_ratio * kd_loss
current_loss.backward()

# Step with optimizer
if (idx + 1) % self._gradient_accumulation_steps == 0:
class_loss = running_class_loss / num_tokens
kd_loss = running_kd_loss / num_tokens
loss = (1 - self._kd_ratio) * class_loss + self._kd_ratio * kd_loss
loss.backward()
# Get total number of tokens across all ranks to normalize gradients
torch.distributed.all_reduce(num_tokens)
# This will ensure that the logged loss matches what we're optimizing
torch.distributed.all_reduce(running_class_loss)
torch.distributed.all_reduce(running_kd_loss)
# Manually scale the gradients from unnormalized loss by total # of tokens
training.scale_grads(self._model, 1 / num_tokens)
class_loss_to_log = running_class_loss.item() / num_tokens
kd_loss_to_log = running_kd_loss.item() / num_tokens
self._optimizer.step()
self._optimizer.zero_grad(set_to_none=True)
self._lr_scheduler.step()
@@ -903,7 +912,8 @@ def train(self) -> None:
"class_loss": class_loss_to_log,
"kd_loss": kd_loss_to_log,
"lr": self._optimizer.param_groups[0]["lr"],
"tokens_per_second_per_gpu": num_tokens / time_per_step,
"tokens_per_second_per_gpu": num_tokens
/ (time_per_step * world_size),
}
if self._log_peak_memory_stats:
log_dict.update(
15 changes: 7 additions & 8 deletions recipes/knowledge_distillation_single_device.py
@@ -704,15 +704,14 @@ def train(self) -> None:
class_loss, kd_loss = self._loss_step(batch)
running_class_loss += class_loss * current_num_tokens
running_kd_loss += kd_loss * current_num_tokens
current_loss = (
1 - self._kd_ratio
) * class_loss + self._kd_ratio * kd_loss
current_loss.backward()

# Step with optimizer
if (idx + 1) % self._gradient_accumulation_steps == 0:
class_loss = running_class_loss / num_tokens
kd_loss = running_kd_loss / num_tokens
loss = (
1 - self._kd_ratio
) * class_loss + self._kd_ratio * kd_loss
loss.backward()
training.scale_grads(self._model, 1 / num_tokens)
if self._clip_grad_norm is not None:
grad_norm = torch.nn.utils.clip_grad_norm_(
self._model.parameters(),
Expand All @@ -724,8 +723,8 @@ def train(self) -> None:
# Update the number of steps when the weights are updated
self.global_step += 1

class_loss_to_log = class_loss.item()
kd_loss_to_log = kd_loss.item()
class_loss_to_log = running_class_loss.item() / num_tokens
kd_loss_to_log = running_kd_loss.item() / num_tokens
loss_to_log = (
1 - self._kd_ratio
) * class_loss_to_log + self._kd_ratio * kd_loss_to_log
20 changes: 14 additions & 6 deletions recipes/lora_finetune_distributed.py
@@ -748,7 +748,7 @@ def train(self) -> None:
# clean up before training begins
training.cleanup_before_training()

_, rank = training.get_world_size_and_rank()
world_size, rank = training.get_world_size_and_rank()

# zero out the gradients before starting training
self._optimizer.zero_grad()
@@ -812,15 +812,22 @@ def train(self) -> None:
# Compute loss
# Loss is normalized by default so we multiply by the number of tokens
# This way we can normalize by the total number of tokens if we're accumulating gradients
running_loss += self._loss_fn(logits, labels) * current_num_tokens
current_loss = self._loss_fn(logits, labels) * current_num_tokens

# free logits otherwise it peaks backward memory
del logits

running_loss += current_loss
current_loss.backward()

# Step with optimizer
if (idx + 1) % self._gradient_accumulation_steps == 0:
loss = running_loss / num_tokens
loss.backward()
# Get total number of tokens across all ranks to normalize gradients
torch.distributed.all_reduce(num_tokens)
# This will ensure that the logged loss matches what we're optimizing
torch.distributed.all_reduce(running_loss)
# Manually scale the gradients from unnormalized loss by total # of tokens
training.scale_grads(self._model, 1 / num_tokens)
if self._clip_grad_norm is not None:
grad_norm = torch.nn.utils.clip_grad_norm_(
self._model.parameters(),
@@ -833,7 +840,7 @@
# Update the number of steps when the weights are updated
self.global_step += 1

loss_to_log = loss.item()
loss_to_log = running_loss.item() / num_tokens
pbar.update(1)
pbar.set_description(
f"{curr_epoch + 1}|{self.global_step}|Loss: {loss_to_log}"
@@ -848,7 +855,7 @@ def train(self) -> None:
log_dict = {
"loss": loss_to_log,
"lr": self._optimizer.param_groups[0]["lr"],
"tokens_per_second_per_gpu": num_tokens / time_per_step,
"tokens_per_second_per_gpu": num_tokens
/ (time_per_step * world_size),
}
if self._log_peak_memory_stats:
log_dict.update(
9 changes: 5 additions & 4 deletions recipes/lora_finetune_single_device.py
@@ -692,12 +692,13 @@ def train(self) -> None:

# Loss is normalized by default so we multiply by the number of tokens
# This way we can normalize by the total number of tokens if we're accumulating gradients
running_loss += self._loss_step(batch) * current_num_tokens
current_loss = self._loss_step(batch) * current_num_tokens
running_loss += current_loss
current_loss.backward()

# Step with optimizer
if (idx + 1) % self._gradient_accumulation_steps == 0:
loss = running_loss / num_tokens
loss.backward()
training.scale_grads(self._model, 1 / num_tokens)
if self._clip_grad_norm is not None:
grad_norm = torch.nn.utils.clip_grad_norm_(
self._model.parameters(),
Expand All @@ -709,7 +710,7 @@ def train(self) -> None:
# Update the number of steps when the weights are updated
self.global_step += 1

loss_to_log = loss.item()
loss_to_log = running_loss.item() / num_tokens
pbar.update(1)
pbar.set_description(
f"{curr_epoch + 1}|{self.global_step}|Loss: {loss_to_log}"
35 changes: 21 additions & 14 deletions recipes/qat_distributed.py
@@ -599,8 +599,7 @@ def train(self) -> None:
"""
# clean up before training begins
training.cleanup_before_training()

_, rank = training.get_world_size_and_rank()
world_size, rank = training.get_world_size_and_rank()

# zero out the gradients before starting training
self._optimizer.zero_grad()
@@ -668,18 +667,16 @@ def train(self) -> None:

# Calculate the number of unmasked tokens in the current batch
# and increment the total number of tokens seen in the step

utils.batch_to_device(batch, self._device)

current_num_tokens = (
batch["labels"] != self._loss_fn.ignore_index
).sum()
num_tokens += current_num_tokens
labels = batch.pop("labels")

labels = labels.to(self._device)
mask = mask.to(self._device) if mask is not None else None
input_pos = (
input_pos.to(self._device) if input_pos is not None else None
)

logits = self._model(tokens, mask=mask, input_pos=input_pos)
logits = self._model(**batch)

# Shift labels to compute loss
# equivalent to doing labels[..., 1:] and logits[..., :-1, :]
@@ -692,22 +689,30 @@ def train(self) -> None:
logits = logits.reshape(-1, logits.size(-1))

# Compute loss
running_loss += self._loss_fn(logits, labels) * current_num_tokens
current_loss = self._loss_fn(logits, labels) * current_num_tokens

# free logits otherwise it peaks backward memory
del logits

running_loss += current_loss
current_loss.backward()

# Step with optimizer
if (idx + 1) % self._gradient_accumulation_steps == 0:
loss = running_loss / num_tokens
loss.backward()
# Get total number of tokens across all ranks to normalize gradients
torch.distributed.all_reduce(num_tokens)
# This will ensure that the logged loss matches what we're optimizing
torch.distributed.all_reduce(running_loss)
# Manually scale the gradients from unnormalized loss by total # of tokens
training.scale_grads(self._model, 1 / num_tokens)

self._optimizer.step()
self._optimizer.zero_grad(set_to_none=True)

# Update the number of steps when the weights are updated
self.global_step += 1

loss_to_log = loss.item()
loss_to_log = running_loss.item() / num_tokens
pbar.update(1)
pbar.set_description(
f"{curr_epoch + 1}|{self.global_step}|Loss: {loss_to_log}"
@@ -722,7 +727,9 @@ def train(self) -> None:
log_dict = {
"loss": loss_to_log,
"lr": self._optimizer.param_groups[0]["lr"],
"tokens_per_second_per_gpu": num_tokens / time_per_step,
"tokens_per_second_per_gpu": (
num_tokens / time_per_step * world_size
),
}
if self._log_peak_memory_stats:
log_dict.update(
13 changes: 6 additions & 7 deletions tests/recipes/test_full_finetune_distributed.py
@@ -45,21 +45,20 @@ def _get_test_config_overrides(self):

def _fetch_expected_loss_values(self, model_type):
loss_values_map = {
"llama2": [10.5136, 10.4813, 10.5088, 10.5250],
"llama3": [12.0673, 11.9072, 11.9302, 11.9355],
"llama2": [10.5209, 10.5217, 10.4945, 10.5136],
"llama3": [11.9839, 11.9684, 11.9596, 11.93656],
}
return loss_values_map[model_type]

@pytest.mark.integration_test
@pytest.mark.parametrize(
"config, model_type, ckpt_type, micro_batch_size, gradient_accumulation_steps",
"config, model_type, ckpt_type, micro_batch_size, gradient_accumulation_steps, optim_in_bwd",
[
("llama2/7B_full", "llama2", "hf", 1, 4),
("llama3/8B_full", "llama3", "tune", 1, 4),
("llama3/8B_full", "llama3", "tune", 4, 1),
("llama2/7B_full", "llama2", "hf", 1, 4, False),
("llama3/8B_full", "llama3", "tune", 1, 4, False),
("llama3/8B_full", "llama3", "tune", 4, 1, True),
],
)
@pytest.mark.parametrize("optim_in_bwd", [True, False])
@gpu_test(gpu_count=2)
def test_loss(
self,