diff --git a/trl/trainer/sft_trainer.py b/trl/trainer/sft_trainer.py
index 380cedf7cb..0dd06dc728 100644
--- a/trl/trainer/sft_trainer.py
+++ b/trl/trainer/sft_trainer.py
@@ -186,7 +186,18 @@ def __init__(
                     inspect.signature(prepare_model_for_kbit_training).parameters
                 )
                 gradient_checkpointing_kwargs = getattr(args, "gradient_checkpointing_kwargs", None) or {}
-                if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False):
+                is_sharded_qlora = False
+                # Below is to support QLoRA + FSDP / DS-Zero3 - one should never call
+                # peft_module_casting_to_bf16 or prepare_model_for_kbit_training when doing
+                # QLoRA + FSDP / DS-Zero3
+                if getattr(model, "is_loaded_in_4bit", False):
+                    for _, param in model.named_parameters():
+                        if param.__class__.__name__ == "Params4bit":
+                            is_sharded_qlora = param.data.device.type == "cpu"
+                            break
+                if getattr(model, "is_loaded_in_8bit", False) or (
+                    getattr(model, "is_loaded_in_4bit", False) and not is_sharded_qlora
+                ):
                     prepare_model_kwargs = {
                         "use_gradient_checkpointing": getattr(args, "gradient_checkpointing", False)
                     }
@@ -213,7 +224,12 @@ def make_inputs_require_grad(module, input, output):
                         model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
 
                 model = get_peft_model(model, peft_config)
-                if args is not None and args.bf16 and getattr(model, "is_loaded_in_4bit", False):
+                if (
+                    args is not None
+                    and args.bf16
+                    and getattr(model, "is_loaded_in_4bit", False)
+                    and not is_sharded_qlora
+                ):
                     peft_module_casting_to_bf16(model)
 
         if tokenizer is None:
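
For reference, below is a minimal standalone sketch of the detection heuristic the patch introduces (not part of the diff; the helper name is_sharded_qlora_model is hypothetical). Under FSDP or DeepSpeed ZeRO-3, the bitsandbytes Params4bit tensors still sit on CPU at this point in __init__, which is what the patch keys on to skip prepare_model_for_kbit_training and peft_module_casting_to_bf16.

# Sketch only, mirroring the patch's logic; is_sharded_qlora_model is a
# hypothetical helper, not a function added by this PR.
def is_sharded_qlora_model(model) -> bool:
    """Return True for a 4-bit model whose quantized weights are sharded
    (QLoRA + FSDP / DeepSpeed ZeRO-3), detected by Params4bit living on CPU."""
    if not getattr(model, "is_loaded_in_4bit", False):
        return False
    for _, param in model.named_parameters():
        if param.__class__.__name__ == "Params4bit":
            # Sharded setups keep the quantized params on CPU at init time.
            return param.data.device.type == "cpu"
    return False

With this check, kbit preparation and bf16 casting run only for non-sharded quantized models, matching the conditions in both hunks above.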