👐 DeepSpeed integration for GRPO (#2652)
qgallouedec authored Jan 25, 2025
1 parent 2578e95 commit aeb03cf
Showing 3 changed files with 48 additions and 5 deletions.
4 changes: 2 additions & 2 deletions trl/models/__init__.py
@@ -20,7 +20,7 @@
 _import_structure = {
     "modeling_base": ["GeometricMixtureWrapper", "PreTrainedModelWrapper", "create_reference_model"],
     "modeling_value_head": ["AutoModelForCausalLMWithValueHead", "AutoModelForSeq2SeqLMWithValueHead"],
-    "utils": ["SUPPORTED_ARCHITECTURES", "setup_chat_format", "unwrap_model_for_generation"],
+    "utils": ["SUPPORTED_ARCHITECTURES", "prepare_deepspeed", "setup_chat_format", "unwrap_model_for_generation"],
 }

 try:
@@ -39,7 +39,7 @@
 if TYPE_CHECKING:
     from .modeling_base import GeometricMixtureWrapper, PreTrainedModelWrapper, create_reference_model
     from .modeling_value_head import AutoModelForCausalLMWithValueHead, AutoModelForSeq2SeqLMWithValueHead
-    from .utils import SUPPORTED_ARCHITECTURES, setup_chat_format, unwrap_model_for_generation
+    from .utils import SUPPORTED_ARCHITECTURES, prepare_deepspeed, setup_chat_format, unwrap_model_for_generation

 try:
     if not is_diffusers_available():
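The change above makes the new helper part of the public `trl.models` surface, resolved lazily through `_import_structure`. A minimal sketch of the import this enables (assuming TRL is installed at or after this commit):

    from trl.models import prepare_deepspeed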
35 changes: 35 additions & 0 deletions trl/models/utils.py
@@ -14,6 +14,7 @@

 import itertools
 from contextlib import contextmanager
+from copy import deepcopy
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Literal, Optional, Union

@@ -193,3 +194,37 @@ def unwrap_model_for_generation(
             add_hooks(model)
     else:
         yield unwrapped_model
+
+
+def prepare_deepspeed(model, accelerator):
+    # Adapted from accelerate: https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473
+    deepspeed_plugin = accelerator.state.deepspeed_plugin
+    config_kwargs = deepcopy(deepspeed_plugin.deepspeed_config)
+    stage = config_kwargs["zero_optimization"]["stage"]
+
+    if model is not None:
+        hidden_size = (
+            max(model.config.hidden_sizes)
+            if getattr(model.config, "hidden_sizes", None)
+            else getattr(model.config, "hidden_size", None)
+        )
+        if hidden_size is not None and stage == 3:
+            # Note that `stage3_prefetch_bucket_size` can produce DeepSpeed messages like: `Invalidate trace cache
+            # @ step 0: expected module 1, but got module 0`
+            # This is expected and is not an error, see: https://github.com/microsoft/DeepSpeed/discussions/4081
+            config_kwargs.update(
+                {
+                    "zero_optimization.reduce_bucket_size": hidden_size * hidden_size,
+                    "zero_optimization.stage3_param_persistence_threshold": 10 * hidden_size,
+                    "zero_optimization.stage3_prefetch_bucket_size": 0.9 * hidden_size * hidden_size,
+                }
+            )
+
+    # If ZeRO-3 is used, we shard both the active and reference model.
+    # Otherwise, we assume the reference model fits in memory and is initialized on each device with ZeRO
+    # disabled (stage 0)
+    if stage != 3:
+        config_kwargs["zero_optimization"]["stage"] = 0
+    model, *_ = deepspeed.initialize(model=model, config=config_kwargs)
+    model.eval()
+    return model
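A minimal usage sketch of the helper added above, assuming the script is started with `accelerate launch` and a DeepSpeed config file (so that `accelerator.state.deepspeed_plugin` is populated); the model name is illustrative:

    from accelerate import Accelerator
    from transformers import AutoModelForCausalLM

    from trl.models import prepare_deepspeed

    accelerator = Accelerator()
    ref_model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")

    # The frozen reference model gets its own DeepSpeed engine: sharded when
    # the plugin uses ZeRO-3, otherwise replicated on each device with ZeRO
    # disabled (stage 0).
    ref_model = prepare_deepspeed(ref_model, accelerator)

Note that `deepspeed.initialize` is called without an optimizer and the engine is put in `eval()` mode, so the result is only suitable for inference on a frozen model.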
14 changes: 11 additions & 3 deletions trl/trainer/grpo_trainer.py
@@ -32,10 +32,11 @@
     TrainerCallback,
     is_wandb_available,
 )
+from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
 from transformers.utils import is_peft_available

 from ..data_utils import apply_chat_template, is_conversational, maybe_apply_chat_template
-from ..models import create_reference_model, unwrap_model_for_generation
+from ..models import create_reference_model, prepare_deepspeed, unwrap_model_for_generation
 from .grpo_config import GRPOConfig
 from .utils import generate_model_card, get_comet_experiment_url
@@ -158,6 +159,7 @@ def __init__(
         # Trained model
         model_init_kwargs = args.model_init_kwargs or {}
         if isinstance(model, str):
+            model_id = model
             torch_dtype = model_init_kwargs.get("torch_dtype")
             if isinstance(torch_dtype, torch.dtype) or torch_dtype == "auto" or torch_dtype is None:
                 pass  # torch_dtype is already a torch.dtype or "auto" or None
@@ -171,6 +173,7 @@
             )
             model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs)
         else:
+            model_id = model.config._name_or_path
             if args.model_init_kwargs is not None:
                 raise ValueError(
                     "You passed `model_init_kwargs` to the `GRPOConfig`, but your model is already instantiated. "
@@ -181,7 +184,9 @@
             model = get_peft_model(model, peft_config)

         # Reference model
-        if peft_config is None:
+        if is_deepspeed_zero3_enabled():
+            self.ref_model = AutoModelForCausalLM.from_pretrained(model_id, **model_init_kwargs)
+        elif peft_config is None:
             # If PEFT configuration is not provided, create a reference model based on the initial model.
             self.ref_model = create_reference_model(model)
         else:
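Why the new first branch: under ZeRO-3 the policy's weights are sharded across processes, so `create_reference_model`, which deep-copies the live module, cannot produce a usable standalone copy; the reference model is instead re-loaded from `model_id`. A hedged sketch of the same decision in isolation, given an instantiated policy `model` (names mirror the diff, `model_id` is the checkpoint the policy came from; the PEFT branch is omitted for brevity):

    from transformers import AutoModelForCausalLM
    from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled

    from trl.models import create_reference_model

    if is_deepspeed_zero3_enabled():
        # ZeRO-3: parameters live as shards, so deep-copying the policy is
        # not viable; load an independent copy of the checkpoint instead.
        ref_model = AutoModelForCausalLM.from_pretrained(model_id, **model_init_kwargs)
    else:
        ref_model = create_reference_model(model)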
@@ -269,7 +274,10 @@ def data_collator(features):  # No data collation is needed in GRPO
         self.model_accepts_loss_kwargs = False

         if self.ref_model is not None:
-            self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
+            if self.is_deepspeed_enabled:
+                self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator)
+            else:
+                self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)

         for i, reward_func in enumerate(self.reward_funcs):
             if isinstance(reward_func, PreTrainedModel):
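With these three changes in place, GRPO can run under DeepSpeed end to end. A hedged sketch of a training script, assuming it is started with `accelerate launch --config_file <deepspeed_zero3_config>.yaml train.py`; the dataset and the toy length reward are illustrative:

    from datasets import load_dataset
    from trl import GRPOConfig, GRPOTrainer

    def reward_len(completions, **kwargs):
        # Toy reward: prefer completions close to 20 characters.
        return [-abs(20 - len(completion)) for completion in completions]

    dataset = load_dataset("trl-lib/tldr", split="train")

    trainer = GRPOTrainer(
        model="Qwen/Qwen2-0.5B-Instruct",
        reward_funcs=reward_len,
        args=GRPOConfig(output_dir="Qwen2-0.5B-GRPO", logging_steps=10),
        train_dataset=dataset,
    )
    trainer.train()

When ZeRO-3 is enabled, `is_deepspeed_zero3_enabled()` makes the trainer load a separate reference model from the same checkpoint, and `prepare_deepspeed` wraps it in its own inference-only engine.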
