From b5199f40c0ed524f2fabb535b045f84365cfda01 Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Fri, 11 Oct 2024 20:31:21 +0800 Subject: [PATCH] [Misc][LoRA] Support loading LoRA weights for target_modules in reg format (#9275) Signed-off-by: Sumit Dubey --- tests/lora/conftest.py | 5 +++++ tests/lora/test_lora_checkpoints.py | 17 ++++++++++++-- vllm/lora/models.py | 7 ++++-- vllm/lora/utils.py | 35 ++++++++++++++++++++++++++++- 4 files changed, 59 insertions(+), 5 deletions(-) diff --git a/tests/lora/conftest.py b/tests/lora/conftest.py index da98fac99cf22..405c0d0efad65 100644 --- a/tests/lora/conftest.py +++ b/tests/lora/conftest.py @@ -199,6 +199,11 @@ def baichuan_zero_lora_files(): return snapshot_download(repo_id="jeeejeee/baichuan7b-zero-init") +@pytest.fixture(scope="session") +def baichuan_regex_lora_files(): + return snapshot_download(repo_id="jeeejeee/baichuan-7b-lora-zero-regex") + + @pytest.fixture(scope="session") def minicpmv_lora_files(): return snapshot_download(repo_id="jeeejeee/minicpmv25-lora-pokemon") diff --git a/tests/lora/test_lora_checkpoints.py b/tests/lora/test_lora_checkpoints.py index 3514dcb7aedf4..9a529e27b4cd8 100644 --- a/tests/lora/test_lora_checkpoints.py +++ b/tests/lora/test_lora_checkpoints.py @@ -5,7 +5,9 @@ from vllm.lora.models import LoRAModel from vllm.model_executor.models.baichuan import BaiChuanBaseForCausalLM -lora_lst = ["baichuan7B", "baichuan7B-zero", "chatglm3-6b"] +lora_lst = [ + "baichuan7B", "baichuan7B-zero", "baichuan7B-zero-regex", "chatglm3-6b" +] @pytest.mark.parametrize("lora_name", lora_lst) @@ -13,6 +15,7 @@ def test_load_checkpoints( lora_name, baichuan_lora_files, baichuan_zero_lora_files, + baichuan_regex_lora_files, chatglm3_lora_files, ): supported_lora_modules = BaiChuanBaseForCausalLM.supported_lora_modules @@ -36,7 +39,7 @@ def test_load_checkpoints( embedding_modules=embedding_modules, embedding_padding_modules=embed_padding_modules) elif lora_name == "baichuan7B-zero": - #Test that the target_modules contain prefix + # Test that the target_modules contain prefix # such as "model.layers.0.self_atten.W_pack", and # the test should pass. LoRAModel.from_local_checkpoint( @@ -46,6 +49,16 @@ def test_load_checkpoints( device="cpu", embedding_modules=embedding_modules, embedding_padding_modules=embed_padding_modules) + elif lora_name == "baichuan7B-zero-regex": + # Test that the `target_modules` in the form of regular expressions, + # such as `model\\..*(W_pack|o_proj)`, and the test should pass. + LoRAModel.from_local_checkpoint( + baichuan_regex_lora_files, + expected_lora_modules, + lora_model_id=1, + device="cpu", + embedding_modules=embedding_modules, + embedding_padding_modules=embed_padding_modules) else: # For the baichuan7B model, load chatglm3-6b's LoRA, # and the test should raise the following error. diff --git a/vllm/lora/models.py b/vllm/lora/models.py index 91e9f55e82433..0dc54516f8671 100644 --- a/vllm/lora/models.py +++ b/vllm/lora/models.py @@ -23,6 +23,7 @@ from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights from vllm.lora.punica import PunicaWrapper from vllm.lora.utils import (from_layer, from_layer_logits_processor, + is_regex_target_modules, parse_fine_tuned_lora_name, replace_submodule) from vllm.model_executor.models import SupportsLoRA, supports_multimodal from vllm.model_executor.models.module_mapping import MultiModelKeys @@ -233,6 +234,8 @@ def from_local_checkpoint( # modules. unexpected_modules = [] target_modules = config["target_modules"] + if not isinstance(target_modules, list): + target_modules = [target_modules] for module in target_modules: # Compatible with more modules, # such as:layers.11.self_attn.k_proj @@ -243,8 +246,8 @@ def from_local_checkpoint( # expected_lora_modules. It is not reliable. See # https://github.com/vllm-project/vllm/pull/5909. But there's no # other better mechanism. - if unexpected_modules: - print(unexpected_modules, "modules") + if unexpected_modules and not is_regex_target_modules( + config["target_modules"], expected_lora_modules): raise ValueError( f"While loading {lora_dir}, expected" f" target modules in {expected_lora_modules}" diff --git a/vllm/lora/utils.py b/vllm/lora/utils.py index ee983328e2c5b..a780429f413d3 100644 --- a/vllm/lora/utils.py +++ b/vllm/lora/utils.py @@ -1,5 +1,6 @@ import os -from typing import List, Optional, Set, Tuple, Type +import re +from typing import List, Optional, Set, Tuple, Type, Union import huggingface_hub from huggingface_hub.utils import (EntryNotFoundError, HfHubHTTPError, @@ -113,6 +114,38 @@ def parse_fine_tuned_lora_name(name: str) -> Tuple[str, bool]: raise ValueError(f"{name} is unsupported LoRA weight") +def is_regex_target_modules(load_modules: Union[str, List[str]], + expected_lora_modules: List[str]) -> bool: + """ + PEFT supports passing `target_modules` in the form of regular expressions, + such as `model.*(q_proj|k_proj|v_proj)$`. This function is mainly used to + determine whether the suffix in the regular expression is present in the + `expected_lora_modules`. + """ + + def is_valid_regex(pattern): + try: + re.compile(pattern) + return True + except re.error: + return False + + def is_subset(sub_list, full_list): + return set(sub_list).issubset(set(full_list)) + + # Similar to PEFT's processing logic, regex-related operations are only + # executed when the load_modules is a `str`. + if not isinstance(load_modules, str): + return False + + if is_valid_regex(load_modules): + match = re.search(r"\((.*?)\)\$?$", load_modules) + if match: + suffix = match.group(1).split("|") + return is_subset(suffix, expected_lora_modules) + return False + + def get_adapter_absolute_path(lora_path: str) -> str: """ Resolves the given lora_path to an absolute local path.