From 50f6debc1bb4e36c2357e772d8a7f2dd06a5a7b2 Mon Sep 17 00:00:00 2001 From: Issa Memari Date: Wed, 11 Sep 2024 11:26:09 +0200 Subject: [PATCH 1/3] Fix logits compuation in KTO trainer prediction step --- trl/trainer/kto_trainer.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/trl/trainer/kto_trainer.py b/trl/trainer/kto_trainer.py index 74c8eaba12..beb80439f9 100644 --- a/trl/trainer/kto_trainer.py +++ b/trl/trainer/kto_trainer.py @@ -1324,8 +1324,7 @@ def prediction_step( "eval_logits/chosen": metrics["logits/chosen"], "eval_logits/rejected": metrics["logits/rejected"], } - logits = tuple(v.unsqueeze(dim=0) for k, v in logits_dict.items() if k not in ignore_keys) - logits = torch.stack(logits).mean(axis=1).to(self.accelerator.device) + logits = torch.tensor([v for k, v in logits_dict.items() if k not in ignore_keys]).to(self.accelerator.device) labels = torch.zeros(logits.shape[0], device=self.accelerator.device) return (loss.detach(), logits, labels) From 9b97a2a82a2e7742a58fb25d92edc3fc3f6a3442 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Wed, 11 Sep 2024 13:10:39 +0200 Subject: [PATCH 2/3] Update trl/trainer/kto_trainer.py --- trl/trainer/kto_trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trl/trainer/kto_trainer.py b/trl/trainer/kto_trainer.py index beb80439f9..060ed894b9 100644 --- a/trl/trainer/kto_trainer.py +++ b/trl/trainer/kto_trainer.py @@ -1324,7 +1324,7 @@ def prediction_step( "eval_logits/chosen": metrics["logits/chosen"], "eval_logits/rejected": metrics["logits/rejected"], } - logits = torch.tensor([v for k, v in logits_dict.items() if k not in ignore_keys]).to(self.accelerator.device) + logits = torch.tensor([v for k, v in logits_dict.items() if k not in ignore_keys], device=self.accelerator.device) labels = torch.zeros(logits.shape[0], device=self.accelerator.device) return (loss.detach(), logits, labels) From fd9b583c74bb7499306917f27f593fda4815f900 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Wed, 11 Sep 2024 13:12:05 +0200 Subject: [PATCH 3/3] Update trl/trainer/kto_trainer.py --- trl/trainer/kto_trainer.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/trl/trainer/kto_trainer.py b/trl/trainer/kto_trainer.py index 060ed894b9..51fdf4c9e9 100644 --- a/trl/trainer/kto_trainer.py +++ b/trl/trainer/kto_trainer.py @@ -1324,7 +1324,9 @@ def prediction_step( "eval_logits/chosen": metrics["logits/chosen"], "eval_logits/rejected": metrics["logits/rejected"], } - logits = torch.tensor([v for k, v in logits_dict.items() if k not in ignore_keys], device=self.accelerator.device) + logits = torch.tensor( + [v for k, v in logits_dict.items() if k not in ignore_keys], device=self.accelerator.device + ) labels = torch.zeros(logits.shape[0], device=self.accelerator.device) return (loss.detach(), logits, labels)