[Model][LoRA]LoRA support added for MiniCPMV2.6 (vllm-project#8943)
Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
Signed-off-by: Amit Garg <mitgarg17495@gmail.com>
2 people authored and garg-amit committed Oct 28, 2024
1 parent 02f63a9 commit 900231d
Showing 3 changed files with 49 additions and 880 deletions.
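For context, the feature can be exercised through vLLM's offline LLM API. The snippet below is a minimal sketch, assuming the standard enable_lora / LoRARequest flow; the adapter path, rank, request id, and prompt are hypothetical placeholders rather than part of this commit.

# Hedged usage sketch: applying a (hypothetical) LoRA adapter to MiniCPM-V 2.6.
# Only the general vLLM LoRA API is assumed; adapter path, rank, id, and prompt
# are illustrative placeholders.
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

llm = LLM(
    model="openbmb/MiniCPM-V-2_6",
    trust_remote_code=True,   # MiniCPM-V ships custom processing code
    enable_lora=True,         # enable LoRA adapter loading in the engine
    max_lora_rank=64,         # assumed; must be >= the adapter's rank
)

# LoRARequest(adapter name, unique integer id, path to the adapter weights)
lora_request = LoRARequest("minicpmv26-lora", 1, "/path/to/minicpmv26_lora")

outputs = llm.generate(
    "Describe the scene in one sentence.",
    SamplingParams(temperature=0.0, max_tokens=64),
    lora_request=lora_request,
)
print(outputs[0].outputs[0].text)

The supported_lora_modules list introduced in minicpmv.py below determines which projections an adapter may target: the vision encoder's fc1/fc2/out_proj, the language model's fused qkv_proj/gate_up_proj plus o_proj/down_proj, and the resampler's kv_proj.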
24 changes: 15 additions & 9 deletions vllm/model_executor/models/idefics2_vision_model.py
@@ -65,11 +65,10 @@ def __init__(self, config: Idefics2VisionConfig):
self.position_embedding = nn.Embedding(self.num_positions,
self.embed_dim)

-    def forward(
-        self,
-        pixel_values: torch.FloatTensor,
-        patch_attention_mask: torch.BoolTensor,
-    ) -> torch.Tensor:
+    def forward(self,
+                pixel_values: torch.FloatTensor,
+                patch_attention_mask: torch.BoolTensor,
+                tgt_sizes: Optional[torch.IntTensor] = None) -> torch.Tensor:
batch_size, _, max_im_h, max_im_w = pixel_values.shape
patch_embeds = self.patch_embedding(pixel_values)
embeddings = patch_embeds.flatten(2).transpose(1, 2)
@@ -84,8 +83,13 @@ def forward(
fill_value=0)

for batch_idx, p_attn_mask in enumerate(patch_attention_mask):
-            nb_patches_h = p_attn_mask[:, 0].sum()
-            nb_patches_w = p_attn_mask[0].sum()
+
+            if tgt_sizes is not None:
+                nb_patches_h = tgt_sizes[batch_idx][0]
+                nb_patches_w = tgt_sizes[batch_idx][1]
+            else:
+                nb_patches_h = p_attn_mask[:, 0].sum()
+                nb_patches_w = p_attn_mask[0].sum()
fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h)
fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w)
bucket_coords_h = torch.bucketize(fractional_coords_h,
@@ -287,10 +291,12 @@ def forward(
self,
pixel_values,
patch_attention_mask: Optional[torch.BoolTensor] = None,
-    ) -> torch.tensor:
+        tgt_sizes: Optional[torch.IntTensor] = None,
+    ) -> torch.Tensor:
hidden_states = self.embeddings(
pixel_values=pixel_values,
-            patch_attention_mask=patch_attention_mask)
+            patch_attention_mask=patch_attention_mask,
+            tgt_sizes=tgt_sizes)
encoder_outputs = self.encoder(hidden_states)
last_hidden_state = self.post_layernorm(encoder_outputs)
return last_hidden_state
101 changes: 34 additions & 67 deletions vllm/model_executor/models/minicpmv.py
@@ -31,17 +31,15 @@
import torch.types
from PIL import Image
from torch import nn
-from torch.nn.init import trunc_normal_
from transformers import PretrainedConfig
from typing_extensions import NotRequired

from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, LoRAConfig, MultiModalConfig
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
-from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
-from vllm.model_executor.layers.resampler import (Resampler2,
+from vllm.model_executor.layers.resampler import (BaseResampler, Resampler2,
get_2d_sincos_pos_embed)
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
@@ -106,58 +104,6 @@ class MiniCPMVImagePixelInputs(TypedDict):
DEFAULT_LN = partial(nn.LayerNorm, eps=1e-6)


-class BaseResampler(nn.Module):
-    """
-    A 2D perceiver-resampler network with one cross attention layers by
-        (grid_size**2) learnable queries and 2d sincos pos_emb
-    Outputs:
-        A tensor with the shape of (grid_size**2, embed_dim)
-    """
-
-    def __init__(
-        self,
-        num_queries: int,
-        embed_dim: int,
-        num_heads: int,
-        kv_dim: Optional[int] = None,
-        norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN,
-    ) -> None:
-        super().__init__()
-
-        self.num_queries = num_queries
-        self.embed_dim = embed_dim
-        self.num_heads = num_heads
-
-        self.query = nn.Parameter(torch.zeros(self.num_queries, embed_dim))
-        trunc_normal_(self.query, std=0.02)
-        if kv_dim is not None and kv_dim != embed_dim:
-            self.kv_proj = ReplicatedLinear(kv_dim, embed_dim, bias=False)
-        else:
-            # Maintain the same return value with ReplicatedLinear.forward
-            self.kv_proj = lambda *args, **kwargs: (
-                nn.Identity()(*args, **kwargs),
-                None,
-            )
-        self.attn = nn.MultiheadAttention(embed_dim, num_heads)
-        self.ln_q = norm_layer(embed_dim)
-        self.ln_kv = norm_layer(embed_dim)
-        self.ln_post = norm_layer(embed_dim)
-        self.proj = nn.Parameter(
-            (embed_dim**-0.5) * torch.randn(embed_dim, embed_dim))
-
-    def _init_weights(self, m: nn.Module) -> None:
-        if isinstance(m, nn.Linear):
-            trunc_normal_(m.weight, std=0.02)
-            if isinstance(m, nn.Linear) and m.bias is not None:
-                nn.init.constant_(m.bias, 0)
-        elif isinstance(m, nn.LayerNorm):
-            nn.init.constant_(m.bias, 0)
-            nn.init.constant_(m.weight, 1.0)
-
-    def _repeat(self, query, N: int):
-        return query.unsqueeze(1).repeat(1, N, 1)
-
-
class Resampler2_5(BaseResampler):

def __init__(
@@ -869,7 +815,35 @@ def is_default_weight_loading(self, name: str) -> bool:
return "resampler" in name


-class MiniCPMV2_6(MiniCPMVBaseModel):
+class MiniCPMV2_6(MiniCPMVBaseModel, SupportsLoRA):
+    packed_modules_mapping = {
+        "qkv_proj": [
+            "q_proj",
+            "k_proj",
+            "v_proj",
+        ],
+        "gate_up_proj": [
+            "gate_proj",
+            "up_proj",
+        ],
+    }
+    # LoRA specific attributes
+    supported_lora_modules = [
+        # vision encoder
+        "fc1",
+        "fc2",
+        "out_proj",
+        # language model
+        "qkv_proj",  # same name with vision encoder
+        "o_proj",
+        "gate_up_proj",
+        "down_proj",
+        # resampler
+        "kv_proj",
+    ]
+
+    embedding_modules = {}
+    embedding_padding_modules = []

def __init__(
self,
@@ -894,15 +868,8 @@ def init_llm(
name="model")

def init_vision_module(self) -> nn.Module:
-        # A custom version of SiglipVisionTransformer, won't work with TP
-        from vllm.model_executor.models.na_vit import SiglipVisionTransformer
-
-        if self.config._attn_implementation == "flash_attention_2":
-            self.config.vision_config._attn_implementation = "flash_attention_2"
-        else:
-            # not support sdpa
-            self.config.vision_config._attn_implementation = "eager"
-        model = SiglipVisionTransformer(self.config.vision_config)
+        model = Idefics2VisionTransformer(self.config.vision_config)
if self.config.drop_vision_last_layer:
model.encoder.layers = model.encoder.layers[:-1]
return model
Expand All @@ -928,7 +895,7 @@ def get_vision_embedding(
pixel_values,
patch_attention_mask=patch_attn_mask,
tgt_sizes=tgt_sizes,
-        ).last_hidden_state
+        )
return vision_embedding

def get_vision_hidden_states(
@@ -960,12 +927,12 @@ def get_vision_hidden_states(
all_pixel_values.type(dtype),
patch_attention_mask=patch_attn_mask,
tgt_sizes=tgt_sizes,
-        ).last_hidden_state
+        )

return self.resampler(vision_embedding, tgt_sizes)

def is_default_weight_loading(self, name: str) -> bool:
-        return "resampler" in name or "vpm" in name
+        return "resampler" in name


_SUPPORT_VERSION = {