Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix logits compuation in KTO trainer prediction step #2050

Merged
merged 3 commits into from
Sep 11, 2024

Conversation

issamemari
Copy link
Contributor

Description of the issue

There is a bug in the following few lines of code in kto_trainer.py

logits_dict = {
    "eval_logits/chosen": metrics["logits/chosen"],
    "eval_logits/rejected": metrics["logits/rejected"],
}
logits = tuple(v.unsqueeze(dim=0) for k, v in logits_dict.items() if k not in ignore_keys)
logits = torch.stack(logits).mean(axis=1).to(self.accelerator.device)

This assumes that the values in the logits_dict are tensors, but they are not. These values are computed in get_batch_loss_metrics where the the logits are averaged and .item() is called on the resulting tensor to get a float.

This causes the following error when running a KTO training:

AttributeError: 'float' object has no attribute 'unsqueeze'

What does this PR do?

  • Treat the values of logits_dict as floats
  • Fixes AttributeError: 'float' object has no attribute 'unsqueeze' during the evaluation phase of a KTO training

trl/trainer/kto_trainer.py Outdated Show resolved Hide resolved
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

trl/trainer/kto_trainer.py Outdated Show resolved Hide resolved
@kashif kashif merged commit 9c043e5 into huggingface:main Sep 11, 2024
9 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants