diff --git a/trl/trainer/kto_trainer.py b/trl/trainer/kto_trainer.py
index 994c6dced2..353b8d1298 100644
--- a/trl/trainer/kto_trainer.py
+++ b/trl/trainer/kto_trainer.py
@@ -1350,11 +1350,13 @@ def get_batch_loss_metrics(
         if all_num_chosen > 0:
             metrics["rewards/chosen_sum"] = self.accelerator.gather(chosen_rewards.nansum()).nansum().item()
             metrics["logps/chosen_sum"] = self.accelerator.gather(policy_chosen_logps.nansum()).nansum().item()
+            metrics["logits/chosen"] = self.accelerator.gather(policy_chosen_logits.nansum()).nanmean().item()
             metrics["count/chosen"] = all_num_chosen
 
         if all_num_rejected > 0:
             metrics["rewards/rejected_sum"] = self.accelerator.gather(rejected_rewards.nansum()).nansum().item()
             metrics["logps/rejected_sum"] = self.accelerator.gather(policy_rejected_logps.nansum()).nansum().item()
+            metrics["logits/rejected"] = self.accelerator.gather(policy_rejected_logits.nansum()).nanmean().item()
             metrics["count/rejected"] = all_num_rejected
 
         metrics["kl"] = kl.item()
@@ -1512,10 +1514,10 @@ def evaluation_loop(
             random_batch = self.data_collator(random_batch_dataset)
             random_batch = self._prepare_inputs(random_batch)
 
-            target_indicies = [i for i in range(len(random_batch["kl"])) if random_batch["kl"][i] is False]
+            target_indicies = [i for i in range(len(random_batch["label"])) if random_batch["label"][i] is False]
             target_batch = {
-                "prompt_input_ids": itemgetter(*target_indicies)(random_batch["prompt_input_ids"]),
-                "prompt_attention_mask": itemgetter(*target_indicies)(random_batch["prompt_attention_mask"]),
+                "prompt_input_ids": random_batch["prompt_input_ids"][target_indicies],
+                "prompt_attention_mask": random_batch["prompt_attention_mask"][target_indicies],
                 "prompt": itemgetter(*target_indicies)(random_batch["prompt"]),
             }
             policy_output_decoded, ref_output_decoded = self.get_batch_samples(self.model, target_batch)
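
A minimal standalone sketch of the indexing difference the second hunk relies on: itemgetter(*idx)(tensor) returns a tuple of row tensors, whereas indexing a tensor with a list of indices keeps a single stacked tensor, which is what the collated prompt_input_ids / prompt_attention_mask need. This is not part of the patch; the shapes and label values below are illustrative assumptions only.

# Illustrative sketch only (assumed shapes/values, not from the patch):
# after collation, "prompt_input_ids" is a 2-D tensor, so list indexing
# keeps it a tensor, while itemgetter returns a tuple of 1-D rows.
import torch
from operator import itemgetter

prompt_input_ids = torch.arange(12).reshape(4, 3)   # batch of 4 prompts, 3 tokens each
labels = [True, False, True, False]                  # KTO "label" column (desirable vs. undesirable)
target_indicies = [i for i in range(len(labels)) if labels[i] is False]

as_tuple = itemgetter(*target_indicies)(prompt_input_ids)   # tuple: (tensor([3, 4, 5]), tensor([9, 10, 11]))
as_tensor = prompt_input_ids[target_indicies]                # single tensor of shape (2, 3)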