
Commit

roll back to distribute generation
qgallouedec committed Feb 6, 2025
1 parent 9025cbc commit 0695722
Showing 1 changed file with 47 additions and 83 deletions.
130 changes: 47 additions & 83 deletions trl/trainer/grpo_trainer.py
@@ -22,12 +22,11 @@
import torch
import torch.utils.data
import transformers
from accelerate import Accelerator
from accelerate.utils import broadcast, broadcast_object_list, gather_object
from accelerate.utils import broadcast_object_list, gather, gather_object
from accelerate.utils.other import is_compiled_module
from datasets import Dataset, IterableDataset
from packaging import version
from torch import Tensor, nn
from torch import nn
from torch.utils.data import Sampler
from transformers import (
AutoModelForCausalLM,
@@ -96,55 +95,6 @@ def __len__(self):
return self.num_samples * self.repeat_count
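
The hunk above only shows `RepeatRandomSampler.__len__`. As context, a hypothetical sampler honoring the same contract might look like the sketch below; the class name `RepeatSampler` and its `__iter__` are illustrative, not the trainer's actual implementation — only the `__len__` shown in the diff is confirmed.

```python
import torch
from torch.utils.data import Sampler


class RepeatSampler(Sampler):
    """Sketch: emits each dataset index repeat_count times, so the completions
    that form one reward group stay together across the batch."""

    def __init__(self, data_source, repeat_count: int):
        self.num_samples = len(data_source)
        self.repeat_count = repeat_count

    def __iter__(self):
        for idx in torch.randperm(self.num_samples).tolist():
            for _ in range(self.repeat_count):
                yield idx

    def __len__(self):
        # This is the only method confirmed by the diff above.
        return self.num_samples * self.repeat_count
```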


def broadcast_and_slice_dict(
accelerator: Accelerator, tensor_dict: Union[dict[str, Tensor], None], from_process: int = 0
) -> dict[str, Tensor]:
"""
Broadcasts a dictionary of tensors from one process to all processes and slices the tensors based on the process
index.
```
Process 0                      Process 0          Process 1
[[ 1,  2,  3],             ->  [[1, 2, 3],        [[ 7,  8,  9],
 [ 4,  5,  6],                  [4, 5, 6]]         [10, 11, 12]]
 [ 7,  8,  9],
 [10, 11, 12]]
```
"""

is_from = accelerator.local_process_index == from_process

# Only rank 0 has the tensor_dict, others start with None
metadata = {k: (v.shape, v.dtype) for k, v in tensor_dict.items()} if is_from else None

# Broadcast metadata to all processes
metadata = [metadata] # Wrap in a list to make it mutable
metadata = broadcast_object_list(metadata, from_process=from_process)
metadata = metadata[0] # Unwrap

# Non-main processes initialize empty tensor_dict
if not is_from:
tensor_dict = {
k: torch.empty(shape, dtype=dtype, device=accelerator.device) for k, (shape, dtype) in metadata.items()
}

# Broadcast tensors
for k in tensor_dict.keys():
tensor_dict[k] = broadcast(tensor_dict[k], from_process=from_process)

# Compute slice indices
B = next(iter(tensor_dict.values())).shape[0] # Get first dimension size
assert B % accelerator.num_processes == 0, "Batch size must be divisible by world size"
chunk_size = B // accelerator.num_processes

# Slice the tensors
start, end = accelerator.local_process_index * chunk_size, (accelerator.local_process_index + 1) * chunk_size
sliced_dict = {k: v[start:end] for k, v in tensor_dict.items()}

return sliced_dict
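
For orientation, here is a minimal driver for the helper above — the approach this commit removes. This is a sketch only: it assumes the script is launched with `accelerate launch --num_processes 2`, that `broadcast_and_slice_dict` is in scope, and the key name `advantages` is illustrative rather than part of the trainer.

```python
import torch
from accelerate import Accelerator

accelerator = Accelerator()

if accelerator.is_main_process:
    # Only the main process materializes the full batch of tensors.
    full = {"advantages": torch.arange(8, dtype=torch.float32, device=accelerator.device)}
else:
    full = None  # Other ranks receive the tensors via the broadcast inside the helper.

# With 2 processes: rank 0 keeps rows 0-3, rank 1 keeps rows 4-7.
local = broadcast_and_slice_dict(accelerator, full)
print(accelerator.process_index, local["advantages"])
```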


class GRPOTrainer(Trainer):
"""
Trainer for the Group Relative Policy Optimization (GRPO) method. This algorithm was initially proposed in the
Expand Down Expand Up @@ -487,28 +437,6 @@ def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep)
return torch.stack(per_token_logps)

def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[str, Union[torch.Tensor, Any]]:
# When using vLLM, we need to first update the model weights
if self.args.use_vllm and self.state.global_step != self._last_loaded_step:
with unwrap_model_for_generation(
self.model, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
) as unwrapped_model:
if is_compiled_module(unwrapped_model):
state_dict = unwrapped_model._orig_mod.state_dict()
else:
state_dict = unwrapped_model.state_dict()
if self.accelerator.is_main_process:
llm_model = self.llm.llm_engine.model_executor.driver_worker.model_runner.model
llm_model.load_weights(state_dict.items())
self._last_loaded_step = self.state.global_step

# Gather inputs and process them in the main process. This is important because the rewards are normalized
# per group.
inputs = gather_object(inputs)
prepared = self._prepare_main(inputs) if self.accelerator.is_main_process else None
prepared = broadcast_and_slice_dict(self.accelerator, prepared)
return prepared

def _prepare_main(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[str, Union[torch.Tensor, Any]]:
device = self.accelerator.device
prompts = [x["prompt"] for x in inputs]
prompts_text = [maybe_apply_chat_template(example, self.processing_class)["prompt"] for example in inputs]
@@ -524,9 +452,36 @@ def _prepare_main(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[str, Union[torch.Tensor, Any]]:

# Generate completions using either vLLM or regular generation
if self.args.use_vllm:
# First, have main process load weights if needed
if self.state.global_step != self._last_loaded_step:
with unwrap_model_for_generation(
self.model, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
) as unwrapped_model:
if is_compiled_module(unwrapped_model):
state_dict = unwrapped_model._orig_mod.state_dict()
else:
state_dict = unwrapped_model.state_dict()
if self.accelerator.is_main_process:
llm_model = self.llm.llm_engine.model_executor.driver_worker.model_runner.model
llm_model.load_weights(state_dict.items())
self._last_loaded_step = self.state.global_step

# Generate completions using vLLM: gather all prompts and use them in a single call in the main process
outputs = self.llm.generate(prompts_text, sampling_params=self.sampling_params, use_tqdm=False)
completion_ids = [out.token_ids for completions in outputs for out in completions.outputs]
all_prompts_text = gather_object(prompts_text)
if self.accelerator.is_main_process:
outputs = self.llm.generate(all_prompts_text, sampling_params=self.sampling_params, use_tqdm=False)
completion_ids = [out.token_ids for completions in outputs for out in completions.outputs]
else:
completion_ids = [None] * len(all_prompts_text) * self.num_generations

# Broadcast the completions from the main process to all processes, ensuring each process receives its
# corresponding slice.
completion_ids = broadcast_object_list(completion_ids, from_process=0)
process_slice = slice(
self.accelerator.process_index * len(prompts) * self.num_generations,
(self.accelerator.process_index + 1) * len(prompts) * self.num_generations,
)
completion_ids = completion_ids[process_slice]

# Pad the completions, and concatenate them with the prompts
completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids]
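
The comment in the hunk above describes the pattern this commit restores: gather the prompts, generate once on the main process, broadcast the results as plain Python objects, and keep each rank's slice. A self-contained sketch of that pattern, assuming two processes, two prompts per rank, and one generation per prompt, with a string transform standing in for the vLLM call:

```python
from accelerate import Accelerator
from accelerate.utils import broadcast_object_list, gather_object

accelerator = Accelerator()
prompts = [f"p{accelerator.process_index}-{i}" for i in range(2)]  # this rank's prompts

# 1) Gather every rank's prompts into one flat list, identical on all ranks.
all_prompts = gather_object(prompts)

# 2) Only the main process generates; other ranks hold placeholders.
if accelerator.is_main_process:
    completions = [p + " -> completion" for p in all_prompts]  # stands in for llm.generate(...)
else:
    completions = [None] * len(all_prompts)

# 3) Broadcast the list of Python objects, then keep only this rank's slice.
completions = broadcast_object_list(completions, from_process=0)
process_slice = slice(
    accelerator.process_index * len(prompts),
    (accelerator.process_index + 1) * len(prompts),
)
print(accelerator.process_index, completions[process_slice])
```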
Expand Down Expand Up @@ -574,7 +529,6 @@ def _prepare_main(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[str
if is_conversational(inputs[0]):
completions = [[{"role": "assistant", "content": completion}] for completion in completions]

# Compute the rewards
rewards_per_func = torch.zeros(len(prompts), len(self.reward_funcs), device=device)
for i, (reward_func, reward_processing_class) in enumerate(
zip(self.reward_funcs, self.reward_processing_classes)
@@ -598,6 +552,9 @@ def _prepare_main(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[str, Union[torch.Tensor, Any]]:
output_reward_func = reward_func(prompts=prompts, completions=completions, **reward_kwargs)
rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device)

# Gather the reward per function: this part is crucial, because the rewards are normalized per group
rewards_per_func = gather(rewards_per_func)

# Sum the rewards from all reward functions
rewards = rewards_per_func.sum(dim=1)

@@ -610,6 +567,13 @@ def _prepare_main(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[str, Union[torch.Tensor, Any]]:
std_grouped_rewards = std_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + 1e-4)

# Slice to keep only the local part of the data
process_slice = slice(
self.accelerator.process_index * len(prompts) * self.num_generations,
(self.accelerator.process_index + 1) * len(prompts) * self.num_generations,
)
advantages = advantages[process_slice]
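
To see why the rewards must be gathered before this normalization, here is a small worked example of the grouped statistics above, assuming `num_generations = 3` and made-up rewards for two prompt groups:

```python
import torch

num_generations = 3
# Two prompt groups of three completions each; values are illustrative.
rewards = torch.tensor([1.0, 2.0, 3.0, 0.0, 0.0, 6.0])

mean_grouped = rewards.view(-1, num_generations).mean(dim=1)  # tensor([2., 2.])
std_grouped = rewards.view(-1, num_generations).std(dim=1)    # tensor([1.0000, 3.4641])
mean_grouped = mean_grouped.repeat_interleave(num_generations, dim=0)
std_grouped = std_grouped.repeat_interleave(num_generations, dim=0)

advantages = (rewards - mean_grouped) / (std_grouped + 1e-4)
# First group: roughly [-1, 0, 1]. Each completion is ranked against the other
# samples for the *same* prompt, so a group split across processes would be
# normalized with the wrong statistics.
```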

# Log the metrics
reward_per_func = rewards_per_func.mean(0)
for i, reward_func in enumerate(self.reward_funcs):
@@ -623,12 +587,12 @@ def _prepare_main(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[str, Union[torch.Tensor, Any]]:
self._metrics["reward_std"].append(std_grouped_rewards.mean().item())

return {
"prompt_ids": prompt_ids.contiguous(),
"prompt_mask": prompt_mask.contiguous(),
"completion_ids": completion_ids.contiguous(),
"completion_mask": completion_mask.contiguous(),
"ref_per_token_logps": ref_per_token_logps.contiguous(),
"advantages": advantages.contiguous(),
"prompt_ids": prompt_ids,
"prompt_mask": prompt_mask,
"completion_ids": completion_ids,
"completion_mask": completion_mask,
"ref_per_token_logps": ref_per_token_logps,
"advantages": advantages,
}

def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
