From 21379173089701aad9d50ebc6f0d665ce9694c7c Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Tue, 6 Aug 2024 11:38:13 +0800 Subject: [PATCH 01/18] init --- vllm/lora/models.py | 7 ++++ vllm/model_executor/models/minicpmv.py | 48 ++++++++++++++++++++------ vllm/worker/model_runner.py | 6 ++-- 3 files changed, 47 insertions(+), 14 deletions(-) diff --git a/vllm/lora/models.py b/vllm/lora/models.py index 017a1002bb9a7..fe33f2544e9a1 100644 --- a/vllm/lora/models.py +++ b/vllm/lora/models.py @@ -439,6 +439,13 @@ def _create_lora_modules(self): self.model, module_name, from_layer(module, self.lora_slots, self.lora_config, packed_moduled_lst, self.model.config)) + # In some models, especially multimodal ones, layers with the same + # name may have different types, such as nn.Linear and + # ReplicatedLinear. The nn.Linear layers cannot be replaced with + # LoRA layers, leading to assertion errors. The following check + # aims to prevent this issue + if not isinstance(new_module, BaseLayerWithLoRA): + continue # LinearScalingRotaryEmbeddingWithLora is used to handle # long context lora. Register relevant metadata. if isinstance(new_module, LinearScalingRotaryEmbeddingWithLora): diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py index 0388259595628..55f68811db3dc 100644 --- a/vllm/model_executor/models/minicpmv.py +++ b/vllm/model_executor/models/minicpmv.py @@ -37,7 +37,7 @@ from transformers.configuration_utils import PretrainedConfig from vllm.attention import AttentionMetadata -from vllm.config import CacheConfig, MultiModalConfig +from vllm.config import CacheConfig, LoRAConfig, MultiModalConfig from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs from vllm.logger import init_logger from vllm.model_executor.layers.linear import ReplicatedLinear @@ -59,6 +59,7 @@ from vllm.sequence import IntermediateTensors, SamplerOutput, SequenceData from .idefics2_vision_model import Idefics2VisionTransformer +from .interfaces import SupportsLoRA logger = init_logger(__name__) @@ -808,7 +809,26 @@ def is_default_weight_loading(self, name: str) -> bool: return "resampler" in name or "vpm" in name -class MiniCPMV2_5(MiniCPMVBaseModel): +class MiniCPMV2_5(MiniCPMVBaseModel, SupportsLoRA): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + # LoRA specific attributes + supported_lora_modules = [ + "qkv_proj", "o_proj", "gate_up_proj", "down_proj", "fc1", "fc2", + "out_proj", "kv_proj" + ] + embedding_modules = {} + embedding_padding_modules = [] def __init__( self, @@ -816,6 +836,7 @@ def __init__( multimodal_config: MultiModalConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + lora_config: Optional[LoRAConfig] = None, ): super().__init__(config, multimodal_config, cache_config, quant_config) assert self.version == (2, 5) @@ -993,20 +1014,25 @@ def is_default_weight_loading(self, name: str) -> bool: @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_minicpmv_image_tokens) @INPUT_REGISTRY.register_dummy_data(dummy_data_for_minicpmv) @INPUT_REGISTRY.register_input_processor(input_processor_for_minicpmv) -class MiniCPMV(MiniCPMVBaseModel): +class MiniCPMV(MiniCPMVBaseModel, SupportsLoRA): """ Different versions of MiniCPMV use different visual encoders and LLMs, which is not conducive to the current integration logic of LoRA and bitsandbytes in vLLM. Therefore, it is necessary to separate them. 
""" - - def __new__( - cls, - config: PretrainedConfig, - multimodal_config: MultiModalConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - ): + packed_modules_mapping = {} + + # LoRA specific attributes + supported_lora_modules = [] + embedding_modules = {} + embedding_padding_modules = [] + + def __new__(cls, + config: PretrainedConfig, + multimodal_config: MultiModalConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + lora_config: Optional[LoRAConfig] = None): if not hasattr(config, "version"): if config.hidden_size == 2304 and config.query_num == 64: version = (2, 0) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index f9c26e0c318b1..7ddbda7ee8ba4 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -734,9 +734,9 @@ def load_model(self) -> None: if self.lora_config: assert supports_lora(self.model), "Model does not support LoRA" - assert not supports_vision( - self.model - ), "To be tested: vision language model with LoRA settings." + # assert not supports_vision( + # self.model + # ), "To be tested: vision language model with LoRA settings." self.lora_manager = LRUCacheWorkerLoRAManager( self.scheduler_config.max_num_seqs, From 5edda37733143f378ee5a65e509c2f6c9590f6eb Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Tue, 6 Aug 2024 18:08:31 +0800 Subject: [PATCH 02/18] optimize minicpmv implementation --- vllm/lora/models.py | 4 +- vllm/model_executor/models/minicpmv.py | 53 ++++++++++++++++++++------ 2 files changed, 44 insertions(+), 13 deletions(-) diff --git a/vllm/lora/models.py b/vllm/lora/models.py index fe33f2544e9a1..310b9a2e267e9 100644 --- a/vllm/lora/models.py +++ b/vllm/lora/models.py @@ -442,8 +442,8 @@ def _create_lora_modules(self): # In some models, especially multimodal ones, layers with the same # name may have different types, such as nn.Linear and # ReplicatedLinear. The nn.Linear layers cannot be replaced with - # LoRA layers, leading to assertion errors. The following check - # aims to prevent this issue + # LoRA layers, leading to assertion error. The following check + # aims to prevent this error if not isinstance(new_module, BaseLayerWithLoRA): continue # LinearScalingRotaryEmbeddingWithLora is used to handle diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py index 55f68811db3dc..7ccdbbac36388 100644 --- a/vllm/model_executor/models/minicpmv.py +++ b/vllm/model_executor/models/minicpmv.py @@ -65,7 +65,6 @@ _KEYS_TO_MODIFY_MAPPING = { "llm.lm_head": "lm_head", - "llm.model": "llm", } @@ -483,6 +482,21 @@ def get_placeholder(image_size: Tuple[int, int], num_image: int): return llm_inputs +class LLMWrapper(nn.Module): + """ + To align with the key names of LoRA trained with PEFT, we need to add an + additional layer to the llm's implementation. 
+ """ + + def __init__(self, llm: nn.Module, name: str) -> None: + super().__init__() + self.model_name = name + setattr(self, name, llm) + + def forward(self, *args, **kwargs) -> Any: + return getattr(self, self.model_name)(*args, **kwargs) + + class MiniCPMVBaseModel(nn.Module, SupportsVision): """ The abstract class of MiniCPMV can only be inherited, but cannot be @@ -521,7 +535,7 @@ def get_embedding( input_ids: torch.Tensor, image_inputs: Optional[MiniCPMVImageInputs], ) -> Tuple[torch.Tensor, torch.Tensor]: - vlm_embedding: torch.Tensor = self.llm.embed_tokens(input_ids) + vlm_embedding: torch.Tensor = self.get_llm_embedding(input_ids) if hasattr(self.config, "scale_emb"): vlm_embedding *= self.config.scale_emb @@ -710,6 +724,9 @@ def get_vision_embedding( ) -> torch.Tensor: raise NotImplementedError + def get_llm_embedding(self, input_ids: torch.Tensor) -> torch.Tensor: + raise NotImplementedError + def get_vision_hidden_states(self, data: MiniCPMVImageInputs) -> torch.Tensor: raise NotImplementedError @@ -736,9 +753,11 @@ def init_llm( cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> nn.Module: - return MiniCPMModel(config, - cache_config=cache_config, - quant_config=quant_config) + + return LLMWrapper(MiniCPMModel(config, + cache_config=cache_config, + quant_config=quant_config), + name="model") def init_vision_module(self) -> nn.Module: # TODO :refactor this vision model @@ -764,6 +783,9 @@ def init_vision_module(self) -> nn.Module: return model + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.embed_tokens(input_ids) + def init_resampler(self, embed_dim: int, vision_dim: int) -> nn.Module: with set_default_torch_dtype(torch.float16): resampler = Resampler2( @@ -799,6 +821,9 @@ def get_vision_embedding( res.append(self.resampler(vision_embedding, tgt_size)) return torch.vstack(res) + def get_llm_embedding(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.llm.embed_tokens(input_ids) + def get_vision_hidden_states(self, data: MiniCPMVImageInputs) -> torch.Tensor: pixel_values = data["pixel_values"] @@ -847,9 +872,10 @@ def init_llm( cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> nn.Module: - return LlamaModel(config, - cache_config=cache_config, - quant_config=quant_config) + return LLMWrapper(LlamaModel(config, + cache_config=cache_config, + quant_config=quant_config), + name="model") def init_vision_module(self) -> nn.Module: model = Idefics2VisionTransformer(self.config.vision_config) @@ -878,6 +904,9 @@ def get_vision_embedding( vision_embedding = self.resampler(vision_embedding, tgt_sizes) return vision_embedding + def get_llm_embedding(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.llm.model.embed_tokens(input_ids) + def get_vision_hidden_states(self, data: MiniCPMVImageInputs) -> torch.Tensor: pixel_values = data["pixel_values"] @@ -957,7 +986,6 @@ def init_resampler(self, embed_dim: int, vision_dim: int) -> nn.Module: num_heads=embed_dim // 128, kv_dim=vision_dim, ) - return resampler def get_vision_embedding( @@ -973,6 +1001,9 @@ def get_vision_embedding( ).last_hidden_state return vision_embedding + def get_llm_embedding(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.llm.embed_tokens(input_ids) + def get_vision_hidden_states(self, data: MiniCPMVImageInputs) -> torch.Tensor: pixel_values = data["pixel_values"] @@ -1020,9 +1051,9 @@ class MiniCPMV(MiniCPMVBaseModel, SupportsLoRA): which 
is not conducive to the current integration logic of LoRA and bitsandbytes in vLLM. Therefore, it is necessary to separate them. """ + # Ensure that the LoRA support check passes when the class is not + # initialized,but set all these attributes to empty packed_modules_mapping = {} - - # LoRA specific attributes supported_lora_modules = [] embedding_modules = {} embedding_padding_modules = [] From 2ea5006cfc2691a081e5a04059355c65e5c23b97 Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Tue, 6 Aug 2024 18:10:50 +0800 Subject: [PATCH 03/18] delete comment --- vllm/worker/model_runner.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 7ddbda7ee8ba4..f0bf981a37462 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -734,10 +734,6 @@ def load_model(self) -> None: if self.lora_config: assert supports_lora(self.model), "Model does not support LoRA" - # assert not supports_vision( - # self.model - # ), "To be tested: vision language model with LoRA settings." - self.lora_manager = LRUCacheWorkerLoRAManager( self.scheduler_config.max_num_seqs, self.scheduler_config.max_num_batched_tokens, From 42e846a48d1b1a3a7a8d5af82fe8be1389039094 Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Wed, 7 Aug 2024 14:32:24 +0800 Subject: [PATCH 04/18] Trigger LoRA test From 9eed2354fb16c442908fb2be3a8802e03a2f75bc Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Wed, 25 Sep 2024 14:28:45 +0800 Subject: [PATCH 05/18] Modify code --- vllm/model_executor/models/minicpmv.py | 4 +--- vllm/worker/model_runner.py | 5 ++--- 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py index b802202b802c1..1d40aad8c82da 100644 --- a/vllm/model_executor/models/minicpmv.py +++ b/vllm/model_executor/models/minicpmv.py @@ -389,7 +389,6 @@ def input_mapper_for_minicpmv(ctx: InputContext, data: object): return MultiModalInputs(batch_data) - class LLMWrapper(nn.Module): """ To align with the key names of LoRA trained with PEFT, we need to add an @@ -405,7 +404,6 @@ def forward(self, *args, **kwargs) -> Any: return getattr(self, self.model_name)(*args, **kwargs) - class MiniCPMVBaseModel(nn.Module, SupportsMultiModal): """ The abstract class of MiniCPMV can only be inherited, but cannot be @@ -448,7 +446,7 @@ def get_embedding( input_ids: torch.Tensor, image_inputs: Optional[MiniCPMVImagePixelInputs], ) -> Tuple[torch.Tensor, torch.Tensor]: - vlm_embedding: torch.Tensor = self.get_llm_embedding(input_ids) + vlm_embedding: torch.Tensor = self.llm.embed_tokens(input_ids) if hasattr(self.config, "scale_emb"): vlm_embedding *= self.config.scale_emb diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index dd20954e88826..dc1020e3ea8ba 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1027,9 +1027,8 @@ def load_model(self) -> None: if self.lora_config: assert supports_lora(self.model), "Model does not support LoRA" - # assert not supports_multimodal( - # self.model - # ), "To be tested: Multi-modal model with LoRA settings." 
+ if supports_multimodal(self.model): + logger.warning("todo:add warning info") self.lora_manager = LRUCacheWorkerLoRAManager( self.scheduler_config.max_num_seqs, From e4e3f46536e96f300697a0ca42f3cc329b7501ba Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Wed, 25 Sep 2024 18:49:51 +0800 Subject: [PATCH 06/18] Complete VL supports lora --- vllm/lora/models.py | 35 ++++++++++-- vllm/model_executor/models/minicpmv.py | 18 +++++- vllm/model_executor/models/module_mapping.py | 59 ++++++++++++++++++++ vllm/worker/model_runner.py | 3 +- 4 files changed, 108 insertions(+), 7 deletions(-) create mode 100644 vllm/model_executor/models/module_mapping.py diff --git a/vllm/lora/models.py b/vllm/lora/models.py index 8d33fc4e2f299..d726a36b83c2c 100644 --- a/vllm/lora/models.py +++ b/vllm/lora/models.py @@ -24,7 +24,9 @@ from vllm.lora.punica import PunicaWrapper from vllm.lora.utils import (from_layer, from_layer_logits_processor, parse_fine_tuned_lora_name, replace_submodule) -from vllm.model_executor.models.interfaces import SupportsLoRA +from vllm.model_executor.models.interfaces import (SupportsLoRA, + supports_multimodal) +from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.models.utils import PPMissingLayer from vllm.utils import is_pin_memory_available @@ -332,6 +334,8 @@ def __init__( self.supported_lora_modules.append("rotary_emb") self.packed_modules_mapping = copy.deepcopy( self.model.packed_modules_mapping) + # Used to indicate whether the model is a multimodal model + self.supports_mm: bool = supports_multimodal(self.model) self.packed_modules: Dict[str, List[str]] = {} self.modules: Dict[str, "BaseLayerWithLoRA"] = {} # Dict instead of a Set for compatibility with LRUCache. @@ -437,6 +441,15 @@ def _create_lora_modules(self): continue if not self._match_target_modules(module_name): continue + # A temporary approach for multimodal models to support LoRA + # TODO: Remove this restriction + if self._filter_unsupported_modules(module_name): + logger.warning( + "Regarding multimodal models, vLLM currently only supports " + "adding LoRA to language models, %s will be ignored.", + module_name, + ) + continue parts = module_name.split(".")[-1] packed_moduled_lst = self.packed_modules_mapping.get(parts, []) new_module = replace_submodule( @@ -485,9 +498,10 @@ def create_dummy_lora( """Create zero-initialized LoRAModel for warmup.""" model = LoRAModel(lora_id, rank, {}, scaling_factor) for module_name, module in self.model.named_modules(): - if not self._match_target_modules(module_name) or not isinstance( - module, BaseLayerWithLoRA) or isinstance( - module, LinearScalingRotaryEmbeddingWithLora): + if (not self._match_target_modules(module_name) + or not isinstance(module, BaseLayerWithLoRA) + or isinstance(module, LinearScalingRotaryEmbeddingWithLora) + or self._filter_unsupported_modules(module_name)): continue parts = module_name.split(".") if module_name not in self.packed_modules: @@ -548,6 +562,19 @@ def _match_target_modules(self, module_name: str): module_name) or target_module == module_name for target_module in self.supported_lora_modules) + def _filter_unsupported_modules(self, module_name: str) -> bool: + """ + Regarding multimodal models, vLLM currently only supports adding LoRA to + language model. LoRA for other modules, such as the vision tower, will + be filtered out. 
+ """ + if self.supports_mm: + prefix = module_name.split(".")[0] + module_mapping: MultiModelKeys = self.model.get_mm_mapping() + return (prefix in module_mapping.connector + or prefix in module_mapping.vision_tower) + return False + def _register_packed_modules(self, module_full_name: str) -> None: parts = module_full_name.split(".") module_name = parts[-1] diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py index 1d40aad8c82da..26adb04cc07c9 100644 --- a/vllm/model_executor/models/minicpmv.py +++ b/vllm/model_executor/models/minicpmv.py @@ -50,6 +50,7 @@ from vllm.model_executor.models.interfaces import SupportsMultiModal from vllm.model_executor.models.llama import LlamaModel from vllm.model_executor.models.minicpm import MiniCPMModel +from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.models.qwen2 import Qwen2Model from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY @@ -403,6 +404,9 @@ def __init__(self, llm: nn.Module, name: str) -> None: def forward(self, *args, **kwargs) -> Any: return getattr(self, self.model_name)(*args, **kwargs) + def embed_tokens(self, *args, **kwargs): + return getattr(self, self.model_name).embed_tokens(*args, **kwargs) + class MiniCPMVBaseModel(nn.Module, SupportsMultiModal): """ @@ -636,6 +640,14 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): default_weight_loader) weight_loader(param, loaded_weight) + def get_mm_mapping(self) -> MultiModelKeys: + """ + Get the module prefix in multimodal models + """ + return MultiModelKeys(language_model="llm", + connector="resampler", + vision_tower="vpm") + def init_llm( self, config: PretrainedConfig, @@ -778,8 +790,10 @@ class MiniCPMV2_5(MiniCPMVBaseModel, SupportsLoRA): # LoRA specific attributes supported_lora_modules = [ - "qkv_proj", "o_proj", "gate_up_proj", "down_proj", "fc1", "fc2", - "out_proj", "kv_proj" + "qkv_proj", + "o_proj", + "gate_up_proj", + "down_proj", ] embedding_modules = {} embedding_padding_modules = [] diff --git a/vllm/model_executor/models/module_mapping.py b/vllm/model_executor/models/module_mapping.py new file mode 100644 index 0000000000000..7e13896acbe90 --- /dev/null +++ b/vllm/model_executor/models/module_mapping.py @@ -0,0 +1,59 @@ + +#Copied code from: https://github.com/modelscope/ms-swift/blob/v2.4.2/swift/utils/module_mapping.py + +from dataclasses import dataclass, field +from typing import List, Union + + +@dataclass +class ModelKeys: + model_type: str = None + + module_list: str = None + + embedding: str = None + + mlp: str = None + + down_proj: str = None + + attention: str = None + + o_proj: str = None + + q_proj: str = None + + k_proj: str = None + + v_proj: str = None + + qkv_proj: str = None + + qk_proj: str = None + + qa_proj: str = None + + qb_proj: str = None + + kva_proj: str = None + + kvb_proj: str = None + + output: str = None + + +@dataclass +class MultiModelKeys(ModelKeys): + language_model: Union[List[str], str] = field(default_factory=list) + connector: Union[List[str], str] = field(default_factory=list) + vision_tower: Union[List[str], str] = field(default_factory=list) + generator: Union[List[str], str] = field(default_factory=list) + + def __post_init__(self): + # compat + for key in ["language_model", "connector", "vision_tower", "generator"]: + v = getattr(self, key) + if isinstance(v, str): + setattr(self, key, [v]) + if v is None: + setattr(self, key, []) diff --git 
a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index dc1020e3ea8ba..970b42b231350 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1028,7 +1028,8 @@ def load_model(self) -> None: if self.lora_config: assert supports_lora(self.model), "Model does not support LoRA" if supports_multimodal(self.model): - logger.warning("todo:add warning info") + logger.warning("Regarding multimodal models, vLLM currently " + "only supports adding LoRA to language model.") self.lora_manager = LRUCacheWorkerLoRAManager( self.scheduler_config.max_num_seqs, From 65b5b08fe3497046812470b25334b9b6b022e7f4 Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Wed, 25 Sep 2024 23:12:18 +0800 Subject: [PATCH 07/18] Format code --- vllm/model_executor/models/module_mapping.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/models/module_mapping.py b/vllm/model_executor/models/module_mapping.py index 7e13896acbe90..10ee06fde0e1d 100644 --- a/vllm/model_executor/models/module_mapping.py +++ b/vllm/model_executor/models/module_mapping.py @@ -1,5 +1,5 @@ - -#Copied code from: https://github.com/modelscope/ms-swift/blob/v2.4.2/swift/utils/module_mapping.py +# Copied code from +# https://github.com/modelscope/ms-swift/blob/v2.4.2/swift/utils/module_mapping.py from dataclasses import dataclass, field from typing import List, Union @@ -51,7 +51,9 @@ class MultiModelKeys(ModelKeys): def __post_init__(self): # compat - for key in ["language_model", "connector", "vision_tower", "generator"]: + for key in [ + "language_model", "connector", "vision_tower", "generator" + ]: v = getattr(self, key) if isinstance(v, str): setattr(self, key, [v]) From 9bf92d5deb03cbd2b33fb2dc7209e9cf60ed1437 Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Wed, 25 Sep 2024 23:36:26 +0800 Subject: [PATCH 08/18] Clean code --- vllm/lora/models.py | 2 +- vllm/model_executor/models/minicpmv.py | 10 ++++++++-- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/vllm/lora/models.py b/vllm/lora/models.py index d726a36b83c2c..b2110e9188be2 100644 --- a/vllm/lora/models.py +++ b/vllm/lora/models.py @@ -446,7 +446,7 @@ def _create_lora_modules(self): if self._filter_unsupported_modules(module_name): logger.warning( "Regarding multimodal models, vLLM currently only supports " - "adding LoRA to language models, %s will be ignored.", + "adding LoRA to language model, %s will be ignored.", module_name, ) continue diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py index 26adb04cc07c9..063c698db91fb 100644 --- a/vllm/model_executor/models/minicpmv.py +++ b/vllm/model_executor/models/minicpmv.py @@ -787,13 +787,19 @@ class MiniCPMV2_5(MiniCPMVBaseModel, SupportsLoRA): "up_proj", ], } - # LoRA specific attributes supported_lora_modules = [ - "qkv_proj", + # vision encoder + "fc1", + "fc2", + "out_proj", + # language model + "qkv_proj", # same name with vision encoder "o_proj", "gate_up_proj", "down_proj", + # resampler + "kv_proj", ] embedding_modules = {} embedding_padding_modules = [] From 561b4b758be556c7e57583572b6d27566c5745a2 Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Thu, 26 Sep 2024 16:05:21 +0800 Subject: [PATCH 09/18] Clean code --- vllm/lora/models.py | 8 ++++---- vllm/model_executor/models/minicpmv.py | 6 +++--- vllm/model_executor/models/module_mapping.py | 8 +++----- 3 files changed, 10 insertions(+), 12 deletions(-) diff --git a/vllm/lora/models.py b/vllm/lora/models.py index b2110e9188be2..c0b7ff8258a89 100644 --- 
a/vllm/lora/models.py +++ b/vllm/lora/models.py @@ -443,7 +443,7 @@ def _create_lora_modules(self): continue # A temporary approach for multimodal models to support LoRA # TODO: Remove this restriction - if self._filter_unsupported_modules(module_name): + if self._filter_unsupported_module(module_name): logger.warning( "Regarding multimodal models, vLLM currently only supports " "adding LoRA to language model, %s will be ignored.", @@ -501,7 +501,7 @@ def create_dummy_lora( if (not self._match_target_modules(module_name) or not isinstance(module, BaseLayerWithLoRA) or isinstance(module, LinearScalingRotaryEmbeddingWithLora) - or self._filter_unsupported_modules(module_name)): + or self._filter_unsupported_module(module_name)): continue parts = module_name.split(".") if module_name not in self.packed_modules: @@ -562,7 +562,7 @@ def _match_target_modules(self, module_name: str): module_name) or target_module == module_name for target_module in self.supported_lora_modules) - def _filter_unsupported_modules(self, module_name: str) -> bool: + def _filter_unsupported_module(self, module_name: str) -> bool: """ Regarding multimodal models, vLLM currently only supports adding LoRA to language model. LoRA for other modules, such as the vision tower, will @@ -572,7 +572,7 @@ def _filter_unsupported_modules(self, module_name: str) -> bool: prefix = module_name.split(".")[0] module_mapping: MultiModelKeys = self.model.get_mm_mapping() return (prefix in module_mapping.connector - or prefix in module_mapping.vision_tower) + or prefix in module_mapping.tower_model) return False def _register_packed_modules(self, module_full_name: str) -> None: diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py index 063c698db91fb..2be49dc18b76b 100644 --- a/vllm/model_executor/models/minicpmv.py +++ b/vllm/model_executor/models/minicpmv.py @@ -404,7 +404,7 @@ def __init__(self, llm: nn.Module, name: str) -> None: def forward(self, *args, **kwargs) -> Any: return getattr(self, self.model_name)(*args, **kwargs) - def embed_tokens(self, *args, **kwargs): + def embed_tokens(self, *args, **kwargs) -> Any: return getattr(self, self.model_name).embed_tokens(*args, **kwargs) @@ -646,7 +646,7 @@ def get_mm_mapping(self) -> MultiModelKeys: """ return MultiModelKeys(language_model="llm", connector="resampler", - vision_tower="vpm") + tower_model="vpm") def init_llm( self, @@ -1001,7 +1001,7 @@ class MiniCPMV(MiniCPMVBaseModel, SupportsLoRA): bitsandbytes in vLLM. Therefore, it is necessary to separate them. """ # Ensure that the LoRA support check passes when the class is not - # initialized,but set all these attributes to empty + # initialized, but set all these attributes to empty. 
packed_modules_mapping = {} supported_lora_modules = [] embedding_modules = {} diff --git a/vllm/model_executor/models/module_mapping.py b/vllm/model_executor/models/module_mapping.py index 10ee06fde0e1d..221be8bb61337 100644 --- a/vllm/model_executor/models/module_mapping.py +++ b/vllm/model_executor/models/module_mapping.py @@ -46,14 +46,12 @@ class ModelKeys: class MultiModelKeys(ModelKeys): language_model: Union[List[str], str] = field(default_factory=list) connector: Union[List[str], str] = field(default_factory=list) - vision_tower: Union[List[str], str] = field(default_factory=list) + # such vision tower and audio tower + tower_model: Union[List[str], str] = field(default_factory=list) generator: Union[List[str], str] = field(default_factory=list) def __post_init__(self): - # compat - for key in [ - "language_model", "connector", "vision_tower", "generator" - ]: + for key in ["language_model", "connector", "tower_model", "generator"]: v = getattr(self, key) if isinstance(v, str): setattr(self, key, [v]) From 578deba51fbdb67f21d88870f80c5b3bd24234ed Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Thu, 26 Sep 2024 16:13:27 +0800 Subject: [PATCH 10/18] Clean code --- vllm/model_executor/models/module_mapping.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/models/module_mapping.py b/vllm/model_executor/models/module_mapping.py index 221be8bb61337..66f5427f8f30c 100644 --- a/vllm/model_executor/models/module_mapping.py +++ b/vllm/model_executor/models/module_mapping.py @@ -46,7 +46,7 @@ class ModelKeys: class MultiModelKeys(ModelKeys): language_model: Union[List[str], str] = field(default_factory=list) connector: Union[List[str], str] = field(default_factory=list) - # such vision tower and audio tower + # vision tower and audio tower tower_model: Union[List[str], str] = field(default_factory=list) generator: Union[List[str], str] = field(default_factory=list) From 99dacdf223d3148ebe4169984012fffc02dd953e Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Fri, 27 Sep 2024 14:12:26 +0800 Subject: [PATCH 11/18] Add unit test for minicpmv25 --- tests/lora/conftest.py | 5 ++ tests/lora/test_minicpmv.py | 99 ++++++++++++++++++++++++++ vllm/model_executor/models/minicpmv.py | 11 ++- vllm/worker/model_runner.py | 4 +- 4 files changed, 115 insertions(+), 4 deletions(-) create mode 100644 tests/lora/test_minicpmv.py diff --git a/tests/lora/conftest.py b/tests/lora/conftest.py index 4834a9d35a3ee..7f6f60f38b5de 100644 --- a/tests/lora/conftest.py +++ b/tests/lora/conftest.py @@ -194,6 +194,11 @@ def baichuan_zero_lora_files(): return snapshot_download(repo_id="jeeejeee/baichuan7b-zero-init") +@pytest.fixture(scope="session") +def minicpmv_lora_files(): + return snapshot_download(repo_id="jeeejeee/minicpmv25-lora-pokemon") + + @pytest.fixture(scope="session") def tinyllama_lora_files(): return snapshot_download(repo_id="jashing/tinyllama-colorist-lora") diff --git a/tests/lora/test_minicpmv.py b/tests/lora/test_minicpmv.py new file mode 100644 index 0000000000000..d279513dabaa8 --- /dev/null +++ b/tests/lora/test_minicpmv.py @@ -0,0 +1,99 @@ +from typing import List + +import pytest + +import vllm +from vllm.lora.request import LoRARequest +from vllm.assets.image import ImageAsset + + +MODEL_PATH = "openbmb/MiniCPM-Llama3-V-2_5" + +PROMPT_TEMPLATE = ( + "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n" + "(./)\nWhat is in the image?<|eot_id|>" + "<|start_header_id|>assistant<|end_header_id|>\n\n" +) + +IMAGE_ASSETS = [ + ImageAsset("stop_sign"), + 
ImageAsset("cherry_blossom"), +] + + +# After fine-tuning with LoRA, all generated content should start begin `A`. +EXPECTED_OUTPUT = [ + "A red and white stop sign with a Chinese archway in the background featuring red lanterns and gold accents.", # noqa: E501 + "A pink cherry blossom tree with a blue sky in the background.", +] + + +def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]: + sampling_params = vllm.SamplingParams( + temperature=0, + max_tokens=256, + stop_token_ids=[128001, 128009], # eos_id, eot_id + ) + + inputs = [ + { + "prompt": PROMPT_TEMPLATE, + "multi_modal_data": {"image": asset.pil_image}, + } + for asset in IMAGE_ASSETS + ] + + outputs = llm.generate( + inputs, + sampling_params, + lora_request=LoRARequest(str(lora_id), lora_id, lora_path) + if lora_id + else None, + ) + # Print the outputs. + generated_texts: List[str] = [] + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text.strip() + generated_texts.append(generated_text) + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + return generated_texts + + +def test_minicpmv_lora(minicpmv_lora_files): + llm = vllm.LLM( + MODEL_PATH, + max_model_len=1024, + enable_lora=True, + max_loras=4, + max_lora_rank=64, + trust_remote_code=True, + ) + + output1 = do_sample(llm, minicpmv_lora_files, lora_id=1) + for i in range(len(EXPECTED_OUTPUT)): + assert output1[i] == EXPECTED_OUTPUT[i] + output2 = do_sample(llm, minicpmv_lora_files, lora_id=2) + for i in range(len(EXPECTED_OUTPUT)): + assert output2[i] == EXPECTED_OUTPUT[i] + + +# @pytest.mark.skip("Requires multiple GPUs") +@pytest.mark.parametrize("fully_sharded", [True, False]) +@pytest.mark.parametrize("tp", [2, 4]) +def test_minicpmv_tensor_parallel(minicpmv_lora_files, fully_sharded, tp): + llm = vllm.LLM( + MODEL_PATH, + enable_lora=True, + max_num_seqs=16, + max_loras=4, + max_lora_rank=64, + tensor_parallel_size=tp, + trust_remote_code=True, + fully_sharded_loras=fully_sharded, + ) + output_tp = do_sample(llm, minicpmv_lora_files, lora_id=1) + + for i in range(len(EXPECTED_OUTPUT)): + assert output_tp[i] == EXPECTED_OUTPUT[i] + diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py index 2be49dc18b76b..d793cfa83e01e 100644 --- a/vllm/model_executor/models/minicpmv.py +++ b/vllm/model_executor/models/minicpmv.py @@ -904,9 +904,14 @@ def init_llm( cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> nn.Module: - return Qwen2Model(config, - cache_config=cache_config, - quant_config=quant_config) + # return Qwen2Model(config, + # cache_config=cache_config, + # quant_config=quant_config) + + return LLMWrapper(Qwen2Model(config, + cache_config=cache_config, + quant_config=quant_config), + name="model") def init_vision_module(self) -> nn.Module: # A custom version of SiglipVisionTransformer, won't work with TP diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 970b42b231350..b2107862aa046 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1026,7 +1026,9 @@ def load_model(self) -> None: self.model_memory_usage / float(2**30)) if self.lora_config: - assert supports_lora(self.model), "Model does not support LoRA" + assert supports_lora( + self.model + ), f"{self.model.__class__.__name__} does not support LoRA yet." 
if supports_multimodal(self.model): logger.warning("Regarding multimodal models, vLLM currently " "only supports adding LoRA to language model.") From 9b85373a54bf8eb0ee1d0932ba48de1419ce6425 Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Fri, 27 Sep 2024 14:18:47 +0800 Subject: [PATCH 12/18] Format code --- tests/lora/test_minicpmv.py | 24 +++++++++--------------- 1 file changed, 9 insertions(+), 15 deletions(-) diff --git a/tests/lora/test_minicpmv.py b/tests/lora/test_minicpmv.py index d279513dabaa8..d0bfdf076a3b4 100644 --- a/tests/lora/test_minicpmv.py +++ b/tests/lora/test_minicpmv.py @@ -3,24 +3,21 @@ import pytest import vllm -from vllm.lora.request import LoRARequest from vllm.assets.image import ImageAsset - +from vllm.lora.request import LoRARequest MODEL_PATH = "openbmb/MiniCPM-Llama3-V-2_5" PROMPT_TEMPLATE = ( "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n" "(./)\nWhat is in the image?<|eot_id|>" - "<|start_header_id|>assistant<|end_header_id|>\n\n" -) + "<|start_header_id|>assistant<|end_header_id|>\n\n") IMAGE_ASSETS = [ ImageAsset("stop_sign"), ImageAsset("cherry_blossom"), ] - # After fine-tuning with LoRA, all generated content should start begin `A`. EXPECTED_OUTPUT = [ "A red and white stop sign with a Chinese archway in the background featuring red lanterns and gold accents.", # noqa: E501 @@ -35,20 +32,18 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]: stop_token_ids=[128001, 128009], # eos_id, eot_id ) - inputs = [ - { - "prompt": PROMPT_TEMPLATE, - "multi_modal_data": {"image": asset.pil_image}, - } - for asset in IMAGE_ASSETS - ] + inputs = [{ + "prompt": PROMPT_TEMPLATE, + "multi_modal_data": { + "image": asset.pil_image + }, + } for asset in IMAGE_ASSETS] outputs = llm.generate( inputs, sampling_params, lora_request=LoRARequest(str(lora_id), lora_id, lora_path) - if lora_id - else None, + if lora_id else None, ) # Print the outputs. 
generated_texts: List[str] = [] @@ -96,4 +91,3 @@ def test_minicpmv_tensor_parallel(minicpmv_lora_files, fully_sharded, tp): for i in range(len(EXPECTED_OUTPUT)): assert output_tp[i] == EXPECTED_OUTPUT[i] - From bf4ee9d6d13c9fa876ec39e5b09d35bde712660b Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Fri, 27 Sep 2024 16:47:30 +0800 Subject: [PATCH 13/18] Modify code --- tests/lora/test_minicpmv.py | 4 +++- vllm/lora/models.py | 6 +++--- vllm/model_executor/models/minicpmv.py | 22 +--------------------- vllm/model_executor/models/utils.py | 22 ++++++++++++++++++++-- 4 files changed, 27 insertions(+), 27 deletions(-) diff --git a/tests/lora/test_minicpmv.py b/tests/lora/test_minicpmv.py index d0bfdf076a3b4..92cd155f56258 100644 --- a/tests/lora/test_minicpmv.py +++ b/tests/lora/test_minicpmv.py @@ -6,6 +6,8 @@ from vllm.assets.image import ImageAsset from vllm.lora.request import LoRARequest +from ..utils import multi_gpu_test + MODEL_PATH = "openbmb/MiniCPM-Llama3-V-2_5" PROMPT_TEMPLATE = ( @@ -73,7 +75,7 @@ def test_minicpmv_lora(minicpmv_lora_files): assert output2[i] == EXPECTED_OUTPUT[i] -# @pytest.mark.skip("Requires multiple GPUs") +@multi_gpu_test(num_gpus=4) @pytest.mark.parametrize("fully_sharded", [True, False]) @pytest.mark.parametrize("tp", [2, 4]) def test_minicpmv_tensor_parallel(minicpmv_lora_files, fully_sharded, tp): diff --git a/vllm/lora/models.py b/vllm/lora/models.py index c0b7ff8258a89..03c019b7f90a2 100644 --- a/vllm/lora/models.py +++ b/vllm/lora/models.py @@ -443,7 +443,7 @@ def _create_lora_modules(self): continue # A temporary approach for multimodal models to support LoRA # TODO: Remove this restriction - if self._filter_unsupported_module(module_name): + if self._filter_unsupported_mm_module(module_name): logger.warning( "Regarding multimodal models, vLLM currently only supports " "adding LoRA to language model, %s will be ignored.", @@ -501,7 +501,7 @@ def create_dummy_lora( if (not self._match_target_modules(module_name) or not isinstance(module, BaseLayerWithLoRA) or isinstance(module, LinearScalingRotaryEmbeddingWithLora) - or self._filter_unsupported_module(module_name)): + or self._filter_unsupported_mm_module(module_name)): continue parts = module_name.split(".") if module_name not in self.packed_modules: @@ -562,7 +562,7 @@ def _match_target_modules(self, module_name: str): module_name) or target_module == module_name for target_module in self.supported_lora_modules) - def _filter_unsupported_module(self, module_name: str) -> bool: + def _filter_unsupported_mm_module(self, module_name: str) -> bool: """ Regarding multimodal models, vLLM currently only supports adding LoRA to language model. 
LoRA for other modules, such as the vision tower, will diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py index d793cfa83e01e..a56559c35e9bf 100644 --- a/vllm/model_executor/models/minicpmv.py +++ b/vllm/model_executor/models/minicpmv.py @@ -52,6 +52,7 @@ from vllm.model_executor.models.minicpm import MiniCPMModel from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.models.qwen2 import Qwen2Model +from vllm.model_executor.models.utils import LLMWrapper from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.base import MultiModalInputs @@ -390,24 +391,6 @@ def input_mapper_for_minicpmv(ctx: InputContext, data: object): return MultiModalInputs(batch_data) -class LLMWrapper(nn.Module): - """ - To align with the key names of LoRA trained with PEFT, we need to add an - additional layer to the llm's implementation. - """ - - def __init__(self, llm: nn.Module, name: str) -> None: - super().__init__() - self.model_name = name - setattr(self, name, llm) - - def forward(self, *args, **kwargs) -> Any: - return getattr(self, self.model_name)(*args, **kwargs) - - def embed_tokens(self, *args, **kwargs) -> Any: - return getattr(self, self.model_name).embed_tokens(*args, **kwargs) - - class MiniCPMVBaseModel(nn.Module, SupportsMultiModal): """ The abstract class of MiniCPMV can only be inherited, but cannot be @@ -904,9 +887,6 @@ def init_llm( cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> nn.Module: - # return Qwen2Model(config, - # cache_config=cache_config, - # quant_config=quant_config) return LLMWrapper(Qwen2Model(config, cache_config=cache_config, diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index 38d6a4653ebd6..f6218bad4ef1e 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -1,7 +1,7 @@ import itertools from collections import UserDict -from typing import (Dict, Iterable, List, Literal, Optional, Protocol, Tuple, - Union, overload) +from typing import (Any, Dict, Iterable, List, Literal, Optional, Protocol, + Tuple, Union, overload) import torch import torch.nn as nn @@ -329,3 +329,21 @@ def make_empty_intermediate_tensors( }) return make_empty_intermediate_tensors + + +class LLMWrapper(nn.Module): + """ + To align with the key names of LoRA trained with PEFT, we need to add an + additional layer to the llm's implementation. 
+ """ + + def __init__(self, llm: nn.Module, name: str) -> None: + super().__init__() + self.model_name = name + setattr(self, name, llm) + + def forward(self, *args, **kwargs) -> Any: + return getattr(self, self.model_name)(*args, **kwargs) + + def embed_tokens(self, *args, **kwargs) -> Any: + return getattr(self, self.model_name).embed_tokens(*args, **kwargs) From a9e724c1a06c7f46b176f61356062f22f6a7a927 Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Sat, 28 Sep 2024 00:31:00 +0800 Subject: [PATCH 14/18] Modify module_mapping logic --- vllm/model_executor/models/minicpmv.py | 6 ++-- vllm/model_executor/models/module_mapping.py | 36 +++++++++++++------- 2 files changed, 26 insertions(+), 16 deletions(-) diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py index a56559c35e9bf..89cdfbcc6afa9 100644 --- a/vllm/model_executor/models/minicpmv.py +++ b/vllm/model_executor/models/minicpmv.py @@ -627,9 +627,9 @@ def get_mm_mapping(self) -> MultiModelKeys: """ Get the module prefix in multimodal models """ - return MultiModelKeys(language_model="llm", - connector="resampler", - tower_model="vpm") + return MultiModelKeys.from_string_field(language_model="llm", + connector="resampler", + tower_model="vpm") def init_llm( self, diff --git a/vllm/model_executor/models/module_mapping.py b/vllm/model_executor/models/module_mapping.py index 66f5427f8f30c..a9102a6073a2f 100644 --- a/vllm/model_executor/models/module_mapping.py +++ b/vllm/model_executor/models/module_mapping.py @@ -1,4 +1,4 @@ -# Copied code from +# Adapted from # https://github.com/modelscope/ms-swift/blob/v2.4.2/swift/utils/module_mapping.py from dataclasses import dataclass, field @@ -44,16 +44,26 @@ class ModelKeys: @dataclass class MultiModelKeys(ModelKeys): - language_model: Union[List[str], str] = field(default_factory=list) - connector: Union[List[str], str] = field(default_factory=list) + language_model: List[str] = field(default_factory=list) + connector: List[str] = field(default_factory=list) # vision tower and audio tower - tower_model: Union[List[str], str] = field(default_factory=list) - generator: Union[List[str], str] = field(default_factory=list) - - def __post_init__(self): - for key in ["language_model", "connector", "tower_model", "generator"]: - v = getattr(self, key) - if isinstance(v, str): - setattr(self, key, [v]) - if v is None: - setattr(self, key, []) + tower_model: List[str] = field(default_factory=list) + generator: List[str] = field(default_factory=list) + + @staticmethod + def from_string_field(language_model: Union[str, List[str]] = None, + connector: Union[str, List[str]] = None, + tower_model: Union[str, List[str]] = None, + generator: Union[str, List[str]] = None, + **kwargs) -> 'MultiModelKeys': + + def to_list(value): + if value is None: + return [] + return [value] if isinstance(value, str) else list(value) + + return MultiModelKeys(language_model=to_list(language_model), + connector=to_list(connector), + tower_model=to_list(tower_model), + generator=to_list(generator), + **kwargs) From be6c92860fd29beca6d07af4c3205c1db1109f90 Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Sat, 28 Sep 2024 00:43:13 +0800 Subject: [PATCH 15/18] Add unit test --- tests/lora/test_minicpmv.py | 24 --------- tests/lora/test_minicpmv_tp.py | 99 ++++++++++++++++++++++++++++++++++ 2 files changed, 99 insertions(+), 24 deletions(-) create mode 100644 tests/lora/test_minicpmv_tp.py diff --git a/tests/lora/test_minicpmv.py b/tests/lora/test_minicpmv.py index 92cd155f56258..4860b72fc1f60 
100644 --- a/tests/lora/test_minicpmv.py +++ b/tests/lora/test_minicpmv.py @@ -1,13 +1,9 @@ from typing import List -import pytest - import vllm from vllm.assets.image import ImageAsset from vllm.lora.request import LoRARequest -from ..utils import multi_gpu_test - MODEL_PATH = "openbmb/MiniCPM-Llama3-V-2_5" PROMPT_TEMPLATE = ( @@ -73,23 +69,3 @@ def test_minicpmv_lora(minicpmv_lora_files): output2 = do_sample(llm, minicpmv_lora_files, lora_id=2) for i in range(len(EXPECTED_OUTPUT)): assert output2[i] == EXPECTED_OUTPUT[i] - - -@multi_gpu_test(num_gpus=4) -@pytest.mark.parametrize("fully_sharded", [True, False]) -@pytest.mark.parametrize("tp", [2, 4]) -def test_minicpmv_tensor_parallel(minicpmv_lora_files, fully_sharded, tp): - llm = vllm.LLM( - MODEL_PATH, - enable_lora=True, - max_num_seqs=16, - max_loras=4, - max_lora_rank=64, - tensor_parallel_size=tp, - trust_remote_code=True, - fully_sharded_loras=fully_sharded, - ) - output_tp = do_sample(llm, minicpmv_lora_files, lora_id=1) - - for i in range(len(EXPECTED_OUTPUT)): - assert output_tp[i] == EXPECTED_OUTPUT[i] diff --git a/tests/lora/test_minicpmv_tp.py b/tests/lora/test_minicpmv_tp.py new file mode 100644 index 0000000000000..40b1aa2701acf --- /dev/null +++ b/tests/lora/test_minicpmv_tp.py @@ -0,0 +1,99 @@ +from typing import List + +import pytest + +import vllm +from vllm.assets.image import ImageAsset +from vllm.lora.request import LoRARequest + +from ..utils import multi_gpu_test + +MODEL_PATH = "openbmb/MiniCPM-Llama3-V-2_5" + +PROMPT_TEMPLATE = ( + "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n" + "(./)\nWhat is in the image?<|eot_id|>" + "<|start_header_id|>assistant<|end_header_id|>\n\n" +) + +IMAGE_ASSETS = [ + ImageAsset("stop_sign"), + ImageAsset("cherry_blossom"), +] + +# After fine-tuning with LoRA, all generated content should start begin `A`. +EXPECTED_OUTPUT = [ + "A red and white stop sign with a Chinese archway in the background featuring red lanterns and gold accents.", # noqa: E501 + "A pink cherry blossom tree with a blue sky in the background.", +] + + +def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]: + sampling_params = vllm.SamplingParams( + temperature=0, + max_tokens=256, + stop_token_ids=[128001, 128009], # eos_id, eot_id + ) + + inputs = [ + { + "prompt": PROMPT_TEMPLATE, + "multi_modal_data": {"image": asset.pil_image}, + } + for asset in IMAGE_ASSETS + ] + + outputs = llm.generate( + inputs, + sampling_params, + lora_request=LoRARequest(str(lora_id), lora_id, lora_path) + if lora_id + else None, + ) + # Print the outputs. 
+ generated_texts: List[str] = [] + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text.strip() + generated_texts.append(generated_text) + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + return generated_texts + +@multi_gpu_test(num_gpus=2) +@pytest.mark.parametrize("fully_sharded", [True, False]) +def test_minicpmv_tp2(minicpmv_lora_files, fully_sharded): + llm = vllm.LLM( + MODEL_PATH, + enable_lora=True, + max_num_seqs=16, + max_loras=4, + max_lora_rank=64, + tensor_parallel_size=2, + trust_remote_code=True, + fully_sharded_loras=fully_sharded, + ) + + output_tp = do_sample(llm, minicpmv_lora_files, lora_id=1) + + for i in range(len(EXPECTED_OUTPUT)): + assert output_tp[i] == EXPECTED_OUTPUT[i] + + +@multi_gpu_test(num_gpus=4) +@pytest.mark.parametrize("fully_sharded", [True, False]) +def test_minicpmv_tp4(minicpmv_lora_files, fully_sharded): + llm = vllm.LLM( + MODEL_PATH, + enable_lora=True, + max_num_seqs=16, + max_loras=4, + max_lora_rank=64, + tensor_parallel_size=4, + trust_remote_code=True, + fully_sharded_loras=fully_sharded, + ) + + output_tp = do_sample(llm, minicpmv_lora_files, lora_id=1) + + for i in range(len(EXPECTED_OUTPUT)): + assert output_tp[i] == EXPECTED_OUTPUT[i] From c9db73e1f4d3f4e8c170346b44daef5a4bde2d1e Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Sat, 28 Sep 2024 22:40:08 +0800 Subject: [PATCH 16/18] Modify unit test --- tests/lora/test_minicpmv.py | 2 +- tests/lora/test_minicpmv_tp.py | 24 +++++++++++------------- 2 files changed, 12 insertions(+), 14 deletions(-) diff --git a/tests/lora/test_minicpmv.py b/tests/lora/test_minicpmv.py index 4860b72fc1f60..4e1b5468641e8 100644 --- a/tests/lora/test_minicpmv.py +++ b/tests/lora/test_minicpmv.py @@ -56,7 +56,7 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]: def test_minicpmv_lora(minicpmv_lora_files): llm = vllm.LLM( MODEL_PATH, - max_model_len=1024, + max_num_seqs=2, enable_lora=True, max_loras=4, max_lora_rank=64, diff --git a/tests/lora/test_minicpmv_tp.py b/tests/lora/test_minicpmv_tp.py index 40b1aa2701acf..260530627597b 100644 --- a/tests/lora/test_minicpmv_tp.py +++ b/tests/lora/test_minicpmv_tp.py @@ -13,8 +13,7 @@ PROMPT_TEMPLATE = ( "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n" "(./)\nWhat is in the image?<|eot_id|>" - "<|start_header_id|>assistant<|end_header_id|>\n\n" -) + "<|start_header_id|>assistant<|end_header_id|>\n\n") IMAGE_ASSETS = [ ImageAsset("stop_sign"), @@ -35,20 +34,18 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]: stop_token_ids=[128001, 128009], # eos_id, eot_id ) - inputs = [ - { - "prompt": PROMPT_TEMPLATE, - "multi_modal_data": {"image": asset.pil_image}, - } - for asset in IMAGE_ASSETS - ] + inputs = [{ + "prompt": PROMPT_TEMPLATE, + "multi_modal_data": { + "image": asset.pil_image + }, + } for asset in IMAGE_ASSETS] outputs = llm.generate( inputs, sampling_params, lora_request=LoRARequest(str(lora_id), lora_id, lora_path) - if lora_id - else None, + if lora_id else None, ) # Print the outputs. 
generated_texts: List[str] = [] @@ -59,13 +56,14 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]: print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") return generated_texts + @multi_gpu_test(num_gpus=2) @pytest.mark.parametrize("fully_sharded", [True, False]) def test_minicpmv_tp2(minicpmv_lora_files, fully_sharded): llm = vllm.LLM( MODEL_PATH, enable_lora=True, - max_num_seqs=16, + max_num_seqs=2, max_loras=4, max_lora_rank=64, tensor_parallel_size=2, @@ -85,7 +83,7 @@ def test_minicpmv_tp4(minicpmv_lora_files, fully_sharded): llm = vllm.LLM( MODEL_PATH, enable_lora=True, - max_num_seqs=16, + max_num_seqs=2, max_loras=4, max_lora_rank=64, tensor_parallel_size=4, From bbfd3e0b8032d504edd52339f542e89fca8a6765 Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Sun, 29 Sep 2024 10:42:47 +0800 Subject: [PATCH 17/18] Delete mincpmv25 distributed test --- tests/lora/test_minicpmv_tp.py | 97 ---------------------------------- 1 file changed, 97 deletions(-) delete mode 100644 tests/lora/test_minicpmv_tp.py diff --git a/tests/lora/test_minicpmv_tp.py b/tests/lora/test_minicpmv_tp.py deleted file mode 100644 index 260530627597b..0000000000000 --- a/tests/lora/test_minicpmv_tp.py +++ /dev/null @@ -1,97 +0,0 @@ -from typing import List - -import pytest - -import vllm -from vllm.assets.image import ImageAsset -from vllm.lora.request import LoRARequest - -from ..utils import multi_gpu_test - -MODEL_PATH = "openbmb/MiniCPM-Llama3-V-2_5" - -PROMPT_TEMPLATE = ( - "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n" - "(./)\nWhat is in the image?<|eot_id|>" - "<|start_header_id|>assistant<|end_header_id|>\n\n") - -IMAGE_ASSETS = [ - ImageAsset("stop_sign"), - ImageAsset("cherry_blossom"), -] - -# After fine-tuning with LoRA, all generated content should start begin `A`. -EXPECTED_OUTPUT = [ - "A red and white stop sign with a Chinese archway in the background featuring red lanterns and gold accents.", # noqa: E501 - "A pink cherry blossom tree with a blue sky in the background.", -] - - -def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]: - sampling_params = vllm.SamplingParams( - temperature=0, - max_tokens=256, - stop_token_ids=[128001, 128009], # eos_id, eot_id - ) - - inputs = [{ - "prompt": PROMPT_TEMPLATE, - "multi_modal_data": { - "image": asset.pil_image - }, - } for asset in IMAGE_ASSETS] - - outputs = llm.generate( - inputs, - sampling_params, - lora_request=LoRARequest(str(lora_id), lora_id, lora_path) - if lora_id else None, - ) - # Print the outputs. 
- generated_texts: List[str] = [] - for output in outputs: - prompt = output.prompt - generated_text = output.outputs[0].text.strip() - generated_texts.append(generated_text) - print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") - return generated_texts - - -@multi_gpu_test(num_gpus=2) -@pytest.mark.parametrize("fully_sharded", [True, False]) -def test_minicpmv_tp2(minicpmv_lora_files, fully_sharded): - llm = vllm.LLM( - MODEL_PATH, - enable_lora=True, - max_num_seqs=2, - max_loras=4, - max_lora_rank=64, - tensor_parallel_size=2, - trust_remote_code=True, - fully_sharded_loras=fully_sharded, - ) - - output_tp = do_sample(llm, minicpmv_lora_files, lora_id=1) - - for i in range(len(EXPECTED_OUTPUT)): - assert output_tp[i] == EXPECTED_OUTPUT[i] - - -@multi_gpu_test(num_gpus=4) -@pytest.mark.parametrize("fully_sharded", [True, False]) -def test_minicpmv_tp4(minicpmv_lora_files, fully_sharded): - llm = vllm.LLM( - MODEL_PATH, - enable_lora=True, - max_num_seqs=2, - max_loras=4, - max_lora_rank=64, - tensor_parallel_size=4, - trust_remote_code=True, - fully_sharded_loras=fully_sharded, - ) - - output_tp = do_sample(llm, minicpmv_lora_files, lora_id=1) - - for i in range(len(EXPECTED_OUTPUT)): - assert output_tp[i] == EXPECTED_OUTPUT[i] From acc836a6a375a82749e6fc394819108b2886f2c3 Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Sun, 29 Sep 2024 12:44:52 +0800 Subject: [PATCH 18/18] Fix lora bug and modify minicpmv lora tests --- tests/lora/test_minicpmv.py | 6 +-- tests/lora/test_minicpmv_tp.py | 95 ++++++++++++++++++++++++++++++++++ vllm/lora/models.py | 17 +++--- 3 files changed, 108 insertions(+), 10 deletions(-) create mode 100644 tests/lora/test_minicpmv_tp.py diff --git a/tests/lora/test_minicpmv.py b/tests/lora/test_minicpmv.py index 4e1b5468641e8..81b8188e638c9 100644 --- a/tests/lora/test_minicpmv.py +++ b/tests/lora/test_minicpmv.py @@ -26,7 +26,7 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]: sampling_params = vllm.SamplingParams( temperature=0, - max_tokens=256, + max_tokens=5, stop_token_ids=[128001, 128009], # eos_id, eot_id ) @@ -65,7 +65,7 @@ def test_minicpmv_lora(minicpmv_lora_files): output1 = do_sample(llm, minicpmv_lora_files, lora_id=1) for i in range(len(EXPECTED_OUTPUT)): - assert output1[i] == EXPECTED_OUTPUT[i] + assert EXPECTED_OUTPUT[i].startswith(output1[i]) output2 = do_sample(llm, minicpmv_lora_files, lora_id=2) for i in range(len(EXPECTED_OUTPUT)): - assert output2[i] == EXPECTED_OUTPUT[i] + assert EXPECTED_OUTPUT[i].startswith(output2[i]) diff --git a/tests/lora/test_minicpmv_tp.py b/tests/lora/test_minicpmv_tp.py new file mode 100644 index 0000000000000..ba29e562e58ec --- /dev/null +++ b/tests/lora/test_minicpmv_tp.py @@ -0,0 +1,95 @@ +from typing import List + +import pytest + +import vllm +from vllm.assets.image import ImageAsset +from vllm.lora.request import LoRARequest + +from ..utils import multi_gpu_test + +MODEL_PATH = "openbmb/MiniCPM-Llama3-V-2_5" + +PROMPT_TEMPLATE = ( + "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n" + "(./)\nWhat is in the image?<|eot_id|>" + "<|start_header_id|>assistant<|end_header_id|>\n\n") + +IMAGE_ASSETS = [ + ImageAsset("stop_sign"), + ImageAsset("cherry_blossom"), +] + +# After fine-tuning with LoRA, all generated content should start begin `A`. 
+EXPECTED_OUTPUT = [ + "A red and white stop sign with a Chinese archway in the background featuring red lanterns and gold accents.", # noqa: E501 + "A pink cherry blossom tree with a blue sky in the background.", +] + + +def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]: + sampling_params = vllm.SamplingParams( + temperature=0, + max_tokens=5, + stop_token_ids=[128001, 128009], # eos_id, eot_id + ) + + inputs = [{ + "prompt": PROMPT_TEMPLATE, + "multi_modal_data": { + "image": asset.pil_image + }, + } for asset in IMAGE_ASSETS] + + outputs = llm.generate( + inputs, + sampling_params, + lora_request=LoRARequest(str(lora_id), lora_id, lora_path) + if lora_id else None, + ) + # Print the outputs. + generated_texts: List[str] = [] + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text.strip() + generated_texts.append(generated_text) + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + return generated_texts + + +@multi_gpu_test(num_gpus=2) +@pytest.mark.parametrize("fully_sharded", [True, False]) +def test_minicpmv_tp2(minicpmv_lora_files, fully_sharded): + llm = vllm.LLM( + MODEL_PATH, + enable_lora=True, + max_num_seqs=2, + max_loras=4, + max_lora_rank=64, + tensor_parallel_size=2, + trust_remote_code=True, + fully_sharded_loras=fully_sharded, + ) + + output_tp = do_sample(llm, minicpmv_lora_files, lora_id=1) + + for i in range(len(EXPECTED_OUTPUT)): + assert EXPECTED_OUTPUT[i].startswith(output_tp[i]) + + +@multi_gpu_test(num_gpus=4) +@pytest.mark.parametrize("fully_sharded", [True, False]) +def test_minicpmv_tp4(minicpmv_lora_files, fully_sharded): + llm = vllm.LLM( + MODEL_PATH, + enable_lora=True, + max_num_seqs=2, + max_loras=4, + max_lora_rank=64, + tensor_parallel_size=4, + trust_remote_code=True, + fully_sharded_loras=fully_sharded, + ) + output_tp = do_sample(llm, minicpmv_lora_files, lora_id=1) + for i in range(len(EXPECTED_OUTPUT)): + assert EXPECTED_OUTPUT[i].startswith(output_tp[i]) diff --git a/vllm/lora/models.py b/vllm/lora/models.py index 03c019b7f90a2..1f80c716bc481 100644 --- a/vllm/lora/models.py +++ b/vllm/lora/models.py @@ -456,13 +456,7 @@ def _create_lora_modules(self): self.model, module_name, from_layer(module, self.lora_slots, self.lora_config, packed_moduled_lst, self.model.config)) - # In some models, especially multimodal ones, layers with the same - # name may have different types, such as nn.Linear and - # ReplicatedLinear. The nn.Linear layers cannot be replaced with - # LoRA layers, leading to assertion error. The following check - # aims to prevent this error - if not isinstance(new_module, BaseLayerWithLoRA): - continue + # LinearScalingRotaryEmbeddingWithLora is used to handle # long context lora. Register relevant metadata. if isinstance(new_module, LinearScalingRotaryEmbeddingWithLora): @@ -480,6 +474,15 @@ def _create_lora_modules(self): module, self.lora_slots, self.lora_config, self.model.config)) + + # In some models, especially multimodal ones, layers with the same + # name may have different types, such as nn.Linear and + # ReplicatedLinear. The nn.Linear layers cannot be replaced with + # LoRA layers, leading to assertion error. The following check + # aims to prevent this error + if self.supports_mm and not isinstance(new_module, + BaseLayerWithLoRA): + continue self.register_module(module_name, new_module) self._register_packed_modules(module_name) # All lora layers share the same punica_wrapper based on reference.
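With the full series applied, MiniCPM-Llama3-V-2_5 can be served with LoRA adapters: LoRA takes effect on the language-model modules only, while vision-tower and resampler modules are filtered out with a logged warning. The sketch below is a minimal offline-inference example mirroring the added tests (tests/lora/conftest.py and tests/lora/test_minicpmv.py); the adapter repo id comes from those tests, and the (<image>./</image>) placeholder in the prompt is an assumption based on vLLM's MiniCPM-V prompt format rather than something stated in this series.

import vllm
from huggingface_hub import snapshot_download
from vllm.assets.image import ImageAsset
from vllm.lora.request import LoRARequest

MODEL_PATH = "openbmb/MiniCPM-Llama3-V-2_5"
# Chat-formatted prompt; the image placeholder tag is assumed (see note above).
PROMPT = ("<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n"
          "(<image>./</image>)\nWhat is in the image?<|eot_id|>"
          "<|start_header_id|>assistant<|end_header_id|>\n\n")

# Adapter used by the new unit tests; LoRARequest expects a local path.
lora_path = snapshot_download(repo_id="jeeejeee/minicpmv25-lora-pokemon")

llm = vllm.LLM(
    MODEL_PATH,
    max_num_seqs=2,
    enable_lora=True,
    max_loras=4,
    max_lora_rank=64,
    trust_remote_code=True,
)

outputs = llm.generate(
    [{
        "prompt": PROMPT,
        "multi_modal_data": {"image": ImageAsset("stop_sign").pil_image},
    }],
    vllm.SamplingParams(temperature=0,
                        max_tokens=64,
                        stop_token_ids=[128001, 128009]),  # eos_id, eot_id
    # LoRA is applied to the language model (qkv_proj, o_proj, gate_up_proj,
    # down_proj); vision-tower/resampler modules are skipped by the filter above.
    lora_request=LoRARequest("minicpmv-lora", 1, lora_path),
)
print(outputs[0].outputs[0].text)

Tensor-parallel use is identical except for passing tensor_parallel_size (and optionally fully_sharded_loras) to vllm.LLM, as exercised by tests/lora/test_minicpmv_tp.py.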