⚖️ Add use_soft_judge option to WinRateCallback #2347

Merged · 14 commits · Nov 15, 2024
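In short, the PR adds a `use_soft_judge` flag so that `WinRateCallback` also logs `eval_avg_win_prob`, the average probability that the policy's completion beats the reference one. A minimal usage sketch, mirroring the new test below; the `trainer` and `judge` objects are assumed to be set up elsewhere and are only placeholders:

```python
from transformers import GenerationConfig
from trl import WinRateCallback

# Assumed setup (not shown): `trainer` is a Trainer with an eval dataset, and
# `judge` is a pairwise judge whose judge() accepts return_scores=True and
# returns one win probability per pair.
win_rate_callback = WinRateCallback(
    judge=judge,
    trainer=trainer,
    generation_config=GenerationConfig(max_new_tokens=32),
    use_soft_judge=True,  # new flag: also log "eval_avg_win_prob"
)
trainer.add_callback(win_rate_callback)
trainer.train()  # "eval_win_rate" and "eval_avg_win_prob" show up in trainer.state.log_history
```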
tests/test_callbacks.py: 49 changes (47 additions, 2 deletions)
@@ -30,11 +30,13 @@


class HalfPairwiseJudge(BasePairwiseJudge):
"""Naive pairwise judge that always returns [1, 0]"""
"""Naive pairwise judge that always returns [1, 0] for two prompts"""

def judge(self, prompts, completions, shuffle_order=True):
def judge(self, prompts, completions, shuffle_order=True, return_scores=False):
# just check that the batch size is 2
assert len(prompts) == 2
if return_scores:
return [0.3, 0.9]
return [1, 0]


@@ -132,6 +134,49 @@ def test_without_ref_model(self):
winrate_history = [h for h in trainer.state.log_history if "eval_win_rate" in h]
self.assertListEqual(winrate_history, self.expected_winrates)

def test_soft_judge(self):
"""Test that the soft judge functionality works correctly"""
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = TrainingArguments(
output_dir=tmp_dir,
eval_strategy="steps",
eval_steps=2, # evaluate every 2 steps
per_device_train_batch_size=2, # 8 samples in total so 4 batches of 2 per epoch
per_device_eval_batch_size=2,
report_to="none",
)
trainer = TrainerWithRefModel(
model=self.model,
ref_model=self.ref_model,
args=training_args,
train_dataset=self.dataset["train"],
eval_dataset=self.dataset["test"],
processing_class=self.tokenizer,
)
win_rate_callback = WinRateCallback(
judge=self.judge, trainer=trainer, generation_config=self.generation_config, use_soft_judge=True
)
trainer.add_callback(win_rate_callback)
trainer.train()

# Expected values based on judge returning [0.3, 0.9] for each pair
expected_soft_winrates = [
{"eval_avg_win_prob": 0.4, "eval_win_rate": 0.5, "epoch": 0.0, "step": 0},
{"eval_avg_win_prob": 0.4, "eval_win_rate": 0.5, "epoch": 0.5, "step": 2},
{"eval_avg_win_prob": 0.4, "eval_win_rate": 0.5, "epoch": 1.0, "step": 4},
{"eval_avg_win_prob": 0.4, "eval_win_rate": 0.5, "epoch": 1.5, "step": 6},
{"eval_avg_win_prob": 0.4, "eval_win_rate": 0.5, "epoch": 2.0, "step": 8},
{"eval_avg_win_prob": 0.4, "eval_win_rate": 0.5, "epoch": 2.5, "step": 10},
{"eval_avg_win_prob": 0.4, "eval_win_rate": 0.5, "epoch": 3.0, "step": 12},
]

winrate_history = [
{k: h[k] for k in ["eval_avg_win_prob", "eval_win_rate", "epoch", "step"]}
for h in trainer.state.log_history
if "eval_avg_win_prob" in h
]
self.assertListEqual(winrate_history, expected_soft_winrates)
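For reference, the expected numbers above follow directly from the dummy judge's fixed scores; a quick arithmetic check that mirrors the callback's logic:

```python
# HalfPairwiseJudge returns [0.3, 0.9]: the probability that the first
# (reference) completion wins each of the two pairs.
scores = [0.3, 0.9]
# The policy completion counts as the winner when the reference's win
# probability is not above 0.5, so the hard win rate is 1/2.
win_rate = sum(s <= 0.5 for s in scores) / len(scores)  # 0.5
# The soft metric is one minus the mean reference win probability.
avg_win_prob = 1.0 - sum(scores) / len(scores)          # 1 - 0.6 = 0.4
```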

@require_peft
def test_lora(self):
with tempfile.TemporaryDirectory() as tmp_dir:
trl/trainer/callbacks.py: 32 changes (28 additions, 4 deletions)
@@ -230,6 +230,9 @@ class WinRateCallback(TrainerCallback):
in the evaluation dataset.
shuffle_order (`bool`, *optional*, defaults to `True`):
Whether to shuffle the order of the completions before judging.
use_soft_judge (`bool`, *optional*, defaults to `False`):
Whether to use a soft judge that returns a win probability between 0 and 1 for the first completion vs the
second.
"""

def __init__(
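For the flag to work, the judge passed to the callback has to accept a `return_scores` argument and, when it is set, return one win probability per pair (the probability that the first, i.e. reference, completion wins), as the test's `HalfPairwiseJudge` does. A minimal sketch of such a judge; the constant 0.4 is an arbitrary illustration, not part of the PR:

```python
from trl import BasePairwiseJudge

class ConstantSoftJudge(BasePairwiseJudge):
    """Toy judge that always mildly prefers the second (policy) completion."""

    def judge(self, prompts, completions, shuffle_order=True, return_scores=False):
        if return_scores:
            # probability that the first (reference) completion wins each pair
            return [0.4 for _ in prompts]
        # hard decision: index of the winning completion in each pair
        return [1 for _ in prompts]
```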
@@ -239,12 +242,14 @@ def __init__(
generation_config: Optional[GenerationConfig] = None,
num_prompts: Optional[int] = None,
shuffle_order: bool = True,
use_soft_judge: bool = False,
):
self.judge = judge
self.trainer = trainer
self.shuffle_order = shuffle_order
self.generation_config = generation_config
self.ref_completions = []
self.use_soft_judge = use_soft_judge

if self.trainer.eval_dataset is None:
raise ValueError("Trainer must have an evaluation dataset to use the WinRateCallback.")
@@ -281,15 +286,24 @@ def on_train_begin(self, args: TrainingArguments, state: TrainerState, control:
)
# Compute initial win rate as a reference point
completions = list(zip(self.ref_completions, self.ref_completions))
winner_indices = self.judge.judge(prompts, completions, self.shuffle_order)
if self.use_soft_judge:
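# each score is the probability that the first (reference) completion wins, so > 0.5 means the reference is preferred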
ref_win_probs = self.judge.judge(prompts, completions, self.shuffle_order, return_scores=True)
winner_indices = [0 if score > 0.5 else 1 for score in ref_win_probs]
ref_win_probs = gather_object(ref_win_probs)
else:
winner_indices = self.judge.judge(prompts, completions, self.shuffle_order)
prompts = gather_object(prompts)
completions = gather_object(completions)
winner_indices = gather_object(winner_indices)

# Logging
if self.trainer.accelerator.is_main_process:
win_rate = sum(winner_idx == 1 for winner_idx in winner_indices) / len(winner_indices)
self.trainer.log({"eval_win_rate": win_rate})
if self.use_soft_judge:
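# ref_win_probs give the reference's chance of winning, so 1 - mean is the model's average win probability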
avg_win_prob = 1.0 - sum(ref_win_probs) / len(ref_win_probs)
self.trainer.log({"eval_avg_win_prob": avg_win_prob, "eval_win_rate": win_rate})
else:
self.trainer.log({"eval_win_rate": win_rate})

if "wandb" in args.report_to:
import wandb
@@ -323,15 +337,25 @@ def on_evaluate(self, args: TrainingArguments, state: TrainerState, control: Tra
)

completions = list(zip(self.ref_completions, completions))
winner_indices = self.judge.judge(prompts, completions, self.shuffle_order)

if self.use_soft_judge:
ref_win_probs = self.judge.judge(prompts, completions, self.shuffle_order, return_scores=True)
winner_indices = [0 if score > 0.5 else 1 for score in ref_win_probs]
ref_win_probs = gather_object(ref_win_probs)
else:
winner_indices = self.judge.judge(prompts, completions, self.shuffle_order)
prompts = gather_object(prompts)
completions = gather_object(completions)
winner_indices = gather_object(winner_indices)

# Logging
if self.trainer.accelerator.is_main_process:
win_rate = sum(winner_idx == 1 for winner_idx in winner_indices) / len(winner_indices)
self.trainer.log({"eval_win_rate": win_rate})
if self.use_soft_judge:
avg_win_prob = 1.0 - sum(ref_win_probs) / len(ref_win_probs)
self.trainer.log({"eval_avg_win_prob": avg_win_prob, "eval_win_rate": win_rate})
else:
self.trainer.log({"eval_win_rate": win_rate})

if "wandb" in args.report_to:
import wandb