From 0817a3d7256511b88a5faa6717558a10d92763ed Mon Sep 17 00:00:00 2001 From: Alexander Kovalchuk Date: Fri, 15 Sep 2023 17:03:49 +0300 Subject: [PATCH 1/4] Fixed multirank multialpha for sequential loras, added tests, fixed docs --- docs/source/conceptual_guides/lora.mdx | 2 ++ src/peft/tuners/lora/config.py | 8 +++++-- src/peft/tuners/lora/model.py | 9 ++++++-- tests/test_custom_models.py | 32 ++++++++++++++++++++++++++ 4 files changed, 47 insertions(+), 4 deletions(-) diff --git a/docs/source/conceptual_guides/lora.mdx b/docs/source/conceptual_guides/lora.mdx index ff028ca4bd..4f3027241c 100644 --- a/docs/source/conceptual_guides/lora.mdx +++ b/docs/source/conceptual_guides/lora.mdx @@ -77,6 +77,8 @@ As with other methods supported by PEFT, to fine-tune a model using LoRA, you ne - `modules_to_save`: List of modules apart from LoRA layers to be set as trainable and saved in the final checkpoint. These typically include model's custom head that is randomly initialized for the fine-tuning task. - `layers_to_transform`: List of layers to be transformed by LoRA. If not specified, all layers in `target_modules` are transformed. - `layers_pattern`: Pattern to match layer names in `target_modules`, if `layers_to_transform` is specified. By default `PeftModel` will look at common layer pattern (`layers`, `h`, `blocks`, etc.), use it for exotic and custom models. +- `rank_pattern`: The mapping from layer names or regexp expression to ranks which are different from the default rank specified by `r`. +- `alpha_pattern`: The mapping from layer names or regexp expression to alphas which are different from the default alpha specified by `lora_alpha`. ## LoRA examples diff --git a/src/peft/tuners/lora/config.py b/src/peft/tuners/lora/config.py index c219d721fd..dff66822f5 100644 --- a/src/peft/tuners/lora/config.py +++ b/src/peft/tuners/lora/config.py @@ -45,6 +45,10 @@ class LoraConfig(PeftConfig): layers_pattern (`str`): The layer pattern name, used only if `layers_to_transform` is different from `None` and if the layer pattern is not in the common layers pattern. + rank_pattern (`dict`): + The mapping from layer names or regexp expression to ranks which are different from the default rank specified by `r`. + alpha_pattern (`dict`): + The mapping from layer names or regexp expression to alphas which are different from the default alpha specified by `lora_alpha`. """ r: int = field(default=8, metadata={"help": "Lora attention dimension"}) @@ -95,7 +99,7 @@ class LoraConfig(PeftConfig): default_factory=dict, metadata={ "help": ( - "The mapping from layer names to ranks which are different from the default rank specified by `r`. " + "The mapping from layer names or regexp expression to ranks which are different from the default rank specified by `r`. " "For example, `{model.decoder.layers.0.encoder_attn.k_proj: 8`}" ) }, @@ -104,7 +108,7 @@ class LoraConfig(PeftConfig): default_factory=dict, metadata={ "help": ( - "The mapping from layer names to alphas which are different from the default alpha specified by `lora_alpha`. " + "The mapping from layer names or regexp expression to alphas which are different from the default alpha specified by `lora_alpha`. 
" "For example, `{model.decoder.layers.0.encoder_attn.k_proj: 32`}" ) }, diff --git a/src/peft/tuners/lora/model.py b/src/peft/tuners/lora/model.py index b3a7b8ab9e..83641ea456 100644 --- a/src/peft/tuners/lora/model.py +++ b/src/peft/tuners/lora/model.py @@ -16,6 +16,7 @@ import warnings from dataclasses import asdict, replace from enum import Enum +from itertools import chain import torch from torch import nn @@ -161,9 +162,13 @@ def _create_and_replace( parent, **optional_kwargs, ): + # Regexp matching - Find key which matches current target_name in patterns provided current_key = optional_kwargs["current_key"] - r = lora_config.rank_pattern.get(current_key, lora_config.r) - alpha = lora_config.alpha_pattern.get(current_key, lora_config.lora_alpha) + pattern_keys = list(chain(lora_config.rank_pattern.keys(), lora_config.alpha_pattern.keys())) + target_name_key = next(filter(lambda key: re.match(f".*\.{key}$", current_key), pattern_keys), target_name) + + r = lora_config.rank_pattern.get(target_name_key, lora_config.r) + alpha = lora_config.alpha_pattern.get(target_name_key, lora_config.lora_alpha) bias = hasattr(target, "bias") and target.bias is not None kwargs = { "r": r, diff --git a/tests/test_custom_models.py b/tests/test_custom_models.py index a801b30e03..4690e9ec4a 100644 --- a/tests/test_custom_models.py +++ b/tests/test_custom_models.py @@ -368,3 +368,35 @@ def run_with_disable(config_kwargs, bias): @parameterized.expand(TEST_CASES) def test_adding_multiple_adapters_with_bias_raises(self, test_name, model_id, config_cls, config_kwargs): self._test_adding_multiple_adapters_with_bias_raises(model_id, config_cls, config_kwargs) + + +class TestMultiRankAdapter(unittest.TestCase): + """Tests related to multirank LoRA adapters""" + + def test_multirank(self): + config_1 = LoraConfig( + r=8, + lora_alpha=8, + init_lora_weights=False, + target_modules=["lin0", "lin1"], + ) + config_2 = LoraConfig( + r=8, + lora_alpha=8, + init_lora_weights=False, + target_modules=["lin0", "lin1"], + rank_pattern={"lin0": 4}, + alpha_pattern={"lin0": 4}, + ) + + # Add first adapter + model = get_peft_model(MLP(), config_1, adapter_name="first") + + # Add second adapter + model.add_adapter("second", config_2) + + # Extract current and expected ranks + rank_current = model.lin0.lora_A["second"].weight.shape[0] + rank_expected = config_2.rank_pattern["lin0"] + + self.assertTrue(rank_current == rank_expected, f"Rank {rank_current} is not equal to expected {rank_expected}") From 2009aa55c2a0c2865346f279c876284bd726fd91 Mon Sep 17 00:00:00 2001 From: Alexander Kovalchuk Date: Fri, 15 Sep 2023 17:04:30 +0300 Subject: [PATCH 2/4] Refactored kohya_ss conversion script for proper support of LoRA-C3Lier --- .../convert_kohya_ss_sd_lora_to_peft.py | 251 +++++++++--------- 1 file changed, 131 insertions(+), 120 deletions(-) diff --git a/examples/lora_dreambooth/convert_kohya_ss_sd_lora_to_peft.py b/examples/lora_dreambooth/convert_kohya_ss_sd_lora_to_peft.py index 21e9204e49..5f59e8b87f 100644 --- a/examples/lora_dreambooth/convert_kohya_ss_sd_lora_to_peft.py +++ b/examples/lora_dreambooth/convert_kohya_ss_sd_lora_to_peft.py @@ -1,15 +1,15 @@ import argparse import os -import re -from typing import Callable, List, Optional, Union +from dataclasses import dataclass +from typing import Dict, Optional import safetensors import torch -import torch.nn as nn from diffusers import UNet2DConditionModel from transformers import CLIPTextModel from peft import LoraConfig, get_peft_model, get_peft_model_state_dict, 
set_peft_model_state_dict +from peft.tuners.lora import LoraConfig # Default kohya_ss LoRA replacement modules @@ -21,44 +21,69 @@ LORA_PREFIX_TEXT_ENCODER = "lora_te" -def get_modules_names( - root_module: nn.Module, - target_replace_modules_linear: Optional[List[str]] = [], - target_replace_modules_conv2d: Optional[List[str]] = [], -): - # Combine replacement modules - target_replace_modules = target_replace_modules_linear + target_replace_modules_conv2d - - # Store result - modules_names = set() - # https://github.com/kohya-ss/sd-scripts/blob/c924c47f374ac1b6e33e71f82948eb1853e2243f/networks/lora.py#L720 - for name, module in root_module.named_modules(): - if module.__class__.__name__ in target_replace_modules: - if len(name) == 0: - continue - for child_name, child_module in module.named_modules(): - if len(child_name) == 0: - continue - is_linear = child_module.__class__.__name__ == "Linear" - is_conv2d = child_module.__class__.__name__ == "Conv2d" - - if (is_linear and module.__class__.__name__ in target_replace_modules_linear) or ( - is_conv2d and module.__class__.__name__ in target_replace_modules_conv2d - ): - modules_names.add(f"{name}.{child_name}") - - return sorted(modules_names) - - -def get_rank_alpha( - layer_names: List[str], - value_getter: Callable[[str], Union[int, float]], - filter_string: str, -) -> Union[int, float]: - values = [value_getter(p) for p in filter(lambda x: bool(re.search(filter_string, x)), layer_names)] - value = values[0] - assert all(v == value for v in values), f"All LoRA ranks and alphas must be same, found: {values}" - return value +@dataclass +class LoRAInfo: + kohya_key: str + peft_key: str + alpha: Optional[float] = None + rank: Optional[int] = None + lora_A: Optional[torch.Tensor] = None + lora_B: Optional[torch.Tensor] = None + + def peft_state_dict(self) -> Dict[str, torch.Tensor]: + if self.lora_A is None or self.lora_B is None: + raise ValueError("One of weights is not present - either lora_A or lora_B") + return {f"{peft_key}.lora_A.weight": self.lora_A, f"{peft_key}.lora_B.weight": self.lora_A} + + +def construct_peft_loraconfig(info: Dict[str, LoRAInfo]) -> LoraConfig: + """Constructs LoraConfig from data extracted from kohya checkpoint + + Args: + info (Dict[str, LoRAInfo]): Information extracted from kohya checkpoint + + Raises: + NotImplementedError: Raises if some layers have different ranks or alphas + + Returns: + LoraConfig: config for constructing LoRA + """ + + # Unpack all ranks and alphas + ranks = {x[0]: x[1].rank for x in info.items()} + alphas = {x[0]: x[1].alpha or x[1].rank for x in info.items()} + + # Determine which modules needs to be transformed + target_modules = list(info.keys()) + + # Determine most common rank and alpha + r = max(set(ranks.values()), key=list(ranks.values()).count) + lora_alpha = max(set(alphas.values()), key=list(alphas.values()).count) + + # Determine which modules have different rank and alpha + rank_pattern = dict(filter(lambda x: x[1] != r, ranks.items())) + alpha_pattern = dict(filter(lambda x: x[1] != lora_alpha, alphas.items())) + + config = LoraConfig( + r=r, + lora_alpha=lora_alpha, + target_modules=target_modules, + lora_dropout=0.0, + bias="none", + init_lora_weights=False, + rank_pattern=rank_pattern, + alpha_pattern=alpha_pattern, + ) + + return config + + +def combine_peft_state_dict(info: Dict[str, LoRAInfo]) -> Dict[str, torch.Tensor]: + result = {} + for key_name, key_info in info.items(): + result[f"base_model.model.{key_name}.lora_A.weight"] = key_info.lora_A + 
result[f"base_model.model.{key_name}.lora_B.weight"] = key_info.lora_B + return result if __name__ == "__main__": @@ -75,93 +100,79 @@ def get_rank_alpha( parser.add_argument("--half", action="store_true", help="Save weights in half precision.") args = parser.parse_args() - # Find text encoder modules to add LoRA to + # Load all models that we need to add adapter to text_encoder = CLIPTextModel.from_pretrained(args.sd_checkpoint, subfolder="text_encoder") - text_encoder_modules_names = get_modules_names( - text_encoder, target_replace_modules_linear=TEXT_ENCODER_TARGET_REPLACE_MODULE - ) - - # Find unet2d modules to add LoRA to unet = UNet2DConditionModel.from_pretrained(args.sd_checkpoint, subfolder="unet") - unet_modules_names = get_modules_names( - unet, - target_replace_modules_linear=UNET_TARGET_REPLACE_MODULE, - target_replace_modules_conv2d=UNET_TARGET_REPLACE_MODULE, - ) + + # Construct possible mapping from kohya keys to peft keys + models_keys = {} + for model, model_key, model_name in [ + (text_encoder, LORA_PREFIX_TEXT_ENCODER, "text_encoder"), + (unet, LORA_PREFIX_UNET, "unet"), + ]: + models_keys.update( + { + f"{model_key}.{peft_key}".replace(".", "_"): peft_key + for peft_key in (x[0] for x in model.named_modules()) + } + ) + + # Store conversion info (model_type -> peft_key -> LoRAInfo) + lora_info: Dict[str, Dict[str, LoRAInfo]] = { + "text_encoder": {}, + "unet": {}, + } # Open kohya_ss checkpoint with safetensors.safe_open(args.kohya_lora_path, framework="pt", device="cpu") as f: # Extract information about LoRA structure metadata = f.metadata() - if (metadata is not None) and ("ss_network_dim" in metadata) and ("ss_network_alpha" in metadata): - # LoRA rank and alpha are in safetensors metadata, just get it - lora_r = lora_text_encoder_r = int(metadata["ss_network_dim"]) - lora_alpha = lora_text_encoder_alpha = float(metadata["ss_network_alpha"]) - else: - # LoRA rank and alpha are not present, so infer them - lora_r = get_rank_alpha( - f.keys(), lambda n: f.get_tensor(n).size(0), f"^{LORA_PREFIX_UNET}\w+\.lora_down\.weight$" - ) - lora_text_encoder_r = get_rank_alpha( - f.keys(), lambda n: f.get_tensor(n).size(0), f"^{LORA_PREFIX_TEXT_ENCODER}\w+\.lora_down\.weight$" - ) - lora_alpha = get_rank_alpha(f.keys(), lambda n: f.get_tensor(n).item(), f"^{LORA_PREFIX_UNET}\w+\.alpha$") - lora_text_encoder_alpha = get_rank_alpha( - f.keys(), lambda n: f.get_tensor(n).item(), f"^{LORA_PREFIX_TEXT_ENCODER}\w+\.alpha$" - ) - - # Create LoRA for text encoder - text_encoder_config = LoraConfig( - r=lora_text_encoder_r, - lora_alpha=lora_text_encoder_alpha, - target_modules=text_encoder_modules_names, - lora_dropout=0.0, - bias="none", - ) - text_encoder = get_peft_model(text_encoder, text_encoder_config) - text_encoder_lora_state_dict = {x: None for x in get_peft_model_state_dict(text_encoder).keys()} - - # Load text encoder values from kohya_ss LoRA - for peft_te_key in text_encoder_lora_state_dict.keys(): - kohya_ss_te_key = peft_te_key.replace("base_model.model", LORA_PREFIX_TEXT_ENCODER) - kohya_ss_te_key = kohya_ss_te_key.replace("lora_A", "lora_down") - kohya_ss_te_key = kohya_ss_te_key.replace("lora_B", "lora_up") - kohya_ss_te_key = kohya_ss_te_key.replace(".", "_", kohya_ss_te_key.count(".") - 2) - text_encoder_lora_state_dict[peft_te_key] = f.get_tensor(kohya_ss_te_key).to(text_encoder.dtype) - - # Load converted kohya_ss text encoder LoRA back to PEFT - set_peft_model_state_dict(text_encoder, text_encoder_lora_state_dict) - if args.half: - text_encoder.to(torch.float16) 
- - # Save text encoder result - text_encoder.save_pretrained( - os.path.join(args.dump_path, "text_encoder"), - ) - - # Create LoRA for unet2d - unet_config = LoraConfig( - r=lora_r, lora_alpha=lora_alpha, target_modules=unet_modules_names, lora_dropout=0.0, bias="none" - ) - unet = get_peft_model(unet, unet_config) - unet_lora_state_dict = {x: None for x in get_peft_model_state_dict(unet).keys()} - - # Load unet2d values from kohya_ss LoRA - for peft_unet_key in unet_lora_state_dict.keys(): - kohya_ss_unet_key = peft_unet_key.replace("base_model.model", LORA_PREFIX_UNET) - kohya_ss_unet_key = kohya_ss_unet_key.replace("lora_A", "lora_down") - kohya_ss_unet_key = kohya_ss_unet_key.replace("lora_B", "lora_up") - kohya_ss_unet_key = kohya_ss_unet_key.replace(".", "_", kohya_ss_unet_key.count(".") - 2) - unet_lora_state_dict[peft_unet_key] = f.get_tensor(kohya_ss_unet_key).to(unet.dtype) - - # Load converted kohya_ss unet LoRA back to PEFT - set_peft_model_state_dict(unet, unet_lora_state_dict) + # Iterate through available info and unpack all the values + for key in f.keys(): + kohya_key, kohya_type = key.split(".")[:2] + + # Find which model this key belongs to + if kohya_key.startswith(LORA_PREFIX_TEXT_ENCODER): + model_type = "text_encoder" + elif kohya_key.startswith(LORA_PREFIX_UNET): + model_type = "unet" + else: + raise ValueError(f"Cannot determine model for key: {key}") + + # Find corresponding peft key + if kohya_key not in models_keys: + raise ValueError(f"Cannot find corresponding key for diffusers/transformers model: {kohya_key}") + peft_key = models_keys[kohya_key] + + if peft_key not in lora_info[model_type]: + lora_info[model_type][peft_key] = LoRAInfo(kohya_key=kohya_key, peft_key=peft_key) + + if kohya_type == "alpha": + lora_info[model_type][peft_key].alpha = f.get_tensor(key).item() + elif kohya_type == "lora_down": + tensor = f.get_tensor(key) + lora_info[model_type][peft_key].lora_A = tensor + lora_info[model_type][peft_key].rank = tensor.shape[0] + elif kohya_type == "lora_up": + tensor = f.get_tensor(key) + lora_info[model_type][peft_key].lora_B = f.get_tensor(key) + lora_info[model_type][peft_key].rank = tensor.shape[1] + else: + raise ValueError(f"Unknown weight name in key: {key} - {kohya_type}") + + # Process each model + for model, model_name in [(text_encoder, "text_encoder"), (unet, "unet")]: + config = construct_peft_loraconfig(lora_info[model_name]) + model = get_peft_model(model, config) + + keys_peft = list(get_peft_model_state_dict(model).keys()) + keys_new = list(combine_peft_state_dict(lora_info[model_name]).keys()) + + set_peft_model_state_dict(model, combine_peft_state_dict(lora_info[model_name])) if args.half: - unet.to(torch.float16) + model.to(torch.float16) - # Save text encoder result - unet.save_pretrained( - os.path.join(args.dump_path, "unet"), - ) + # Save model to disk + model.save_pretrained(os.path.join(args.dump_path, model_name)) From dc945781bdb807c8d85c1b6df8bb907339ca6ab0 Mon Sep 17 00:00:00 2001 From: Alexander Kovalchuk Date: Sun, 17 Sep 2023 12:32:35 +0300 Subject: [PATCH 3/4] Fixed styling --- .../lora_dreambooth/convert_kohya_ss_sd_lora_to_peft.py | 1 - src/peft/tuners/lora/config.py | 6 ++++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/examples/lora_dreambooth/convert_kohya_ss_sd_lora_to_peft.py b/examples/lora_dreambooth/convert_kohya_ss_sd_lora_to_peft.py index 5f59e8b87f..266d5d352d 100644 --- a/examples/lora_dreambooth/convert_kohya_ss_sd_lora_to_peft.py +++ 
b/examples/lora_dreambooth/convert_kohya_ss_sd_lora_to_peft.py @@ -9,7 +9,6 @@ from transformers import CLIPTextModel from peft import LoraConfig, get_peft_model, get_peft_model_state_dict, set_peft_model_state_dict -from peft.tuners.lora import LoraConfig # Default kohya_ss LoRA replacement modules diff --git a/src/peft/tuners/lora/config.py b/src/peft/tuners/lora/config.py index dff66822f5..d85aa79239 100644 --- a/src/peft/tuners/lora/config.py +++ b/src/peft/tuners/lora/config.py @@ -46,9 +46,11 @@ class LoraConfig(PeftConfig): The layer pattern name, used only if `layers_to_transform` is different from `None` and if the layer pattern is not in the common layers pattern. rank_pattern (`dict`): - The mapping from layer names or regexp expression to ranks which are different from the default rank specified by `r`. + The mapping from layer names or regexp expression to ranks which are different from the default rank + specified by `r`. alpha_pattern (`dict`): - The mapping from layer names or regexp expression to alphas which are different from the default alpha specified by `lora_alpha`. + The mapping from layer names or regexp expression to alphas which are different from the default alpha + specified by `lora_alpha`. """ r: int = field(default=8, metadata={"help": "Lora attention dimension"}) From 9c46c16ddaeaf0bc8d445f3b7d3546ccb3ef4db6 Mon Sep 17 00:00:00 2001 From: Alexander Kovalchuk Date: Mon, 18 Sep 2023 11:37:45 +0300 Subject: [PATCH 4/4] Removed old comment from docstring --- examples/lora_dreambooth/convert_kohya_ss_sd_lora_to_peft.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/examples/lora_dreambooth/convert_kohya_ss_sd_lora_to_peft.py b/examples/lora_dreambooth/convert_kohya_ss_sd_lora_to_peft.py index 266d5d352d..264b8ea07f 100644 --- a/examples/lora_dreambooth/convert_kohya_ss_sd_lora_to_peft.py +++ b/examples/lora_dreambooth/convert_kohya_ss_sd_lora_to_peft.py @@ -41,9 +41,6 @@ def construct_peft_loraconfig(info: Dict[str, LoRAInfo]) -> LoraConfig: Args: info (Dict[str, LoRAInfo]): Information extracted from kohya checkpoint - Raises: - NotImplementedError: Raises if some layers have different ranks or alphas - Returns: LoraConfig: config for constructing LoRA """
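
Usage note: the `rank_pattern` / `alpha_pattern` options added in patch 1 can be exercised with a minimal sketch like the one below. The two-layer MLP is a toy model written here purely for illustration (mirroring the custom-model pattern used by `tests/test_custom_models.py`); the attribute path `model.lin0.lora_A["default"]` follows the access pattern used by the new `test_multirank` test.

    import torch.nn as nn
    from peft import LoraConfig, get_peft_model

    class MLP(nn.Module):
        # Toy stand-in model; any nn.Module with named Linear layers works the same way.
        def __init__(self):
            super().__init__()
            self.lin0 = nn.Linear(10, 20)
            self.relu = nn.ReLU()
            self.lin1 = nn.Linear(20, 2)

        def forward(self, x):
            return self.lin1(self.relu(self.lin0(x)))

    config = LoraConfig(
        r=8,                        # default rank applied to all target modules
        lora_alpha=8,               # default scaling alpha
        target_modules=["lin0", "lin1"],
        rank_pattern={"lin0": 4},   # override: lin0 gets rank 4 instead of 8
        alpha_pattern={"lin0": 4},  # override: lin0 gets alpha 4 instead of 8
    )

    model = get_peft_model(MLP(), config)

    # lin0 carries the overridden rank, lin1 keeps the default one.
    print(model.lin0.lora_A["default"].weight.shape)  # torch.Size([4, 10])
    print(model.lin1.lora_A["default"].weight.shape)  # torch.Size([8, 20])

Per the change to `_create_and_replace` in `src/peft/tuners/lora/model.py`, the keys of `rank_pattern` and `alpha_pattern` may also be regular expressions: each key is matched against the tail of the full module path via `re.match(f".*\.{key}$", current_key)`, so a single entry such as `{"encoder_attn.k_proj": 16}` overrides every layer's `k_proj` at once.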
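
The refactored converter in patch 2 builds on the same mechanism: instead of requiring one uniform rank/alpha across the whole checkpoint (which LoRA-C3Lier files break, since their conv and linear blocks can be trained with different dims and alphas), `construct_peft_loraconfig` takes the most common rank and alpha as the `LoraConfig` defaults and records the exceptions in `rank_pattern` / `alpha_pattern`. A toy illustration of that split, with made-up module names:

    # Same default-plus-exceptions logic as construct_peft_loraconfig, on toy data.
    ranks = {"unet_attn_q": 8, "unet_attn_k": 8, "unet_conv_in": 4}
    alphas = {"unet_attn_q": 8.0, "unet_attn_k": 8.0, "unet_conv_in": 4.0}

    # The most common value becomes the LoraConfig default ...
    r = max(set(ranks.values()), key=list(ranks.values()).count)             # 8
    lora_alpha = max(set(alphas.values()), key=list(alphas.values()).count)  # 8.0

    # ... and only the outliers end up in the per-layer override dicts.
    rank_pattern = {k: v for k, v in ranks.items() if v != r}                    # {"unet_conv_in": 4}
    alpha_pattern = {k: v for k, v in alphas.items() if v != lora_alpha}         # {"unet_conv_in": 4.0}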