From 900231d038ca255cfc7d3a95db7668c388f9c648 Mon Sep 17 00:00:00 2001
From: Jee Jee Li
Date: Mon, 30 Sep 2024 12:31:55 +0800
Subject: [PATCH] [Model][LoRA]LoRA support added for MiniCPMV2.6 (#8943)

Co-authored-by: DarkLight1337
Signed-off-by: Amit Garg
---
 .../models/idefics2_vision_model.py    |  24 +-
 vllm/model_executor/models/minicpmv.py | 101 +--
 vllm/model_executor/models/na_vit.py   | 804 ------------------
 3 files changed, 49 insertions(+), 880 deletions(-)
 delete mode 100644 vllm/model_executor/models/na_vit.py

diff --git a/vllm/model_executor/models/idefics2_vision_model.py b/vllm/model_executor/models/idefics2_vision_model.py
index cc448ed28d2dc..3b0b6febaa48c 100644
--- a/vllm/model_executor/models/idefics2_vision_model.py
+++ b/vllm/model_executor/models/idefics2_vision_model.py
@@ -65,11 +65,10 @@ def __init__(self, config: Idefics2VisionConfig):
         self.position_embedding = nn.Embedding(self.num_positions,
                                                self.embed_dim)
 
-    def forward(
-        self,
-        pixel_values: torch.FloatTensor,
-        patch_attention_mask: torch.BoolTensor,
-    ) -> torch.Tensor:
+    def forward(self,
+                pixel_values: torch.FloatTensor,
+                patch_attention_mask: torch.BoolTensor,
+                tgt_sizes: Optional[torch.IntTensor] = None) -> torch.Tensor:
         batch_size, _, max_im_h, max_im_w = pixel_values.shape
         patch_embeds = self.patch_embedding(pixel_values)
         embeddings = patch_embeds.flatten(2).transpose(1, 2)
@@ -84,8 +83,13 @@ def forward(
                                    fill_value=0)
 
         for batch_idx, p_attn_mask in enumerate(patch_attention_mask):
-            nb_patches_h = p_attn_mask[:, 0].sum()
-            nb_patches_w = p_attn_mask[0].sum()
+
+            if tgt_sizes is not None:
+                nb_patches_h = tgt_sizes[batch_idx][0]
+                nb_patches_w = tgt_sizes[batch_idx][1]
+            else:
+                nb_patches_h = p_attn_mask[:, 0].sum()
+                nb_patches_w = p_attn_mask[0].sum()
             fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h)
             fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w)
             bucket_coords_h = torch.bucketize(fractional_coords_h,
@@ -287,10 +291,12 @@ def forward(
         self,
         pixel_values,
         patch_attention_mask: Optional[torch.BoolTensor] = None,
-    ) -> torch.tensor:
+        tgt_sizes: Optional[torch.IntTensor] = None,
+    ) -> torch.Tensor:
         hidden_states = self.embeddings(
             pixel_values=pixel_values,
-            patch_attention_mask=patch_attention_mask)
+            patch_attention_mask=patch_attention_mask,
+            tgt_sizes=tgt_sizes)
         encoder_outputs = self.encoder(hidden_states)
         last_hidden_state = self.post_layernorm(encoder_outputs)
         return last_hidden_state
diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py
index 89cdfbcc6afa9..aaae4397c01d2 100644
--- a/vllm/model_executor/models/minicpmv.py
+++ b/vllm/model_executor/models/minicpmv.py
@@ -31,17 +31,15 @@
 import torch.types
 from PIL import Image
 from torch import nn
-from torch.nn.init import trunc_normal_
 from transformers import PretrainedConfig
 from typing_extensions import NotRequired
 
 from vllm.attention import AttentionMetadata
 from vllm.config import CacheConfig, LoRAConfig, MultiModalConfig
 from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
-from vllm.model_executor.layers.linear import ReplicatedLinear
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
-from vllm.model_executor.layers.resampler import (Resampler2,
+from vllm.model_executor.layers.resampler import (BaseResampler, Resampler2,
                                                   get_2d_sincos_pos_embed)
 from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
 from
vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead @@ -106,58 +104,6 @@ class MiniCPMVImagePixelInputs(TypedDict): DEFAULT_LN = partial(nn.LayerNorm, eps=1e-6) -class BaseResampler(nn.Module): - """ - A 2D perceiver-resampler network with one cross attention layers by - (grid_size**2) learnable queries and 2d sincos pos_emb - Outputs: - A tensor with the shape of (grid_size**2, embed_dim) - """ - - def __init__( - self, - num_queries: int, - embed_dim: int, - num_heads: int, - kv_dim: Optional[int] = None, - norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN, - ) -> None: - super().__init__() - - self.num_queries = num_queries - self.embed_dim = embed_dim - self.num_heads = num_heads - - self.query = nn.Parameter(torch.zeros(self.num_queries, embed_dim)) - trunc_normal_(self.query, std=0.02) - if kv_dim is not None and kv_dim != embed_dim: - self.kv_proj = ReplicatedLinear(kv_dim, embed_dim, bias=False) - else: - # Maintain the same return value with ReplicatedLinear.forward - self.kv_proj = lambda *args, **kwargs: ( - nn.Identity()(*args, **kwargs), - None, - ) - self.attn = nn.MultiheadAttention(embed_dim, num_heads) - self.ln_q = norm_layer(embed_dim) - self.ln_kv = norm_layer(embed_dim) - self.ln_post = norm_layer(embed_dim) - self.proj = nn.Parameter( - (embed_dim**-0.5) * torch.randn(embed_dim, embed_dim)) - - def _init_weights(self, m: nn.Module) -> None: - if isinstance(m, nn.Linear): - trunc_normal_(m.weight, std=0.02) - if isinstance(m, nn.Linear) and m.bias is not None: - nn.init.constant_(m.bias, 0) - elif isinstance(m, nn.LayerNorm): - nn.init.constant_(m.bias, 0) - nn.init.constant_(m.weight, 1.0) - - def _repeat(self, query, N: int): - return query.unsqueeze(1).repeat(1, N, 1) - - class Resampler2_5(BaseResampler): def __init__( @@ -869,7 +815,35 @@ def is_default_weight_loading(self, name: str) -> bool: return "resampler" in name -class MiniCPMV2_6(MiniCPMVBaseModel): +class MiniCPMV2_6(MiniCPMVBaseModel, SupportsLoRA): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + # LoRA specific attributes + supported_lora_modules = [ + # vision encoder + "fc1", + "fc2", + "out_proj", + # language model + "qkv_proj", # same name with vision encoder + "o_proj", + "gate_up_proj", + "down_proj", + # resampler + "kv_proj", + ] + + embedding_modules = {} + embedding_padding_modules = [] def __init__( self, @@ -894,15 +868,8 @@ def init_llm( name="model") def init_vision_module(self) -> nn.Module: - # A custom version of SiglipVisionTransformer, won't work with TP - from vllm.model_executor.models.na_vit import SiglipVisionTransformer - if self.config._attn_implementation == "flash_attention_2": - self.config.vision_config._attn_implementation = "flash_attention_2" - else: - # not support sdpa - self.config.vision_config._attn_implementation = "eager" - model = SiglipVisionTransformer(self.config.vision_config) + model = Idefics2VisionTransformer(self.config.vision_config) if self.config.drop_vision_last_layer: model.encoder.layers = model.encoder.layers[:-1] return model @@ -928,7 +895,7 @@ def get_vision_embedding( pixel_values, patch_attention_mask=patch_attn_mask, tgt_sizes=tgt_sizes, - ).last_hidden_state + ) return vision_embedding def get_vision_hidden_states( @@ -960,12 +927,12 @@ def get_vision_hidden_states( all_pixel_values.type(dtype), patch_attention_mask=patch_attn_mask, tgt_sizes=tgt_sizes, - ).last_hidden_state + ) return self.resampler(vision_embedding, 
tgt_sizes) def is_default_weight_loading(self, name: str) -> bool: - return "resampler" in name or "vpm" in name + return "resampler" in name _SUPPORT_VERSION = { diff --git a/vllm/model_executor/models/na_vit.py b/vllm/model_executor/models/na_vit.py deleted file mode 100644 index 1d6f26f0d4fb5..0000000000000 --- a/vllm/model_executor/models/na_vit.py +++ /dev/null @@ -1,804 +0,0 @@ -import logging -import math -import os -import warnings -from typing import Optional, Tuple, Union - -import numpy as np -import torch -import torch.nn.functional as F -from torch import nn -from torch.nn.init import _calculate_fan_in_and_fan_out -from transformers.activations import ACT2FN -from transformers.configuration_utils import PretrainedConfig -from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask -from transformers.modeling_outputs import (BaseModelOutput, - BaseModelOutputWithPooling) -from transformers.modeling_utils import PreTrainedModel -from transformers.utils import (ModelOutput, is_flash_attn_2_available, - replace_return_docstrings) - -logger = logging.getLogger("vllm") - - -# For Siglip: copied from -# HuggingFaceM4/siglip-so400m-14-980-flash-attn2-navit and add tgt_sizes -# Remove hints as there's little possibility to change these code. -class SiglipVisionConfig(PretrainedConfig): - - model_type = "siglip_vision_model" - - def __init__( - self, - hidden_size=768, - intermediate_size=3072, - num_hidden_layers=12, - num_attention_heads=12, - num_channels=3, - image_size=224, - patch_size=16, - hidden_act="gelu_pytorch_tanh", - layer_norm_eps=1e-6, - attention_dropout=0.0, - **kwargs, - ): - super().__init__(**kwargs) - - self.hidden_size = hidden_size - self.intermediate_size = intermediate_size - self.num_hidden_layers = num_hidden_layers - self.num_attention_heads = num_attention_heads - self.num_channels = num_channels - self.patch_size = patch_size - self.image_size = image_size - self.attention_dropout = attention_dropout - self.layer_norm_eps = layer_norm_eps - self.hidden_act = hidden_act - - @classmethod - def from_pretrained(cls, pretrained_model_name_or_path: Union[str, - os.PathLike], - **kwargs) -> "PretrainedConfig": - cls._set_token_in_kwargs(kwargs) - - config_dict, kwargs = cls.get_config_dict( - pretrained_model_name_or_path, **kwargs) - - # get the vision config dict if we are loading from SiglipConfig - if config_dict.get("model_type") == "siglip": - config_dict = config_dict["vision_config"] - - if "model_type" in config_dict and hasattr( - cls, - "model_type") and config_dict["model_type"] != cls.model_type: - logger.warning( - "You are using a model of type %s to " - "instantiate a model of type %s. 
" - "This is not supported for all configurations" - "of models and can yield errors.", config_dict['model_type'], - cls.model_type) - - return cls.from_dict(config_dict, **kwargs) - - -_CHECKPOINT_FOR_DOC = "google/siglip-base-patch16-224" - -SIGLIP_PRETRAINED_MODEL_ARCHIVE_LIST = [ - "google/siglip-base-patch16-224", - # See all SigLIP models at https://huggingface.co/models?filter=siglip -] - -if is_flash_attn_2_available(): - from flash_attn import flash_attn_func, flash_attn_varlen_func - from flash_attn.bert_padding import pad_input # noqa - from flash_attn.bert_padding import index_first_axis, unpad_input - - -# Copied from transformers.models.llama.modeling_llama._get_unpad_data -def _get_unpad_data(attention_mask): - seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) - indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() - max_seqlen_in_batch = seqlens_in_batch.max().item() - cu_seqlens = F.pad( - torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) - return ( - indices, - cu_seqlens, - max_seqlen_in_batch, - ) - - -def _trunc_normal_(tensor, mean, std, a, b): - - def norm_cdf(x): - # Computes standard normal cumulative distribution function - return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0 - - if (mean < a - 2 * std) or (mean > b + 2 * std): - warnings.warn( - "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " - "The distribution of values may be incorrect.", - stacklevel=2, - ) - - # Values are generated by using a truncated uniform distribution and - # then using the inverse CDF for the normal distribution. - # Get upper and lower cdf values - l_ = norm_cdf((a - mean) / std) - u = norm_cdf((b - mean) / std) - - # Uniformly fill tensor with values from [l, u], then translate to - # [2l-1, 2u-1]. - tensor.uniform_(2 * l_ - 1, 2 * u - 1) - - # Use inverse cdf transform for normal distribution to get truncated - # standard normal - if tensor.dtype in [torch.float16, torch.bfloat16]: - # The `erfinv_` op is not (yet?) defined in float16+cpu, bfloat16+gpu - og_dtype = tensor.dtype - tensor = tensor.to(torch.float32) - tensor.erfinv_() - tensor = tensor.to(og_dtype) - else: - tensor.erfinv_() - - # Transform to proper mean, std - tensor.mul_(std * math.sqrt(2.0)) - tensor.add_(mean) - - # Clamp to ensure it's in the proper range - if tensor.dtype == torch.float16: - # The `clamp_` op is not (yet?) 
defined in float16+cpu - tensor = tensor.to(torch.float32) - tensor.clamp_(min=a, max=b) - tensor = tensor.to(torch.float16) - else: - tensor.clamp_(min=a, max=b) - - -def trunc_normal_tf_(tensor: torch.Tensor, - mean: float = 0.0, - std: float = 1.0, - a: float = -2.0, - b: float = 2.0) -> torch.Tensor: - with torch.no_grad(): - _trunc_normal_(tensor, 0, 1.0, a, b) - tensor.mul_(std).add_(mean) - - -def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"): - fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor) - if mode == "fan_in": - denom = fan_in - elif mode == "fan_out": - denom = fan_out - elif mode == "fan_avg": - denom = (fan_in + fan_out) / 2 - - variance = scale / denom - - if distribution == "truncated_normal": - # constant is stddev of standard normal truncated to (-2, 2) - trunc_normal_tf_(tensor, std=math.sqrt(variance) / 0.87962566103423978) - elif distribution == "normal": - with torch.no_grad(): - tensor.normal_(std=math.sqrt(variance)) - elif distribution == "uniform": - bound = math.sqrt(3 * variance) - with torch.no_grad(): - tensor.uniform_(-bound, bound) - else: - raise ValueError(f"invalid distribution {distribution}") - - -def lecun_normal_(tensor): - variance_scaling_(tensor, mode="fan_in", distribution="truncated_normal") - - -def default_flax_embed_init(tensor): - variance_scaling_(tensor, mode="fan_in", distribution="normal") - - -class SiglipVisionModelOutput(ModelOutput): - image_embeds: Optional[torch.FloatTensor] = None - last_hidden_state: torch.FloatTensor = None - hidden_states: Optional[Tuple[torch.FloatTensor]] = None - attentions: Optional[Tuple[torch.FloatTensor]] = None - - -class SiglipVisionEmbeddings(nn.Module): - - def __init__(self, config: SiglipVisionConfig): - super().__init__() - self.config = config - self.embed_dim = config.hidden_size - self.image_size = config.image_size - self.patch_size = config.patch_size - - self.patch_embedding = nn.Conv2d( - in_channels=config.num_channels, - out_channels=self.embed_dim, - kernel_size=self.patch_size, - stride=self.patch_size, - padding="valid", - ) - - self.num_patches_per_side = self.image_size // self.patch_size - self.num_patches = self.num_patches_per_side**2 - self.num_positions = self.num_patches - self.position_embedding = nn.Embedding(self.num_positions, - self.embed_dim) - - def forward(self, - pixel_values: torch.FloatTensor, - patch_attention_mask: torch.BoolTensor, - tgt_sizes: Optional[torch.IntTensor] = None) -> torch.Tensor: - batch_size = pixel_values.size(0) - - patch_embeds = self.patch_embedding(pixel_values) - embeddings = patch_embeds.flatten(2).transpose(1, 2) - - max_im_h, max_im_w = pixel_values.size(2), pixel_values.size(3) - max_nb_patches_h, max_nb_patches_w = (max_im_h // self.patch_size, - max_im_w // self.patch_size) - boundaries = torch.arange(1 / self.num_patches_per_side, 1.0, - 1 / self.num_patches_per_side) - position_ids = torch.full( - size=( - batch_size, - max_nb_patches_h * max_nb_patches_w, - ), - fill_value=0, - ) - - for batch_idx, p_attn_mask in enumerate(patch_attention_mask): - if tgt_sizes is not None: - nb_patches_h = tgt_sizes[batch_idx][0] - nb_patches_w = tgt_sizes[batch_idx][1] - else: - nb_patches_h = p_attn_mask[:, 0].sum() - nb_patches_w = p_attn_mask[0].sum() - - fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h) - fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w) - - bucket_coords_h = torch.bucketize(fractional_coords_h, - boundaries, - right=True) - bucket_coords_w = 
torch.bucketize(fractional_coords_w, - boundaries, - right=True) - - pos_ids = (bucket_coords_h[:, None] * self.num_patches_per_side + - bucket_coords_w).flatten() - position_ids[batch_idx][p_attn_mask.view(-1).cpu()] = pos_ids - - position_ids = position_ids.to(self.position_embedding.weight.device) - - embeddings = embeddings + self.position_embedding(position_ids) - return embeddings - - -class SiglipAttention(nn.Module): - """Multi-headed attention from 'Attention Is All You Need' paper""" - - # Copied from transformers.models.clip.modeling_clip.CLIPAttention.__init__ - def __init__(self, config): - super().__init__() - self.config = config - self.embed_dim = config.hidden_size - self.num_heads = config.num_attention_heads - self.head_dim = self.embed_dim // self.num_heads - if self.head_dim * self.num_heads != self.embed_dim: - raise ValueError( - "embed_dim must be divisible by num_heads (got `embed_dim`: " - f"{self.embed_dim} and `num_heads`:" - f" {self.num_heads}).") - self.scale = self.head_dim**-0.5 - self.dropout = config.attention_dropout - - self.k_proj = nn.Linear(self.embed_dim, self.embed_dim) - self.v_proj = nn.Linear(self.embed_dim, self.embed_dim) - self.q_proj = nn.Linear(self.embed_dim, self.embed_dim) - self.out_proj = nn.Linear(self.embed_dim, self.embed_dim) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - output_attentions: Optional[bool] = False, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], - Optional[Tuple[torch.Tensor]]]: - """Input shape: Batch x Time x Channel""" - - batch_size, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view(batch_size, q_len, self.num_heads, - self.head_dim).transpose(1, 2) - key_states = key_states.view(batch_size, q_len, self.num_heads, - self.head_dim).transpose(1, 2) - value_states = value_states.view(batch_size, q_len, self.num_heads, - self.head_dim).transpose(1, 2) - - k_v_seq_len = key_states.shape[-2] - attn_weights = torch.matmul(query_states, key_states.transpose( - 2, 3)) * self.scale - - if attn_weights.size() != (batch_size, self.num_heads, q_len, - k_v_seq_len): - raise ValueError( - "Attention weights should be of size " - f"{(batch_size, self.num_heads, q_len, k_v_seq_len)}, but is" - f" {attn_weights.size()}") - - if attention_mask is not None: - if attention_mask.size() != (batch_size, 1, q_len, k_v_seq_len): - raise ValueError( - "Attention mask should be of size " - f"{(batch_size, 1, q_len, k_v_seq_len)}", - f"but is {attention_mask.size()}") - attn_weights = attn_weights + attention_mask - - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, - dim=-1, - dtype=torch.float32).to( - query_states.dtype) - attn_weights = nn.functional.dropout(attn_weights, - p=self.dropout, - training=self.training) - attn_output = torch.matmul(attn_weights, value_states) - - if attn_output.size() != (batch_size, self.num_heads, q_len, - self.head_dim): - raise ValueError( - "`attn_output` should be of size " - f"{(batch_size, self.num_heads, q_len, self.head_dim)}, " - "but is" - f" {attn_output.size()}") - - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim) - - attn_output = self.out_proj(attn_output) - - return attn_output, attn_weights - - -class SiglipFlashAttention2(SiglipAttention): - - def __init__(self, *args, 
**kwargs): - super().__init__(*args, **kwargs) - self.is_causal = False # Hack to make sure we don't use a causal mask - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.LongTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: bool = False, - use_cache: bool = False, - **kwargs, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], - Optional[Tuple[torch.Tensor]]]: - output_attentions = False - - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view(bsz, q_len, self.num_heads, - self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_heads, - self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_heads, - self.head_dim).transpose(1, 2) - - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - kv_seq_len += past_key_value.get_usable_length( - kv_seq_len, self.layer_idx) - - query_states = query_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) - - dropout_rate = self.dropout if self.training else 0.0 - - input_dtype = query_states.dtype - if input_dtype == torch.float32: - if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() - # Handle the case where the model is quantized - elif hasattr(self.config, "_pre_quantization_dtype"): - target_dtype = self.config._pre_quantization_dtype - else: - target_dtype = self.q_proj.weight.dtype - - logger.warning( - "The input hidden states seems to be " - "silently casted in float32, " - "this might be related to the fact " - "you have upcasted embedding or layer norm layers in float32. 
" - "We will cast back the input in" - " %s.", target_dtype) - - query_states = query_states.to(target_dtype) - key_states = key_states.to(target_dtype) - value_states = value_states.to(target_dtype) - - attn_output = self._flash_attention_forward(query_states, - key_states, - value_states, - attention_mask, - q_len, - dropout=dropout_rate) - - attn_output = attn_output.reshape(bsz, q_len, - self.embed_dim).contiguous() - attn_output = self.out_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights - - def _flash_attention_forward(self, - query_states, - key_states, - value_states, - attention_mask, - query_length, - dropout=0.0, - softmax_scale=None): - causal = self.is_causal and query_length != 1 - - # Contains at least one padding token in the sequence - if attention_mask is not None: - batch_size = query_states.shape[0] - (query_states, key_states, value_states, indices_q, cu_seq_lens, - max_seq_lens) = self._upad_input(query_states, key_states, - value_states, attention_mask, - query_length) - - cu_seqlens_q, cu_seqlens_k = cu_seq_lens - max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens - - attn_output_unpad = flash_attn_varlen_func( - query_states, - key_states, - value_states, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_k=cu_seqlens_k, - max_seqlen_q=max_seqlen_in_batch_q, - max_seqlen_k=max_seqlen_in_batch_k, - dropout_p=dropout, - softmax_scale=softmax_scale, - causal=causal, - ) - - attn_output = pad_input(attn_output_unpad, indices_q, batch_size, - query_length) - else: - attn_output = flash_attn_func(query_states, - key_states, - value_states, - dropout, - softmax_scale=softmax_scale, - causal=causal) - - return attn_output - - def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, - query_length): - indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data( - attention_mask) - batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape - - key_layer = index_first_axis( - key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, - head_dim), indices_k) - value_layer = index_first_axis( - value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, - head_dim), indices_k) - if query_length == kv_seq_len: - query_layer = index_first_axis( - query_layer.reshape(batch_size * kv_seq_len, self.num_heads, - head_dim), indices_k) - cu_seqlens_q = cu_seqlens_k - max_seqlen_in_batch_q = max_seqlen_in_batch_k - indices_q = indices_k - elif query_length == 1: - max_seqlen_in_batch_q = 1 - cu_seqlens_q = torch.arange( - batch_size + 1, dtype=torch.int32, device=query_layer.device - ) # There is a memcpy here, that is very bad. - indices_q = cu_seqlens_q[:-1] - query_layer = query_layer.squeeze(1) - else: - # The -q_len: slice assumes left padding. 
- attention_mask = attention_mask[:, -query_length:] - (query_layer, indices_q, cu_seqlens_q, - max_seqlen_in_batch_q) = unpad_input(query_layer, attention_mask) - - return ( - query_layer, - key_layer, - value_layer, - indices_q, - (cu_seqlens_q, cu_seqlens_k), - (max_seqlen_in_batch_q, max_seqlen_in_batch_k), - ) - - -# Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->Siglip -class SiglipMLP(nn.Module): - - def __init__(self, config): - super().__init__() - self.config = config - self.activation_fn = ACT2FN[config.hidden_act] - self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size) - self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size) - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - hidden_states = self.fc1(hidden_states) - hidden_states = self.activation_fn(hidden_states) - hidden_states = self.fc2(hidden_states) - return hidden_states - - -# Copied from transformers.models.clip.modeling_clip.CLIPEncoderLayer -# with CLIP->Siglip -class SiglipEncoderLayer(nn.Module): - - def __init__(self, config: SiglipVisionConfig): - super().__init__() - self.embed_dim = config.hidden_size - self._use_flash_attention_2 = ( - config._attn_implementation == "flash_attention_2") - self.self_attn = (SiglipAttention(config) - if not self._use_flash_attention_2 else - SiglipFlashAttention2(config)) - self.layer_norm1 = nn.LayerNorm(self.embed_dim, - eps=config.layer_norm_eps) - self.mlp = SiglipMLP(config) - self.layer_norm2 = nn.LayerNorm(self.embed_dim, - eps=config.layer_norm_eps) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: torch.Tensor, - output_attentions: Optional[bool] = False, - ) -> Tuple[torch.FloatTensor]: - residual = hidden_states - - hidden_states = self.layer_norm1(hidden_states) - hidden_states, attn_weights = self.self_attn( - hidden_states=hidden_states, - attention_mask=attention_mask, - output_attentions=output_attentions, - ) - hidden_states = residual + hidden_states - - residual = hidden_states - hidden_states = self.layer_norm2(hidden_states) - hidden_states = self.mlp(hidden_states) - hidden_states = residual + hidden_states - - outputs = (hidden_states, ) - - if output_attentions: - outputs += (attn_weights, ) - - return outputs - - -class SiglipPreTrainedModel(PreTrainedModel): - config_class = SiglipVisionConfig - base_model_prefix = "siglip" - supports_gradient_checkpointing = True - - def _init_weights(self, module): - """Initialize the weights""" - - if isinstance(module, SiglipVisionEmbeddings): - width = self.config.hidden_size - nn.init.normal_(module.position_embedding.weight, - std=1 / np.sqrt(width)) - elif isinstance(module, nn.Embedding): - default_flax_embed_init(module.weight) - elif isinstance(module, SiglipAttention): - nn.init.normal_(module.q_proj.weight) - nn.init.normal_(module.k_proj.weight) - nn.init.normal_(module.v_proj.weight) - nn.init.normal_(module.out_proj.weight) - nn.init.zeros_(module.q_proj.bias) - nn.init.zeros_(module.k_proj.bias) - nn.init.zeros_(module.v_proj.bias) - nn.init.zeros_(module.out_proj.bias) - elif isinstance(module, SiglipMLP): - nn.init.normal_(module.fc1.weight) - nn.init.normal_(module.fc2.weight) - nn.init.normal_(module.fc1.bias, std=1e-6) - nn.init.normal_(module.fc2.bias, std=1e-6) - elif isinstance(module, (nn.Linear, nn.Conv2d)): - lecun_normal_(module.weight) - if module.bias is not None: - nn.init.zeros_(module.bias) - elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) - - -# 
Copied from transformers.models.clip.modeling_clip.CLIPEncoder -# with CLIP->Siglip -class SiglipEncoder(nn.Module): - - def __init__(self, config: SiglipVisionConfig): - super().__init__() - self.config = config - self.layers = nn.ModuleList([ - SiglipEncoderLayer(config) for _ in range(config.num_hidden_layers) - ]) - self.gradient_checkpointing = False - - # Ignore copy - def forward( - self, - inputs_embeds, - attention_mask: Optional[torch.Tensor] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, BaseModelOutput]: - output_attentions = output_attentions if output_attentions is not None \ - else self.config.output_attentions - output_hidden_states = (output_hidden_states - if output_hidden_states is not None else - self.config.output_hidden_states) - return_dict = return_dict if return_dict is not None \ - else self.config.use_return_dict - - encoder_states = () if output_hidden_states else None - all_attentions = () if output_attentions else None - - hidden_states = inputs_embeds - for encoder_layer in self.layers: - if output_hidden_states: - encoder_states = encoder_states + (hidden_states, ) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - encoder_layer.__call__, - hidden_states, - attention_mask, - output_attentions, - ) - else: - layer_outputs = encoder_layer( - hidden_states, - attention_mask, - output_attentions=output_attentions, - ) - - hidden_states = layer_outputs[0] - - if output_attentions: - all_attentions = all_attentions + (layer_outputs[1], ) - - if output_hidden_states: - encoder_states = encoder_states + (hidden_states, ) - - if not return_dict: - return tuple( - v for v in [hidden_states, encoder_states, all_attentions] - if v is not None) - return BaseModelOutput(last_hidden_state=hidden_states, - hidden_states=encoder_states, - attentions=all_attentions) - - -class SiglipVisionTransformer(SiglipPreTrainedModel): - config_class = SiglipVisionConfig - main_input_name = "pixel_values" - _supports_flash_attn_2 = True - - def __init__(self, config: SiglipVisionConfig): - super().__init__(config) - self.config = config - embed_dim = config.hidden_size - - self.embeddings = SiglipVisionEmbeddings(config) - self.encoder = SiglipEncoder(config) - self.post_layernorm = nn.LayerNorm(embed_dim, - eps=config.layer_norm_eps) - self._use_flash_attention_2 = ( - config._attn_implementation == "flash_attention_2") - - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self) -> nn.Module: - return self.embeddings.patch_embedding - - @replace_return_docstrings(output_type=BaseModelOutputWithPooling, - config_class=SiglipVisionConfig) - def forward( - self, - pixel_values, - patch_attention_mask: Optional[torch.BoolTensor] = None, - tgt_sizes: Optional[torch.IntTensor] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, BaseModelOutputWithPooling]: - r""" - Returns: - """ - output_attentions = output_attentions if output_attentions is not None \ - else self.config.output_attentions - output_hidden_states = (output_hidden_states - if output_hidden_states is not None else - self.config.output_hidden_states) - return_dict = return_dict if return_dict is not None \ - else self.config.use_return_dict - - batch_size = pixel_values.size(0) - if patch_attention_mask is None: - 
patch_attention_mask = torch.ones( - size=( - batch_size, - pixel_values.size(2) // self.config.patch_size, - pixel_values.size(3) // self.config.patch_size, - ), - dtype=torch.bool, - device=pixel_values.device, - ) - - hidden_states = self.embeddings( - pixel_values=pixel_values, - patch_attention_mask=patch_attention_mask, - tgt_sizes=tgt_sizes) - - patch_attention_mask = patch_attention_mask.view(batch_size, -1) - # The call to `_upad_input` in `_flash_attention_forward` is expensive - # So when the `patch_attention_mask` is full of 1s - # (i.e. attending to the whole sequence), - # avoiding passing the attention_mask, - # which is equivalent to attending to the full sequence - if not torch.any(~patch_attention_mask): - attention_mask = None - else: - attention_mask = (_prepare_4d_attention_mask( - patch_attention_mask, hidden_states.dtype) - if not self._use_flash_attention_2 else - patch_attention_mask) - - encoder_outputs = self.encoder( - inputs_embeds=hidden_states, - attention_mask=attention_mask, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - last_hidden_state = encoder_outputs[0] - last_hidden_state = self.post_layernorm(last_hidden_state) - - if not return_dict: - return (last_hidden_state, None) + encoder_outputs[1:] - - return BaseModelOutputWithPooling( - last_hidden_state=last_hidden_state, - pooler_output=None, - hidden_states=encoder_outputs.hidden_states, - attentions=encoder_outputs.attentions, - )
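
Two standalone sketches follow; they are illustrations written against this
patch, not part of it, and any name in them that does not appear above is a
placeholder.

First, the position-id bucketing that Idefics2VisionEmbeddings.forward performs
once `tgt_sizes` supplies each image's real (nb_patches_h, nb_patches_w): every
real patch is mapped onto the model's fixed num_patches_per_side**2 position
embedding grid. The sizes below are made up for the example.

import torch

num_patches_per_side = 70    # e.g. a 980px input with 14px patches
boundaries = torch.arange(1 / num_patches_per_side, 1.0,
                          1 / num_patches_per_side)

nb_patches_h, nb_patches_w = 24, 32    # hypothetical tgt_sizes[batch_idx]
frac_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h)
frac_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w)
bucket_h = torch.bucketize(frac_h, boundaries, right=True)
bucket_w = torch.bucketize(frac_w, boundaries, right=True)

# One position id per real patch, row-major over the fixed grid.
pos_ids = (bucket_h[:, None] * num_patches_per_side + bucket_w).flatten()
assert pos_ids.shape == (nb_patches_h * nb_patches_w, )

Second, with MiniCPMV2_6 now inheriting SupportsLoRA, an adapter trained on the
supported_lora_modules listed above can be served roughly as follows. The model
id matches the upstream checkpoint; the adapter path, image file, prompt text,
and max_lora_rank are placeholders, and production prompts should go through
the model's chat template as in the vLLM MiniCPM-V examples.

from PIL import Image

from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

llm = LLM(
    model="openbmb/MiniCPM-V-2_6",
    trust_remote_code=True,
    enable_lora=True,       # enabled for this model by the change above
    max_lora_rank=64,       # match the rank the adapter was trained with
)

# "/path/to/minicpmv26-lora" stands in for a locally trained adapter.
lora = LoRARequest("minicpmv26-adapter", 1, "/path/to/minicpmv26-lora")
image = Image.open("example.jpg").convert("RGB")

outputs = llm.generate(
    {
        "prompt": "(<image>./</image>)\nDescribe this image.",
        "multi_modal_data": {"image": image},
    },
    SamplingParams(temperature=0.0, max_tokens=64),
    lora_request=lora,
)
print(outputs[0].outputs[0].text)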