enable average tokens across devices #34373

Merged
12 commits merged on Oct 28, 2024
7 changes: 7 additions & 0 deletions src/transformers/trainer.py
@@ -3603,6 +3603,9 @@ def training_step(
scaled_loss.backward()
else:
loss *= self.args.gradient_accumulation_steps
if (self.args.average_tokens_across_devices and num_items_in_batch is not None and
self.args.world_size > 1):
loss *= self.args.world_size
muellerzr (Contributor) commented on Oct 24, 2024:
I think both of these chunks can go under an if num_items_in_batch is not None and self.model_accepts_loss_kwargs, since both need to be valid for the loss *= self.args.gradient_accumulation_steps to apply.
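For concreteness, a minimal sketch of the restructuring being suggested (a hypothetical standalone helper, not the merged training_step code; names mirror the diff above):

```python
def scale_loss(loss, gradient_accumulation_steps, world_size,
               average_tokens_across_devices, num_items_in_batch,
               model_accepts_loss_kwargs):
    # Hypothetical restructuring per the suggestion above: both scalings sit
    # under one guard, since each only makes sense when a token count exists
    # and the model's loss function accepts num_items_in_batch.
    if num_items_in_batch is not None and model_accepts_loss_kwargs:
        loss = loss * gradient_accumulation_steps
        if average_tokens_across_devices and world_size > 1:
            loss = loss * world_size
    return loss
```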

Contributor (Author) replied:
I think gradient accumulation is orthogonal to DDP, so I used a new if statement. Please check my code. Thanks.

Contributor replied:
It is; it's a matter of self.model_accepts_loss_kwargs.

self.accelerator.backward(loss, **kwargs)

return loss.detach() / self.args.gradient_accumulation_steps
Comment on lines 3641 to 3642
Member commented:
We modified how the loss is computed quite a bit; are we sure this loss is the same as the one applied?

muellerzr (Contributor) replied on Oct 25, 2024:
Good catch, it should be loss.detach() / self.args.gradient_accumulation_steps / self.accelerator.num_processes (dividing by num processes if and only if we did our loss function num tokens logic)
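Expressed as code, the proposal reads roughly like this (a hypothetical helper, not the merged Trainer code):

```python
import torch


def reported_loss(loss: torch.Tensor, gradient_accumulation_steps: int,
                  num_processes: int, scaled_by_world_size: bool) -> torch.Tensor:
    # Undo the gradient-accumulation scaling for reporting, and additionally
    # undo the world_size scaling, but only if the num-tokens logic ran and
    # the loss was multiplied by world_size earlier in training_step.
    reported = loss.detach() / gradient_accumulation_steps
    if scaled_by_world_size:
        reported = reported / num_processes
    return reported
```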

@@ -3617,6 +3620,10 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N
labels = inputs.pop("labels")
else:
labels = None
if (self.args.average_tokens_across_devices and num_items_in_batch is not None and
self.args.world_size > 1):
num_items_in_batch_tensor = torch.tensor(num_items_in_batch, device=self.args.device)
num_items_in_batch = int(self.accelerator.gather(num_items_in_batch_tensor).sum().cpu())
if self.model_accepts_loss_kwargs:
loss_kwargs = {}
if num_items_in_batch is not None:
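To see what the gather in compute_loss plus the earlier loss *= self.args.world_size buys, here is a small self-contained sketch (plain Python, no distributed setup; the token counts and loss sums are made-up numbers) comparing the default per-device normalization with the globally averaged one under DDP-style gradient averaging:

```python
# Two hypothetical ranks with different numbers of non-padding tokens.
ce_sum = [90.0, 30.0]       # summed cross-entropy per rank
tokens = [300, 60]          # non-padding tokens per rank
world_size = 2
total_tokens = sum(tokens)  # what accelerator.gather(...).sum() would return

# Without average_tokens_across_devices: each rank normalizes by its own
# token count, and DDP averages the results -> mean of per-rank means.
per_rank_mean = [c / t for c, t in zip(ce_sum, tokens)]
ddp_default = sum(per_rank_mean) / world_size             # 0.4

# With the flag: each rank divides by the global token count and multiplies
# by world_size, so the DDP average reduces to the true global per-token mean.
scaled = [c / total_tokens * world_size for c in ce_sum]
ddp_averaged_tokens = sum(scaled) / world_size             # sum(ce) / total_tokens
global_mean = sum(ce_sum) / total_tokens                   # ~0.333

print(ddp_default, ddp_averaged_tokens, global_mean)
```

With equal token counts per device the two quantities coincide; they only diverge when devices see different numbers of non-padding tokens, which is the case this flag targets.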
7 changes: 7 additions & 0 deletions src/transformers/training_args.py
@@ -1530,6 +1530,13 @@ class TrainingArguments:
},
)

average_tokens_across_devices: Optional[bool] = field(
default=False,
metadata={
"help": "Whether or not to average tokens across devices."
Member commented:
Maybe we could explain a bit more about why this arg could be useful?

Contributor (Author) replied:
Done, please review my code.

}
)
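As an illustration only (not the wording that was merged), a more descriptive help text might read something like:

```python
# Hypothetical expanded help string for the new argument; the merged wording
# may differ.
AVERAGE_TOKENS_HELP = (
    "Whether or not to average tokens across devices. If enabled, the per-device "
    "token counts (num_items_in_batch) are gathered and summed so the loss is "
    "normalized by the global number of non-padding tokens instead of the local "
    "count, keeping the loss scale correct when devices receive batches with "
    "different numbers of tokens. Only has an effect when world_size > 1."
)
```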
Contributor commented:
During __post_init__ we call setup_devices. I think we can change the average_tokens_across_devices value to False there if the world size is 1.

This then simplifies the earlier check to just if self.args.average_tokens_across_devices.
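A minimal sketch of the kind of check being suggested (hypothetical, expressed as a standalone helper rather than the actual __post_init__ code):

```python
import warnings


def maybe_disable_token_averaging(average_tokens_across_devices: bool, world_size: int) -> bool:
    # Turn the flag off when there is only one process, so later code can
    # simply test `if self.args.average_tokens_across_devices`.
    if average_tokens_across_devices and world_size <= 1:
        warnings.warn("average_tokens_across_devices requires world_size > 1; disabling it.")
        return False
    return average_tokens_across_devices
```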

Contributor (Author) replied:
Thanks, already fixed it, please check.


def __post_init__(self):
# Parse in args that could be `dict` sent in from the CLI as a string
for field in _VALID_DICT_FIELDS: