
Commit

rename kto loss (#1127)
kashif authored Dec 22, 2023
1 parent 06b7959 commit 814fe39
Showing 3 changed files with 7 additions and 6 deletions.
2 changes: 1 addition & 1 deletion docs/source/dpo_trainer.mdx
@@ -90,7 +90,7 @@ The [IPO](https://arxiv.org/abs/2310.12036) authors provide a deeper theoretical
 
 The [cDPO](https://ericmitchell.ai/cdpo.pdf) is a tweak on the DPO loss where we assume that the preference labels are noisy with some probability that can be passed to the `DPOTrainer` via `label_smoothing` argument (between 0 and 0.5) and then a conservative DPO loss is used. Use the `loss_type="cdpo"` argument to the trainer to use it.
 
-The [KTO](https://github.com/ContextualAI/HALOs/blob/main/assets/report.pdf) loss is derived to directly maximize the utility of LLM generations instead of the log-likelihood of prefereces. Thus the dataset are not neccsarily prefereces but rather desirable vs undersirable pairs. Use the `loss_type="kto"` argument to the trainer to utilize this loss.
+The [KTO](https://github.com/ContextualAI/HALOs/blob/main/assets/report.pdf) loss is derived to directly maximize the utility of LLM generations instead of the log-likelihood of preferences. Thus the dataset are not necessarily preferences but rather desirable vs undesirable completions. For paired preference data as required by the `DPOTrainer`, use the `loss_type="kto_pair"` argument to the trainer to utilize this loss, while for the more general case of desired and undesirable data, use the as of yet unimplemented `KTOTrainer`.
 
 ## Logging
 
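For orientation, the sketch below shows how the renamed `loss_type="kto_pair"` option would be selected when constructing a `DPOTrainer`, mirroring the dummy paired dataset used in the test file touched by this commit. The model choice, dataset contents, and training arguments are illustrative assumptions, not part of the diff.

```python
# Illustrative sketch only (not part of this commit): selecting the renamed
# "kto_pair" loss through the DPOTrainer API. Model, dataset contents, and
# training arguments below are placeholder assumptions.
from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
from trl import DPOTrainer

model = AutoModelForCausalLM.from_pretrained("gpt2")
ref_model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no pad token by default

# A tiny paired-preference dataset with the columns DPOTrainer expects.
train_dataset = Dataset.from_dict(
    {
        "prompt": ["Hello,", "How are you?"],
        "chosen": [" nice to meet you.", " I am fine, thanks."],
        "rejected": [" leave me alone.", " not your business."],
    }
)

trainer = DPOTrainer(
    model=model,
    ref_model=ref_model,
    args=TrainingArguments(
        output_dir="./dpo-kto-pair",
        per_device_train_batch_size=2,
        max_steps=10,
        remove_unused_columns=False,
    ),
    beta=0.1,
    loss_type="kto_pair",  # previously "kto"; renamed by this commit
    train_dataset=train_dataset,
    tokenizer=tokenizer,
    max_length=128,
    max_prompt_length=64,
)
trainer.train()
```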
4 changes: 2 additions & 2 deletions tests/test_dpo_trainer.py
@@ -86,8 +86,8 @@ def _init_dummy_dataset(self):
             ["t5", "hinge", False],
             ["gpt2", "ipo", False],
             ["t5", "ipo", True],
-            ["gpt2", "kto", True],
-            ["t5", "kto", False],
+            ["gpt2", "kto_pair", True],
+            ["t5", "kto_pair", False],
         ]
     )
     def test_dpo_trainer(self, name, loss_type, pre_compute):
7 changes: 4 additions & 3 deletions trl/trainer/dpo_trainer.py
@@ -121,6 +121,7 @@ class DPOTrainer(Trainer):
         ref_model_init_kwargs: (`Optional[Dict]`, *optional*):
             Dict of Optional kwargs to pass when instantiating the ref model from a string
     """
+
     _tag_name = "trl-dpo"
 
     def __init__(
@@ -320,7 +321,7 @@ def make_inputs_require_grad(module, input, output):
         self._precomputed_train_ref_log_probs = False
         self._precomputed_eval_ref_log_probs = False
 
-        if loss_type in ["hinge", "ipo", "kto"] and label_smoothing > 0:
+        if loss_type in ["hinge", "ipo", "kto_pair"] and label_smoothing > 0:
             warnings.warn(
                 "You are using a loss type that does not support label smoothing. Ignoring label_smoothing parameter."
             )
@@ -801,7 +802,7 @@ def dpo_loss(
         elif self.loss_type == "ipo":
             # eqn (17) of the paper where beta is the regularization parameter for the IPO loss, denoted by tau in the paper.
             losses = (logits - 1 / (2 * self.beta)) ** 2
-        elif self.loss_type == "kto":
+        elif self.loss_type == "kto_pair":
             # eqn (7) of the HALOs paper
             chosen_KL = (policy_chosen_logps - reference_chosen_logps).mean().clamp(min=0)
             rejected_KL = (policy_rejected_logps - reference_rejected_logps).mean().clamp(min=0)
@@ -818,7 +819,7 @@
             )
         else:
             raise ValueError(
-                f"Unknown loss type: {self.loss_type}. Should be one of ['sigmoid', 'hinge', 'ipo', 'kto']"
+                f"Unknown loss type: {self.loss_type}. Should be one of ['sigmoid', 'hinge', 'ipo', 'kto_pair']"
             )
 
         chosen_rewards = self.beta * (policy_chosen_logps - reference_chosen_logps).detach()
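Because most of the `kto_pair` branch is collapsed in this view, here is a self-contained sketch of a paired-KTO style loss in the spirit of eqn (7) of the HALOs report. The clamped KL terms follow the two lines visible in the diff; the final assembly of the loss vector is an assumption reconstructed from the surrounding code, not a copy of the hidden lines.

```python
# Sketch of a paired KTO ("kto_pair") style loss, following eqn (7) of the HALOs
# report. The clamped batch-level KL terms mirror the lines visible in the diff;
# the concatenation of desirable/undesirable terms is an assumption, not copied
# from the collapsed hunk.
import torch


def kto_pair_loss(
    policy_chosen_logps: torch.Tensor,
    policy_rejected_logps: torch.Tensor,
    reference_chosen_logps: torch.Tensor,
    reference_rejected_logps: torch.Tensor,
    beta: float = 0.1,
) -> torch.Tensor:
    # Per-example log-ratios between the policy and the reference model.
    chosen_logratios = policy_chosen_logps - reference_chosen_logps
    rejected_logratios = policy_rejected_logps - reference_rejected_logps

    # Batch-level KL estimates, clamped at zero as in the visible diff lines.
    chosen_KL = chosen_logratios.mean().clamp(min=0)
    rejected_KL = rejected_logratios.mean().clamp(min=0)

    # Desirable and undesirable halves of the objective, stacked into one vector;
    # each half penalises log-ratios that fall below the opposite half's KL estimate.
    return torch.cat(
        (
            1 - torch.sigmoid(beta * (chosen_logratios - rejected_KL)),
            1 - torch.sigmoid(beta * (chosen_KL - rejected_logratios)),
        ),
        dim=0,
    )
```

As with the other loss branches in `dpo_loss`, a caller would typically average this vector over the batch.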
