diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py
index 916f42072b..8e7555d86d 100644
--- a/trl/trainer/dpo_trainer.py
+++ b/trl/trainer/dpo_trainer.py
@@ -964,14 +964,14 @@ def get_batch_loss_metrics(
         reward_accuracies = (chosen_rewards > rejected_rewards).float()

         prefix = "eval_" if train_eval == "eval" else ""
-        metrics[f"{prefix}rewards/chosen"] = chosen_rewards.cpu().mean()
-        metrics[f"{prefix}rewards/rejected"] = rejected_rewards.cpu().mean()
-        metrics[f"{prefix}rewards/accuracies"] = reward_accuracies.cpu().mean()
-        metrics[f"{prefix}rewards/margins"] = (chosen_rewards - rejected_rewards).cpu().mean()
-        metrics[f"{prefix}logps/rejected"] = policy_rejected_logps.detach().cpu().mean()
-        metrics[f"{prefix}logps/chosen"] = policy_chosen_logps.detach().cpu().mean()
-        metrics[f"{prefix}logits/rejected"] = policy_rejected_logits.detach().cpu().mean()
-        metrics[f"{prefix}logits/chosen"] = policy_chosen_logits.detach().cpu().mean()
+        metrics[f"{prefix}rewards/chosen"] = chosen_rewards.mean().cpu()
+        metrics[f"{prefix}rewards/rejected"] = rejected_rewards.mean().cpu()
+        metrics[f"{prefix}rewards/accuracies"] = reward_accuracies.mean().cpu()
+        metrics[f"{prefix}rewards/margins"] = (chosen_rewards - rejected_rewards).mean().cpu()
+        metrics[f"{prefix}logps/rejected"] = policy_rejected_logps.detach().mean().cpu()
+        metrics[f"{prefix}logps/chosen"] = policy_chosen_logps.detach().mean().cpu()
+        metrics[f"{prefix}logits/rejected"] = policy_rejected_logits.detach().mean().cpu()
+        metrics[f"{prefix}logits/chosen"] = policy_chosen_logits.detach().mean().cpu()

         return losses.mean(), metrics
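
Rationale (sketch, not part of the patch): calling .mean() before .cpu() reduces each tensor on the device it already lives on and transfers only a single scalar to the host, instead of copying the full per-example tensor to CPU and reducing there. A minimal illustration, assuming the rewards live on a CUDA device during training; the variable names below are made up for the example:

    import torch

    # Stand-in for e.g. chosen_rewards: one reward per example, on the accelerator.
    rewards = torch.randn(1024, device="cuda" if torch.cuda.is_available() else "cpu")

    # Old order: copy the whole 1024-element tensor to host, then reduce on CPU.
    mean_old = rewards.cpu().mean()

    # New order: reduce on the device, then copy a single scalar to host.
    mean_new = rewards.mean().cpu()

    # Both orderings yield the same value (up to floating-point reduction order).
    assert torch.allclose(mean_old, mean_new)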