
enable average tokens across devices #34373

Merged
12 commits merged on Oct 28, 2024
10 changes: 9 additions & 1 deletion src/transformers/trainer.py
@@ -3602,7 +3602,12 @@ def training_step(
             with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                 scaled_loss.backward()
         else:
-            loss *= self.args.gradient_accumulation_steps
+            if num_items_in_batch is not None:
+                if self.compute_loss_func or self.model_accepts_loss_kwargs:
+                    loss *= self.args.gradient_accumulation_steps
+                # Average tokens across devices is orthogonal to gradient accumulation
+                if self.args.average_tokens_across_devices:
+                    loss *= self.args.world_size
             self.accelerator.backward(loss, **kwargs)

         return loss.detach() / self.args.gradient_accumulation_steps
Comment on lines 3641 to 3642
A Member commented:

We've modified how the loss is computed quite a bit; are we sure this loss is the same as the one that is applied?

@muellerzr (Contributor) commented on Oct 25, 2024:

Good catch, it should be loss.detach() / self.args.gradient_accumulation_steps / self.accelerator.num_processes (dividing by the number of processes if and only if we applied our num-tokens loss logic).
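As a reader's aid, here is a minimal arithmetic sketch, not the Trainer code itself, of why the backward loss is multiplied by world_size (and by the gradient accumulation steps) and why the detached loss should be scaled back down as muellerzr suggests; the variable names and numbers are invented for illustration:

```python
import torch

world_size = 4                              # data-parallel processes
grad_accum_steps = 2
local_token_loss_sum = torch.tensor(12.0)   # sum of per-token losses on this rank
global_num_tokens = 120                     # token count gathered across all ranks

# With average_tokens_across_devices, the loss is normalized by the *global* token count:
loss = local_token_loss_sum / global_num_tokens

# DDP averages gradients over ranks (an implicit 1/world_size), and the surrounding
# Trainer/Accelerate machinery applies a 1/grad_accum_steps factor elsewhere, which is
# why the diff multiplies both factors back in; the effective gradient then behaves like
# sum(per-token losses over all ranks) / global_num_tokens.
loss_for_backward = loss * grad_accum_steps * world_size

# For logging, the scaling is undone again, matching the review suggestion of dividing
# the detached loss by both gradient_accumulation_steps and num_processes.
reported = loss_for_backward.detach() / grad_accum_steps / world_size
assert torch.isclose(reported, loss)
```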

@@ -3617,6 +3622,9 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
             labels = inputs.pop("labels")
         else:
             labels = None
+        if self.args.average_tokens_across_devices and num_items_in_batch is not None:
+            num_items_in_batch_tensor = torch.tensor(num_items_in_batch, device=self.args.device)
+            num_items_in_batch = int(self.accelerator.gather(num_items_in_batch_tensor).sum().cpu())
         if self.model_accepts_loss_kwargs:
             loss_kwargs = {}
             if num_items_in_batch is not None:
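For context, a standalone sketch of the gather step above, assuming a script started with accelerate launch; the per-rank token counts are made up for illustration:

```python
# Hypothetical illustration; run with: accelerate launch gather_tokens_sketch.py
import torch
from accelerate import Accelerator

accelerator = Accelerator()

# Pretend each rank counted a different number of label tokens in its micro-batch.
local_num_tokens = 100 + 10 * accelerator.process_index
local_tensor = torch.tensor(local_num_tokens, device=accelerator.device)

# gather() collects the value from every process; summing it yields the global token
# count, mirroring self.accelerator.gather(num_items_in_batch_tensor).sum() in the diff.
global_num_tokens = int(accelerator.gather(local_tensor).sum().cpu())

accelerator.print(f"Global token count used to normalize the loss: {global_num_tokens}")
```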
22 changes: 22 additions & 0 deletions src/transformers/training_args.py
@@ -1530,6 +1530,15 @@ class TrainingArguments:
         },
     )

+    average_tokens_across_devices: Optional[bool] = field(
+        default=False,
+        metadata={
+            "help": "Whether or not to average tokens across devices. If enabled, will use all_reduce to "
+            "synchronize num_tokens_in_batch for precise loss calculation. Reference: "
+            "https://github.com/huggingface/transformers/issues/34242"
+        },
+    )
+
     def __post_init__(self):
         # Parse in args that could be `dict` sent in from the CLI as a string
         for field in _VALID_DICT_FIELDS:
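A usage sketch for the new argument, with placeholder values; the model, dataset, and output directory are assumptions, and the flag only has an effect in a multi-process run, as the __post_init__ guard in the next hunk enforces:

```python
from transformers import Trainer, TrainingArguments

# Hypothetical values; `model` and `train_dataset` are assumed to exist elsewhere.
args = TrainingArguments(
    output_dir="out",
    per_device_train_batch_size=8,
    gradient_accumulation_steps=2,
    average_tokens_across_devices=True,  # new flag: all-reduce token counts before loss normalization
)
# trainer = Trainer(model=model, args=args, train_dataset=train_dataset)
# trainer.train()  # launch with e.g. `torchrun --nproc_per_node=4 train.py` so world_size > 1
```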
@@ -1763,6 +1772,19 @@ def __post_init__(self):
         if self.framework == "pt" and is_torch_available():
             self.device

+        # Disable average tokens when using single device
+        if self.average_tokens_across_devices:
+            try:
+                if self.world_size == 1:
+                    logger.warning(
+                        "average_tokens_across_devices is set to True but it is invalid when world size is "
+                        "1. Setting it to False automatically."
+                    )
+                    self.average_tokens_across_devices = False
+            except ImportError as e:
+                logger.warning(f"Cannot determine world size due to {e}. Setting average_tokens_across_devices to False.")
+                self.average_tokens_across_devices = False
+
         if self.torchdynamo is not None:
             warnings.warn(
                 "`torchdynamo` is deprecated and will be removed in version 5 of 🤗 Transformers. Use"