diff --git a/examples/olora_finetuning/README.md b/examples/olora_finetuning/README.md
index fd6e5c3e0c..358a275289 100644
--- a/examples/olora_finetuning/README.md
+++ b/examples/olora_finetuning/README.md
@@ -69,7 +69,7 @@ train(olora_model) # Your training loop
 #Save the model after training
 olora_model.save_pretrained(output_dir, path_initial_model_for_weight_conversion=init_path)
 ```
-After completing training, you can save and convert your OLoRA model to a conventional LoRA model by setting `path_initial_model_for_weight_conversion` to `init_path`, that is the path of your untrained OLoRA model. This conversion enables you to use multiple adapters with your LoRA model.
+After completing training, you can save and convert your OLoRA model to a conventional LoRA model by setting `path_initial_model_for_weight_conversion` to `init_path`, that is, the path of your untrained OLoRA model. This conversion enables you to use multiple adapters with your LoRA model. Note that this conversion is not supported if `rslora` is used in combination with `rank_pattern` or `alpha_pattern`.
 
 ## Citation
 ```
@@ -81,4 +81,4 @@ After completing training, you can save and convert your OLoRA model to a conven
   archivePrefix={arXiv},
   primaryClass={cs.CL}
 }
-```
\ No newline at end of file
+```
diff --git a/examples/pissa_finetuning/README.md b/examples/pissa_finetuning/README.md
index a80aab8f24..abfddbf685 100644
--- a/examples/pissa_finetuning/README.md
+++ b/examples/pissa_finetuning/README.md
@@ -90,7 +90,7 @@ peft_model = PeftModel.from_pretrained(model, "pissa-llama-2-7b-lora")
 ```
 Utilizing the converted LoRA does not require modifying the parameters of the base model. When multiple converted LoRAs are needed simultaneously, each adapter operates independently without interference, allowing for the adapters to be freely deleted or added.
-
+Note that this conversion is not supported if `rslora` is used in combination with `rank_pattern` or `alpha_pattern`.
 ### Fine-tune in 4-bit or 8-bit
 If quantization fine-tuning is desired, it is necessary to first decompose the original model at full precision and then reload the residual model in either 4-bit or 8-bit configurations.
 
@@ -128,4 +128,4 @@ This approach ensures the preservation of high-frequency, out-of-distribution pa
   journal={arXiv preprint arXiv:2404.02948},
   year={2024}
 }
-```
\ No newline at end of file
+```
diff --git a/src/peft/peft_model.py b/src/peft/peft_model.py
index 22216577b4..d1e2cabccb 100644
--- a/src/peft/peft_model.py
+++ b/src/peft/peft_model.py
@@ -231,9 +231,11 @@ def save_pretrained(
                 difference in adapter before and after fine-tuning is calculated. This difference can be represented as
                 the parameters of a standard LoRA adapter. Using this converted adapter does not require changes to the
                 base model, thus conveniently allowing the use of multiple PiSSA or OLoRA adapters with LoRA adapters,
-                and the activation or deactivation of any adapters.
+                and the activation or deactivation of any adapters. Note that this conversion is not supported if
+                `rslora` is used in combination with `rank_pattern` or `alpha_pattern`.
             kwargs (additional keyword arguments, *optional*):
                 Additional keyword arguments passed along to the `push_to_hub` method.
+
         """
         if os.path.isfile(save_directory):
            raise ValueError(f"Provided path ({save_directory}) should be a directory, not a file")
@@ -258,6 +260,13 @@ def save_pretrained(
             path_initial_model_for_weight_conversion = convert_pissa_to_lora
 
         def save_mutated_as_lora(peft_config, path_initial_model_for_weight_conversion, output_state_dict, kwargs):
+            if peft_config.use_rslora and (peft_config.rank_pattern or peft_config.alpha_pattern):
+                msg = (
+                    "Passing `path_initial_model_for_weight_conversion` to `save_pretrained` is not supported when "
+                    "using `rank_pattern` or `alpha_pattern` at the same time as `use_rslora=True`."
+                )
+                raise ValueError(msg)
+
             if not any(
                 str(peft_config.init_lora_weights).lower().startswith(prefix) for prefix in ["pissa", "olora", "true"]
             ):
@@ -368,7 +377,17 @@ def save_mutated_as_lora(peft_config, path_initial_model_for_weight_conversion,
             if path_initial_model_for_weight_conversion is not None:
                 peft_config.init_lora_weights = True
                 peft_config.r *= 2
-                peft_config.lora_alpha *= 2
+                if not peft_config.use_rslora:
+                    peft_config.lora_alpha *= 2
+                else:
+                    # with rslora, we have scaling = alpha / sqrt(r), so we adjust alpha to keep the same scaling
+                    peft_config.lora_alpha *= 2**0.5
+
+                if peft_config.rank_pattern:
+                    peft_config.rank_pattern = {key: 2 * val for key, val in peft_config.rank_pattern.items()}
+                if peft_config.alpha_pattern:
+                    peft_config.alpha_pattern = {key: 2 * val for key, val in peft_config.alpha_pattern.items()}
+
             peft_config.save_pretrained(output_dir, auto_mapping_dict=auto_mapping_dict)
             peft_config.inference_mode = inference_mode
diff --git a/src/peft/tuners/lora/config.py b/src/peft/tuners/lora/config.py
index 9593712481..941582fe89 100644
--- a/src/peft/tuners/lora/config.py
+++ b/src/peft/tuners/lora/config.py
@@ -14,6 +14,7 @@
 
 from __future__ import annotations
 
+import warnings
 from dataclasses import dataclass, field
 from typing import Literal, Optional, Union
 
@@ -346,6 +347,26 @@ def __post_init__(self):
             if self.loftq_config is None:
                 raise ValueError("`loftq_config` must be specified when `init_lora_weights` is 'loftq'.")
 
+        # Using post-training conversion of modified base weights to restore their initial values (PiSSA, OLoRA) cannot
+        # be correctly done when using rslora + rank_pattern/alpha_pattern. We can't really know if the user intends
+        # this when they'll eventually call save_pretrained (i.e. if they'll pass
+        # path_initial_model_for_weight_conversion). Therefore, we only warn but don't raise an error here.
+        if (
+            self.use_rslora
+            and (self.rank_pattern or self.alpha_pattern)
+            and (
+                (isinstance(self.init_lora_weights, str) and (self.init_lora_weights.startswith("pissa")))
+                or (self.init_lora_weights == "olora")
+            )
+        ):
+            msg = (
+                "Using Rank-Stabilized LoRA with rank_pattern/alpha_pattern and post-training conversion of modified "
+                "base weights (PiSSA, OLoRA) means that you won't be able to pass "
+                "`path_initial_model_for_weight_conversion` to `save_pretrained` to restore the initial values of the "
+                "base weights; if you intend to do this, please ensure not to use rslora or rank_pattern/alpha_pattern."
+            )
+            warnings.warn(msg)
+
         # convert loftq_config to dict
         if self.loftq_config and not isinstance(self.loftq_config, dict):
             self.loftq_config = vars(self.loftq_config)
diff --git a/tests/test_gpu_examples.py b/tests/test_gpu_examples.py
index df544b606c..09c69a8635 100644
--- a/tests/test_gpu_examples.py
+++ b/tests/test_gpu_examples.py
@@ -1470,6 +1470,7 @@ def test_offload_merge(self):
 
 
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires a GPU")
+@pytest.mark.single_gpu_tests
 class TestPiSSA:
     r"""
     Tests for PiSSA to ensure that it reduces the quantization error compared to normal LoRA quantization.
@@ -1656,7 +1657,9 @@ def forward(self, x):
         assert model_loaded.base_model.model.linear.lora_A["default"].weight.shape[0] == 8
 
         # save the model with conversion
-        peft_model.save_pretrained(tmp_path / "pissa-model-converted", convert_mutated_to_lora=tmp_path / "init-model")
+        peft_model.save_pretrained(
+            tmp_path / "pissa-model-converted", path_initial_model_for_weight_conversion=tmp_path / "init-model"
+        )
         model_converted = PeftModel.from_pretrained(deepcopy(model), tmp_path / "pissa-model-converted")
         output_converted = model_converted(data)[0]
 
@@ -1672,6 +1675,7 @@ def forward(self, x):
 
 
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires a GPU")
+@pytest.mark.single_gpu_tests
 class TestOLoRA:
     r"""
     Tests for OLoRA to ensure that it reduces the quantization error compared to normal LoRA quantization.
diff --git a/tests/test_initialization.py b/tests/test_initialization.py
index 32f6a647a8..bc6bce2eb9 100644
--- a/tests/test_initialization.py
+++ b/tests/test_initialization.py
@@ -338,6 +338,195 @@ def test_lora_pissa_conversion_same_output_after_loading(self, data, tmp_path):
             model.linear.weight, model_converted.base_model.model.linear.base_layer.weight, atol=tol, rtol=tol
         )
 
+    def test_lora_pissa_conversion_same_output_after_loading_with_rank_pattern(self, data, tmp_path):
+        # same as above, but using rank_pattern
+        model = self.get_model()
+        output_base = model(data)[0]
+
+        # use rank_pattern here; note that since there is only a single linear layer, r is completely overridden
+        config = LoraConfig(init_lora_weights="pissa", target_modules=["linear"], r=8, rank_pattern={"linear": 32})
+        peft_model = get_peft_model(deepcopy(model), config)
+        # save the initial model
+        peft_model.peft_config["default"].init_lora_weights = True
+        peft_model.save_pretrained(tmp_path / "init-model")
+        peft_model.peft_config["default"].init_lora_weights = "pissa"
+
+        # modify the weights, or else the adapter performs an identity transformation
+        peft_model.base_model.linear.lora_B["default"].weight.data *= 2.0
+        output_pissa = peft_model(data)[0]
+
+        # sanity check
+        tol = 1e-06
+        assert not torch.allclose(output_base, output_pissa, atol=tol, rtol=tol)
+
+        # save the model normally
+        peft_model.save_pretrained(tmp_path / "pissa-model")
+        model_loaded = PeftModel.from_pretrained(deepcopy(model), tmp_path / "pissa-model")
+        output_loaded = model_loaded(data)[0]
+
+        assert torch.allclose(output_pissa, output_loaded, atol=tol, rtol=tol)
+        # sanity check: the config rank should still be 8 as initially, while the layer rank follows rank_pattern (32)
+        assert model_loaded.peft_config["default"].r == 8
+        assert model_loaded.base_model.model.linear.lora_A["default"].weight.shape[0] == 32
+        # sanity check: the base model weights were indeed changed
+        assert not torch.allclose(
+            model.linear.weight, model_loaded.base_model.model.linear.base_layer.weight, atol=tol, rtol=tol
+        )
+
+        # save the model with conversion
+        peft_model.save_pretrained(
+            tmp_path / "pissa-model-converted", path_initial_model_for_weight_conversion=tmp_path / "init-model"
+        )
+        model_converted = PeftModel.from_pretrained(deepcopy(model), tmp_path / "pissa-model-converted")
+        output_converted = model_converted(data)[0]
+
+        assert torch.allclose(output_pissa, output_converted, atol=tol, rtol=tol)
+        # rank should be double of what it was initially
+        assert model_converted.peft_config["default"].r == 16
+        assert model_converted.base_model.model.linear.lora_A["default"].weight.shape[0] == 64
+        # base model weights should be the same as the initial model
+        assert torch.allclose(
+            model.linear.weight, model_converted.base_model.model.linear.base_layer.weight, atol=tol, rtol=tol
+        )
+
+    def test_lora_pissa_conversion_same_output_after_loading_with_alpha_pattern(self, data, tmp_path):
+        # same as above, but using alpha_pattern
+        model = self.get_model()
+        output_base = model(data)[0]
+
+        # use alpha_pattern here; note that since there is only a single linear layer, lora_alpha is completely
+        # overridden
+        config = LoraConfig(init_lora_weights="pissa", target_modules=["linear"], alpha_pattern={"linear": 5})
+        peft_model = get_peft_model(deepcopy(model), config)
+        # save the initial model
+        peft_model.peft_config["default"].init_lora_weights = True
+        peft_model.save_pretrained(tmp_path / "init-model")
+        peft_model.peft_config["default"].init_lora_weights = "pissa"
+
+        # modify the weights, or else the adapter performs an identity transformation
+        peft_model.base_model.linear.lora_B["default"].weight.data *= 2.0
+        output_pissa = peft_model(data)[0]
+
+        # sanity check
+        tol = 1e-06
+        assert not torch.allclose(output_base, output_pissa, atol=tol, rtol=tol)
+
+        # save the model normally
+        peft_model.save_pretrained(tmp_path / "pissa-model")
+        model_loaded = PeftModel.from_pretrained(deepcopy(model), tmp_path / "pissa-model")
+        output_loaded = model_loaded(data)[0]
+
+        assert torch.allclose(output_pissa, output_loaded, atol=tol, rtol=tol)
+        # sanity check: ranks should still be 8 as initially
+        assert model_loaded.peft_config["default"].r == 8
+        assert model_loaded.base_model.model.linear.lora_A["default"].weight.shape[0] == 8
+        assert model_loaded.base_model.model.linear.scaling["default"] == 5 / 8
+        # sanity check: the base model weights were indeed changed
+        assert not torch.allclose(
+            model.linear.weight, model_loaded.base_model.model.linear.base_layer.weight, atol=tol, rtol=tol
+        )
+
+        # save the model with conversion
+        peft_model.save_pretrained(
+            tmp_path / "pissa-model-converted", path_initial_model_for_weight_conversion=tmp_path / "init-model"
+        )
+        model_converted = PeftModel.from_pretrained(deepcopy(model), tmp_path / "pissa-model-converted")
+        output_converted = model_converted(data)[0]
+
+        assert torch.allclose(output_pissa, output_converted, atol=tol, rtol=tol)
+        # rank should be double of what it was initially
+        assert model_converted.peft_config["default"].r == 16
+        assert model_converted.base_model.model.linear.lora_A["default"].weight.shape[0] == 16
+        assert model_converted.base_model.model.linear.scaling["default"] == 10 / 16
+        # base model weights should be the same as the initial model
+        assert torch.allclose(
+            model.linear.weight, model_converted.base_model.model.linear.base_layer.weight, atol=tol, rtol=tol
+        )
+
+    def test_lora_pissa_conversion_same_output_after_loading_with_rslora(self, data, tmp_path):
+        model = self.get_model()
+        output_base = model(data)[0]
+
+        config = LoraConfig(init_lora_weights="pissa", target_modules=["linear"], r=8, use_rslora=True)
+        peft_model = get_peft_model(deepcopy(model), config)
+        # save the initial model
+        peft_model.peft_config["default"].init_lora_weights = True
+        peft_model.save_pretrained(tmp_path / "init-model")
+        peft_model.peft_config["default"].init_lora_weights = "pissa"
+
+        # modify the weights, or else the adapter performs an identity transformation
+        peft_model.base_model.linear.lora_B["default"].weight.data *= 2.0
+        output_pissa = peft_model(data)[0]
+
+        # sanity check
+        tol = 1e-06
+        assert not torch.allclose(output_base, output_pissa, atol=tol, rtol=tol)
+
+        # save the model normally
+        peft_model.save_pretrained(tmp_path / "pissa-model")
+        model_loaded = PeftModel.from_pretrained(deepcopy(model), tmp_path / "pissa-model")
+        output_loaded = model_loaded(data)[0]
+
+        assert torch.allclose(output_pissa, output_loaded, atol=tol, rtol=tol)
+        # sanity check: ranks should still be 8 as initially
+        assert model_loaded.peft_config["default"].r == 8
+        assert model_loaded.base_model.model.linear.lora_A["default"].weight.shape[0] == 8
+        assert model_loaded.base_model.model.linear.scaling["default"] == 8 / (8**0.5)
+        # sanity check: the base model weights were indeed changed
+        assert not torch.allclose(
+            model.linear.weight, model_loaded.base_model.model.linear.base_layer.weight, atol=tol, rtol=tol
+        )
+
+        # save the model with conversion
+        peft_model.save_pretrained(
+            tmp_path / "pissa-model-converted", path_initial_model_for_weight_conversion=tmp_path / "init-model"
+        )
+        model_converted = PeftModel.from_pretrained(deepcopy(model), tmp_path / "pissa-model-converted")
+        output_converted = model_converted(data)[0]
+
+        assert torch.allclose(output_pissa, output_converted, atol=tol, rtol=tol)
+        # rank should be double of what it was initially
+        assert model_converted.peft_config["default"].r == 16
+        assert model_converted.base_model.model.linear.lora_A["default"].weight.shape[0] == 16
+        # same scale as before with a little bit of floating point imprecision
+        assert model_converted.base_model.model.linear.scaling["default"] == pytest.approx(8 / (8**0.5))
+        # base model weights should be the same as the initial model
+        assert torch.allclose(
+            model.linear.weight, model_converted.base_model.model.linear.base_layer.weight, atol=tol, rtol=tol
+        )
+
+    def test_pissa_rank_pattern_and_rslora_raises(self, tmp_path):
+        # it's not possible to determine the correct scale when using rslora with rank or alpha pattern, because the
+        # scale is not stored in the state_dict
+        model = self.get_model()
+        config = LoraConfig(
+            init_lora_weights="pissa", target_modules=["linear"], r=8, rank_pattern={"linear": 2}, use_rslora=True
+        )
+        peft_model = get_peft_model(model, config)
+        peft_model.save_pretrained(tmp_path / "init-model")
+
+        msg = re.escape("Passing `path_initial_model_for_weight_conversion` to `save_pretrained`")
+        with pytest.raises(ValueError, match=msg):
+            peft_model.save_pretrained(
+                tmp_path / "pissa-model", path_initial_model_for_weight_conversion=tmp_path / "init-model"
+            )
+
+    def test_pissa_alpha_pattern_and_rslora_raises(self, tmp_path):
+        # it's not possible to determine the correct scale when using rslora with rank or alpha pattern, because the
+        # scale is not stored in the state_dict
+        model = self.get_model()
+        config = LoraConfig(
+            init_lora_weights="pissa", target_modules=["linear"], r=8, alpha_pattern={"linear": 2}, use_rslora=True
+        )
+        peft_model = get_peft_model(model, config)
+        peft_model.save_pretrained(tmp_path / "init-model")
+
+        msg = re.escape("Passing `path_initial_model_for_weight_conversion` to `save_pretrained`")
+        with pytest.raises(ValueError, match=msg):
+            peft_model.save_pretrained(
+                tmp_path / "pissa-model", path_initial_model_for_weight_conversion=tmp_path / "init-model"
+            )
+
     # TODO: remove test for deprecated arg in PEFT v0.14.0
     def test_lora_pissa_conversion_same_output_after_loading_with_deprecated_arg(self, data, tmp_path):
         model = self.get_model()
@@ -423,6 +612,314 @@ def test_olora_conversion_same_output_after_loading(self, data, tmp_path):
             model.linear.weight, model_converted.base_model.model.linear.base_layer.weight, atol=tol, rtol=tol
         )
 
+    def test_olora_conversion_same_output_after_loading_with_rank_pattern(self, data, tmp_path):
+        # same as above, but using rank_pattern
+        model = self.get_model()
+        output_base = model(data)[0]
+
+        # use rank_pattern here; note that since there is only a single linear layer, r is completely overridden
+        config = LoraConfig(init_lora_weights="olora", target_modules=["linear"], r=8, rank_pattern={"linear": 32})
+        peft_model = get_peft_model(deepcopy(model), config)
+        # save the initial model
+        peft_model.save_pretrained(tmp_path / "init-model")
+
+        # modify the weights, or else the adapter performs an identity transformation
+        peft_model.base_model.linear.lora_B["default"].weight.data *= 2.0
+        output_olora = peft_model(data)[0]
+
+        # sanity check
+        tol = 1e-06
+        assert not torch.allclose(output_base, output_olora, atol=tol, rtol=tol)
+
+        # save the model normally
+        peft_model.save_pretrained(tmp_path / "olora-model")
+        model_loaded = PeftModel.from_pretrained(deepcopy(model), tmp_path / "olora-model")
+        output_loaded = model_loaded(data)[0]
+
+        assert torch.allclose(output_olora, output_loaded, atol=tol, rtol=tol)
+        # sanity check: the config rank should still be 8 as initially, while the layer rank follows rank_pattern (32)
+        assert model_loaded.peft_config["default"].r == 8
+        assert model_loaded.base_model.model.linear.lora_A["default"].weight.shape[0] == 32
+        # sanity check: the base model weights were indeed changed
+        assert not torch.allclose(
+            model.linear.weight, model_loaded.base_model.model.linear.base_layer.weight, atol=tol, rtol=tol
+        )
+
+        # save the model with conversion
+        peft_model.save_pretrained(
+            tmp_path / "olora-model-converted", path_initial_model_for_weight_conversion=tmp_path / "init-model"
+        )
+        model_converted = PeftModel.from_pretrained(deepcopy(model), tmp_path / "olora-model-converted")
+        output_converted = model_converted(data)[0]
+
+        assert torch.allclose(output_olora, output_converted, atol=tol, rtol=tol)
+        # rank should be double of what it was initially
+        assert model_converted.peft_config["default"].r == 16
+        assert model_converted.base_model.model.linear.lora_A["default"].weight.shape[0] == 64
+        # base model weights should be the same as the initial model
+        assert torch.allclose(
+            model.linear.weight, model_converted.base_model.model.linear.base_layer.weight, atol=tol, rtol=tol
+        )
+
+    def test_olora_conversion_same_output_after_loading_with_alpha_pattern(self, data, tmp_path):
+        # same as above, but using alpha_pattern
+        model = self.get_model()
+        output_base = model(data)[0]
+
+        # use alpha_pattern here; note that since there is only a single linear layer, lora_alpha is completely
+        # overridden
+        config = LoraConfig(init_lora_weights="olora", target_modules=["linear"], alpha_pattern={"linear": 5})
+        peft_model = get_peft_model(deepcopy(model), config)
+        # save the initial model
+        peft_model.save_pretrained(tmp_path / "init-model")
+
+        # modify the weights, or else the adapter performs an identity transformation
+        peft_model.base_model.linear.lora_B["default"].weight.data *= 2.0
+        output_olora = peft_model(data)[0]
+
+        # sanity check
+        tol = 1e-06
+        assert not torch.allclose(output_base, output_olora, atol=tol, rtol=tol)
+
+        # save the model normally
+        peft_model.save_pretrained(tmp_path / "olora-model")
+        model_loaded = PeftModel.from_pretrained(deepcopy(model), tmp_path / "olora-model")
+        output_loaded = model_loaded(data)[0]
+
+        assert torch.allclose(output_olora, output_loaded, atol=tol, rtol=tol)
+        # sanity check: ranks should still be 8 as initially
+        assert model_loaded.peft_config["default"].r == 8
+        assert model_loaded.base_model.model.linear.lora_A["default"].weight.shape[0] == 8
+        assert model_loaded.base_model.model.linear.scaling["default"] == 5 / 8
+        # sanity check: the base model weights were indeed changed
+        assert not torch.allclose(
+            model.linear.weight, model_loaded.base_model.model.linear.base_layer.weight, atol=tol, rtol=tol
+        )
+
+        # save the model with conversion
+        peft_model.save_pretrained(
+            tmp_path / "olora-model-converted", path_initial_model_for_weight_conversion=tmp_path / "init-model"
+        )
+        model_converted = PeftModel.from_pretrained(deepcopy(model), tmp_path / "olora-model-converted")
+        output_converted = model_converted(data)[0]
+
+        assert torch.allclose(output_olora, output_converted, atol=tol, rtol=tol)
+        # rank should be double of what it was initially
+        assert model_converted.peft_config["default"].r == 16
+        assert model_converted.base_model.model.linear.lora_A["default"].weight.shape[0] == 16
+        assert model_converted.base_model.model.linear.scaling["default"] == 10 / 16
+        # base model weights should be the same as the initial model
+        assert torch.allclose(
+            model.linear.weight, model_converted.base_model.model.linear.base_layer.weight, atol=tol, rtol=tol
+        )
+
+    def test_olora_conversion_same_output_after_loading_with_rslora(self, data, tmp_path):
+        # same as above, but using rslora
+        model = self.get_model()
+        output_base = model(data)[0]
+
+        config = LoraConfig(init_lora_weights="olora", target_modules=["linear"], r=8, use_rslora=True)
+        peft_model = get_peft_model(deepcopy(model), config)
+        # save the initial model
+        peft_model.save_pretrained(tmp_path / "init-model")
+
+        # modify the weights, or else the adapter performs an identity transformation
+        peft_model.base_model.linear.lora_B["default"].weight.data *= 2.0
+        output_olora = peft_model(data)[0]
+
+        # sanity check
+        tol = 1e-06
+        assert not torch.allclose(output_base, output_olora, atol=tol, rtol=tol)
+
+        # save the model normally
+        peft_model.save_pretrained(tmp_path / "olora-model")
+        model_loaded = PeftModel.from_pretrained(deepcopy(model), tmp_path / "olora-model")
+        output_loaded = model_loaded(data)[0]
+
+        assert torch.allclose(output_olora, output_loaded, atol=tol, rtol=tol)
+        # sanity check: ranks should still be 8 as initially
+        assert model_loaded.peft_config["default"].r == 8
+        assert model_loaded.base_model.model.linear.lora_A["default"].weight.shape[0] == 8
+        assert model_loaded.base_model.model.linear.scaling["default"] == 8 / (8**0.5)
+        # sanity check: the base model weights were indeed changed
+        assert not torch.allclose(
+            model.linear.weight, model_loaded.base_model.model.linear.base_layer.weight, atol=tol, rtol=tol
+        )
+
+        # save the model with conversion
+        peft_model.save_pretrained(
+            tmp_path / "olora-model-converted", path_initial_model_for_weight_conversion=tmp_path / "init-model"
+        )
+        model_converted = PeftModel.from_pretrained(deepcopy(model), tmp_path / "olora-model-converted")
+        output_converted = model_converted(data)[0]
+
+        assert torch.allclose(output_olora, output_converted, atol=tol, rtol=tol)
+        # rank should be double of what it was initially
+        assert model_converted.peft_config["default"].r == 16
+        assert model_converted.base_model.model.linear.lora_A["default"].weight.shape[0] == 16
+        # same scale as before with a little bit of floating point imprecision
+        assert model_converted.base_model.model.linear.scaling["default"] == pytest.approx(8 / (8**0.5))
+        # base model weights should be the same as the initial model
+        assert torch.allclose(
+            model.linear.weight, model_converted.base_model.model.linear.base_layer.weight, atol=tol, rtol=tol
+        )
+
+    def test_olora_rank_pattern_and_rslora_raises(self, tmp_path):
+        # it's not possible to determine the correct scale when using rslora with rank or alpha pattern, because the
+        # scale is not stored in the state_dict
+        model = self.get_model()
+        config = LoraConfig(
+            init_lora_weights="olora", target_modules=["linear"], r=8, rank_pattern={"linear": 2}, use_rslora=True
+        )
+        peft_model = get_peft_model(model, config)
+        peft_model.save_pretrained(tmp_path / "init-model")
+
+        msg = re.escape("Passing `path_initial_model_for_weight_conversion` to `save_pretrained`")
+        with pytest.raises(ValueError, match=msg):
+            peft_model.save_pretrained(
+                tmp_path / "olora-model", path_initial_model_for_weight_conversion=tmp_path / "init-model"
+            )
+
+    def test_olora_alpha_pattern_and_rslora_raises(self, tmp_path):
+        # it's not possible to determine the correct scale when using rslora with rank or alpha pattern, because the
+        # scale is not stored in the state_dict
+        model = self.get_model()
+        config = LoraConfig(
+            init_lora_weights="olora", target_modules=["linear"], r=8, alpha_pattern={"linear": 2}, use_rslora=True
+        )
+        peft_model = get_peft_model(model, config)
+        peft_model.save_pretrained(tmp_path / "init-model")
+
+        msg = re.escape("Passing `path_initial_model_for_weight_conversion` to `save_pretrained`")
+        with pytest.raises(ValueError, match=msg):
+            peft_model.save_pretrained(
+                tmp_path / "olora-model", path_initial_model_for_weight_conversion=tmp_path / "init-model"
+            )
+
+    @pytest.mark.parametrize(
+        "config_kwargs, should_warn",
+        [
+            # no warning
+            ({"init_lora_weights": "pissa", "target_modules": ["linear"]}, False),
+            ({"init_lora_weights": "pissa_niter_3", "target_modules": ["linear"]}, False),
+            ({"init_lora_weights": "olora", "target_modules": ["linear"]}, False),
+            ({"init_lora_weights": "pissa", "target_modules": ["linear"], "use_rslora": True}, False),
+            ({"init_lora_weights": "pissa_niter_3", "target_modules": ["linear"], "use_rslora": True}, False),
+            ({"init_lora_weights": "olora", "target_modules": ["linear"], "use_rslora": True}, False),
+            ({"init_lora_weights": "pissa", "target_modules": ["linear"], "rank_pattern": {"linear": 8}}, False),
+            (
+                {"init_lora_weights": "pissa_niter_3", "target_modules": ["linear"], "rank_pattern": {"linear": 8}},
+                False,
+            ),
+            ({"init_lora_weights": "olora", "target_modules": ["linear"], "rank_pattern": {"linear": 8}}, False),
+            ({"init_lora_weights": "pissa", "target_modules": ["linear"], "alpha_pattern": {"linear": 8}}, False),
+            (
+                {"init_lora_weights": "pissa_niter_3", "target_modules": ["linear"], "alpha_pattern": {"linear": 8}},
+                False,
+            ),
+            ({"init_lora_weights": "olora", "target_modules": ["linear"], "alpha_pattern": {"linear": 8}}, False),
+            # warning
+            (
+                {
+                    "init_lora_weights": "pissa",
"pissa", + "target_modules": ["linear"], + "use_rslora": True, + "rank_pattern": {"linear": 8}, + }, + True, + ), + ( + { + "init_lora_weights": "pissa_niter_3", + "target_modules": ["linear"], + "use_rslora": True, + "rank_pattern": {"linear": 8}, + }, + True, + ), + ( + { + "init_lora_weights": "olora", + "target_modules": ["linear"], + "use_rslora": True, + "rank_pattern": {"linear": 8}, + }, + True, + ), + ( + { + "init_lora_weights": "pissa", + "target_modules": ["linear"], + "use_rslora": True, + "alpha_pattern": {"linear": 8}, + }, + True, + ), + ( + { + "init_lora_weights": "pissa_niter_3", + "target_modules": ["linear"], + "use_rslora": True, + "alpha_pattern": {"linear": 8}, + }, + True, + ), + ( + { + "init_lora_weights": "olora", + "target_modules": ["linear"], + "use_rslora": True, + "alpha_pattern": {"linear": 8}, + }, + True, + ), + ( + { + "init_lora_weights": "pissa", + "target_modules": ["linear"], + "use_rslora": True, + "rank_pattern": {"linear": 8}, + "alpha_pattern": {"linear": 8}, + }, + True, + ), + ( + { + "init_lora_weights": "pissa_niter_3", + "target_modules": ["linear"], + "use_rslora": True, + "rank_pattern": {"linear": 8}, + "alpha_pattern": {"linear": 8}, + }, + True, + ), + ( + { + "init_lora_weights": "olora", + "target_modules": ["linear"], + "use_rslora": True, + "rank_pattern": {"linear": 8}, + "alpha_pattern": {"linear": 8}, + }, + True, + ), + ], + ) + def test_lora_config_pissa_olora_warns(self, config_kwargs, should_warn, recwarn): + # Using post training conversion of modified base weights to restore their initial values (PiSSA, OLoRA) cannot + # be correctly done when using rslora + rank_pattern/alpha_pattern. We can't really know if the user intends + # this when they'll eventually call save_pretrained (i.e. if they'll pass + # path_initial_model_for_weight_conversionl). Therefore, we only warn but don't raise an error here. + msg = re.escape("Using Rank-Stabilized LoRA with rank_pattern/alpha_pattern and post-training conversion") + if should_warn: + LoraConfig(**config_kwargs) + assert len(recwarn.list) == 1 + with pytest.warns(UserWarning, match=msg): + LoraConfig(**config_kwargs) + else: + LoraConfig(**config_kwargs) + assert not recwarn.list + def test_lora_rslora_scaling(self): # default is True torch.manual_seed(0)