Fixed multirank + multialpha for sequential LoRAs, added correct support of LoRA-C3Lier conversion #937

Merged
2 changes: 2 additions & 0 deletions docs/source/conceptual_guides/lora.mdx
@@ -77,6 +77,8 @@ As with other methods supported by PEFT, to fine-tune a model using LoRA, you need to:
- `modules_to_save`: List of modules apart from LoRA layers to be set as trainable and saved in the final checkpoint. These typically include the model's custom head, which is randomly initialized for the fine-tuning task.
- `layers_to_transform`: List of layers to be transformed by LoRA. If not specified, all layers in `target_modules` are transformed.
- `layers_pattern`: Pattern to match layer names in `target_modules`, if `layers_to_transform` is specified. By default `PeftModel` will look at common layer patterns (`layers`, `h`, `blocks`, etc.); use it for exotic and custom models.
- `rank_pattern`: The mapping from layer names or regexp expressions to ranks which are different from the default rank specified by `r`.
- `alpha_pattern`: The mapping from layer names or regexp expressions to alphas which are different from the default alpha specified by `lora_alpha`. See the sketch below for an example of both options.
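
As a minimal sketch of how these two options are used together (the module names `q_proj`/`k_proj`/`v_proj` are placeholders for whatever your model actually uses):

```python
from peft import LoraConfig

# Most modules use r=8 and lora_alpha=8; modules matching "k_proj" get their own rank and alpha.
config = LoraConfig(
    r=8,
    lora_alpha=8,
    target_modules=["q_proj", "k_proj", "v_proj"],
    rank_pattern={"k_proj": 4},    # overrides `r` for matching layers
    alpha_pattern={"k_proj": 16},  # overrides `lora_alpha` for matching layers
)
```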

## LoRA examples

247 changes: 127 additions & 120 deletions examples/lora_dreambooth/convert_kohya_ss_sd_lora_to_peft.py
@@ -1,11 +1,10 @@
import argparse
import os
import re
from typing import Callable, List, Optional, Union
from dataclasses import dataclass
from typing import Dict, Optional

import safetensors
import torch
import torch.nn as nn
from diffusers import UNet2DConditionModel
from transformers import CLIPTextModel

@@ -21,44 +20,66 @@
LORA_PREFIX_TEXT_ENCODER = "lora_te"


def get_modules_names(
root_module: nn.Module,
target_replace_modules_linear: Optional[List[str]] = [],
target_replace_modules_conv2d: Optional[List[str]] = [],
):
# Combine replacement modules
target_replace_modules = target_replace_modules_linear + target_replace_modules_conv2d

# Store result
modules_names = set()
# https://github.com/kohya-ss/sd-scripts/blob/c924c47f374ac1b6e33e71f82948eb1853e2243f/networks/lora.py#L720
for name, module in root_module.named_modules():
if module.__class__.__name__ in target_replace_modules:
if len(name) == 0:
continue
for child_name, child_module in module.named_modules():
if len(child_name) == 0:
continue
is_linear = child_module.__class__.__name__ == "Linear"
is_conv2d = child_module.__class__.__name__ == "Conv2d"

if (is_linear and module.__class__.__name__ in target_replace_modules_linear) or (
is_conv2d and module.__class__.__name__ in target_replace_modules_conv2d
):
modules_names.add(f"{name}.{child_name}")

return sorted(modules_names)


def get_rank_alpha(
layer_names: List[str],
value_getter: Callable[[str], Union[int, float]],
filter_string: str,
) -> Union[int, float]:
values = [value_getter(p) for p in filter(lambda x: bool(re.search(filter_string, x)), layer_names)]
value = values[0]
assert all(v == value for v in values), f"All LoRA ranks and alphas must be same, found: {values}"
return value
@dataclass
class LoRAInfo:
kohya_key: str
peft_key: str
alpha: Optional[float] = None
rank: Optional[int] = None
lora_A: Optional[torch.Tensor] = None
lora_B: Optional[torch.Tensor] = None

def peft_state_dict(self) -> Dict[str, torch.Tensor]:
if self.lora_A is None or self.lora_B is None:
raise ValueError("At least one of lora_A or lora_B is missing for this key")
return {f"{self.peft_key}.lora_A.weight": self.lora_A, f"{self.peft_key}.lora_B.weight": self.lora_B}


def construct_peft_loraconfig(info: Dict[str, LoRAInfo]) -> LoraConfig:
"""Constructs LoraConfig from data extracted from kohya checkpoint

Args:
info (Dict[str, LoRAInfo]): Information extracted from kohya checkpoint

Returns:
LoraConfig: config for constructing LoRA
"""

# Unpack all ranks and alphas; fall back to the rank when the checkpoint stores no alpha
ranks = {key: val.rank for key, val in info.items()}
alphas = {key: val.alpha or val.rank for key, val in info.items()}

# Determine which modules need to be transformed
target_modules = list(info.keys())

# Determine most common rank and alpha
r = max(set(ranks.values()), key=list(ranks.values()).count)
lora_alpha = max(set(alphas.values()), key=list(alphas.values()).count)

# Determine which modules have different rank and alpha
rank_pattern = dict(filter(lambda x: x[1] != r, ranks.items()))
alpha_pattern = dict(filter(lambda x: x[1] != lora_alpha, alphas.items()))

config = LoraConfig(
r=r,
lora_alpha=lora_alpha,
target_modules=target_modules,
lora_dropout=0.0,
bias="none",
init_lora_weights=False,
rank_pattern=rank_pattern,
alpha_pattern=alpha_pattern,
)

return config


def combine_peft_state_dict(info: Dict[str, LoRAInfo]) -> Dict[str, torch.Tensor]:
result = {}
for key_name, key_info in info.items():
result[f"base_model.model.{key_name}.lora_A.weight"] = key_info.lora_A
result[f"base_model.model.{key_name}.lora_B.weight"] = key_info.lora_B
return result


if __name__ == "__main__":
@@ -75,93 +96,79 @@ def get_rank_alpha(
parser.add_argument("--half", action="store_true", help="Save weights in half precision.")
args = parser.parse_args()

# Find text encoder modules to add LoRA to
# Load all models that we need to add adapter to
text_encoder = CLIPTextModel.from_pretrained(args.sd_checkpoint, subfolder="text_encoder")
text_encoder_modules_names = get_modules_names(
text_encoder, target_replace_modules_linear=TEXT_ENCODER_TARGET_REPLACE_MODULE
)

# Find unet2d modules to add LoRA to
unet = UNet2DConditionModel.from_pretrained(args.sd_checkpoint, subfolder="unet")
unet_modules_names = get_modules_names(
unet,
target_replace_modules_linear=UNET_TARGET_REPLACE_MODULE,
target_replace_modules_conv2d=UNET_TARGET_REPLACE_MODULE,
)

# Construct possible mapping from kohya keys to peft keys
models_keys = {}
for model, model_key, model_name in [
(text_encoder, LORA_PREFIX_TEXT_ENCODER, "text_encoder"),
(unet, LORA_PREFIX_UNET, "unet"),
]:
models_keys.update(
{
f"{model_key}.{peft_key}".replace(".", "_"): peft_key
for peft_key in (x[0] for x in model.named_modules())
}
)

# Store conversion info (model_type -> peft_key -> LoRAInfo)
lora_info: Dict[str, Dict[str, LoRAInfo]] = {
"text_encoder": {},
"unet": {},
}

# Open kohya_ss checkpoint
with safetensors.safe_open(args.kohya_lora_path, framework="pt", device="cpu") as f:
# Extract information about LoRA structure
metadata = f.metadata()
if (metadata is not None) and ("ss_network_dim" in metadata) and ("ss_network_alpha" in metadata):
# LoRA rank and alpha are in safetensors metadata, just get it
lora_r = lora_text_encoder_r = int(metadata["ss_network_dim"])
lora_alpha = lora_text_encoder_alpha = float(metadata["ss_network_alpha"])
else:
# LoRA rank and alpha are not present, so infer them
lora_r = get_rank_alpha(
f.keys(), lambda n: f.get_tensor(n).size(0), f"^{LORA_PREFIX_UNET}\w+\.lora_down\.weight$"
)
lora_text_encoder_r = get_rank_alpha(
f.keys(), lambda n: f.get_tensor(n).size(0), f"^{LORA_PREFIX_TEXT_ENCODER}\w+\.lora_down\.weight$"
)
lora_alpha = get_rank_alpha(f.keys(), lambda n: f.get_tensor(n).item(), f"^{LORA_PREFIX_UNET}\w+\.alpha$")
lora_text_encoder_alpha = get_rank_alpha(
f.keys(), lambda n: f.get_tensor(n).item(), f"^{LORA_PREFIX_TEXT_ENCODER}\w+\.alpha$"
)

# Create LoRA for text encoder
text_encoder_config = LoraConfig(
r=lora_text_encoder_r,
lora_alpha=lora_text_encoder_alpha,
target_modules=text_encoder_modules_names,
lora_dropout=0.0,
bias="none",
)
text_encoder = get_peft_model(text_encoder, text_encoder_config)
text_encoder_lora_state_dict = {x: None for x in get_peft_model_state_dict(text_encoder).keys()}

# Load text encoder values from kohya_ss LoRA
for peft_te_key in text_encoder_lora_state_dict.keys():
kohya_ss_te_key = peft_te_key.replace("base_model.model", LORA_PREFIX_TEXT_ENCODER)
kohya_ss_te_key = kohya_ss_te_key.replace("lora_A", "lora_down")
kohya_ss_te_key = kohya_ss_te_key.replace("lora_B", "lora_up")
kohya_ss_te_key = kohya_ss_te_key.replace(".", "_", kohya_ss_te_key.count(".") - 2)
text_encoder_lora_state_dict[peft_te_key] = f.get_tensor(kohya_ss_te_key).to(text_encoder.dtype)

# Load converted kohya_ss text encoder LoRA back to PEFT
set_peft_model_state_dict(text_encoder, text_encoder_lora_state_dict)
# Iterate through available info and unpack all the values
for key in f.keys():
kohya_key, kohya_type = key.split(".")[:2]

# Find which model this key belongs to
if kohya_key.startswith(LORA_PREFIX_TEXT_ENCODER):
model_type = "text_encoder"
elif kohya_key.startswith(LORA_PREFIX_UNET):
model_type = "unet"
else:
raise ValueError(f"Cannot determine model for key: {key}")

# Find corresponding peft key
if kohya_key not in models_keys:
raise ValueError(f"Cannot find corresponding key for diffusers/transformers model: {kohya_key}")
peft_key = models_keys[kohya_key]

if peft_key not in lora_info[model_type]:
lora_info[model_type][peft_key] = LoRAInfo(kohya_key=kohya_key, peft_key=peft_key)

if kohya_type == "alpha":
lora_info[model_type][peft_key].alpha = f.get_tensor(key).item()
elif kohya_type == "lora_down":
tensor = f.get_tensor(key)
lora_info[model_type][peft_key].lora_A = tensor
lora_info[model_type][peft_key].rank = tensor.shape[0]
elif kohya_type == "lora_up":
tensor = f.get_tensor(key)
lora_info[model_type][peft_key].lora_B = tensor
lora_info[model_type][peft_key].rank = tensor.shape[1]
else:
raise ValueError(f"Unknown weight name in key: {key} - {kohya_type}")

# Process each model
for model, model_name in [(text_encoder, "text_encoder"), (unet, "unet")]:
config = construct_peft_loraconfig(lora_info[model_name])
model = get_peft_model(model, config)

keys_peft = list(get_peft_model_state_dict(model).keys())
keys_new = list(combine_peft_state_dict(lora_info[model_name]).keys())

set_peft_model_state_dict(model, combine_peft_state_dict(lora_info[model_name]))

if args.half:
text_encoder.to(torch.float16)

# Save text encoder result
text_encoder.save_pretrained(
os.path.join(args.dump_path, "text_encoder"),
)
model.to(torch.float16)

# Create LoRA for unet2d
unet_config = LoraConfig(
r=lora_r, lora_alpha=lora_alpha, target_modules=unet_modules_names, lora_dropout=0.0, bias="none"
)
unet = get_peft_model(unet, unet_config)
unet_lora_state_dict = {x: None for x in get_peft_model_state_dict(unet).keys()}

# Load unet2d values from kohya_ss LoRA
for peft_unet_key in unet_lora_state_dict.keys():
kohya_ss_unet_key = peft_unet_key.replace("base_model.model", LORA_PREFIX_UNET)
kohya_ss_unet_key = kohya_ss_unet_key.replace("lora_A", "lora_down")
kohya_ss_unet_key = kohya_ss_unet_key.replace("lora_B", "lora_up")
kohya_ss_unet_key = kohya_ss_unet_key.replace(".", "_", kohya_ss_unet_key.count(".") - 2)
unet_lora_state_dict[peft_unet_key] = f.get_tensor(kohya_ss_unet_key).to(unet.dtype)

# Load converted kohya_ss unet LoRA back to PEFT
set_peft_model_state_dict(unet, unet_lora_state_dict)

if args.half:
unet.to(torch.float16)

# Save text encoder result
unet.save_pretrained(
os.path.join(args.dump_path, "unet"),
)
# Save model to disk
model.save_pretrained(os.path.join(args.dump_path, model_name))
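
For reference, a small standalone sketch (with made-up module names and ranks) of how `construct_peft_loraconfig` collapses the per-module ranks into a base `r` plus a `rank_pattern` of outliers:

```python
# Hypothetical ranks extracted from a kohya_ss checkpoint
ranks = {"unet.down.attn.to_q": 8, "unet.down.attn.to_k": 8, "unet.mid.attn.to_q": 4}

# The most common rank becomes the base `r` of the LoraConfig ...
r = max(set(ranks.values()), key=list(ranks.values()).count)  # -> 8

# ... and only the modules that deviate from it end up in `rank_pattern`.
rank_pattern = {name: rank for name, rank in ranks.items() if rank != r}  # -> {"unet.mid.attn.to_q": 4}
```

The alphas are reduced the same way, falling back to each module's rank when the checkpoint stores no alpha for it.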
10 changes: 8 additions & 2 deletions src/peft/tuners/lora/config.py
@@ -45,6 +45,12 @@ class LoraConfig(PeftConfig):
layers_pattern (`str`):
The layer pattern name, used only if `layers_to_transform` is different from `None` and if the layer
pattern is not in the common layers pattern.
rank_pattern (`dict`):
The mapping from layer names or regexp expressions to ranks which are different from the default rank
specified by `r`.
alpha_pattern (`dict`):
The mapping from layer names or regexp expressions to alphas which are different from the default alpha
specified by `lora_alpha`.
"""

r: int = field(default=8, metadata={"help": "Lora attention dimension"})
@@ -95,7 +101,7 @@ class LoraConfig(PeftConfig):
default_factory=dict,
metadata={
"help": (
"The mapping from layer names to ranks which are different from the default rank specified by `r`. "
"The mapping from layer names or regexp expressions to ranks which are different from the default rank specified by `r`. "
"For example, `{model.decoder.layers.0.encoder_attn.k_proj: 8}`"
)
},
@@ -104,7 +110,7 @@
default_factory=dict,
metadata={
"help": (
"The mapping from layer names to alphas which are different from the default alpha specified by `lora_alpha`. "
"The mapping from layer names or regexp expressions to alphas which are different from the default alpha specified by `lora_alpha`. "
"For example, `{model.decoder.layers.0.encoder_attn.k_proj: 32}`"
)
},
9 changes: 7 additions & 2 deletions src/peft/tuners/lora/model.py
@@ -16,6 +16,7 @@
import warnings
from dataclasses import asdict, replace
from enum import Enum
from itertools import chain

import torch
from torch import nn
@@ -161,9 +162,13 @@ def _create_and_replace(
parent,
**optional_kwargs,
):
# Regexp matching - find the pattern key that matches the current key in the provided rank/alpha patterns
current_key = optional_kwargs["current_key"]
r = lora_config.rank_pattern.get(current_key, lora_config.r)
alpha = lora_config.alpha_pattern.get(current_key, lora_config.lora_alpha)
pattern_keys = list(chain(lora_config.rank_pattern.keys(), lora_config.alpha_pattern.keys()))
target_name_key = next(filter(lambda key: re.match(rf".*\.{key}$", current_key), pattern_keys), target_name)

r = lora_config.rank_pattern.get(target_name_key, lora_config.r)
alpha = lora_config.alpha_pattern.get(target_name_key, lora_config.lora_alpha)
bias = hasattr(target, "bias") and target.bias is not None
kwargs = {
"r": r,
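
A standalone sketch of the key-matching logic added above (the module path and pattern values are hypothetical):

```python
import re
from itertools import chain

rank_pattern = {"lin0": 4}
alpha_pattern = {"lin0": 4}
default_r, default_alpha = 8, 8

current_key = "base_model.model.lin0"  # full dotted path of the module being replaced
target_name = "lin0"                   # plain module name, used as the fallback key

# A pattern key applies when the full key ends with ".<pattern key>".
pattern_keys = list(chain(rank_pattern.keys(), alpha_pattern.keys()))
target_name_key = next(
    filter(lambda key: re.match(rf".*\.{key}$", current_key), pattern_keys), target_name
)

r = rank_pattern.get(target_name_key, default_r)           # -> 4
alpha = alpha_pattern.get(target_name_key, default_alpha)  # -> 4
```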
32 changes: 32 additions & 0 deletions tests/test_custom_models.py
@@ -368,3 +368,35 @@ def run_with_disable(config_kwargs, bias):
@parameterized.expand(TEST_CASES)
def test_adding_multiple_adapters_with_bias_raises(self, test_name, model_id, config_cls, config_kwargs):
self._test_adding_multiple_adapters_with_bias_raises(model_id, config_cls, config_kwargs)


class TestMultiRankAdapter(unittest.TestCase):
"""Tests related to multirank LoRA adapters"""

def test_multirank(self):
config_1 = LoraConfig(
r=8,
lora_alpha=8,
init_lora_weights=False,
target_modules=["lin0", "lin1"],
)
config_2 = LoraConfig(
r=8,
lora_alpha=8,
init_lora_weights=False,
target_modules=["lin0", "lin1"],
rank_pattern={"lin0": 4},
alpha_pattern={"lin0": 4},
)

# Add first adapter
model = get_peft_model(MLP(), config_1, adapter_name="first")

# Add second adapter
model.add_adapter("second", config_2)

# Extract current and expected ranks
rank_current = model.lin0.lora_A["second"].weight.shape[0]
rank_expected = config_2.rank_pattern["lin0"]

self.assertTrue(rank_current == rank_expected, f"Rank {rank_current} is not equal to expected {rank_expected}")
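
If one also wanted to cover `alpha_pattern`, a hypothetical continuation of the test above (assuming LoRA layers keep a per-adapter `scaling` dict equal to `lora_alpha / r`, which is not shown in this diff) could check the resulting scaling factor:

```python
# Hypothetical continuation of test_multirank (not part of the PR):
# the "second" adapter on lin0 should use alpha_pattern["lin0"] / rank_pattern["lin0"].
scaling_current = model.lin0.scaling["second"]
scaling_expected = config_2.alpha_pattern["lin0"] / config_2.rank_pattern["lin0"]
self.assertEqual(scaling_current, scaling_expected)
```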