enable average tokens across devices #34373

Merged (12 commits, Oct 28, 2024)
src/transformers/trainer.py (9 additions & 1 deletion)

@@ -3602,7 +3602,12 @@ def training_step(
                 with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                     scaled_loss.backward()
         else:
-            loss *= self.args.gradient_accumulation_steps
+            if num_items_in_batch is not None:
+                if self.compute_loss_func or self.model_accepts_loss_kwargs:
+                    loss *= self.args.gradient_accumulation_steps
+            # Average tokens across devices is orthogonal to gradient accumulation
+            if self.args.average_tokens_across_devices:
+                loss *= self.args.world_size
             self.accelerator.backward(loss, **kwargs)
 
         return loss.detach() / self.args.gradient_accumulation_steps
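For intuition, here is a small self-contained numerical sketch of why the loss is multiplied by `world_size` once token counts are summed across devices: each rank normalizes by the global token count, and gradient averaging across ranks effectively divides by the number of processes again, so the extra factor cancels that division. The numbers (`world_size`, `loss_sums`, `tokens`) are made up for illustration and are not from this PR.

```python
world_size = 2
loss_sums = [12.0, 4.0]      # hypothetical per-rank sums of token losses
tokens = [6, 2]              # hypothetical per-rank token counts
global_tokens = sum(tokens)  # what gathering and summing num_items_in_batch yields

# Each rank divides its local loss sum by the *global* token count,
# then multiplies by world_size to pre-compensate for gradient averaging.
per_rank_loss = [s / global_tokens * world_size for s in loss_sums]

# Gradient averaging across ranks effectively divides by world_size again.
effective = sum(per_rank_loss) / world_size

# This equals the true global per-token average: (12 + 4) / (6 + 2) = 2.0
assert abs(effective - sum(loss_sums) / global_tokens) < 1e-9
print(effective)  # 2.0
```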
Comment on lines 3641 to 3642

Member:
We modified how the loss is computed quite a bit; are we sure that this loss is the same as the one that gets applied?

Contributor @muellerzr (Oct 25, 2024):
Good catch, it should be `loss.detach() / self.args.gradient_accumulation_steps / self.accelerator.num_processes` (dividing by num_processes if and only if we went through our loss-function num-tokens logic).
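A minimal, self-contained sketch of that suggested reporting fix, under stated assumptions: `grad_accum_steps`, `num_processes`, and `used_token_count_loss` are hypothetical stand-ins for `self.args.gradient_accumulation_steps`, `self.accelerator.num_processes`, and the `compute_loss_func or model_accepts_loss_kwargs` check; this is not the merged code.

```python
import torch

loss = torch.tensor(8.0)        # hypothetical accumulated loss on this rank
grad_accum_steps = 4
num_processes = 2
used_token_count_loss = True    # i.e. a custom loss func or the loss-kwargs path was taken

# Report the detached loss, dividing by num_processes only on the token-count path.
reported = loss.detach() / grad_accum_steps
if used_token_count_loss:
    reported = reported / num_processes
print(reported)  # tensor(1.)
```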

@@ -3617,6 +3622,9 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
             labels = inputs.pop("labels")
         else:
             labels = None
+        if self.args.average_tokens_across_devices and num_items_in_batch is not None:
+            num_items_in_batch_tensor = torch.tensor(num_items_in_batch, device=self.args.device)
+            num_items_in_batch = int(self.accelerator.gather(num_items_in_batch_tensor).sum().cpu())
         if self.model_accepts_loss_kwargs:
             loss_kwargs = {}
             if num_items_in_batch is not None:
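To make concrete what the gather above computes, here is a single-process sketch that simulates `accelerator.gather` with a plain `torch.stack`; the per-rank token counts are hypothetical.

```python
import torch

# Hypothetical local token counts from 4 ranks; accelerator.gather would hand
# every rank a tensor containing all of these.
local_counts = [torch.tensor(57), torch.tensor(63), torch.tensor(60), torch.tensor(62)]
gathered = torch.stack(local_counts)      # simulated gather across processes
num_items_in_batch = int(gathered.sum())  # every rank now normalizes by the same global count
print(num_items_in_batch)  # 242
```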
src/transformers/training_args.py (11 additions & 0 deletions)

@@ -1530,6 +1530,13 @@ class TrainingArguments:
         },
     )
 
+    average_tokens_across_devices: Optional[bool] = field(
+        default=False,
+        metadata={
+            "help": "Whether or not to average tokens across devices."
Member:
Maybe we could explain a bit more about why this arg could be useful?

Contributor Author:
Done, please review my code.

+        }
+    )
Contributor:
During __post_init__ we call setup_devices. We could change the average_tokens_across_devices value to False if the world size is 1, I think!

This then simplifies it earlier to just be `if self.args.average_tokens_across_devices`.

Contributor Author:
Thanks, already fixed it, please check.


     def __post_init__(self):
         # Parse in args that could be `dict` sent in from the CLI as a string
         for field in _VALID_DICT_FIELDS:
@@ -1763,6 +1770,10 @@ def __post_init__(self):
         if self.framework == "pt" and is_torch_available():
             self.device
 
+        # Disable average tokens when using single device
+        if self.world_size == 1:
+            self.average_tokens_across_devices = False
+
         if self.torchdynamo is not None:
             warnings.warn(
                 "`torchdynamo` is deprecated and will be removed in version 5 of 🤗 Transformers. Use"
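A hedged usage sketch of the new flag, assuming a standard Trainer setup; the `output_dir` and other values are illustrative. On a single device, `__post_init__` switches the flag back to False per the change above.

```python
from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="out",                     # illustrative path
    average_tokens_across_devices=True,   # request global token-count averaging
    gradient_accumulation_steps=4,
)

# Prints False when world_size == 1 (disabled in __post_init__),
# and stays True under a multi-GPU launch.
print(args.average_tokens_across_devices)
```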