Skip to content

Commit

Permalink
FIX PiSSA & OLoRA with rank/alpha pattern, rslora (#1930)
Browse files Browse the repository at this point in the history
* FIX PiSSA & OLoRA with rank/alpha pattern, rslora

See #1929 (comment)

At the moment, when using PiSSA or OLoRA with weight conversion to
restore the original base weights, there is an error when either of
rank_pattern, alpha_pattern, or rslora is being used. This PR fixes
this.

The issue is that we need to double the rank of the LoRA adapter. Right
now, this is done by simply doubling r and alpha. But if rank_pattern
and alpha_pattern are being used, those need to be doubled too.

Furthermore, when using rslora, the scaling is again different, namely
alpha / sqrt(r). This also needs to be adjusted.

Unfortunately, when using rslora with rank_pattern and alpha_pattern,
this gets way more complicated. Since we don't store the scaling in the
state_dict, we would have to resolve all the patterns here to determine
the correct scaling, i.e. reimplement the whole matching and init logic.
This is a lot of work for a very edgy edge case.

Therefore, I opted to raise an error instead. This is not super nice, as
the error is only raised when trying to save the model, i.e. a lot of
time may already have been spent to train the model. But we cannot know
this earlier, so not much can be done.

Overall, this fix is ugly because it further couples unrelated code. For
instance, if we add new init methods that affect the scaling, we need to
remember to change the saving logic accordingly. If anyone has a better
idea, LMK.

* Make style

* Also warn during init if there is a potential

... for saving not to work

* Ensure that GPU tests are run for PiSSA+OLoRA

* Use renamed argument name

* Make style

* Reviewer feedback: Better document the change

* Add clarifying comments to tests
  • Loading branch information
BenjaminBossan authored Jul 19, 2024
1 parent 5268495 commit e02b938
Show file tree
Hide file tree
Showing 6 changed files with 548 additions and 7 deletions.
4 changes: 2 additions & 2 deletions examples/olora_finetuning/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ train(olora_model) # Your training loop
#Save the model after training
olora_model.save_pretrained(output_dir, path_initial_model_for_weight_conversion=init_path)
```
After completing training, you can save and convert your OLoRA model to a conventional LoRA model by setting `path_initial_model_for_weight_conversion` to `init_path`, that is the path of your untrained OLoRA model. This conversion enables you to use multiple adapters with your LoRA model.
After completing training, you can save and convert your OLoRA model to a conventional LoRA model by setting `path_initial_model_for_weight_conversion` to `init_path`, that is the path of your untrained OLoRA model. This conversion enables you to use multiple adapters with your LoRA model. Note that this conversion is not supported if `rslora` is used in combination with `rank_pattern` or `alpha_pattern`.

## Citation
```
Expand All @@ -81,4 +81,4 @@ After completing training, you can save and convert your OLoRA model to a conven
archivePrefix={arXiv},
primaryClass={cs.CL}
}
```
```
4 changes: 2 additions & 2 deletions examples/pissa_finetuning/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ peft_model = PeftModel.from_pretrained(model, "pissa-llama-2-7b-lora")
```
Utilizing the converted LoRA does not require modifying the parameters of the base model. When multiple converted LoRAs are needed simultaneously, each adapter operates independently without interference, allowing for the adapters to be freely deleted or added.


Note that this conversion is not supported if `rslora` is used in combination with `rank_pattern` or `alpha_pattern`.

### Fine-tune in 4-bit or 8-bit
If quantization fine-tuning is desired, it is necessary to first decompose the original model at full precision and then reload the residual model in either 4-bit or 8-bit configurations.
Expand Down Expand Up @@ -128,4 +128,4 @@ This approach ensures the preservation of high-frequency, out-of-distribution pa
journal={arXiv preprint arXiv:2404.02948},
year={2024}
}
```
```
23 changes: 21 additions & 2 deletions src/peft/peft_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,9 +231,11 @@ def save_pretrained(
difference in adapter before and after fine-tuning is calculated. This difference can be represented as
the parameters of a standard LoRA adapter. Using this converted adapter does not require changes to the
base model, thus conveniently allowing the use of multiple PiSSA or OLoRA adapters with LoRA adapters,
and the activation or deactivation of any adapters.
and the activation or deactivation of any adapters. Note that this conversion is not supported if
`rslora` is used in combination with `rank_pattern` or `alpha_pattern`.
kwargs (additional keyword arguments, *optional*):
Additional keyword arguments passed along to the `push_to_hub` method.
"""
if os.path.isfile(save_directory):
raise ValueError(f"Provided path ({save_directory}) should be a directory, not a file")
Expand All @@ -258,6 +260,13 @@ def save_pretrained(
path_initial_model_for_weight_conversion = convert_pissa_to_lora

def save_mutated_as_lora(peft_config, path_initial_model_for_weight_conversion, output_state_dict, kwargs):
if peft_config.use_rslora and (peft_config.rank_pattern or peft_config.alpha_pattern):
msg = (
"Passing `path_initial_model_for_weight_conversion` to `save_pretrained` is not supported when "
"using `rank_pattern` or `alpha_pattern` at the same time as `use_rslora=True`."
)
raise ValueError(msg)

if not any(
str(peft_config.init_lora_weights).lower().startswith(prefix) for prefix in ["pissa", "olora", "true"]
):
Expand Down Expand Up @@ -368,7 +377,17 @@ def save_mutated_as_lora(peft_config, path_initial_model_for_weight_conversion,
if path_initial_model_for_weight_conversion is not None:
peft_config.init_lora_weights = True
peft_config.r *= 2
peft_config.lora_alpha *= 2
if not peft_config.use_rslora:
peft_config.lora_alpha *= 2
else:
# with rslora, we have scaling = alpha / sqrt(r), we thus adjust alpha to keep the same scaling
peft_config.lora_alpha *= 2**0.5

if peft_config.rank_pattern:
peft_config.rank_pattern = {key: 2 * val for key, val in peft_config.rank_pattern.items()}
if peft_config.alpha_pattern:
peft_config.alpha_pattern = {key: 2 * val for key, val in peft_config.alpha_pattern.items()}

peft_config.save_pretrained(output_dir, auto_mapping_dict=auto_mapping_dict)
peft_config.inference_mode = inference_mode

Expand Down
21 changes: 21 additions & 0 deletions src/peft/tuners/lora/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from __future__ import annotations

import warnings
from dataclasses import dataclass, field
from typing import Literal, Optional, Union

Expand Down Expand Up @@ -346,6 +347,26 @@ def __post_init__(self):
if self.loftq_config is None:
raise ValueError("`loftq_config` must be specified when `init_lora_weights` is 'loftq'.")

# Using post training conversion of modified base weights to restore their initial values (PiSSA, OLoRA) cannot
# be correctly done when using rslora + rank_pattern/alpha_pattern. We can't really know if the user intends
# this when they'll eventually call save_pretrained (i.e. if they'll pass
# path_initial_model_for_weight_conversionl). Therefore, we only warn but don't raise an error here.
if (
self.use_rslora
and (self.rank_pattern or self.alpha_pattern)
and (
(isinstance(self.init_lora_weights, str) and (self.init_lora_weights.startswith("pissa")))
or (self.init_lora_weights == "olora")
)
):
msg = (
"Using Rank-Stabilized LoRA with rank_pattern/alpha_pattern and post-training conversion of modified "
"base weights (PiSSA, OLoRA) means that you won't be able to pass "
"`path_initial_model_for_weight_conversion` to `save_pretrained` to restore the initial values of the "
"base weights; if you intend to do this, please ensure not to use rslora or rank_pattern/alpha_pattern."
)
warnings.warn(msg)

# convert loftq_config to dict
if self.loftq_config and not isinstance(self.loftq_config, dict):
self.loftq_config = vars(self.loftq_config)
Expand Down
6 changes: 5 additions & 1 deletion tests/test_gpu_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -1470,6 +1470,7 @@ def test_offload_merge(self):


@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires a GPU")
@pytest.mark.single_gpu_tests
class TestPiSSA:
r"""
Tests for PiSSA to ensure that it reduces the quantization error compared to normal LoRA quantization.
Expand Down Expand Up @@ -1656,7 +1657,9 @@ def forward(self, x):
assert model_loaded.base_model.model.linear.lora_A["default"].weight.shape[0] == 8

# save the model with conversion
peft_model.save_pretrained(tmp_path / "pissa-model-converted", convert_mutated_to_lora=tmp_path / "init-model")
peft_model.save_pretrained(
tmp_path / "pissa-model-converted", path_initial_model_for_weight_conversion=tmp_path / "init-model"
)
model_converted = PeftModel.from_pretrained(deepcopy(model), tmp_path / "pissa-model-converted")
output_converted = model_converted(data)[0]

Expand All @@ -1672,6 +1675,7 @@ def forward(self, x):


@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires a GPU")
@pytest.mark.single_gpu_tests
class TestOLoRA:
r"""
Tests for OLoRA to ensure that it reduces the quantization error compared to normal LoRA quantization.
Expand Down
Loading

0 comments on commit e02b938

Please sign in to comment.