Feat: Add support for APO-zero in KTOTrainer #1952

Merged: 17 commits, Sep 4, 2024
Changes from all commits
13 changes: 10 additions & 3 deletions examples/scripts/kto.py
@@ -59,7 +59,7 @@
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser

-from trl import KTOConfig, KTOTrainer, ModelConfig, get_peft_config, setup_chat_format
+from trl import KTOConfig, KTOTrainer, ModelConfig, get_peft_config, maybe_reformat_dpo_to_kto, setup_chat_format


# Define and parse arguments.
@@ -97,10 +97,17 @@ class ScriptArguments:
# Load the dataset
dataset = load_dataset(script_args.dataset_name)

+# If needed, reformat a DPO-formatted dataset (prompt, chosen, rejected) to a KTO-format (prompt, completion, label)
+dataset = maybe_reformat_dpo_to_kto(dataset, num_proc=kto_args.dataset_num_proc)

# Apply chat template
def format_dataset(example):
example["prompt"] = tokenizer.apply_chat_template(example["prompt"], tokenize=False)
example["completion"] = tokenizer.apply_chat_template(example["completion"], tokenize=False)
if isinstance(example["completion"], str):
example["prompt"] = tokenizer.apply_chat_template(example["prompt"], tokenize=False)
example["completion"] = tokenizer.apply_chat_template(example["completion"], tokenize=False)
else:
example["prompt"] = tokenizer.apply_chat_template(example["completion"][:-1], tokenize=False)
example["completion"] = tokenizer.apply_chat_template([example["completion"][-1]], tokenize=False)
return example

# Compute that only on the main process for faster data processing.
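Note on the format_dataset change above: the new branch distinguishes plain-text completions from conversational ones. A hypothetical row of each shape (illustrative values only, not taken from the PR) would look like this:

# Hypothetical example rows, for illustration only.

# Case 1: "completion" is already a string, so the prompt and the completion
# are passed through the chat template independently.
string_row = {
    "prompt": [{"role": "user", "content": "What is AI?"}],
    "completion": "AI is artificial intelligence.",
}

# Case 2: "completion" is a full conversation (a list of messages). Everything
# up to the last message becomes the prompt, and the final message alone
# becomes the completion.
conversation_row = {
    "completion": [
        {"role": "user", "content": "What is AI?"},
        {"role": "assistant", "content": "AI is artificial intelligence."},
    ],
}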
71 changes: 71 additions & 0 deletions tests/test_dataset_reformat.py
@@ -0,0 +1,71 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest

from datasets import Dataset, DatasetDict

from trl.data_utils import maybe_reformat_dpo_to_kto


class MaybeReformatDPOToKTOTester(unittest.TestCase):
    def setUp(self):
        # Create a sample DPO-formatted dataset for testing
        self.dpo_data = {
            "prompt": ["What is AI?", "Define machine learning."],
            "chosen": ["AI is artificial intelligence.", "Machine learning is a subset of AI."],
            "rejected": ["AI is a computer.", "Machine learning is a program."],
        }
        self.dpo_dataset = DatasetDict({"train": Dataset.from_dict(self.dpo_data)})

        # Create a sample KTO-formatted dataset for testing
        self.kto_data = {
            "prompt": ["What is AI?", "Define machine learning.", "What is AI?", "Define machine learning."],
            "completion": [
                "AI is artificial intelligence.",
                "Machine learning is a subset of AI.",
                "AI is a computer.",
                "Machine learning is a program.",
            ],
            "label": [True, True, False, False],
        }
        self.kto_dataset = DatasetDict({"train": Dataset.from_dict(self.kto_data)})

    def test_dpo_to_kto_conversion(self):
        # Test that a DPO-formatted dataset is correctly reformatted to KTO format
        reformatted_dataset = maybe_reformat_dpo_to_kto(self.dpo_dataset)
        self.assertEqual(
            reformatted_dataset["train"].to_dict(),
            self.kto_dataset["train"].to_dict(),
            "The DPO-formatted dataset was not correctly reformatted to KTO format.",
        )

    def test_already_kto_format(self):
        # Test that a KTO-formatted dataset remains unchanged
        reformatted_dataset = maybe_reformat_dpo_to_kto(self.kto_dataset)
        self.assertEqual(
            reformatted_dataset["train"].to_dict(),
            self.kto_dataset["train"].to_dict(),
            "The KTO-formatted dataset should remain unchanged.",
        )

    def test_invalid_format(self):
        # Test that a dataset with an incompatible format raises a ValueError
        invalid_data = {
            "input": ["What is AI?", "Define machine learning."],
            "output": ["AI is artificial intelligence.", "Machine learning is a subset of AI."],
        }
        invalid_dataset = DatasetDict({"train": Dataset.from_dict(invalid_data)})

        with self.assertRaises(ValueError):
            maybe_reformat_dpo_to_kto(invalid_dataset)
17 changes: 10 additions & 7 deletions tests/test_kto_trainer.py
@@ -75,15 +75,17 @@ def _init_dummy_dataset(self):

    @parameterized.expand(
        [
-            ["gpt2", True, True],
-            ["gpt2", True, False],
-            # ["t5", True],
-            ["gpt2", False, True],
-            ["gpt2", False, False],
-            # ["t5", False],
+            ["gpt2", "kto", True, True],
+            ["gpt2", "kto", True, False],
+            ["gpt2", "kto", False, True],
+            ["gpt2", "kto", False, False],
+            ["gpt2", "apo_zero_unpaired", True, True],
+            ["gpt2", "apo_zero_unpaired", True, False],
+            ["gpt2", "apo_zero_unpaired", False, True],
+            ["gpt2", "apo_zero_unpaired", False, False],
        ]
    )
-    def test_kto_trainer(self, name, pre_compute, eval_dataset):
+    def test_kto_trainer(self, name, loss_type, pre_compute, eval_dataset):
        with tempfile.TemporaryDirectory() as tmp_dir:
            training_args = KTOConfig(
                output_dir=tmp_dir,
@@ -95,6 +97,7 @@ def test_kto_trainer(self, name, pre_compute, eval_dataset):
                eval_strategy="steps",
                beta=0.1,
                precompute_ref_log_probs=pre_compute,
+                loss_type=loss_type,
                report_to="none",
            )
2 changes: 2 additions & 0 deletions trl/__init__.py
@@ -81,6 +81,7 @@
"MultitaskPromptTuningConfig",
"MultitaskPromptTuningInit",
],
"data_utils": ["maybe_reformat_dpo_to_kto"],
}

try:
@@ -162,6 +163,7 @@
    from .trainer.callbacks import RichProgressCallback, SyncRefModelCallback
    from .trainer.utils import get_kbit_device_map, get_peft_config, get_quantization_config
    from .commands.cli_utils import init_zero_verbose, SFTScriptArguments, DPOScriptArguments, TrlParser
+    from .data_utils import maybe_reformat_dpo_to_kto

    try:
        if not is_diffusers_available():
2 changes: 1 addition & 1 deletion trl/commands/cli.py
@@ -21,7 +21,7 @@
from rich.console import Console


-SUPPORTED_COMMANDS = ["sft", "dpo", "chat"]
+SUPPORTED_COMMANDS = ["sft", "dpo", "chat", "kto"]


def main():
74 changes: 74 additions & 0 deletions trl/data_utils.py
@@ -0,0 +1,74 @@
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from copy import deepcopy

from datasets import DatasetDict


def _reformat_row_dpo_to_kto(row: dict):
    """Turn a DPO-formatted dataset row into two KTO-formatted rows."""

    chosen_row = {"prompt": row["prompt"], "completion": row["chosen"], "label": [True] * len(row["chosen"])}
    rejected_row = {
        "prompt": row["prompt"],
        "completion": row["rejected"],
        "label": [False] * len(row["chosen"]),
    }
    new_rows = {k: chosen_row[k] + rejected_row[k] for k in chosen_row.keys()}
    return new_rows


Member: For public methods, would you mind adding a docstring and a unit test please?

Contributor Author: Done!

def maybe_reformat_dpo_to_kto(dataset: DatasetDict, num_proc: int = None):
"""
Reformat a dataset from the DPO format to the KTO format if necessary.

This function checks whether the input dataset is already in the KTO format (containing "prompt", "completion", and "label" fields).
If the dataset is in DPO format (with "prompt", "chosen", and "rejected" fields), it converts it to KTO format by:
- Removing any unnecessary columns.
- Reformatting each row to create a unified format suitable for KTO training.

Args:
dataset (DatasetDict): The dataset to potentially reformat.
num_proc (int, optional): The number of processes to use for multiprocessing during dataset transformation. Defaults to None.

Returns:
DatasetDict: The reformatted dataset, if conversion was needed; otherwise, the original dataset.

Raises:
ValueError: If the dataset format is not compatible with KTO or DPO.
"""
keys = list(dataset["train"].features.keys())

# check if the dataset is in the KTO format or needs to be reformatted
if "prompt" in keys and "completion" in keys and "label" in keys:
return dataset
elif "prompt" in keys and "rejected" in keys and "chosen" in keys:
# remove unnecessary fields
keys_to_remove = deepcopy(keys)
keys_to_remove.remove("prompt")
keys_to_remove.remove("chosen")
keys_to_remove.remove("rejected")
dataset = dataset.remove_columns(keys_to_remove)

# turn each DPO-formatted row into two KTO-formatted rows.
dataset = dataset.map(
_reformat_row_dpo_to_kto,
num_proc=num_proc,
batched=True,
remove_columns=["chosen", "rejected"],
desc="Reformatting Dataset from DPO format to KTO format.",
)
return dataset
else:
raise ValueError("Dataset format not compatible with KTO.")
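A minimal usage sketch of the new helper, mirroring the unit test above (the row values are illustrative):

from datasets import Dataset, DatasetDict

from trl import maybe_reformat_dpo_to_kto

# One DPO-formatted row: a prompt with a preferred and a dispreferred answer.
dpo_dataset = DatasetDict(
    {
        "train": Dataset.from_dict(
            {
                "prompt": ["What is AI?"],
                "chosen": ["AI is artificial intelligence."],
                "rejected": ["AI is a computer."],
            }
        )
    }
)

kto_dataset = maybe_reformat_dpo_to_kto(dpo_dataset)
print(kto_dataset["train"].to_dict())
# {'prompt': ['What is AI?', 'What is AI?'],
#  'completion': ['AI is artificial intelligence.', 'AI is a computer.'],
#  'label': [True, False]}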
11 changes: 10 additions & 1 deletion trl/trainer/kto_config.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
-from typing import Dict, Optional
+from typing import Dict, Literal, Optional

from transformers import TrainingArguments

@@ -27,6 +27,11 @@ class KTOConfig(TrainingArguments):
        command line.

    Parameters:
+        loss_type (`str`, *optional*, defaults to `"kto"`):
+            The type of unpaired loss to use. Possible values are:
+
+                - `"kto"`: KTO loss from the [KTO](https://huggingface.co/papers/2402.01306) paper.
+                - `"apo_zero_unpaired"`: Unpaired variant of APO-zero loss from the [APO](https://huggingface.co/papers/2408.06266) paper.
        max_length (`int`, *optional*, defaults to `None`):
            The maximum length of the sequences in the batch. This argument is required if you want to use the default data collator.
        max_prompt_length (`int`, *optional*, defaults to `None`):
Expand Down Expand Up @@ -60,6 +65,10 @@ class KTOConfig(TrainingArguments):
            Number of processes to use for processing the datasets.
    """

+    loss_type: Literal[
+        "kto",
+        "apo_zero_unpaired",
+    ] = "kto"

Member: Is "unpaired" really necessary? As far as I understand, there is no such thing as a "paired" version of KTO, right?

Contributor Author: APO-zero does have a paired and unpaired variant, and you could definitely construct a paired variant of KTO. We can remove "_unpaired" here since the KTOTrainer also implies it, but I thought it would be good for people to actively think about the distinction when selecting a loss.

Member: Yes, given we also have an apo_zero loss in the DPOTrainer, it's good to retain the _unpaired distinction IMO. Would you mind adding this loss term to the integration tests (the @parameterized.expand block in tests/test_kto_trainer.py)? You might want to look at the DPO trainer's @parameterized.expand for inspiration.
    max_length: Optional[int] = None
    """The maximum length of the sequences in the batch. This argument is required if you want to use the default data collator."""
    max_prompt_length: Optional[int] = None
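For context on what loss_type switches between, here is a sketch contrasting the two unpaired losses as described in the linked KTO and APO papers. This is an illustration rather than the trainer's implementation: the function name, the tensor arguments, and the handling of the KL term are assumptions.

import torch


def unpaired_losses(chosen_logratios, rejected_logratios, kl, beta=0.1, loss_type="kto"):
    # chosen_logratios / rejected_logratios: log(pi_theta / pi_ref) for
    # desirable / undesirable completions; kl: KTO's reference-KL estimate.
    if loss_type == "kto":
        # KTO (https://huggingface.co/papers/2402.01306): both terms are
        # anchored to the KL estimate.
        chosen_losses = 1 - torch.sigmoid(beta * (chosen_logratios - kl))
        rejected_losses = 1 - torch.sigmoid(beta * (kl - rejected_logratios))
    elif loss_type == "apo_zero_unpaired":
        # APO-zero, unpaired variant (https://huggingface.co/papers/2408.06266):
        # push up the likelihood of desirable completions and push down the
        # likelihood of undesirable ones, with no KL anchor.
        chosen_losses = 1 - torch.sigmoid(beta * chosen_logratios)
        rejected_losses = torch.sigmoid(beta * rejected_logratios)
    else:
        raise ValueError(f"Unknown loss_type: {loss_type}")
    return torch.cat((chosen_losses, rejected_losses), 0)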