
Feat: Add support for APO-zero in KTOTrainer #1952

Merged (17 commits) on Sep 4, 2024

Conversation

@KarelDO (Contributor) commented on Aug 21, 2024

Feat: Add support for APO-zero in KTOTrainer

Now ready to merge

This PR adds support for the unpaired variant of APO-zero in the KTOTrainer. See the APO paper.

To achieve this, I:

  • Added a loss_type variable to KTOConfig (similar to DPOConfig)
  • Added the APO loss to KTOTrainer (similar to DPOTrainer)
  • Updated KTOTrainer to only calculate the KL term when the loss requires it (the KL calculation is expensive, so APO-zero trains faster than KTO); see the loss sketch after this list.
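
To make the difference between the two losses concrete, here is a minimal sketch of the idea (simplified variable names, not the exact trainer code): KTO anchors both terms on a reference KL estimate, while APO-zero-unpaired drops that term entirely.

import torch

def unpaired_losses(chosen_logratios, rejected_logratios, kl, beta, loss_type):
    # chosen_logratios / rejected_logratios: policy-vs-reference log-prob ratios
    # kl: batch-level KL estimate between policy and reference (only used by "kto")
    if loss_type == "kto":
        # KTO anchors both terms on the KL estimate
        chosen_losses = 1 - torch.sigmoid(beta * (chosen_logratios - kl))
        rejected_losses = 1 - torch.sigmoid(beta * (kl - rejected_logratios))
    elif loss_type == "apo_zero_unpaired":
        # APO-zero pushes chosen log-ratios up and rejected log-ratios down
        # with no KL term, so the expensive KL pass can be skipped entirely
        chosen_losses = 1 - torch.sigmoid(beta * chosen_logratios)
        rejected_losses = torch.sigmoid(beta * rejected_logratios)
    return torch.cat((chosen_losses, rejected_losses), dim=0)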

Additionally, I updated the kto.py script to be interoperable with any DPO-formatted dataset:

  • I added a util (in data_util.py) which checks the format of the dataset and turns a DPO-formatted dataset into a KTO-formatted dataset; a sketch of the idea follows below.
  • I've updated the kto script to also work with ChatML-formatted datasets (similar to what happens in dpo.py).
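
The conversion itself is simple. Roughly (illustrative code, not the exact util in the PR): each (prompt, chosen, rejected) pair becomes two unpaired rows with a boolean label.

from datasets import Dataset

def dpo_pairs_to_kto_rows(batch):
    # Each DPO-style pair yields two KTO-style rows:
    # the chosen completion with label True, the rejected one with label False.
    new_rows = {"prompt": [], "completion": [], "label": []}
    for prompt, chosen, rejected in zip(batch["prompt"], batch["chosen"], batch["rejected"]):
        new_rows["prompt"] += [prompt, prompt]
        new_rows["completion"] += [chosen, rejected]
        new_rows["label"] += [True, False]
    return new_rows

dpo_ds = Dataset.from_dict(
    {"prompt": ["Q1"], "chosen": ["good answer"], "rejected": ["bad answer"]}
)
kto_ds = dpo_ds.map(dpo_pairs_to_kto_rows, batched=True, remove_columns=dpo_ds.column_names)
# kto_ds now has columns ["prompt", "completion", "label"], with two rows per original pair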

@KarelDO (Contributor, Author) commented on Aug 23, 2024

I've confirmed that the training dynamics of APO-zero-unpaired are as intended.

Compared to a KTO run, APO-zero trains about 40% faster since no KL values need to be calculated.

I will run more downstream evaluations later to understand the differences between KTO and APO-zero for unpaired alignment.

This PR can now be merged

@qgallouedec (Member) commented

Thank you @KarelDO, it's a feature we're very happy to see coming to TRL. We'll be reviewing your PR soon for sure. Can you share the elements you have to confirm that everything is working as expected? Maybe curves, trained models, etc?

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@@ -60,6 +65,10 @@ class KTOConfig(TrainingArguments):
Number of processes to use for processing the datasets.
"""

loss_type: Literal[
"kto",
"apo_zero_unpaired",
Member:

is "unpaired" really necessary? As far as I understand, there is no such thing as "paired" version for kto, right?

Contributor (Author):

APO-zero does have a paired and unpaired variant, and you could definitely construct a paired variant of KTO.

We can remove "_unpaired" here since the KTOTrainer also implies it, but I thought it would be good for people to actively think about the distinction when selecting a loss.

Member:

Yes, given we also have an apo_zero loss in the DPOTrainer, it's good to retain the _unpaired distinction IMO.

Would you mind adding this loss term to the integration tests here:

@parameterized.expand(

You might want to look at the DPO trainer for inspiration:

@parameterized.expand(
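
Roughly something along these lines (a sketch only; class and test names are illustrative, and the real test should build a small KTOTrainer and run a training step per loss type):

import unittest
from parameterized import parameterized

class KTOTrainerLossTypeTester(unittest.TestCase):
    @parameterized.expand(
        [
            ("kto",),
            ("apo_zero_unpaired",),
        ]
    )
    def test_kto_trainer_loss_types(self, loss_type):
        # Placeholder check; the real test would build a KTOConfig with this
        # loss_type, instantiate a small KTOTrainer, and call trainer.train().
        self.assertIn(loss_type, ["kto", "apo_zero_unpaired"])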

@KarelDO (Contributor, Author) commented on Aug 24, 2024

Hey @qgallouedec, thanks!

Here's a WandB report PDF with some training curves. I can't share the full WandB project yet, but a printout of the report is attached. I've summarized the main takeaways below:

  • APO-zero trains considerably faster
  • On an RLAIF preference dataset, the training dynamics of KTO and APO-zero seem identical
  • On a CLAIR preference dataset, the training dynamics of KTO and APO-zero differ. This is due to a higher KL on this dataset. APO-zero does display the intended training dynamic of smoothly increasing desirable rewards and decreasing undesirable rewards, without calculating a KL.

TL;DR: these are different loss functions; they sometimes behave similarly depending on the underlying preference dataset, and APO-zero trains faster.

@lewtun (Member) left a comment:

Thanks a lot for this nice contribution @KarelDO! Overall LGTM once we have some unit / integration tests added.

return new_rows


def maybe_reformat_dpo_to_kto(dataset: DatasetDict, num_proc: int = None):
Member:

For public methods, would you mind adding a docstring and a unit test please?
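
For example, a quick check along these lines (a rough sketch only; the import path and the exact output columns are assumptions based on the description above, not the test actually added in the PR):

import unittest
from datasets import Dataset, DatasetDict

class MaybeReformatDpoToKtoTester(unittest.TestCase):
    def test_dpo_dataset_is_reformatted(self):
        # NOTE: import path is an assumption; the PR adds the util in data_util.py
        from trl.data_util import maybe_reformat_dpo_to_kto

        dpo_ds = DatasetDict({"train": Dataset.from_dict(
            {"prompt": ["Q"], "chosen": ["good"], "rejected": ["bad"]}
        )})
        kto_ds = maybe_reformat_dpo_to_kto(dpo_ds, num_proc=1)
        # One DPO pair should become two unpaired KTO rows
        self.assertEqual(set(kto_ds["train"].column_names), {"prompt", "completion", "label"})
        self.assertEqual(kto_ds["train"].num_rows, 2)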

Contributor (Author):

Done!

@lewtun (Member) commented on Aug 29, 2024

Thanks for iterating! Would you mind fixing the code quality test? Then we can merge!

@KarelDO (Contributor, Author) commented on Aug 29, 2024

Thanks @lewtun, should be fixed now!

@lewtun (Member) commented on Sep 2, 2024

Ah, it seems some of the KTO tests are now failing after rebasing on main - would you mind fixing those? 🙏

@karel-contextual (Contributor) commented

@lewtun we should be good now!

@lewtun (Member) commented on Sep 4, 2024

Thanks for iterating!

@lewtun merged commit 7acb9c2 into huggingface:main on Sep 4, 2024. 9 checks passed.