Feat: Add support for APO-zero in KTOTrainer #1952
New file: tests for `maybe_reformat_dpo_to_kto` (`@@ -0,0 +1,71 @@`):

```python
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest

from datasets import Dataset, DatasetDict

from trl.data_utils import maybe_reformat_dpo_to_kto


class MaybeReformatDPOToKTOTester(unittest.TestCase):
    def setUp(self):
        # Create a sample DPO-formatted dataset for testing
        self.dpo_data = {
            "prompt": ["What is AI?", "Define machine learning."],
            "chosen": ["AI is artificial intelligence.", "Machine learning is a subset of AI."],
            "rejected": ["AI is a computer.", "Machine learning is a program."],
        }
        self.dpo_dataset = DatasetDict({"train": Dataset.from_dict(self.dpo_data)})

        # Create a sample KTO-formatted dataset for testing
        self.kto_data = {
            "prompt": ["What is AI?", "Define machine learning.", "What is AI?", "Define machine learning."],
            "completion": [
                "AI is artificial intelligence.",
                "Machine learning is a subset of AI.",
                "AI is a computer.",
                "Machine learning is a program.",
            ],
            "label": [True, True, False, False],
        }
        self.kto_dataset = DatasetDict({"train": Dataset.from_dict(self.kto_data)})

    def test_dpo_to_kto_conversion(self):
        # Test that a DPO-formatted dataset is correctly reformatted to KTO format
        reformatted_dataset = maybe_reformat_dpo_to_kto(self.dpo_dataset)
        self.assertEqual(
            reformatted_dataset["train"].to_dict(),
            self.kto_dataset["train"].to_dict(),
            "The DPO-formatted dataset was not correctly reformatted to KTO format.",
        )

    def test_already_kto_format(self):
        # Test that a KTO-formatted dataset remains unchanged
        reformatted_dataset = maybe_reformat_dpo_to_kto(self.kto_dataset)
        self.assertEqual(
            reformatted_dataset["train"].to_dict(),
            self.kto_dataset["train"].to_dict(),
            "The KTO-formatted dataset should remain unchanged.",
        )

    def test_invalid_format(self):
        # Test that a dataset with an incompatible format raises a ValueError
        invalid_data = {
            "input": ["What is AI?", "Define machine learning."],
            "output": ["AI is artificial intelligence.", "Machine learning is a subset of AI."],
        }
        invalid_dataset = DatasetDict({"train": Dataset.from_dict(invalid_data)})

        with self.assertRaises(ValueError):
            maybe_reformat_dpo_to_kto(invalid_dataset)
```
New file: `trl/data_utils.py` (`@@ -0,0 +1,74 @@`):

```python
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from copy import deepcopy

from datasets import DatasetDict


def _reformat_row_dpo_to_kto(row: dict):
    """Turn a DPO-formatted dataset row into two KTO-formatted rows."""

    chosen_row = {"prompt": row["prompt"], "completion": row["chosen"], "label": [True] * len(row["chosen"])}
    rejected_row = {
        "prompt": row["prompt"],
        "completion": row["rejected"],
        "label": [False] * len(row["chosen"]),
    }
    new_rows = {k: chosen_row[k] + rejected_row[k] for k in chosen_row.keys()}
    return new_rows


def maybe_reformat_dpo_to_kto(dataset: DatasetDict, num_proc: int = None):
    """
    Reformat a dataset from the DPO format to the KTO format if necessary.

    This function checks whether the input dataset is already in the KTO format (containing "prompt", "completion",
    and "label" fields). If the dataset is in DPO format (with "prompt", "chosen", and "rejected" fields), it
    converts it to KTO format by:
    - Removing any unnecessary columns.
    - Reformatting each row to create a unified format suitable for KTO training.

    Args:
        dataset (DatasetDict): The dataset to potentially reformat.
        num_proc (int, optional): The number of processes to use for multiprocessing during dataset transformation.
            Defaults to None.

    Returns:
        DatasetDict: The reformatted dataset, if conversion was needed; otherwise, the original dataset.

    Raises:
        ValueError: If the dataset format is not compatible with KTO or DPO.
    """
    keys = list(dataset["train"].features.keys())

    # check if the dataset is in the KTO format or needs to be reformatted
    if "prompt" in keys and "completion" in keys and "label" in keys:
        return dataset
    elif "prompt" in keys and "rejected" in keys and "chosen" in keys:
        # remove unnecessary fields
        keys_to_remove = deepcopy(keys)
        keys_to_remove.remove("prompt")
        keys_to_remove.remove("chosen")
        keys_to_remove.remove("rejected")
        dataset = dataset.remove_columns(keys_to_remove)

        # turn each DPO-formatted row into two KTO-formatted rows
        dataset = dataset.map(
            _reformat_row_dpo_to_kto,
            num_proc=num_proc,
            batched=True,
            remove_columns=["chosen", "rejected"],
            desc="Reformatting Dataset from DPO format to KTO format.",
        )
        return dataset
    else:
        raise ValueError("Dataset format not compatible with KTO.")
```
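Usage is a single call; a quick illustration with a toy dataset whose values mirror the test fixtures above:

```python
from datasets import Dataset, DatasetDict

from trl.data_utils import maybe_reformat_dpo_to_kto

# One DPO preference pair per row
dpo_dataset = DatasetDict(
    {
        "train": Dataset.from_dict(
            {
                "prompt": ["What is AI?"],
                "chosen": ["AI is artificial intelligence."],
                "rejected": ["AI is a computer."],
            }
        )
    }
)

# Each pair is split into two unpaired rows: (prompt, chosen, True) and (prompt, rejected, False)
kto_dataset = maybe_reformat_dpo_to_kto(dpo_dataset)
print(kto_dataset["train"].to_dict())
# {'prompt': ['What is AI?', 'What is AI?'],
#  'completion': ['AI is artificial intelligence.', 'AI is a computer.'],
#  'label': [True, False]}
```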
Changes to `KTOConfig`:

```diff
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from dataclasses import dataclass
-from typing import Dict, Optional
+from typing import Dict, Literal, Optional
 
 from transformers import TrainingArguments
```
The class docstring gains the new option:

```diff
@@ -27,6 +27,11 @@ class KTOConfig(TrainingArguments):
         command line.
 
     Parameters:
+        loss_type (`str`, *optional*, defaults to `"kto"`):
+            The type of unpaired loss to use. Possible values are:
+
+                - `"kto"`: KTO loss from the [KTO](https://huggingface.co/papers/2402.01306) paper.
+                - `"apo_zero_unpaired"`: Unpaired variant of APO-zero loss from the [APO](https://huggingface.co/papers/2408.06266) paper.
         max_length (`int`, *optional*, defaults to `None`):
             The maximum length of the sequences in the batch. This argument is required if you want to use the default data collator.
         max_prompt_length (`int`, *optional*, defaults to `None`):
```
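For context on the two options: KTO weights each example against an estimate of a reference-point KL term, while APO-zero simply pushes desirable completions toward higher likelihood and undesirable ones toward lower likelihood in absolute terms. A minimal sketch of the unpaired APO-zero terms, assuming per-example policy/reference log-ratios and the usual `beta`; the names and wiring here are illustrative, not a quote of the trainer's final implementation:

```python
import torch


def apo_zero_unpaired_loss(
    chosen_logratios: torch.Tensor,    # log(pi_theta(y|x) / pi_ref(y|x)) for desirable completions
    rejected_logratios: torch.Tensor,  # same quantity for undesirable completions
    beta: float = 0.1,
) -> torch.Tensor:
    """Illustrative sketch of the unpaired APO-zero objective."""
    # Loss shrinks as desirable completions become more likely than under the reference model...
    chosen_losses = 1 - torch.sigmoid(beta * chosen_logratios)
    # ...and as undesirable completions become less likely than under the reference model.
    rejected_losses = torch.sigmoid(beta * rejected_logratios)
    return torch.cat((chosen_losses, rejected_losses)).mean()
```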
The new field itself:

```diff
@@ -60,6 +65,10 @@ class KTOConfig(TrainingArguments):
             Number of processes to use for processing the datasets.
     """
 
+    loss_type: Literal[
+        "kto",
+        "apo_zero_unpaired",
+    ] = "kto"
     max_length: Optional[int] = None
     """The maximum length of the sequences in the batch. This argument is required if you want to use the default data collator."""
     max_prompt_length: Optional[int] = None
```

Review discussion on `loss_type`:

- Reviewer: Is "unpaired" really necessary? As far as I understand, there is no such thing as a "paired" version of KTO, right?
- Author: APO-zero does have a paired and an unpaired variant, and you could definitely construct a paired variant of KTO. We can remove "_unpaired" here since the KTOTrainer also implies it, but I thought it would be good for people to actively think about the distinction when selecting a loss.
- Reviewer: Yes, given we have an [...] Would you mind adding this loss term to the integration tests (Line 76 in 47ab034)? You might want to look at the DPO trainer for inspiration (Line 251 in 47ab034).
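With the config change in place, opting into the new loss is a one-line switch. A minimal sketch; the output path is a placeholder and the commented trainer wiring follows the usual TRL pattern rather than quoting this PR:

```python
from trl import KTOConfig

training_args = KTOConfig(
    output_dir="./kto-apo-zero",    # placeholder
    loss_type="apo_zero_unpaired",  # new option from this PR; the default remains "kto"
    beta=0.1,                       # the same beta is reused by both loss types
)

# Typical wiring (sketch):
# trainer = KTOTrainer(model=model, ref_model=ref_model, args=training_args,
#                      train_dataset=kto_dataset["train"], tokenizer=tokenizer)
# trainer.train()
```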
Separate review thread:

- Reviewer: For public methods, would you mind adding a docstring and a unit test please?
- Author: Done!