diff --git a/examples/nlp/gpt/conf/gpt_reinforce_actor.yaml b/examples/nlp/gpt/conf/gpt_reinforce_actor.yaml
index 40ec9c54d..3fb6b5011 100644
--- a/examples/nlp/gpt/conf/gpt_reinforce_actor.yaml
+++ b/examples/nlp/gpt/conf/gpt_reinforce_actor.yaml
@@ -9,6 +9,9 @@ trainer:
   precision: bf16
 
   reinforce:
+    rpo_metric: "sq_loo" # sq_loo or bwd_kl
+    gt_reward_scale: 1. # scale applied to the RM's ground-truth (GT) rewards
+
     # How many steps we train warmup the critic for (without training the policy)
     # this may help prevent the critic loss from hockey sticking since
     # the critic is initialized from the reward model and may not be initially
diff --git a/nemo_aligner/algorithms/reinforce.py b/nemo_aligner/algorithms/reinforce.py
index bf9e6900b..257cbb738 100644
--- a/nemo_aligner/algorithms/reinforce.py
+++ b/nemo_aligner/algorithms/reinforce.py
@@ -39,6 +39,7 @@
 from nemo_aligner.utils.parallel_state import is_trt_llm_reshard, trt_llm_reshard_region
 from nemo_aligner.utils.ppo_utils import (
     calculate_rloo_baseline,
+    calculate_rewards_logprobs,
     calculate_kl_penalty,
     create_mask,
 )
@@ -362,17 +363,34 @@ def _run_inference(self, dataloader_builder, consumed_samples, is_validation):
         init_policy_kl = masked_mean(init_policy_kl, mask, dim=-1)
 
         # Calculate RLOO baseline
-        rewards_with_kl = balanced_local_batch["rewards"] - self.cfg.initial_policy_kl_penalty * init_policy_kl
-        baseline, baseline_std = calculate_rloo_baseline(
-            prompts=balanced_local_batch["prompt_tokens"],
-            reward=rewards_with_kl,
-            mask=balanced_local_batch["is_end"].float(),
-            normalize_variance=self.cfg.normalize_variance
-        )
+        if self.cfg.rpo_metric == "sq_loo":
+            rewards_with_kl = balanced_local_batch["rewards"] - self.cfg.initial_policy_kl_penalty * init_policy_kl
+            baseline, baseline_std = calculate_rloo_baseline(
+                prompts=balanced_local_batch["prompt_tokens"],
+                reward=rewards_with_kl,
+                mask=balanced_local_batch["is_end"].float(),
+                normalize_variance=self.cfg.normalize_variance
+            )
+
+            if self.cfg.normalize_variance:
+                rewards_with_kl /= baseline_std
+                baseline /= baseline_std
+        elif self.cfg.rpo_metric == "bwd_kl":
+            logprobs_gt_rewards = calculate_rewards_logprobs(
+                prompts=balanced_local_batch["prompt_tokens"],
+                reward=self.cfg.gt_reward_scale * balanced_local_batch["rewards"],
+                mask=balanced_local_batch["is_end"].float(),
+            )
+            logprobs_predicted_rewards = calculate_rewards_logprobs(
+                prompts=balanced_local_batch["prompt_tokens"],
+                reward=self.cfg.initial_policy_kl_penalty * init_policy_kl,
+                mask=balanced_local_batch["is_end"].float(),
+            )
 
-        if self.cfg.normalize_variance:
-            rewards_with_kl /= baseline_std
-            baseline /= baseline_std
+            rewards_with_kl = logprobs_gt_rewards
+            baseline = logprobs_predicted_rewards
+        else:
+            raise ValueError(f"rpo_metric={self.cfg.rpo_metric} is not supported; expected 'sq_loo' or 'bwd_kl'")
 
         balanced_local_batch["rewards_with_kl"] = rewards_with_kl
         balanced_local_batch["baseline"] = baseline
diff --git a/nemo_aligner/utils/ppo_utils.py b/nemo_aligner/utils/ppo_utils.py
index 38fb36062..b847b7b00 100644
--- a/nemo_aligner/utils/ppo_utils.py
+++ b/nemo_aligner/utils/ppo_utils.py
@@ -121,4 +121,25 @@ def calculate_rloo_baseline(prompts, reward, mask, normalize_variance):
             baseline_std[prompt_idx] = loo_std + 1e-3
 
     print(loo_std.shape, baseline_std)
-    return baseline, baseline_std
\ No newline at end of file
+    return baseline, baseline_std
+
+
+def calculate_rewards_logprobs(prompts, reward, mask):
+    '''
+    For each prompt and its corresponding rewards r1, r2, ..., rk,
+    compute the log probs of the rewards: log(exp(ri) / (exp(r1) + ... + exp(rk))).
+    '''
+    unique_prompts = torch.unique(prompts, dim=0)
+
+    logprobs = torch.zeros_like(reward)
+    reward_device = reward.device
+    for i in range(len(unique_prompts)):
+        is_matching_prompt = (prompts == unique_prompts[i]).all(1)
+        prompt_idx = torch.arange(len(prompts), device=reward_device)[is_matching_prompt]
+
+        if mask[prompt_idx].sum() <= 1:
+            logprobs[prompt_idx] = torch.ones_like(reward[prompt_idx])
+        else:
+            logprobs[prompt_idx] = torch.nn.functional.log_softmax(reward[prompt_idx] * mask[prompt_idx], dim=0)
+
+    return logprobs
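Reviewer note: below is a minimal, self-contained sketch of what the new bwd_kl path computes, for poking at it outside NeMo-Aligner. The helper mirrors calculate_rewards_logprobs from this patch (a per-prompt log-softmax over the k rewards sampled for the same prompt). The name rewards_logprobs and the toy batch values are made up for illustration; the real code operates on balanced_local_batch inside _run_inference.

import torch
import torch.nn.functional as F

def rewards_logprobs(prompts, reward, mask):
    # Per-prompt log softmax over rewards, mirroring calculate_rewards_logprobs in this patch.
    logprobs = torch.zeros_like(reward)
    for prompt in torch.unique(prompts, dim=0):
        idx = (prompts == prompt).all(dim=1)
        if mask[idx].sum() <= 1:
            # fewer than two valid responses for this prompt: log-softmax is degenerate
            logprobs[idx] = torch.ones_like(reward[idx])
        else:
            logprobs[idx] = F.log_softmax(reward[idx] * mask[idx], dim=0)
    return logprobs

# Toy batch: 2 prompts, 2 sampled responses each (prompt tokens repeated per response).
prompts = torch.tensor([[1, 2], [1, 2], [3, 4], [3, 4]])
gt_rewards = torch.tensor([1.0, 0.0, 0.5, 0.5])    # stands in for gt_reward_scale * rewards
kl_rewards = torch.tensor([0.2, 0.1, 0.3, 0.3])    # stands in for initial_policy_kl_penalty * init_policy_kl
is_end = torch.ones(4)                             # all responses terminated properly

rewards_with_kl = rewards_logprobs(prompts, gt_rewards, is_end)  # plays the role of logprobs_gt_rewards
baseline = rewards_logprobs(prompts, kl_rewards, is_end)         # plays the role of logprobs_predicted_rewards
print(rewards_with_kl, baseline)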