From 58c088899670bba21cdf39f06536e0735fdd68c1 Mon Sep 17 00:00:00 2001
From: Sourab Mangrulkar <13534540+pacman100@users.noreply.github.com>
Date: Wed, 13 Mar 2024 15:13:45 +0530
Subject: [PATCH] Add support for FSDP+QLoRA and DeepSpeed ZeRO3+QLoRA (#1416)

* don't do mp casting

* don't use `prepare_for_kbit` when using fsdp+qlora or dsz3+qlora

* changes to enable fsdp+qlora and dsz3+qlora

* revert

* Update sft_trainer.py

* quality

* fix deprecation using changes from PR https://github.com/huggingface/trl/pull/1415

* fixes

* quality

* Update trl/trainer/sft_trainer.py

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* quality

* relaunch tests

---------

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
---
 trl/trainer/sft_trainer.py | 20 ++++++++++++++++++--
 1 file changed, 18 insertions(+), 2 deletions(-)

diff --git a/trl/trainer/sft_trainer.py b/trl/trainer/sft_trainer.py
index 380cedf7cb..0dd06dc728 100644
--- a/trl/trainer/sft_trainer.py
+++ b/trl/trainer/sft_trainer.py
@@ -186,7 +186,18 @@ def __init__(
                     inspect.signature(prepare_model_for_kbit_training).parameters
                 )
                 gradient_checkpointing_kwargs = getattr(args, "gradient_checkpointing_kwargs", None) or {}
-                if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False):
+                is_sharded_qlora = False
+                # Below is to support QLoRA + FSDP / DS-Zero3 - one should never call
+                # peft_module_casting_to_bf16 or prepare_model_for_kbit_training when doing
+                # QLoRA + FSDP / DS-Zero3
+                if getattr(model, "is_loaded_in_4bit", False):
+                    for _, param in model.named_parameters():
+                        if param.__class__.__name__ == "Params4bit":
+                            is_sharded_qlora = param.data.device.type == "cpu"
+                            break
+                if getattr(model, "is_loaded_in_8bit", False) or (
+                    getattr(model, "is_loaded_in_4bit", False) and not is_sharded_qlora
+                ):
                     prepare_model_kwargs = {
                         "use_gradient_checkpointing": getattr(args, "gradient_checkpointing", False)
                     }
@@ -213,7 +224,12 @@ def make_inputs_require_grad(module, input, output):
                         model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
 
                 model = get_peft_model(model, peft_config)
-                if args is not None and args.bf16 and getattr(model, "is_loaded_in_4bit", False):
+                if (
+                    args is not None
+                    and args.bf16
+                    and getattr(model, "is_loaded_in_4bit", False)
+                    and not is_sharded_qlora
+                ):
                     peft_module_casting_to_bf16(model)
 
         if tokenizer is None:
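
Note (not part of the patch): the detection heuristic added above can be read in isolation as the minimal sketch below. The helper name is_sharded_qlora_model is hypothetical; in the patch this logic is inlined in SFTTrainer.__init__. The idea, as the patch's comment states, is that prepare_model_for_kbit_training and peft_module_casting_to_bf16 must not run for QLoRA + FSDP / DS-Zero3; a 4-bit model loaded for later sharding still has its bitsandbytes Params4bit weights on CPU at this point, which is the signal used.

    def is_sharded_qlora_model(model) -> bool:
        """Return True for a 4-bit (QLoRA) model loaded for FSDP / DeepSpeed ZeRO-3 sharding.

        Sketch only: mirrors the inlined check the patch adds to SFTTrainer.__init__.
        """
        if not getattr(model, "is_loaded_in_4bit", False):
            return False
        for _, param in model.named_parameters():
            # bitsandbytes stores 4-bit weights in Params4bit; when the model is
            # loaded for sharded training, those weights still sit on CPU here.
            if param.__class__.__name__ == "Params4bit":
                return param.data.device.type == "cpu"
        return False

With this flag, the trainer keeps the usual prepare_model_for_kbit_training and bf16 casting path for plain single-device QLoRA, but skips both when the quantized weights are about to be sharded by FSDP or DeepSpeed ZeRO-3, since those calls would otherwise cast or move parameters before the wrapper takes over.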