[Model] Upgrade Aria to transformers 4.48 #12203

Merged 10 commits on Jan 20, 2025

Changes from 1 commit
3 changes: 0 additions & 3 deletions examples/offline_inference/vision_language.py
@@ -26,11 +26,8 @@ def run_aria(question: str, modality: str):

# NOTE: Need L40 (or equivalent) to avoid OOM
llm = LLM(model=model_name,
tokenizer_mode="slow",
dtype="bfloat16",
max_model_len=4096,
max_num_seqs=2,
trust_remote_code=True,
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)

prompt = (f"<|im_start|>user\n<fim_prefix><|img|><fim_suffix>\n{question}"
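Since Aria now ships with transformers 4.48, the example no longer needs tokenizer_mode="slow" or trust_remote_code=True. A minimal, self-contained sketch of the simplified call (the checkpoint id, image, and prompt suffix are assumptions; the real script wires these through args):

from PIL import Image

from vllm import LLM, SamplingParams

# "rhymes-ai/Aria" is assumed; substitute the checkpoint you actually use.
llm = LLM(model="rhymes-ai/Aria",
          dtype="bfloat16",
          max_model_len=4096,
          max_num_seqs=2)  # NOTE: Need L40 (or equivalent) to avoid OOM

question = "What is in this image?"
# The "<|im_end|>...assistant" suffix is an assumed chat-template continuation;
# the diff above truncates the prompt string.
prompt = (f"<|im_start|>user\n<fim_prefix><|img|><fim_suffix>\n{question}"
          "<|im_end|>\n<|im_start|>assistant\n")

image = Image.open("example.jpg")  # any local image
outputs = llm.generate({"prompt": prompt, "multi_modal_data": {"image": image}},
                       SamplingParams(max_tokens=64))
print(outputs[0].outputs[0].text)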
148 changes: 45 additions & 103 deletions vllm/model_executor/models/aria.py
@@ -1,9 +1,12 @@
from typing import (Callable, Iterable, List, Mapping, Optional, Set, Tuple,
TypedDict, Union)
from typing import (Iterable, List, Mapping, Optional, Set, Tuple, TypedDict,
Union)

import torch
import torch.nn as nn
from transformers import BatchFeature, PretrainedConfig
from transformers import (AriaConfig, AriaTextConfig, BatchFeature,
PretrainedConfig)
from transformers.models.aria.modeling_aria import AriaCrossAttention
from transformers.models.aria.processing_aria import AriaProcessor

from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, QuantizationConfig, VllmConfig
@@ -26,10 +29,11 @@
BaseProcessingInfo, PromptReplacement)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs.aria import (AriaMoELMConfig,
AriaVisionConfig)

from .idefics2_vision_model import Idefics2VisionTransformer
# yapf: disable
from .idefics2_vision_model import (
Idefics2VisionTransformer as Idefics3VisionTransformer)
# yapf: enable
from .interfaces import SupportsMultiModal
from .llama import LlamaDecoderLayer, LlamaMLP, LlamaModel
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
@@ -47,35 +51,18 @@ class AriaImagePixelInputs(TypedDict):
"""


class AriaVisionTransformer(Idefics2VisionTransformer):
"""
AriaVisionTransformer is a modified version of Idefics2VisionTransformer
that replaces the post-layernorm with an identity layer.
"""

def __init__(
self,
config: AriaVisionConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__(config, quant_config, prefix)
self.post_layernorm = nn.Identity()


class AriaVisionModel(nn.Module):
config_class = AriaVisionConfig

def __init__(
self,
config: AriaVisionConfig,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
*,
prefix: str = "",
) -> None:
super().__init__()

self.vision_model = AriaVisionTransformer(
self.vision_model = Idefics3VisionTransformer(
config,
quant_config,
prefix=f"{prefix}.vision_model",
@@ -122,7 +109,7 @@ def _create_image_attention_mask(
return torch.logical_not(flattened_mask)


class FFN(nn.Module):
class AriaProjectorMLP(nn.Module):

def __init__(self, embed_dim: int, ff_dim: int, output_dim: int) -> None:
super().__init__()
@@ -137,46 +124,6 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
return hidden_states


class CrossAttention(nn.Module):

def __init__(self, kv_dim: int, embed_dim: int, num_heads: int) -> None:
super().__init__()
self.num_heads = num_heads
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=False)
self.k_proj = nn.Linear(kv_dim, embed_dim, bias=False)
self.v_proj = nn.Linear(kv_dim, embed_dim, bias=False)

self.multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
self.linear = nn.Linear(embed_dim, embed_dim)

self.layer_norm = nn.LayerNorm(embed_dim)
self.ln_kv = nn.LayerNorm(kv_dim)

def forward(
self,
x: torch.Tensor,
hidden_states: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
normed_hidden_states = self.layer_norm(hidden_states)
query = self.q_proj(normed_hidden_states).permute(1, 0, 2)

x = self.ln_kv(x)
key = self.k_proj(x).permute(1, 0, 2)
value = self.v_proj(x).permute(1, 0, 2)

attn_output, _ = self.multihead_attn(query,
key,
value,
attn_mask=attn_mask)

attn_output = attn_output.permute(1, 0, 2)

attn_output = self.linear(attn_output)

return attn_output


class AriaProjector(nn.Module):
"""
A projection module with one cross attention layer and one FFN layer, which
@@ -198,28 +145,26 @@ class AriaProjector(nn.Module):
A tensor with the shape of (batch_size, query_number, output_dim)
"""

def __init__(
self,
patch_to_query_dict: dict[int, int],
embed_dim: int,
num_heads: int,
kv_dim: int,
ff_dim: int,
output_dim: int,
norm_layer: Callable[[int], nn.Module] = nn.LayerNorm,
) -> None:
def __init__(self, config: AriaConfig) -> None:
super().__init__()
self.patch_to_query_dict = patch_to_query_dict
self.embed_dim = embed_dim
self.num_heads = num_heads

self.patch_to_query_dict = config.projector_patch_to_query_dict
self.in_features = config.vision_config.hidden_size
self.num_heads = config.vision_config.num_attention_heads
self.kv_dim = config.vision_config.hidden_size
self.hidden_features = config.text_config.hidden_size
self.output_dim = config.text_config.hidden_size

self.query = nn.Parameter(
torch.empty(max(patch_to_query_dict.values()), self.embed_dim))
torch.empty(config.max_value_projector_patch_to_query_dict,
self.in_features))

self.cross_attn = CrossAttention(kv_dim, embed_dim, num_heads)
self.cross_attn = AriaCrossAttention(config)

self.ln_ffn = norm_layer(embed_dim)
self.ffn = FFN(embed_dim, ff_dim, output_dim)
self.layer_norm = nn.LayerNorm(self.in_features)
self.feed_forward = AriaProjectorMLP(self.in_features,
self.hidden_features,
self.output_dim)

def forward(
self,
@@ -241,7 +186,7 @@ def forward(

attention_out = self.cross_attn(x, queries, attn_mask=attn_mask)

out = self.ffn(self.ln_ffn(attention_out))
out = self.feed_forward(self.layer_norm(attention_out))

return out
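For reference, the projector no longer takes individual dimension arguments; everything is read from the Hugging Face AriaConfig. A rough sketch of the new construction path (default-constructed config for illustration only; real values come from the checkpoint):

from transformers import AriaConfig

from vllm.model_executor.models.aria import AriaProjector

config = AriaConfig()  # defaults, illustration only
projector = AriaProjector(config)

# What the old constructor took as explicit arguments now comes from the config:
#   in_features / kv_dim  <- config.vision_config.hidden_size
#   num_heads             <- config.vision_config.num_attention_heads
#   hidden_features       <- config.text_config.hidden_size
#   output_dim            <- config.text_config.hidden_size
# and the learned query bank is sized by
#   config.max_value_projector_patch_to_query_dict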

@@ -289,7 +234,7 @@ class MoELayer(nn.Module):

def __init__(
self,
config: AriaMoELMConfig,
config: AriaTextConfig,
quant_config: Optional[QuantizationConfig],
) -> None:
super().__init__()
@@ -303,13 +248,13 @@ def __init__(
num_experts=config.moe_num_experts,
top_k=config.moe_topk,
hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size,
intermediate_size=config.intermediate_size,
quant_config=quant_config,
reduce_results=True,
)
self.shared_experts = LlamaMLP(
config.hidden_size,
config.moe_intermediate_size * config.moe_num_shared_experts,
config.intermediate_size * config.moe_num_shared_experts,
"silu",
quant_config=quant_config,
)
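As a quick sanity check on the renamed fields, the expert widths can be read off a default AriaTextConfig (illustrative only; real checkpoints override these values):

from transformers import AriaTextConfig

text_config = AriaTextConfig()  # defaults, illustration only

# moe_intermediate_size is gone; the per-expert width is plain intermediate_size.
expert_width = text_config.intermediate_size
shared_width = text_config.intermediate_size * text_config.moe_num_shared_experts

print(text_config.moe_num_experts, text_config.moe_topk, expert_width, shared_width)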
@@ -344,7 +289,7 @@ class MoEDecoderLayer(LlamaDecoderLayer):

def __init__(
self,
config: AriaMoELMConfig,
config: AriaTextConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
@@ -434,25 +379,23 @@ def load_weights(self, weights: Iterable[Tuple[str,
return loaded_params


def build_mm_projector(config: PretrainedConfig):
return AriaProjector(
patch_to_query_dict=config.projector_patch_to_query_dict,
embed_dim=config.vision_config.hidden_size,
num_heads=config.vision_config.num_attention_heads,
kv_dim=config.vision_config.hidden_size,
ff_dim=config.text_config.hidden_size,
output_dim=config.text_config.hidden_size,
)


class AriaProcessingInfo(BaseProcessingInfo):

def get_hf_config(self):
return self.ctx.get_hf_config()
return self.ctx.get_hf_config(AriaConfig)

def get_vision_config(self) -> AriaVisionConfig:
def get_vision_config(self):
return self.get_hf_config().vision_config

def get_hf_processor(self):
processor = self.ctx.get_hf_processor(AriaProcessor)

# Patch for https://github.com/huggingface/transformers/issues/35768
processor.tokenizer.image_token = "<|img|>"
processor.image_token = "<|img|>"

return processor

def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None}
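For context, the tokenizer patch above amounts to roughly the following when loading the processor on its own (checkpoint id assumed; the patch can be dropped once transformers#35768 is fixed upstream):

from transformers import AriaProcessor

processor = AriaProcessor.from_pretrained("rhymes-ai/Aria")  # assumed checkpoint id

# Work around https://github.com/huggingface/transformers/issues/35768:
# the image token is not populated correctly, so set it explicitly.
processor.tokenizer.image_token = "<|img|>"
processor.image_token = "<|img|>"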

@@ -555,7 +498,7 @@ def __init__(

self.config = config
self.vision_tower = AriaVisionModel(config.vision_config)
self.multi_modal_projector = build_mm_projector(config)
self.multi_modal_projector = AriaProjector(config)
self.vocab_size = config.text_config.vocab_size
self.language_model = AriaMoELMModel(
vllm_config=vllm_config.with_hf_config(config.text_config),
@@ -683,6 +626,5 @@ def sample(
return next_tokens

def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):

loader = AutoWeightsLoader(self)
loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
9 changes: 4 additions & 5 deletions vllm/transformers_utils/config.py
@@ -22,10 +22,10 @@
from vllm.logger import init_logger
# yapf conflicts with isort for this block
# yapf: disable
from vllm.transformers_utils.configs import (AriaConfig, ChatGLMConfig,
Cohere2Config, DbrxConfig,
DeepseekVLV2Config, EAGLEConfig,
ExaoneConfig, H2OVLChatConfig,
from vllm.transformers_utils.configs import (ChatGLMConfig, Cohere2Config,
DbrxConfig, DeepseekVLV2Config,
EAGLEConfig, ExaoneConfig,
H2OVLChatConfig,
InternVLChatConfig, JAISConfig,
MedusaConfig, MllamaConfig,
MLPSpeculatorConfig, MPTConfig,
@@ -52,7 +52,6 @@
}

_CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
"aria": AriaConfig,
"chatglm": ChatGLMConfig,
"cohere2": Cohere2Config,
"dbrx": DbrxConfig,
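Since AriaConfig is now provided by transformers itself, no vLLM-side registry entry is needed; AutoConfig resolves the model type directly. A quick sketch (the checkpoint id is an assumption):

from transformers import AriaConfig, AutoConfig  # requires transformers >= 4.48

config = AutoConfig.from_pretrained("rhymes-ai/Aria")  # assumed checkpoint id
assert isinstance(config, AriaConfig)
print(type(config.text_config).__name__, type(config.vision_config).__name__)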
2 changes: 0 additions & 2 deletions vllm/transformers_utils/configs/__init__.py
@@ -1,4 +1,3 @@
from vllm.transformers_utils.configs.aria import AriaConfig
from vllm.transformers_utils.configs.chatglm import ChatGLMConfig
from vllm.transformers_utils.configs.cohere2 import Cohere2Config
from vllm.transformers_utils.configs.dbrx import DbrxConfig
@@ -24,7 +23,6 @@
from vllm.transformers_utils.configs.ultravox import UltravoxConfig

__all__ = [
"AriaConfig",
"ChatGLMConfig",
"Cohere2Config",
"DbrxConfig",