FIX PiSSA & OLoRA with rank/alpha pattern, rslora #1930

Merged
Changes from 6 commits
19 changes: 18 additions & 1 deletion src/peft/peft_model.py
@@ -258,6 +258,13 @@ def save_pretrained(
             path_initial_model_for_weight_conversion = convert_pissa_to_lora

         def save_mutated_as_lora(peft_config, path_initial_model_for_weight_conversion, output_state_dict, kwargs):
+            if peft_config.use_rslora and (peft_config.rank_pattern or peft_config.alpha_pattern):
+                msg = (
+                    "Passing `path_initial_model_for_weight_conversion` to `save_pretrained` is not supported when "
+                    "using `rank_pattern` or `alpha_pattern` at the same time as `use_rslora=True`."
+                )
+                raise ValueError(msg)
+
             if not any(
                 str(peft_config.init_lora_weights).lower().startswith(prefix) for prefix in ["pissa", "olora", "true"]
             ):
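For illustration only, a minimal sketch (not part of this PR) of a config combination that now triggers this error when conversion is requested at save time; the rank_pattern key is a hypothetical module name:

from peft import LoraConfig

config = LoraConfig(
    init_lora_weights="pissa",
    use_rslora=True,
    rank_pattern={"q_proj": 16},  # per-module rank override, hypothetical key
)
# peft_model.save_pretrained(out_dir, path_initial_model_for_weight_conversion=init_dir)
# would now raise the ValueError above instead of silently saving a broken conversion.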
@@ -368,7 +375,17 @@ def save_mutated_as_lora(peft_config, path_initial_model_for_weight_conversion,
             if path_initial_model_for_weight_conversion is not None:
                 peft_config.init_lora_weights = True
                 peft_config.r *= 2
-                peft_config.lora_alpha *= 2
+                if not peft_config.use_rslora:
+                    peft_config.lora_alpha *= 2
+                else:
+                    # with rslora, scaling = alpha / sqrt(r), so alpha is multiplied by sqrt(2)
+                    # to keep the scaling unchanged when r doubles
+                    peft_config.lora_alpha *= 2**0.5
+
+                if peft_config.rank_pattern:
+                    peft_config.rank_pattern = {key: 2 * val for key, val in peft_config.rank_pattern.items()}
+                if peft_config.alpha_pattern:
+                    peft_config.alpha_pattern = {key: 2 * val for key, val in peft_config.alpha_pattern.items()}
+
             peft_config.save_pretrained(output_dir, auto_mapping_dict=auto_mapping_dict)
             peft_config.inference_mode = inference_mode

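Why 2**0.5: the conversion doubles r, and alpha must change so the effective scaling is preserved. A quick standalone check of both branches (illustrative numbers, not from the PR):

import math

r, alpha = 8, 16

# default LoRA: scaling = alpha / r, so doubling both r and alpha preserves it
assert (2 * alpha) / (2 * r) == alpha / r

# rslora: scaling = alpha / sqrt(r); with r doubled, alpha must grow by sqrt(2):
# (alpha * sqrt(2)) / sqrt(2 * r) == alpha / sqrt(r)
assert math.isclose((alpha * 2**0.5) / math.sqrt(2 * r), alpha / math.sqrt(r))

The rank_pattern/alpha_pattern entries follow the default rule, hence the plain factor of 2 for them; the rslora + pattern combination is rejected earlier by the new ValueError.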
21 changes: 21 additions & 0 deletions src/peft/tuners/lora/config.py
@@ -14,6 +14,7 @@

 from __future__ import annotations

+import warnings
 from dataclasses import dataclass, field
 from typing import Literal, Optional, Union

@@ -346,6 +347,26 @@ def __post_init__(self):
             if self.loftq_config is None:
                 raise ValueError("`loftq_config` must be specified when `init_lora_weights` is 'loftq'.")

+        # Using post-training conversion of the modified base weights to restore their initial values (PiSSA, OLoRA)
+        # cannot be done correctly when using rslora + rank_pattern/alpha_pattern. We can't really know if the user
+        # intends this when they'll eventually call save_pretrained (i.e. if they'll pass
+        # path_initial_model_for_weight_conversion). Therefore, we only warn but don't raise an error here.
+        if (
+            self.use_rslora
+            and (self.rank_pattern or self.alpha_pattern)
+            and (
+                (isinstance(self.init_lora_weights, str) and (self.init_lora_weights.startswith("pissa")))
+                or (self.init_lora_weights == "olora")
+            )
+        ):
+            msg = (
+                "Using Rank-Stabilized LoRA with rank_pattern/alpha_pattern and post-training conversion of modified "
+                "base weights (PiSSA, OLoRA) means that you won't be able to pass "
+                "`path_initial_model_for_weight_conversion` to `save_pretrained` to restore the initial values of the "
+                "base weights; if you intend to do this, please ensure not to use rslora or rank_pattern/alpha_pattern."
+            )
+            warnings.warn(msg)
+
         # convert loftq_config to dict
         if self.loftq_config and not isinstance(self.loftq_config, dict):
             self.loftq_config = vars(self.loftq_config)
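Since the check runs in __post_init__, the warning fires as soon as the config is constructed. A sketch of how this could be observed (config values are illustrative, not from the PR):

import warnings

from peft import LoraConfig

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    LoraConfig(init_lora_weights="olora", use_rslora=True, alpha_pattern={"v_proj": 32})

assert any("path_initial_model_for_weight_conversion" in str(w.message) for w in caught)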
6 changes: 5 additions & 1 deletion tests/test_gpu_examples.py
@@ -1470,6 +1470,7 @@ def test_offload_merge(self):


 @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires a GPU")
+@pytest.mark.single_gpu_tests
 class TestPiSSA:
     r"""
     Tests for PiSSA to ensure that it reduces the quantization error compared to normal LoRA quantization.
@@ -1656,7 +1657,9 @@ def forward(self, x):
         assert model_loaded.base_model.model.linear.lora_A["default"].weight.shape[0] == 8

         # save the model with conversion
-        peft_model.save_pretrained(tmp_path / "pissa-model-converted", convert_mutated_to_lora=tmp_path / "init-model")
+        peft_model.save_pretrained(
+            tmp_path / "pissa-model-converted", path_initial_model_for_weight_conversion=tmp_path / "init-model"
+        )
         model_converted = PeftModel.from_pretrained(deepcopy(model), tmp_path / "pissa-model-converted")
         output_converted = model_converted(data)[0]

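For context, a hedged sketch of the end-to-end flow this test exercises, assuming a small causal LM; the model id, paths, and the temporary init_lora_weights flip follow the PEFT PiSSA docs and may differ in detail from the test:

import tempfile
from pathlib import Path

from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

tmp_path = Path(tempfile.mkdtemp())
base_model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")

config = LoraConfig(init_lora_weights="pissa", r=8, lora_alpha=8)
peft_model = get_peft_model(base_model, config)

# snapshot the initial adapter: init_lora_weights is temporarily set to True so the
# saved checkpoint holds the unmodified initialization (as in the PEFT PiSSA docs)
peft_model.peft_config["default"].init_lora_weights = True
peft_model.save_pretrained(tmp_path / "init-model")
peft_model.peft_config["default"].init_lora_weights = "pissa"

# ... training would happen here ...

# save with conversion: the result loads as a plain LoRA adapter with r doubled
peft_model.save_pretrained(
    tmp_path / "pissa-model-converted",
    path_initial_model_for_weight_conversion=tmp_path / "init-model",
)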
@@ -1672,6 +1675,7 @@


 @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires a GPU")
+@pytest.mark.single_gpu_tests
 class TestOLoRA:
     r"""
     Tests for OLoRA to ensure that it reduces the quantization error compared to normal LoRA quantization.