[DPO] add KTO loss #1075

Merged · 5 commits · Dec 11, 2023
7 changes: 5 additions & 2 deletions docs/source/dpo_trainer.mdx
@@ -80,14 +80,17 @@ dpo_trainer.train()

Note that the `beta` is the temperature parameter for the DPO loss, typically in the range of `0.1` to `0.5`. The reference model is ignored as `beta` -> 0.

## Loss function
## Loss functions

Given the preference data, we can fit a binary classifier according to the Bradley-Terry model, and in fact the DPO authors propose the sigmoid loss on the normalized likelihood via `logsigmoid` to fit a logistic regression.

The [RSO](https://arxiv.org/abs/2309.06657) authors propose to use a hinge loss on the normalized likelihood from the [SLiC](https://arxiv.org/abs/2305.10425) paper. The `DPOTrainer` can be switched to this loss via the `loss_type="hinge"` argument, and in this case the `beta` is the reciprocal of the margin.

The [IPO](https://arxiv.org/abs/2310.12036) authors provide a deeper theoretical understanding of the DPO algorithms, identify an issue with overfitting, and propose an alternative loss which can be used via the `loss_type="ipo"` argument to the trainer.

The [cDPO](https://ericmitchell.ai/cdpo.pdf) is a tweak on the DPO loss where we assume that the preference labels are noisy with some probability that can be passed to the `DPOTrainer` via `label_smoothing` argument (between 0 and 0.5) and then a conservative DPO loss is used.
The [cDPO](https://ericmitchell.ai/cdpo.pdf) is a tweak on the DPO loss where we assume that the preference labels are noisy with some probability. Pass that probability to the `DPOTrainer` via the `label_smoothing` argument (between 0 and 0.5) to use the conservative DPO loss.

The [KTO](https://github.com/ContextualAI/HALOs/blob/main/assets/report.pdf) loss is derived to directly maximize the utility of LLM generations instead of the log-likelihood of preferences. The dataset therefore does not have to consist of preferences, but rather of desirable vs. undesirable pairs. Use the `loss_type="kto"` argument to the trainer to utilize this loss.

## Logging

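For illustration, here is a minimal usage sketch of the new `loss_type="kto"` option described in the documentation above. It is a sketch rather than code from this PR: the model names, the toy dataset, and the training arguments are placeholders, and the keyword arguments follow the `DPOTrainer` signature shown in the `trl/trainer/dpo_trainer.py` diff below.

```python
# Minimal sketch (not from this PR): wiring up DPOTrainer with the KTO loss.
# Model names, the toy dataset, and the training arguments are placeholders.
from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
from trl import DPOTrainer

model = AutoModelForCausalLM.from_pretrained("gpt2")
ref_model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # gpt2 has no pad token by default

# Toy data: for KTO, "chosen" plays the role of desirable and "rejected" of undesirable generations.
train_dataset = Dataset.from_dict(
    {
        "prompt": ["How do I bake bread?", "Summarize the article."],
        "chosen": ["Mix flour, water, salt and yeast, then ...", "The article argues that ..."],
        "rejected": ["I have no idea.", "Lorem ipsum dolor sit amet."],
    }
)

training_args = TrainingArguments(
    output_dir="./dpo-kto-sketch",
    per_device_train_batch_size=2,
    max_steps=10,
    remove_unused_columns=False,  # keep the prompt/chosen/rejected columns for the DPO data collator
)

dpo_trainer = DPOTrainer(
    model,
    ref_model,
    beta=0.1,          # temperature of the loss, typically 0.1 to 0.5
    loss_type="kto",   # one of "sigmoid" (default), "hinge", "ipo", "kto"
    args=training_args,
    train_dataset=train_dataset,
    tokenizer=tokenizer,
)
dpo_trainer.train()
```

Switching back to the standard DPO objective only requires dropping `loss_type` (or setting it to `"sigmoid"`); the rest of the setup is unchanged.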
4 changes: 3 additions & 1 deletion tests/test_dpo_trainer.py
@@ -74,7 +74,9 @@ def _init_dummy_dataset(self):
# fmt: on
return Dataset.from_dict(dummy_dataset_dict)

@parameterized.expand([["gpt2", "sigmoid"], ["t5", "hinge"], ["gpt2", "ipo"], ["t5", "ipo"]])
@parameterized.expand(
[["gpt2", "sigmoid"], ["t5", "hinge"], ["gpt2", "ipo"], ["t5", "ipo"], ["gpt2", "kto"], ["t5", "kto"]]
)
def test_dpo_trainer(self, name, loss_type):
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = TrainingArguments(
27 changes: 22 additions & 5 deletions trl/trainer/dpo_trainer.py
@@ -65,9 +65,9 @@ class DPOTrainer(Trainer):
beta (`float`, defaults to 0.1):
The beta factor in DPO loss. Higher beta means less divergence from the initial policy. For the IPO loss, beta is the regularization parameter denoted by tau in the paper.
label_smoothing (`float`, defaults to 0):
The robust DPO label smoothing parameter that should be between 0 and 0.5.
The robust DPO label smoothing parameter from the [cDPO](https://ericmitchell.ai/cdpo.pdf) report that should be between 0 and 0.5.
loss_type (`str`, defaults to `"sigmoid"`):
The type of DPO loss to use. Either `"sigmoid"` the default DPO loss,`"hinge"` loss from SLiC paper or `"ipo"` from IPO paper.
The type of DPO loss to use. Either `"sigmoid"` the default DPO loss, `"hinge"` loss from the [SLiC](https://arxiv.org/abs/2305.10425) paper, `"ipo"` from the [IPO](https://arxiv.org/abs/2310.12036) paper, or `"kto"` from the HALOs [report](https://github.com/ContextualAI/HALOs/blob/main/assets/report.pdf).
args (`transformers.TrainingArguments`):
The arguments to use for training.
data_collator (`transformers.DataCollator`):
@@ -123,7 +123,7 @@ def __init__(
ref_model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
beta: float = 0.1,
label_smoothing: float = 0,
loss_type: Literal["sigmoid", "hinge", "ipo"] = "sigmoid",
loss_type: Literal["sigmoid", "hinge", "ipo", "kto"] = "sigmoid",
args: TrainingArguments = None,
data_collator: Optional[DataCollator] = None,
label_pad_token_id: int = -100,
@@ -311,7 +311,7 @@ def make_inputs_require_grad(module, input, output):
self.label_pad_token_id = label_pad_token_id
self.padding_value = padding_value

if loss_type in ["hinge", "ipo"] and label_smoothing > 0:
if loss_type in ["hinge", "ipo", "kto"] and label_smoothing > 0:
warnings.warn(
"You are using a loss type that does not support label smoothing. Ignoring label_smoothing parameter."
)
@@ -465,8 +465,25 @@ def dpo_loss(
elif self.loss_type == "ipo":
# eqn (17) of the paper where beta is the regularization parameter for the IPO loss, denoted by tau in the paper.
losses = (logits - 1 / (2 * self.beta)) ** 2
elif self.loss_type == "kto":
# eqn (7) of the HALOs paper
chosen_KL = (policy_chosen_logps - reference_chosen_logps).mean().clamp(min=0)
rejected_KL = (policy_rejected_logps - reference_rejected_logps).mean().clamp(min=0)

chosen_logratios = policy_chosen_logps - reference_chosen_logps
rejected_logratios = policy_rejected_logps - reference_rejected_logps
# As described in the KTO report, the KL term for chosen (rejected) is estimated using the rejected (chosen) half.
losses = torch.cat(
(
1 - F.sigmoid(self.beta * (chosen_logratios - rejected_KL)),
1 - F.sigmoid(self.beta * (chosen_KL - rejected_logratios)),
),
0,
)
else:
raise ValueError(f"Unknown loss type: {self.loss_type}. Should be one of ['sigmoid', 'hinge', 'ipo']")
raise ValueError(
f"Unknown loss type: {self.loss_type}. Should be one of ['sigmoid', 'hinge', 'ipo', 'kto']"
)

chosen_rewards = self.beta * (policy_chosen_logps - reference_chosen_logps).detach()
rejected_rewards = self.beta * (policy_rejected_logps - reference_rejected_logps).detach()
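To make the new branch above concrete, here is a small self-contained sketch (with made-up log-probabilities rather than real model outputs) that reproduces the KTO computation from `dpo_loss` outside the trainer:

```python
# Standalone sketch of the KTO branch in dpo_loss, on made-up log-probabilities.
import torch
import torch.nn.functional as F

beta = 0.1

# Hypothetical per-sequence log-probs for a batch of 3 preference pairs.
policy_chosen_logps = torch.tensor([-10.0, -12.0, -9.5])
policy_rejected_logps = torch.tensor([-14.0, -13.5, -15.0])
reference_chosen_logps = torch.tensor([-11.0, -12.5, -10.0])
reference_rejected_logps = torch.tensor([-13.0, -13.0, -14.0])

# Batch-level KL estimates, clamped at zero; as noted in the diff, the estimate
# from the rejected half is used for the chosen term and vice versa.
chosen_KL = (policy_chosen_logps - reference_chosen_logps).mean().clamp(min=0)
rejected_KL = (policy_rejected_logps - reference_rejected_logps).mean().clamp(min=0)

# Per-example log-ratios between the policy and the reference model.
chosen_logratios = policy_chosen_logps - reference_chosen_logps
rejected_logratios = policy_rejected_logps - reference_rejected_logps

# eqn (7) of the HALOs report: one term per desirable example, one per undesirable.
losses = torch.cat(
    (
        1 - F.sigmoid(beta * (chosen_logratios - rejected_KL)),
        1 - F.sigmoid(beta * (chosen_KL - rejected_logratios)),
    ),
    0,
)
print(losses.shape)   # torch.Size([6]) -> 2 * batch_size terms
print(losses.mean())  # the trainer averages these terms over the batch
```

Note that, unlike the sigmoid, hinge, and IPO losses, the KTO terms couple examples through the batch-level KL estimates, so each per-example loss depends on the rest of the batch.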