-
Notifications
You must be signed in to change notification settings - Fork 27.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
enable average tokens across devices #34373
Changes from 2 commits
590a522
647203b
8fd583d
190d534
ae9fbe9
4c8d02f
70919a1
8ca8c4d
1eab209
cdbb3d3
bd8ae99
588a80e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3603,6 +3603,9 @@ def training_step( | |
scaled_loss.backward() | ||
else: | ||
loss *= self.args.gradient_accumulation_steps | ||
if (self.args.average_tokens_across_devices and num_items_in_batch is not None and | ||
self.args.world_size > 1): | ||
loss *= self.args.world_size | ||
self.accelerator.backward(loss, **kwargs) | ||
|
||
return loss.detach() / self.args.gradient_accumulation_steps | ||
Comment on lines
3641
to
3642
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We modified a lot how loss is computed, are we sure that this is loss is the same as the one applied ? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Good catch, it should be |
||
|
@@ -3617,6 +3620,10 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N | |
labels = inputs.pop("labels") | ||
else: | ||
labels = None | ||
if (self.args.average_tokens_across_devices and num_items_in_batch is not None and | ||
self.args.world_size > 1): | ||
num_items_in_batch_tensor = torch.tensor(num_items_in_batch, device=self.args.device) | ||
num_items_in_batch = int(self.accelerator.gather(num_items_in_batch_tensor).sum().cpu()) | ||
if self.model_accepts_loss_kwargs: | ||
loss_kwargs = {} | ||
if num_items_in_batch is not None: | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1530,6 +1530,13 @@ class TrainingArguments: | |
}, | ||
) | ||
|
||
average_tokens_across_devices: Optional[bool] = field( | ||
default=False, | ||
metadata={ | ||
"help": "Whether or not to average tokens across devices." | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe we could share a bit more why this arg could be useful ? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done, please review my code. |
||
} | ||
) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. During the This then simplifies it earlier to just be There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks, already fixed it, please check. |
||
|
||
def __post_init__(self): | ||
# Parse in args that could be `dict` sent in from the CLI as a string | ||
for field in _VALID_DICT_FIELDS: | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Both of these chunks I think can be under a
if num_items_in_batch is not None and self.model_accepts_loss_kwargs
, since both need to be valid for theloss *= self.args.gradient_accumulation_steps
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think gradient accumulation is orthogonal to DDP, and used a new if statement. Please check my code. Thanks.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It is, it's a matter of
self.model.accepts_loss_kwargs