Skip to content

Commit

Permalink
[Model] Port over CLIPVisionModel for VLMs (vllm-project#5591)
Browse files Browse the repository at this point in the history
  • Loading branch information
ywang96 authored Jun 20, 2024
1 parent 111af1f commit ad137cd
Show file tree
Hide file tree
Showing 9 changed files with 269 additions and 21 deletions.
12 changes: 12 additions & 0 deletions csrc/activation_kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,12 @@ __device__ __forceinline__ T gelu_fast_kernel(const T& x) {
return ((T)0.5) * x * (((T)1.0) + t);
}

template <typename T>
__device__ __forceinline__ T gelu_quick_kernel(const T& x) {
// x * sigmoid(1.702 * x)
return (T)(((float)x) / (1.0f + expf(-1.702f * (float)x)));
}

} // namespace vllm

void gelu_new(torch::Tensor& out, // [..., d]
Expand All @@ -148,3 +154,9 @@ void gelu_fast(torch::Tensor& out, // [..., d]
{
LAUNCH_ACTIVATION_KERNEL(vllm::gelu_fast_kernel);
}

void gelu_quick(torch::Tensor& out, // [..., d]
torch::Tensor& input) // [..., d]
{
LAUNCH_ACTIVATION_KERNEL(vllm::gelu_quick_kernel);
}
2 changes: 2 additions & 0 deletions csrc/ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ void gelu_new(torch::Tensor& out, torch::Tensor& input);

void gelu_fast(torch::Tensor& out, torch::Tensor& input);

void gelu_quick(torch::Tensor& out, torch::Tensor& input);

#ifndef USE_ROCM
torch::Tensor aqlm_gemm(const torch::Tensor& input, const torch::Tensor& codes,
const torch::Tensor& codebooks,
Expand Down
4 changes: 4 additions & 0 deletions csrc/torch_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.def("gelu_fast(Tensor! out, Tensor input) -> ()");
ops.impl("gelu_fast", torch::kCUDA, &gelu_fast);

// Quick GELU implementation.
ops.def("gelu_quick(Tensor! out, Tensor input) -> ()");
ops.impl("gelu_quick", torch::kCUDA, &gelu_quick);

// Layernorm
// Apply Root Mean Square (RMS) Normalization to the input tensor.
ops.def(
Expand Down
4 changes: 4 additions & 0 deletions vllm/_custom_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,10 @@ def gelu_new(out: torch.Tensor, x: torch.Tensor) -> None:
torch.ops._C.gelu_new(out, x)


def gelu_quick(out: torch.Tensor, x: torch.Tensor) -> None:
torch.ops._C.gelu_quick(out, x)


# page attention ops
def paged_attention_v1(
out: torch.Tensor,
Expand Down
16 changes: 16 additions & 0 deletions vllm/model_executor/layers/activation.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,21 @@ def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
return out


class QuickGELU(CustomOp):

# https://github.com/huggingface/transformers/blob/main/src/transformers/activations.py#L90
def forward_native(self, x: torch.Tensor) -> torch.Tensor:
"""PyTorch-native implementation equivalent to forward()."""
return x * torch.sigmoid(1.702 * x)

def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
from vllm import _custom_ops as ops

out = torch.empty_like(x)
ops.gelu_quick(out, x)
return out


class ScaledActivation(nn.Module):
"""An activation function with post-scale parameters.
Expand Down Expand Up @@ -189,6 +204,7 @@ def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
"gelu_new": NewGELU(),
"gelu_pytorch_tanh": nn.GELU(approximate="tanh"),
"relu": nn.ReLU(),
"quick_gelu": QuickGELU(),
}


Expand Down
203 changes: 203 additions & 0 deletions vllm/model_executor/models/clip.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,203 @@
"""Minimal implementation of CLIPVisionModel intended to be only used
within a vision language model."""
from typing import Optional, Tuple

import torch
import torch.nn as nn
from transformers import CLIPVisionConfig
from transformers.models.clip.modeling_clip import CLIPAttention

from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)


def get_clip_num_patches(image_size: int, patch_size: int) -> int:
assert image_size % patch_size == 0
return (image_size // patch_size)**2


# Adapted from https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/clip/modeling_clip.py#L164 # noqa
class CLIPVisionEmbeddings(nn.Module):

def __init__(self, config: CLIPVisionConfig):
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
self.image_size = config.image_size
self.patch_size = config.patch_size

self.class_embedding = nn.Parameter(torch.randn(self.embed_dim))

self.patch_embedding = nn.Conv2d(
in_channels=config.num_channels,
out_channels=self.embed_dim,
kernel_size=self.patch_size,
stride=self.patch_size,
bias=False,
)

self.num_patches = get_clip_num_patches(self.image_size,
self.patch_size)
self.num_positions = self.num_patches + 1
self.position_embedding = nn.Embedding(self.num_positions,
self.embed_dim)
self.register_buffer("position_ids",
torch.arange(self.num_positions).expand((1, -1)),
persistent=False)

def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
batch_size = pixel_values.shape[0]
target_dtype = self.patch_embedding.weight.dtype
patch_embeds = self.patch_embedding(pixel_values.to(
dtype=target_dtype)) # shape = [*, width, grid, grid]
patch_embeds = patch_embeds.flatten(2).transpose(1, 2)

class_embeds = self.class_embedding.expand(batch_size, 1, -1)
embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
embeddings = embeddings + self.position_embedding(self.position_ids)

return embeddings


class CLIPMLP(nn.Module):

def __init__(self,
config: CLIPVisionConfig,
quant_config: Optional[QuantizationConfig] = None):
super().__init__()
self.config = config
self.activation_fn = get_act_fn(config.hidden_act)
self.fc1 = ColumnParallelLinear(config.hidden_size,
config.intermediate_size,
bias=True,
quant_config=quant_config)
self.fc2 = RowParallelLinear(config.intermediate_size,
config.hidden_size,
bias=True,
quant_config=quant_config)

def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states, _ = self.fc1(hidden_states)
hidden_states = self.activation_fn(hidden_states)
hidden_states, _ = self.fc2(hidden_states)

return hidden_states


class CLIPEncoderLayer(nn.Module):

def __init__(self,
config: CLIPVisionConfig,
quant_config: Optional[QuantizationConfig] = None):
super().__init__()

self.self_attn = CLIPAttention(config)
self.layer_norm1 = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)
self.mlp = CLIPMLP(config, quant_config=quant_config)
self.layer_norm2 = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)

def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor]:

residual = hidden_states

hidden_states = self.layer_norm1(hidden_states)
hidden_states, _ = self.self_attn(hidden_states=hidden_states)
hidden_states = residual + hidden_states

residual = hidden_states
hidden_states = self.layer_norm2(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states

return hidden_states


class CLIPEncoder(nn.Module):
"""
Transformer encoder consisting of `config.num_hidden_layers` self
attention layers. Each layer is a [`CLIPEncoderLayer`].
Args:
config: CLIPConfig
"""

def __init__(self,
config: CLIPVisionConfig,
quant_config: Optional[QuantizationConfig] = None):
super().__init__()
self.config = config
self.layers = nn.ModuleList([
CLIPEncoderLayer(config=config, quant_config=quant_config)
for _ in range(config.num_hidden_layers)
])

def forward(self,
inputs_embeds: torch.Tensor,
vision_feature_layer: int = -1):

# Encoder forward pass only up to the required layer
num_layer = len(self.layers) + vision_feature_layer + 1
hidden_states = inputs_embeds
for encoder_layer in self.layers[:num_layer]:
hidden_states = encoder_layer(hidden_states)

return hidden_states


class CLIPVisionTransformer(nn.Module):

def __init__(self,
config: CLIPVisionConfig,
quant_config: Optional[QuantizationConfig] = None):
super().__init__()
self.config = config
embed_dim = config.hidden_size

self.embeddings = CLIPVisionEmbeddings(config)

# NOTE: This typo of "layrnorm" is not fixed on purpose to match
# the original transformers code and name of the model weights.
self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
self.encoder = CLIPEncoder(config=config, quant_config=quant_config)

def forward(
self,
pixel_values: torch.Tensor,
vision_feature_layer: int = -1,
) -> torch.Tensor:

hidden_states = self.embeddings(pixel_values)
hidden_states = self.pre_layrnorm(hidden_states)
hidden_states = self.encoder(inputs_embeds=hidden_states,
vision_feature_layer=vision_feature_layer)

return hidden_states


class CLIPVisionModel(nn.Module):

config_class = CLIPVisionConfig
main_input_name = "pixel_values"

def __init__(self,
config: CLIPVisionConfig,
quant_config: Optional[QuantizationConfig] = None):
super().__init__()
self.vision_model = CLIPVisionTransformer(config=config,
quant_config=quant_config)

def forward(self,
pixel_values: Optional[torch.Tensor] = None,
vision_feature_layer: int = -1):

return self.vision_model(pixel_values=pixel_values,
vision_feature_layer=vision_feature_layer)

@property
def device(self):
return next(self.parameters()).device
17 changes: 9 additions & 8 deletions vllm/model_executor/models/llava.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,7 @@

import torch
import torch.nn as nn
# TODO(xwjiang): We should port CLIPVisionModel's code over to not depend on
# transformers' impl.
from transformers import CLIPVisionModel, LlavaConfig
from transformers import LlavaConfig

from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, VisionLanguageConfig
Expand All @@ -15,6 +13,7 @@
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.clip import CLIPVisionModel
from vllm.model_executor.models.llama import LlamaModel
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
Expand Down Expand Up @@ -189,12 +188,11 @@ def _select_image_features(self, image_features: torch.Tensor, *,

def _image_pixels_to_features(self, vision_tower: CLIPVisionModel,
pixel_values: torch.Tensor) -> torch.Tensor:
# TODO(xwjiang): Maybe port minimal CLIPVisionModel over.
image_outputs = vision_tower(pixel_values.to(vision_tower.device),
output_hidden_states=True)

image_features = image_outputs.hidden_states[
self.config.vision_feature_layer]
# NOTE: we skip the step to select the vision feature layer since
# this is already done inside the vision tower
image_features = vision_tower(pixel_values.to(vision_tower.device),
self.config.vision_feature_layer)

return self._select_image_features(
image_features,
Expand Down Expand Up @@ -317,6 +315,9 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
# post_layernorm is not needed in CLIPVisionModel
if "vision_model.post_layernorm" in name:
continue
for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items():
if key_to_modify in name:
name = name.replace(key_to_modify, new_key)
Expand Down
19 changes: 10 additions & 9 deletions vllm/model_executor/models/llava_next.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,7 @@
import torch
import torch.nn as nn
from PIL import Image
# TODO(xwjiang): We should port CLIPVisionModel's code over to not depend on
# transformers' impl.
from transformers import CLIPVisionModel, LlavaNextConfig
from transformers import LlavaNextConfig
from transformers.models.llava_next.modeling_llava_next import (
get_anyres_image_grid_shape, unpad_image)
from typing_extensions import NotRequired
Expand All @@ -20,6 +18,7 @@
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.clip import CLIPVisionModel
from vllm.model_executor.models.llama import LlamaModel
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalData
Expand Down Expand Up @@ -121,7 +120,7 @@ def __init__(self,

if self.vision_language_config.image_input_type == (
VisionLanguageConfig.ImageInputType.PIXEL_VALUES):
self.vision_tower = CLIPVisionModel(config.vision_config)
self.vision_tower = CLIPVisionModel(config=config.vision_config)
else:
raise TypeError("Image features are not supported by LLaVA-NeXT")

Expand Down Expand Up @@ -219,12 +218,11 @@ def _select_image_features(self, image_features: torch.Tensor, *,

def _image_pixels_to_features(self, vision_tower: CLIPVisionModel,
pixel_values: torch.Tensor) -> torch.Tensor:
# TODO(xwjiang): Maybe port minimal CLIPVisionModel over.
image_outputs = vision_tower(pixel_values.to(vision_tower.device),
output_hidden_states=True)

image_features = image_outputs.hidden_states[
self.config.vision_feature_layer]
# NOTE: we skip the step to select the vision feature layer since
# this is already done inside the vision tower
image_features = vision_tower(pixel_values.to(vision_tower.device),
self.config.vision_feature_layer)

return self._select_image_features(
image_features,
Expand Down Expand Up @@ -430,6 +428,9 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
# post_layernorm is not needed in CLIPVisionModel
if "vision_model.post_layernorm" in name:
continue
for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items():
if key_to_modify in name:
name = name.replace(key_to_modify, new_key)
Expand Down
Loading

0 comments on commit ad137cd

Please sign in to comment.