
enable average tokens across devices #34373

Merged · 12 commits · Oct 28, 2024
8 changes: 8 additions & 0 deletions src/transformers/trainer.py
@@ -3636,7 +3636,15 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
                model_name = unwrapped_model._get_name()
            # User-defined compute_loss function
            if self.compute_loss_func is not None:
                if (self.args.average_tokens_across_devices and num_items_in_batch is not None and
                        self.args.world_size > 1):
                    num_items_in_batch_tensor = torch.tensor(num_items_in_batch, device=self.args.device)
                    num_items_in_batch = int(self.accelerator.gather(num_items_in_batch_tensor).sum().cpu())
                    device_count_for_loss = self.args.world_size
                else:
                    device_count_for_loss = 1
                loss = self.compute_loss_func(outputs, labels, num_items_in_batch=num_items_in_batch)
                loss *= device_count_for_loss
            elif model_name in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values():
                loss = self.label_smoother(outputs, labels, shift_labels=True)

Contributor:

Thanks! We also need to handle the case where we don't define this, e.g. when num_items_in_batch is passed to the model's forward(). So it would be better to perform the gather much earlier and pass the new num_items_in_batch as part of the call to compute_loss.

Then apply the loss *= where we call loss *= self.args.gradient_accumulation_steps later (right before we call accelerator.backward()).

Contributor Author:

Thanks for the advice! I've fixed it, please check again.

            else:
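
To make the reviewer's suggestion concrete, here is a minimal sketch of gathering num_items_in_batch earlier and rescaling the loss next to the existing gradient-accumulation scaling. This is hypothetical code sketched against Trainer.training_step, not the merged implementation; names such as num_items_in_batch, model, inputs, and self.accelerator are assumed to already be in scope as in the current trainer.

```python
# Hypothetical sketch of the reviewer's suggestion, not the merged code.
import torch  # already imported at module level in trainer.py

if self.args.average_tokens_across_devices and num_items_in_batch is not None:
    # Sum the per-device token counts so every rank normalizes the loss
    # by the same global denominator.
    num_items_in_batch = int(
        self.accelerator.gather(
            torch.tensor(num_items_in_batch, device=self.args.device)
        ).sum()
    )

loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)

# ...later, next to the existing `loss *= self.args.gradient_accumulation_steps`
# and right before `self.accelerator.backward(loss)`: compensate for DDP's
# implicit mean over devices now that the denominator is the global token count.
if self.args.average_tokens_across_devices and num_items_in_batch is not None:
    loss *= self.args.world_size
```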
7 changes: 7 additions & 0 deletions src/transformers/training_args.py
@@ -1530,6 +1530,13 @@ class TrainingArguments:
        },
    )

    average_tokens_across_devices: Optional[bool] = field(
        default=False,
        metadata={
            "help": "Whether or not to average tokens across devices."

Member:

Maybe we could share a bit more about why this arg could be useful?

Contributor Author:

Done, please review my code.

        }
    )
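
As an illustration of the fuller help text the reviewer asks for above, the field's metadata might be worded along these lines (hypothetical wording only; the merged string may differ):

```python
# Hypothetical wording, not the merged help string.
average_tokens_across_devices: Optional[bool] = field(
    default=False,
    metadata={
        "help": (
            "Whether or not to average the number of tokens across devices before normalizing the loss. "
            "If enabled, the per-device token counts are synchronized (gathered and summed) so that every "
            "rank divides by the global token count, keeping token-normalized losses consistent between "
            "single-device and multi-device training."
        )
    },
)
```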

Contributor:

During __post_init__ we call setup_devices, so we could set average_tokens_across_devices to False there when the world size is <= 1, I think!

That then simplifies the earlier check to just if self.args.average_tokens_across_devices.

Contributor Author:

Thanks, it's fixed now, please check.
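
A minimal sketch of the guard the reviewer describes, assuming it lives in TrainingArguments.__post_init__ after device setup has resolved self.world_size (hypothetical placement and wording; the merged code may differ):

```python
# Hypothetical sketch, not the merged code. Assumes `self` is the
# TrainingArguments instance and `self.world_size` has already been
# resolved by the device setup that __post_init__ triggers.
import warnings

if self.average_tokens_across_devices and self.world_size <= 1:
    warnings.warn(
        "average_tokens_across_devices is only meaningful with more than one device; setting it to False."
    )
    self.average_tokens_across_devices = False
```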


    def __post_init__(self):
        # Parse in args that could be `dict` sent in from the CLI as a string
        for field in _VALID_DICT_FIELDS: