Skip to content

Commit

Permalink
fix slice
Browse files Browse the repository at this point in the history
  • Loading branch information
qgallouedec committed Feb 6, 2025
1 parent 49dbdf5 commit cb42eb0
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions trl/trainer/grpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -483,8 +483,8 @@ def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[s
# corresponding slice.
completion_ids = broadcast_object_list(completion_ids, from_process=0)
process_slice = slice(
self.accelerator.process_index * len(prompts) * self.num_generations,
(self.accelerator.process_index + 1) * len(prompts) * self.num_generations,
self.accelerator.process_index * len(prompts),
(self.accelerator.process_index + 1) * len(prompts),
)
completion_ids = completion_ids[process_slice]

Expand Down Expand Up @@ -575,8 +575,8 @@ def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[s

# Slice to keep only the local part of the data
process_slice = slice(
self.accelerator.process_index * len(prompts) * self.num_generations,
(self.accelerator.process_index + 1) * len(prompts) * self.num_generations,
self.accelerator.process_index * len(prompts),
(self.accelerator.process_index + 1) * len(prompts),
)
advantages = advantages[process_slice]

Expand Down

0 comments on commit cb42eb0

Please sign in to comment.