
Add support for FSDP+QLoRA and DeepSpeed ZeRO3+QLoRA #1416

Merged 14 commits on Mar 13, 2024
trl/trainer/sft_trainer.py: 17 changes (15 additions, 2 deletions)
@@ -186,7 +186,15 @@ def __init__(
     inspect.signature(prepare_model_for_kbit_training).parameters
 )
 gradient_checkpointing_kwargs = getattr(args, "gradient_checkpointing_kwargs", None) or {}
-if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False):
+is_sharded_qlora = False
+if getattr(model, "is_loaded_in_4bit", False):
+    for _, param in model.named_parameters():
+        if param.__class__.__name__ == "Params4bit":
+            is_sharded_qlora = param.data.device.type == "cpu"
+            break
+if getattr(model, "is_loaded_in_8bit", False) or (
+    getattr(model, "is_loaded_in_4bit", False) and not is_sharded_qlora
+):
     prepare_model_kwargs = {
         "use_gradient_checkpointing": getattr(args, "gradient_checkpointing", False)
     }
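A note on the hunk above: the new check uses the device of the bitsandbytes `Params4bit` weights as a proxy for a sharded QLoRA setup. When QLoRA is combined with FSDP or DeepSpeed ZeRO-3, the 4-bit weights are still on the CPU at this point (the distributed wrapper shards and moves them later), so `prepare_model_for_kbit_training` and the later bf16 cast are skipped. A minimal standalone sketch of the same check; the helper name is made up for illustration, and `model` is assumed to be a transformers model loaded with a 4-bit quantization config:

```python
def looks_like_sharded_qlora(model) -> bool:
    """Heuristic mirrored from the diff above: a 4-bit model whose
    Params4bit weights are still on CPU is being prepared for
    FSDP / DeepSpeed ZeRO-3 sharding."""
    if not getattr(model, "is_loaded_in_4bit", False):
        return False
    for _, param in model.named_parameters():
        # bitsandbytes stores 4-bit quantized weights in Params4bit parameters
        if param.__class__.__name__ == "Params4bit":
            return param.data.device.type == "cpu"
    return False
```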
@@ -213,7 +221,12 @@ def make_inputs_require_grad(module, input, output):
 model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)

 model = get_peft_model(model, peft_config)
-if args is not None and args.bf16 and getattr(model, "is_loaded_in_4bit", False):
+if (
+    args is not None
+    and args.bf16
+    and getattr(model, "is_loaded_in_4bit", False)
+    and not is_sharded_qlora
+):
     peft_module_casting_to_bf16(model)

 if tokenizer is None:
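For context on how this change is meant to be used, here is a rough end-to-end sketch of QLoRA training with SFTTrainer under FSDP or DeepSpeed ZeRO-3. It is a sketch under assumptions, not part of this PR: the `bnb_4bit_quant_storage` option, the accelerate config file name, the model id, and the tiny in-memory dataset are all illustrative.

```python
# Rough sketch (not part of this PR): QLoRA + FSDP / DeepSpeed ZeRO-3 with SFTTrainer.
# Assumed launch command, with fsdp_config.yaml an accelerate config enabling FSDP
# (or a ZeRO-3 DeepSpeed config):
#   accelerate launch --config_file fsdp_config.yaml sft_qlora_fsdp.py
import torch
from datasets import Dataset
from peft import LoraConfig
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TrainingArguments,
)
from trl import SFTTrainer

model_id = "meta-llama/Llama-2-7b-hf"  # illustrative model id

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    # Storing the quantized weights in bf16 lets FSDP / ZeRO-3 shard them;
    # assumes a transformers version that exposes bnb_4bit_quant_storage.
    bnb_4bit_quant_storage=torch.bfloat16,
)

model = AutoModelForCausalLM.from_pretrained(
    model_id, quantization_config=bnb_config, torch_dtype=torch.bfloat16
)
tokenizer = AutoTokenizer.from_pretrained(model_id)

peft_config = LoraConfig(r=16, lora_alpha=32, lora_dropout=0.05, task_type="CAUSAL_LM")

# Tiny illustrative dataset with a "text" column.
train_dataset = Dataset.from_dict({"text": ["Hello world!"] * 8})

training_args = TrainingArguments(
    output_dir="sft-qlora-fsdp",
    per_device_train_batch_size=1,
    bf16=True,
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False},
)

trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    dataset_text_field="text",
    max_seq_length=512,
    peft_config=peft_config,
    tokenizer=tokenizer,
)
trainer.train()
```

In such a run, the weights SFTTrainer sees are the CPU-resident 4-bit parameters, so the detection added in this PR skips `prepare_model_for_kbit_training` and `peft_module_casting_to_bf16`, leaving the sharding wrapper to handle device placement.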