From ad137cd1112ab9b17ac36fc123fc7806a1d7473d Mon Sep 17 00:00:00 2001
From: Roger Wang <136131678+ywang96@users.noreply.github.com>
Date: Thu, 20 Jun 2024 04:52:09 -0700
Subject: [PATCH] [Model] Port over CLIPVisionModel for VLMs (#5591)

---
 csrc/activation_kernels.cu               |  12 ++
 csrc/ops.h                               |   2 +
 csrc/torch_bindings.cpp                  |   4 +
 vllm/_custom_ops.py                      |   4 +
 vllm/model_executor/layers/activation.py |  16 ++
 vllm/model_executor/models/clip.py       | 203 +++++++++++++++++++++++
 vllm/model_executor/models/llava.py      |  17 +-
 vllm/model_executor/models/llava_next.py |  19 ++-
 vllm/model_executor/models/phi3v.py      |  13 +-
 9 files changed, 269 insertions(+), 21 deletions(-)
 create mode 100644 vllm/model_executor/models/clip.py

diff --git a/csrc/activation_kernels.cu b/csrc/activation_kernels.cu
index 86ac2e75e78ee..5ed1dc3b8f792 100644
--- a/csrc/activation_kernels.cu
+++ b/csrc/activation_kernels.cu
@@ -135,6 +135,12 @@ __device__ __forceinline__ T gelu_fast_kernel(const T& x) {
   return ((T)0.5) * x * (((T)1.0) + t);
 }
 
+template <typename T>
+__device__ __forceinline__ T gelu_quick_kernel(const T& x) {
+  // x * sigmoid(1.702 * x)
+  return (T)(((float)x) / (1.0f + expf(-1.702f * (float)x)));
+}
+
 }  // namespace vllm
 
 void gelu_new(torch::Tensor& out,    // [..., d]
@@ -148,3 +154,9 @@ void gelu_fast(torch::Tensor& out,    // [..., d]
 {
   LAUNCH_ACTIVATION_KERNEL(vllm::gelu_fast_kernel);
 }
+
+void gelu_quick(torch::Tensor& out,    // [..., d]
+                torch::Tensor& input)  // [..., d]
+{
+  LAUNCH_ACTIVATION_KERNEL(vllm::gelu_quick_kernel);
+}
diff --git a/csrc/ops.h b/csrc/ops.h
index 9e2e977fa3c2e..ba92cc5373d7a 100644
--- a/csrc/ops.h
+++ b/csrc/ops.h
@@ -49,6 +49,8 @@ void gelu_new(torch::Tensor& out, torch::Tensor& input);
 
 void gelu_fast(torch::Tensor& out, torch::Tensor& input);
 
+void gelu_quick(torch::Tensor& out, torch::Tensor& input);
+
 #ifndef USE_ROCM
 torch::Tensor aqlm_gemm(const torch::Tensor& input, const torch::Tensor& codes,
                         const torch::Tensor& codebooks,
diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp
index 867bf438937cd..953f2eb4d8e7d 100644
--- a/csrc/torch_bindings.cpp
+++ b/csrc/torch_bindings.cpp
@@ -68,6 +68,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   ops.def("gelu_fast(Tensor! out, Tensor input) -> ()");
   ops.impl("gelu_fast", torch::kCUDA, &gelu_fast);
 
+  // Quick GELU implementation.
+  ops.def("gelu_quick(Tensor! out, Tensor input) -> ()");
+  ops.impl("gelu_quick", torch::kCUDA, &gelu_quick);
+
   // Layernorm
   // Apply Root Mean Square (RMS) Normalization to the input tensor.
   ops.def(
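
The gelu_quick kernel above evaluates x / (1 + exp(-1.702 * x)), which is algebraically identical to x * sigmoid(1.702 * x), the "quick GELU" approximation that CLIP checkpoints were trained with. As a point of reference only (not part of this patch), a plain PyTorch sketch of the same math and of how far it drifts from exact GELU:

    import torch
    import torch.nn.functional as F

    def quick_gelu_ref(x: torch.Tensor) -> torch.Tensor:
        # Same math as the CUDA kernel: x / (1 + exp(-1.702 x)) == x * sigmoid(1.702 x)
        return x * torch.sigmoid(1.702 * x)

    x = torch.randn(1 << 16)
    assert torch.allclose(quick_gelu_ref(x), x / (1.0 + torch.exp(-1.702 * x)), atol=1e-6)

    # Quick GELU only approximates exact GELU, so a small gap is expected.
    print((quick_gelu_ref(x) - F.gelu(x)).abs().max())
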
diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py
index ab2a67950bfea..a053a3aa237e7 100644
--- a/vllm/_custom_ops.py
+++ b/vllm/_custom_ops.py
@@ -66,6 +66,10 @@ def gelu_new(out: torch.Tensor, x: torch.Tensor) -> None:
     torch.ops._C.gelu_new(out, x)
 
 
+def gelu_quick(out: torch.Tensor, x: torch.Tensor) -> None:
+    torch.ops._C.gelu_quick(out, x)
+
+
 # page attention ops
 def paged_attention_v1(
     out: torch.Tensor,
diff --git a/vllm/model_executor/layers/activation.py b/vllm/model_executor/layers/activation.py
index eb0606948686d..80cad15b43426 100644
--- a/vllm/model_executor/layers/activation.py
+++ b/vllm/model_executor/layers/activation.py
@@ -141,6 +141,21 @@ def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
         return out
 
 
+class QuickGELU(CustomOp):
+
+    # https://github.com/huggingface/transformers/blob/main/src/transformers/activations.py#L90
+    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
+        """PyTorch-native implementation equivalent to forward()."""
+        return x * torch.sigmoid(1.702 * x)
+
+    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
+        from vllm import _custom_ops as ops
+
+        out = torch.empty_like(x)
+        ops.gelu_quick(out, x)
+        return out
+
+
 class ScaledActivation(nn.Module):
     """An activation function with post-scale parameters.
 
@@ -189,6 +204,7 @@ def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
     "gelu_new": NewGELU(),
     "gelu_pytorch_tanh": nn.GELU(approximate="tanh"),
     "relu": nn.ReLU(),
+    "quick_gelu": QuickGELU(),
 }
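
With the Python binding in _custom_ops and the QuickGELU op registered under the "quick_gelu" key, the activation can be looked up by name via get_act_fn; the new CLIP module below relies on this, since CLIPVisionConfig.hidden_act is "quick_gelu". A rough usage sketch, assuming a vLLM build that includes these changes:

    import torch
    from vllm.model_executor.layers.activation import QuickGELU, get_act_fn

    act = get_act_fn("quick_gelu")
    assert isinstance(act, QuickGELU)

    x = torch.randn(4, 1024)
    # forward_native is the reference path; the CustomOp base class can dispatch
    # to the fused torch.ops._C.gelu_quick kernel bound above when running on CUDA.
    ref = act.forward_native(x)
    assert torch.allclose(ref, x * torch.sigmoid(1.702 * x))
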
diff --git a/vllm/model_executor/models/clip.py b/vllm/model_executor/models/clip.py
new file mode 100644
index 0000000000000..aa4e87228a7e4
--- /dev/null
+++ b/vllm/model_executor/models/clip.py
@@ -0,0 +1,203 @@
+"""Minimal implementation of CLIPVisionModel intended to be only used
+within a vision language model."""
+from typing import Optional, Tuple
+
+import torch
+import torch.nn as nn
+from transformers import CLIPVisionConfig
+from transformers.models.clip.modeling_clip import CLIPAttention
+
+from vllm.model_executor.layers.activation import get_act_fn
+from vllm.model_executor.layers.linear import (ColumnParallelLinear,
+                                               RowParallelLinear)
+from vllm.model_executor.layers.quantization.base_config import (
+    QuantizationConfig)
+
+
+def get_clip_num_patches(image_size: int, patch_size: int) -> int:
+    assert image_size % patch_size == 0
+    return (image_size // patch_size)**2
+
+
+# Adapted from https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/clip/modeling_clip.py#L164 # noqa
+class CLIPVisionEmbeddings(nn.Module):
+
+    def __init__(self, config: CLIPVisionConfig):
+        super().__init__()
+        self.config = config
+        self.embed_dim = config.hidden_size
+        self.image_size = config.image_size
+        self.patch_size = config.patch_size
+
+        self.class_embedding = nn.Parameter(torch.randn(self.embed_dim))
+
+        self.patch_embedding = nn.Conv2d(
+            in_channels=config.num_channels,
+            out_channels=self.embed_dim,
+            kernel_size=self.patch_size,
+            stride=self.patch_size,
+            bias=False,
+        )
+
+        self.num_patches = get_clip_num_patches(self.image_size,
+                                                self.patch_size)
+        self.num_positions = self.num_patches + 1
+        self.position_embedding = nn.Embedding(self.num_positions,
+                                               self.embed_dim)
+        self.register_buffer("position_ids",
+                             torch.arange(self.num_positions).expand((1, -1)),
+                             persistent=False)
+
+    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
+        batch_size = pixel_values.shape[0]
+        target_dtype = self.patch_embedding.weight.dtype
+        patch_embeds = self.patch_embedding(pixel_values.to(
+            dtype=target_dtype))  # shape = [*, width, grid, grid]
+        patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
+
+        class_embeds = self.class_embedding.expand(batch_size, 1, -1)
+        embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
+        embeddings = embeddings + self.position_embedding(self.position_ids)
+
+        return embeddings
+
+
+class CLIPMLP(nn.Module):
+
+    def __init__(self,
+                 config: CLIPVisionConfig,
+                 quant_config: Optional[QuantizationConfig] = None):
+        super().__init__()
+        self.config = config
+        self.activation_fn = get_act_fn(config.hidden_act)
+        self.fc1 = ColumnParallelLinear(config.hidden_size,
+                                        config.intermediate_size,
+                                        bias=True,
+                                        quant_config=quant_config)
+        self.fc2 = RowParallelLinear(config.intermediate_size,
+                                     config.hidden_size,
+                                     bias=True,
+                                     quant_config=quant_config)
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        hidden_states, _ = self.fc1(hidden_states)
+        hidden_states = self.activation_fn(hidden_states)
+        hidden_states, _ = self.fc2(hidden_states)
+
+        return hidden_states
+
+
+class CLIPEncoderLayer(nn.Module):
+
+    def __init__(self,
+                 config: CLIPVisionConfig,
+                 quant_config: Optional[QuantizationConfig] = None):
+        super().__init__()
+
+        self.self_attn = CLIPAttention(config)
+        self.layer_norm1 = nn.LayerNorm(config.hidden_size,
+                                        eps=config.layer_norm_eps)
+        self.mlp = CLIPMLP(config, quant_config=quant_config)
+        self.layer_norm2 = nn.LayerNorm(config.hidden_size,
+                                        eps=config.layer_norm_eps)
+
+    def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor]:
+
+        residual = hidden_states
+
+        hidden_states = self.layer_norm1(hidden_states)
+        hidden_states, _ = self.self_attn(hidden_states=hidden_states)
+        hidden_states = residual + hidden_states
+
+        residual = hidden_states
+        hidden_states = self.layer_norm2(hidden_states)
+        hidden_states = self.mlp(hidden_states)
+        hidden_states = residual + hidden_states
+
+        return hidden_states
+
+
+class CLIPEncoder(nn.Module):
+    """
+    Transformer encoder consisting of `config.num_hidden_layers` self
+    attention layers. Each layer is a [`CLIPEncoderLayer`].
+
+    Args:
+        config: CLIPConfig
+    """
+
+    def __init__(self,
+                 config: CLIPVisionConfig,
+                 quant_config: Optional[QuantizationConfig] = None):
+        super().__init__()
+        self.config = config
+        self.layers = nn.ModuleList([
+            CLIPEncoderLayer(config=config, quant_config=quant_config)
+            for _ in range(config.num_hidden_layers)
+        ])
+
+    def forward(self,
+                inputs_embeds: torch.Tensor,
+                vision_feature_layer: int = -1):
+
+        # Encoder forward pass only up to the required layer
+        num_layer = len(self.layers) + vision_feature_layer + 1
+        hidden_states = inputs_embeds
+        for encoder_layer in self.layers[:num_layer]:
+            hidden_states = encoder_layer(hidden_states)
+
+        return hidden_states
+
+
+class CLIPVisionTransformer(nn.Module):
+
+    def __init__(self,
+                 config: CLIPVisionConfig,
+                 quant_config: Optional[QuantizationConfig] = None):
+        super().__init__()
+        self.config = config
+        embed_dim = config.hidden_size
+
+        self.embeddings = CLIPVisionEmbeddings(config)
+
+        # NOTE: This typo of "layrnorm" is not fixed on purpose to match
+        # the original transformers code and name of the model weights.
+        self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
+        self.encoder = CLIPEncoder(config=config, quant_config=quant_config)
+
+    def forward(
+        self,
+        pixel_values: torch.Tensor,
+        vision_feature_layer: int = -1,
+    ) -> torch.Tensor:
+
+        hidden_states = self.embeddings(pixel_values)
+        hidden_states = self.pre_layrnorm(hidden_states)
+        hidden_states = self.encoder(inputs_embeds=hidden_states,
+                                     vision_feature_layer=vision_feature_layer)
+
+        return hidden_states
+
+
+class CLIPVisionModel(nn.Module):
+
+    config_class = CLIPVisionConfig
+    main_input_name = "pixel_values"
+
+    def __init__(self,
+                 config: CLIPVisionConfig,
+                 quant_config: Optional[QuantizationConfig] = None):
+        super().__init__()
+        self.vision_model = CLIPVisionTransformer(config=config,
+                                                  quant_config=quant_config)
+
+    def forward(self,
+                pixel_values: Optional[torch.Tensor] = None,
+                vision_feature_layer: int = -1):
+
+        return self.vision_model(pixel_values=pixel_values,
+                                 vision_feature_layer=vision_feature_layer)
+
+    @property
+    def device(self):
+        return next(self.parameters()).device
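
The main behavioural difference from transformers' CLIPVisionModel is that the encoder stops early rather than materialising every hidden state: for a negative vision_feature_layer, len(self.layers) + vision_feature_layer + 1 layers are executed, and the result corresponds to hidden_states[vision_feature_layer] of the Hugging Face model (post_layernorm is never applied). A small illustrative sketch of the indexing, using CLIP ViT-L/14's 24 encoder layers as an example:

    def num_layers_to_run(num_hidden_layers: int, vision_feature_layer: int) -> int:
        # Mirrors CLIPEncoder.forward: only self.layers[:num_layer] are executed.
        return num_hidden_layers + vision_feature_layer + 1

    assert num_layers_to_run(24, -1) == 24  # last encoder layer output
    assert num_layers_to_run(24, -2) == 23  # LLaVA's default vision_feature_layer
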
diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py
index 39355b9d3ab44..8e36c54b1c511 100644
--- a/vllm/model_executor/models/llava.py
+++ b/vllm/model_executor/models/llava.py
@@ -2,9 +2,7 @@
 
 import torch
 import torch.nn as nn
-# TODO(xwjiang): We should port CLIPVisionModel's code over to not depend on
-# transformers' impl.
-from transformers import CLIPVisionModel, LlavaConfig
+from transformers import LlavaConfig
 
 from vllm.attention import AttentionMetadata
 from vllm.config import CacheConfig, VisionLanguageConfig
@@ -15,6 +13,7 @@
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.models.clip import CLIPVisionModel
 from vllm.model_executor.models.llama import LlamaModel
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
@@ -189,12 +188,11 @@ def _select_image_features(self, image_features: torch.Tensor, *,
 
     def _image_pixels_to_features(self, vision_tower: CLIPVisionModel,
                                   pixel_values: torch.Tensor) -> torch.Tensor:
-        # TODO(xwjiang): Maybe port minimal CLIPVisionModel over.
-        image_outputs = vision_tower(pixel_values.to(vision_tower.device),
-                                     output_hidden_states=True)
-        image_features = image_outputs.hidden_states[
-            self.config.vision_feature_layer]
+        # NOTE: we skip the step to select the vision feature layer since
+        # this is already done inside the vision tower
+        image_features = vision_tower(pixel_values.to(vision_tower.device),
+                                      self.config.vision_feature_layer)
 
         return self._select_image_features(
             image_features,
@@ -317,6 +315,9 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name:
                 continue
+            # post_layernorm is not needed in CLIPVisionModel
+            if "vision_model.post_layernorm" in name:
+                continue
             for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items():
                 if key_to_modify in name:
                     name = name.replace(key_to_modify, new_key)
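
The tensor handed back by the tower still carries the class embedding at position 0; LLaVA's _select_image_features (unchanged by this patch) strips it under the default feature-select strategy. A standalone sketch of that selection with illustrative shapes for CLIP ViT-L/14 at 336 px (1 CLS token + 24 x 24 patch tokens, hidden size 1024); the helper name below is hypothetical:

    import torch

    def select_image_features(feats: torch.Tensor, strategy: str) -> torch.Tensor:
        # "default" keeps only the patch tokens; "full" keeps the CLS token too.
        if strategy == "default":
            return feats[:, 1:]
        if strategy == "full":
            return feats
        raise ValueError(f"Unexpected feature select strategy: {strategy}")

    image_features = torch.randn(1, 577, 1024)
    assert select_image_features(image_features, "default").shape == (1, 576, 1024)
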
diff --git a/vllm/model_executor/models/llava_next.py b/vllm/model_executor/models/llava_next.py
index 0ab9afea9ac69..c1158c933c88b 100644
--- a/vllm/model_executor/models/llava_next.py
+++ b/vllm/model_executor/models/llava_next.py
@@ -4,9 +4,7 @@
 import torch
 import torch.nn as nn
 from PIL import Image
-# TODO(xwjiang): We should port CLIPVisionModel's code over to not depend on
-# transformers' impl.
-from transformers import CLIPVisionModel, LlavaNextConfig
+from transformers import LlavaNextConfig
 from transformers.models.llava_next.modeling_llava_next import (
     get_anyres_image_grid_shape, unpad_image)
 from typing_extensions import NotRequired
@@ -20,6 +18,7 @@
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.models.clip import CLIPVisionModel
 from vllm.model_executor.models.llama import LlamaModel
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalData
@@ -121,7 +120,7 @@ def __init__(self,
         if self.vision_language_config.image_input_type == (
                 VisionLanguageConfig.ImageInputType.PIXEL_VALUES):
-            self.vision_tower = CLIPVisionModel(config.vision_config)
+            self.vision_tower = CLIPVisionModel(config=config.vision_config)
         else:
             raise TypeError("Image features are not supported by LLaVA-NeXT")
@@ -219,12 +218,11 @@ def _select_image_features(self, image_features: torch.Tensor, *,
 
     def _image_pixels_to_features(self, vision_tower: CLIPVisionModel,
                                   pixel_values: torch.Tensor) -> torch.Tensor:
-        # TODO(xwjiang): Maybe port minimal CLIPVisionModel over.
-        image_outputs = vision_tower(pixel_values.to(vision_tower.device),
-                                     output_hidden_states=True)
-        image_features = image_outputs.hidden_states[
-            self.config.vision_feature_layer]
+        # NOTE: we skip the step to select the vision feature layer since
+        # this is already done inside the vision tower
+        image_features = vision_tower(pixel_values.to(vision_tower.device),
+                                      self.config.vision_feature_layer)
 
         return self._select_image_features(
             image_features,
@@ -430,6 +428,9 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name:
                 continue
+            # post_layernorm is not needed in CLIPVisionModel
+            if "vision_model.post_layernorm" in name:
+                continue
             for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items():
                 if key_to_modify in name:
                     name = name.replace(key_to_modify, new_key)
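
Both LLaVA variants replace the same pattern: previously the full Hugging Face forward pass ran and one entry of hidden_states was picked out afterwards; now the layer index is pushed into the tower, so the layers past it (and post_layernorm) are never executed. For reference, what the removed code had been computing, as a transformers-only sketch with randomly initialised weights (purely illustrative):

    import torch
    from transformers import CLIPVisionConfig, CLIPVisionModel

    config = CLIPVisionConfig()            # toy tower, random weights
    vision_tower = CLIPVisionModel(config)
    pixel_values = torch.randn(1, 3, config.image_size, config.image_size)
    vision_feature_layer = -2

    with torch.no_grad():
        outputs = vision_tower(pixel_values, output_hidden_states=True)
    # This is the tensor the in-tree tower now returns directly.
    image_features = outputs.hidden_states[vision_feature_layer]
    print(image_features.shape)            # (1, num_patches + 1, hidden_size)
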
diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py
index 35f3b894f099a..fa20a7c5903d6 100644
--- a/vllm/model_executor/models/phi3v.py
+++ b/vllm/model_executor/models/phi3v.py
@@ -17,7 +17,7 @@
 import torch
 import torch.nn as nn
-from transformers import CLIPVisionConfig, CLIPVisionModel, PretrainedConfig
+from transformers import CLIPVisionConfig, PretrainedConfig
 
 from vllm.attention import AttentionMetadata
 from vllm.config import CacheConfig, VisionLanguageConfig
@@ -27,6 +27,7 @@
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.models.clip import CLIPVisionModel
 from vllm.model_executor.models.llama import LlamaModel
 from vllm.model_executor.models.vlm_base import VisionLanguageModelBase
 from vllm.model_executor.sampling_metadata import SamplingMetadata
@@ -70,9 +71,10 @@ def get_img_features(self,
         LAYER_IDX = self.layer_idx
         TYPE_FEATURE = self.type_feature
 
-        img_processor_output = self.img_processor(img_embeds,
-                                                  output_hidden_states=True)
-        img_feature = img_processor_output.hidden_states[LAYER_IDX]
+        # NOTE: we skip the step to select the vision feature layer since
+        # this is already done inside the img_processor
+        img_feature = self.img_processor(img_embeds,
+                                         vision_feature_layer=LAYER_IDX)
 
         if TYPE_FEATURE == "patch":
             patch_feature = img_feature[:, 1:]
@@ -352,6 +354,9 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name:
                 continue
+            # post_layernorm is not needed in CLIPVisionModel
+            if "vision_model.post_layernorm" in name:
+                continue
             for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items():
                 if key_to_modify in name:
                     name = name.replace(key_to_modify, new_key)
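
The three load_weights changes all exist for the same reason: HF checkpoints still ship vision_model.post_layernorm.{weight,bias}, but the ported tower never instantiates that module, so those entries have to be dropped before the name-based parameter lookup. A minimal sketch of the filtering idea, over hypothetical checkpoint names:

    import torch

    weights = [
        ("vision_tower.vision_model.encoder.layers.0.mlp.fc1.weight", torch.zeros(1)),
        ("vision_tower.vision_model.post_layernorm.weight", torch.zeros(1)),
        ("language_model.model.embed_tokens.weight", torch.zeros(1)),
    ]

    kept = [(name, w) for name, w in weights
            if "vision_model.post_layernorm" not in name]
    assert len(kept) == 2
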