Add KTO Loss #475
base: main
Conversation
Took a brief look. I'm not very familiar with the KTO math, but why don't we have `KL_log_probs` when the original HF implementation does? See https://github.com/huggingface/trl/blob/cd7156fb34ddf9a8c04fcd640a4067933461d44e/trl/trainer/kto_trainer.py#L1121. We also need to be careful about scaling: in the original HF implementation, `kto_loss` returns an unreduced version, but we probably need to reduce it as a mean. cc @shivam15s
About the KL term, I'll take a further look.
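To make the KL and reduction questions concrete, here is a minimal sketch of a KTO-style loss that includes the KL reference point and reduces as a mean. All names, the clamping, and the detach behavior are illustrative assumptions, not the actual Liger or trl API:

```python
import torch

def kto_loss_sketch(policy_logps, ref_logps, policy_kl_logps, ref_kl_logps,
                    is_chosen, beta=0.1, desirable_weight=1.0, undesirable_weight=1.0):
    # KL reference point z0, estimated from mismatched (x, y') pairs;
    # detached so it only shifts the reference point and carries no gradient.
    kl = (policy_kl_logps - ref_kl_logps).mean().clamp(min=0).detach()

    # Implied reward: log pi(y|x) - log pi_ref(y|x).
    rewards = policy_logps - ref_logps

    chosen_losses = desirable_weight * (1 - torch.sigmoid(beta * (rewards - kl)))
    rejected_losses = undesirable_weight * (1 - torch.sigmoid(beta * (kl - rewards)))

    losses = torch.where(is_chosen, chosen_losses, rejected_losses)
    # HF's kto_loss returns per-example losses; a fused version that should
    # match a reduced reference needs to take the mean itself.
    return losses.mean()
```

With unit weights the per-example losses lie in (0, 1), so the mean does too; the reduction choice only changes the gradient scale, not its direction.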
Summary
Closes the KTO item of the roadmap: #371
Implements the Kahneman-Tversky Optimization (KTO) loss function.
KTO Loss Function
For a policy π compared to a reference policy π₀ (equations reconstructed from the standard KTO formulation):

When y is chosen (desirable):

$$v(x, y) = \lambda_D \, \sigma\big(\beta(r(x, y) - z_0)\big)$$

When y is rejected (undesirable):

$$v(x, y) = \lambda_U \, \sigma\big(\beta(z_0 - r(x, y))\big)$$

where:
- $r(x, y) = \log \frac{\pi(y \mid x)}{\pi_0(y \mid x)}$ is the implied reward,
- $z_0 = \mathrm{KL}\big(\pi(y' \mid x) \,\|\, \pi_0(y' \mid x)\big)$ is a reference point estimated from mismatched completions $y'$,
- $\sigma$ is the sigmoid, $\beta$ controls risk sensitivity, and $\lambda_D, \lambda_U$ weight the desirable and undesirable examples.

The per-example loss is $\lambda_y - v(x, y)$, so chosen examples are pushed toward rewards above $z_0$ and rejected examples toward rewards below it.
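As a quick numeric sanity check of the two branches (the numbers here are illustrative, not from the PR):

```python
import math

sigmoid = lambda t: 1.0 / (1.0 + math.exp(-t))

# Illustrative values: beta = 0.1, reference point z0 = 0.2, and an
# implied reward r = log pi(y|x)/pi0(y|x) = 1.2 for some completion y.
beta, z0, r = 0.1, 0.2, 1.2

v_chosen = sigmoid(beta * (r - z0))    # value if y is labeled desirable
v_rejected = sigmoid(beta * (z0 - r))  # value if y is labeled undesirable

# Per-example losses with unit weights (lambda_D = lambda_U = 1):
loss_chosen = 1 - v_chosen      # ~0.475: reward above z0, small loss
loss_rejected = 1 - v_rejected  # ~0.525: the same y labeled bad costs more
```

Note that for a single completion the two branch losses sum to 1 here, since σ(t) + σ(−t) = 1; the asymmetry in practice comes from λ_D, λ_U and from z₀ shifting with the data.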
Intuition
KTO loss is inspired by prospect theory from behavioral economics, which models how humans make decisions under uncertainty.
The loss function is asymmetric, treating gains and losses differently, similar to
human decision-making patterns.
Credit: https://www.youtube.com/watch?v=nSrj1J6ODoM&t=422s
Benchmark Result
Memory:
Speed:
Key Changes
- Added `LigerFusedLinearKTOLoss` class and `LigerFusedLinearKTOFunction` for the core KTO computation
- Added `test_kto_loss.py` with a reference implementation (`HFKTOLoss`) based on Hugging Face's implementation
Testing Done
Tests are passing now: `pytest test/chunked_loss/test_kto_loss.py`

- `make test` to ensure correctness
- `make checkstyle` to ensure code style
- `make test-convergence` to ensure convergence