Feat: Add support for APO-zero in KTOTrainer #1952
Conversation
I've confirmed the training dynamics of APO-zero-unpaired are as intended. Compared to a KTO run, APO-zero trains about 40% faster since no KL values need to be calculated. I will run more downstream evaluations later to understand the differences between KTO and APO-zero for unpaired alignment. This PR can now be merged.
Thank you @KarelDO, it's a feature we're very happy to see coming to TRL. We'll be reviewing your PR soon for sure. Can you share the elements you have to confirm that everything is working as expected? Maybe curves, trained models, etc.?
@@ -60,6 +65,10 @@ class KTOConfig(TrainingArguments):
        Number of processes to use for processing the datasets.
    """

    loss_type: Literal[
        "kto",
        "apo_zero_unpaired",
Is "unpaired" really necessary? As far as I understand, there is no such thing as a "paired" version of KTO, right?
APO-zero does have a paired and unpaired variant, and you could definitely construct a paired variant of KTO.
We can remove "_unpaired" here since the KTOTrainer also implies it, but I thought it would be good for people to actively think about the distinction when selecting a loss.
Yes, given we have an apo_zero loss also in the DPOTrainer, it's good to retain the _unpaired distinction IMO.
Would you mind adding this loss term to the integration tests here:
Line 76 in 47ab034: @parameterized.expand(
You might want to look at the DPO trainer for inspiration:
Line 251 in 47ab034: @parameterized.expand(
Hey @qgallouedec, thanks! Here's a WandB report PDF with some training curves. I can't share the full WandB project yet, but I have a printout of the report attached. I've summarized the main takeaways below:
TL;DR: the loss functions differ, though they sometimes behave similarly depending on the underlying preference dataset; APO-zero trains faster.
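For context on why the two losses can behave similarly yet APO-zero trains faster, here is a minimal sketch of the per-example loss shapes (an illustration based on the KTO and APO papers, not the PR's actual code; function names and the simplified KTO form are assumptions): KTO measures each example's implicit reward relative to a reference KL estimate, while APO-zero scores each example independently, so no KL estimate is needed.

```python
import math

def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))

def kto_loss(logratio: float, kl: float, desirable: bool, beta: float = 0.1) -> float:
    # KTO (simplified sketch): the policy/reference log-ratio is offset by a
    # KL estimate, which requires extra forward passes to compute.
    if desirable:
        return 1.0 - sigmoid(beta * (logratio - kl))
    return 1.0 - sigmoid(beta * (kl - logratio))

def apo_zero_unpaired_loss(logratio: float, desirable: bool, beta: float = 0.1) -> float:
    # APO-zero (unpaired sketch): each example is scored on its own; there is
    # no KL term, which is why this variant skips the KL computation entirely.
    if desirable:
        return 1.0 - sigmoid(beta * logratio)
    return sigmoid(beta * logratio)

# When the KL estimate is zero, the two losses coincide on desirable examples,
# which is one reason the training curves can look similar on some datasets.
assert math.isclose(kto_loss(2.0, 0.0, True), apo_zero_unpaired_loss(2.0, True))
```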
Thanks a lot for this nice contribution @KarelDO ! Overall it LGTM once we have some unit / integration tests added.
return new_rows


def maybe_reformat_dpo_to_kto(dataset: DatasetDict, num_proc: int = None):
For public methods, would you mind adding a docstring and a unit test please?
Done!
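For readers following along, the core row transformation a helper like maybe_reformat_dpo_to_kto performs can be sketched as follows (a minimal illustration, not the PR's implementation; the function name dpo_rows_to_kto_rows is hypothetical, and the column names follow TRL's dpo- and kto-formatted dataset conventions):

```python
def dpo_rows_to_kto_rows(batch: dict) -> dict:
    """Turn dpo-formatted rows (prompt/chosen/rejected) into kto-formatted
    rows (prompt/completion/label): each chosen completion becomes a
    desirable example (label=True), each rejected one undesirable (False)."""
    new_rows = {"prompt": [], "completion": [], "label": []}
    for prompt, chosen, rejected in zip(
        batch["prompt"], batch["chosen"], batch["rejected"]
    ):
        # Each paired row yields two unpaired rows sharing the same prompt.
        new_rows["prompt"] += [prompt, prompt]
        new_rows["completion"] += [chosen, rejected]
        new_rows["label"] += [True, False]
    return new_rows

example = {"prompt": ["Q?"], "chosen": ["good"], "rejected": ["bad"]}
print(dpo_rows_to_kto_rows(example)["label"])  # -> [True, False]
```

In practice such a batched transform would be applied via datasets' Dataset.map with batched=True, which is where a num_proc argument comes in.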
Thanks for iterating! Would you mind fixing the code quality test, and then we can merge?
Thanks @lewtun, should be fixed now!
Ah, it seems some of the KTO tests are now failing after rebasing on main - would you mind fixing those 🙏?
@lewtun we should be good now!
Thanks for iterating!
Feat: Add support for APO-zero in KTOTrainer
Now ready to merge
This PR adds support for the unpaired variant of APO-zero in the KTOTrainer. See the APO paper.
To achieve this, I:
- Added a loss_type variable to KTOConfig (similar to DPOConfig)
- Added the APO-zero-unpaired loss to the KTOTrainer (similar to the DPOTrainer)
- Updated the KTOTrainer to only calculate the KL when the loss requires it (calculating the KL is expensive; as a result, APO-zero runs faster than KTO).

Additionally, I updated the kto.py script to be interoperable with any dpo-formatted dataset:
- Added a utility function (in data_util.py) which checks the format of the dataset and turns a dpo-formatted dataset into a kto-formatted dataset.
- Updated the example script accordingly (similar to dpo.py).
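Once a TRL release includes this PR, opting into the new loss would look roughly like the sketch below (an illustrative fragment, not documented API: only loss_type="apo_zero_unpaired" is introduced by this PR, and the surrounding arguments are assumptions modeled on TRL's existing KTO example; model, ref_model, and train_dataset are placeholders):

```python
from trl import KTOConfig, KTOTrainer

config = KTOConfig(
    output_dir="kto-apo-zero",
    loss_type="apo_zero_unpaired",  # new option from this PR; "kto" remains the default
    beta=0.1,
)
trainer = KTOTrainer(
    model=model,            # placeholder: a causal LM
    ref_model=ref_model,    # placeholder: frozen reference model
    args=config,
    train_dataset=train_dataset,  # placeholder: kto-formatted dataset
)
trainer.train()
```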