Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Model][LoRA]LoRA support added for MiniCPMV2.5 #7199

Merged
merged 25 commits into from
Sep 29, 2024
Merged
Show file tree
Hide file tree
Changes from 22 commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions tests/lora/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,11 @@ def baichuan_zero_lora_files():
return snapshot_download(repo_id="jeeejeee/baichuan7b-zero-init")


@pytest.fixture(scope="session")
def minicpmv_lora_files():
return snapshot_download(repo_id="jeeejeee/minicpmv25-lora-pokemon")


@pytest.fixture(scope="session")
def tinyllama_lora_files():
return snapshot_download(repo_id="jashing/tinyllama-colorist-lora")
Expand Down
71 changes: 71 additions & 0 deletions tests/lora/test_minicpmv.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
from typing import List

import vllm
from vllm.assets.image import ImageAsset
from vllm.lora.request import LoRARequest

MODEL_PATH = "openbmb/MiniCPM-Llama3-V-2_5"

PROMPT_TEMPLATE = (
"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n"
"(<image>./</image>)\nWhat is in the image?<|eot_id|>"
"<|start_header_id|>assistant<|end_header_id|>\n\n")

IMAGE_ASSETS = [
ImageAsset("stop_sign"),
ImageAsset("cherry_blossom"),
]

# After fine-tuning with LoRA, all generated content should start begin `A`.
EXPECTED_OUTPUT = [
"A red and white stop sign with a Chinese archway in the background featuring red lanterns and gold accents.", # noqa: E501
"A pink cherry blossom tree with a blue sky in the background.",
]


def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:
sampling_params = vllm.SamplingParams(
temperature=0,
max_tokens=256,
stop_token_ids=[128001, 128009], # eos_id, eot_id
)

inputs = [{
"prompt": PROMPT_TEMPLATE,
"multi_modal_data": {
"image": asset.pil_image
},
} for asset in IMAGE_ASSETS]

outputs = llm.generate(
inputs,
sampling_params,
lora_request=LoRARequest(str(lora_id), lora_id, lora_path)
if lora_id else None,
)
# Print the outputs.
generated_texts: List[str] = []
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text.strip()
generated_texts.append(generated_text)
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
return generated_texts


def test_minicpmv_lora(minicpmv_lora_files):
llm = vllm.LLM(
MODEL_PATH,
max_model_len=1024,
DarkLight1337 marked this conversation as resolved.
Show resolved Hide resolved
enable_lora=True,
max_loras=4,
max_lora_rank=64,
trust_remote_code=True,
)

output1 = do_sample(llm, minicpmv_lora_files, lora_id=1)
for i in range(len(EXPECTED_OUTPUT)):
assert output1[i] == EXPECTED_OUTPUT[i]
output2 = do_sample(llm, minicpmv_lora_files, lora_id=2)
for i in range(len(EXPECTED_OUTPUT)):
assert output2[i] == EXPECTED_OUTPUT[i]
99 changes: 99 additions & 0 deletions tests/lora/test_minicpmv_tp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
from typing import List

import pytest

import vllm
from vllm.assets.image import ImageAsset
from vllm.lora.request import LoRARequest

from ..utils import multi_gpu_test

MODEL_PATH = "openbmb/MiniCPM-Llama3-V-2_5"

PROMPT_TEMPLATE = (
"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n"
"(<image>./</image>)\nWhat is in the image?<|eot_id|>"
"<|start_header_id|>assistant<|end_header_id|>\n\n"
)

IMAGE_ASSETS = [
ImageAsset("stop_sign"),
ImageAsset("cherry_blossom"),
]

# After fine-tuning with LoRA, all generated content should start begin `A`.
EXPECTED_OUTPUT = [
"A red and white stop sign with a Chinese archway in the background featuring red lanterns and gold accents.", # noqa: E501
"A pink cherry blossom tree with a blue sky in the background.",
]


def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:
sampling_params = vllm.SamplingParams(
temperature=0,
max_tokens=256,
stop_token_ids=[128001, 128009], # eos_id, eot_id
)

inputs = [
{
"prompt": PROMPT_TEMPLATE,
"multi_modal_data": {"image": asset.pil_image},
}
for asset in IMAGE_ASSETS
]

outputs = llm.generate(
inputs,
sampling_params,
lora_request=LoRARequest(str(lora_id), lora_id, lora_path)
if lora_id
else None,
)
# Print the outputs.
generated_texts: List[str] = []
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text.strip()
generated_texts.append(generated_text)
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
return generated_texts

@multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize("fully_sharded", [True, False])
def test_minicpmv_tp2(minicpmv_lora_files, fully_sharded):
llm = vllm.LLM(
MODEL_PATH,
enable_lora=True,
max_num_seqs=16,
max_loras=4,
max_lora_rank=64,
tensor_parallel_size=2,
trust_remote_code=True,
fully_sharded_loras=fully_sharded,
)

output_tp = do_sample(llm, minicpmv_lora_files, lora_id=1)

for i in range(len(EXPECTED_OUTPUT)):
assert output_tp[i] == EXPECTED_OUTPUT[i]


@multi_gpu_test(num_gpus=4)
@pytest.mark.parametrize("fully_sharded", [True, False])
def test_minicpmv_tp4(minicpmv_lora_files, fully_sharded):
llm = vllm.LLM(
MODEL_PATH,
enable_lora=True,
max_num_seqs=16,
max_loras=4,
max_lora_rank=64,
tensor_parallel_size=4,
trust_remote_code=True,
fully_sharded_loras=fully_sharded,
)

output_tp = do_sample(llm, minicpmv_lora_files, lora_id=1)

for i in range(len(EXPECTED_OUTPUT)):
assert output_tp[i] == EXPECTED_OUTPUT[i]
42 changes: 38 additions & 4 deletions vllm/lora/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@
from vllm.lora.punica import PunicaWrapper
from vllm.lora.utils import (from_layer, from_layer_logits_processor,
parse_fine_tuned_lora_name, replace_submodule)
from vllm.model_executor.models.interfaces import SupportsLoRA
from vllm.model_executor.models.interfaces import (SupportsLoRA,
supports_multimodal)
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.models.utils import PPMissingLayer
from vllm.utils import is_pin_memory_available

Expand Down Expand Up @@ -332,6 +334,8 @@ def __init__(
self.supported_lora_modules.append("rotary_emb")
self.packed_modules_mapping = copy.deepcopy(
self.model.packed_modules_mapping)
# Used to indicate whether the model is a multimodal model
self.supports_mm: bool = supports_multimodal(self.model)
self.packed_modules: Dict[str, List[str]] = {}
self.modules: Dict[str, "BaseLayerWithLoRA"] = {}
# Dict instead of a Set for compatibility with LRUCache.
Expand Down Expand Up @@ -437,12 +441,28 @@ def _create_lora_modules(self):
continue
if not self._match_target_modules(module_name):
continue
# A temporary approach for multimodal models to support LoRA
# TODO: Remove this restriction
if self._filter_unsupported_mm_module(module_name):
logger.warning(
"Regarding multimodal models, vLLM currently only supports "
"adding LoRA to language model, %s will be ignored.",
module_name,
)
continue
parts = module_name.split(".")[-1]
packed_moduled_lst = self.packed_modules_mapping.get(parts, [])
new_module = replace_submodule(
self.model, module_name,
from_layer(module, self.lora_slots, self.lora_config,
packed_moduled_lst, self.model.config))
# In some models, especially multimodal ones, layers with the same
# name may have different types, such as nn.Linear and
# ReplicatedLinear. The nn.Linear layers cannot be replaced with
# LoRA layers, leading to assertion error. The following check
# aims to prevent this error
if not isinstance(new_module, BaseLayerWithLoRA):
continue
# LinearScalingRotaryEmbeddingWithLora is used to handle
# long context lora. Register relevant metadata.
if isinstance(new_module, LinearScalingRotaryEmbeddingWithLora):
Expand Down Expand Up @@ -478,9 +498,10 @@ def create_dummy_lora(
"""Create zero-initialized LoRAModel for warmup."""
model = LoRAModel(lora_id, rank, {}, scaling_factor)
for module_name, module in self.model.named_modules():
if not self._match_target_modules(module_name) or not isinstance(
module, BaseLayerWithLoRA) or isinstance(
module, LinearScalingRotaryEmbeddingWithLora):
if (not self._match_target_modules(module_name)
or not isinstance(module, BaseLayerWithLoRA)
or isinstance(module, LinearScalingRotaryEmbeddingWithLora)
or self._filter_unsupported_mm_module(module_name)):
continue
parts = module_name.split(".")
if module_name not in self.packed_modules:
Expand Down Expand Up @@ -541,6 +562,19 @@ def _match_target_modules(self, module_name: str):
module_name) or target_module == module_name
for target_module in self.supported_lora_modules)

def _filter_unsupported_mm_module(self, module_name: str) -> bool:
"""
Regarding multimodal models, vLLM currently only supports adding LoRA to
language model. LoRA for other modules, such as the vision tower, will
be filtered out.
"""
if self.supports_mm:
prefix = module_name.split(".")[0]
module_mapping: MultiModelKeys = self.model.get_mm_mapping()
return (prefix in module_mapping.connector
or prefix in module_mapping.tower_model)
return False

def _register_packed_modules(self, module_full_name: str) -> None:
parts = module_full_name.split(".")
module_name = parts[-1]
Expand Down
Loading
Loading