Skip to content

Commit

Permalink
[BugFix] Propagate 'trust_remote_code' setting in internvl and minicp…
Browse files Browse the repository at this point in the history
…mv (#8250)
  • Loading branch information
zifeitong authored Sep 25, 2024
1 parent fc3afc2 commit e3dd069
Show file tree
Hide file tree
Showing 3 changed files with 126 additions and 41 deletions.
15 changes: 9 additions & 6 deletions vllm/model_executor/models/internvl.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,8 +230,9 @@ def input_processor_for_internvl(ctx: InputContext, llm_inputs: LLMInputs):
else:
raise TypeError(f"Invalid image type: {type(image_data)}")

tokenizer = cached_get_tokenizer(model_config.tokenizer,
trust_remote_code=True)
tokenizer = cached_get_tokenizer(
model_config.tokenizer,
trust_remote_code=model_config.trust_remote_code)

prompt = llm_inputs.get("prompt")
prompt_token_ids = llm_inputs["prompt_token_ids"]
Expand Down Expand Up @@ -278,8 +279,9 @@ def input_mapper_for_internvl(ctx: InputContext, data: object):
use_thumbnail=use_thumbnail) for img in data
]
model_config = ctx.model_config
tokenizer = cached_get_tokenizer(model_config.tokenizer,
trust_remote_code=True)
tokenizer = cached_get_tokenizer(
model_config.tokenizer,
trust_remote_code=model_config.trust_remote_code)
image_token_id = tokenizer.encode(IMG_CONTEXT,
add_special_tokens=False,
return_tensors="pt")[0]
Expand All @@ -298,8 +300,9 @@ def dummy_data_for_internvl(ctx: InputContext, seq_len: int,
model_config = ctx.model_config
hf_config = ctx.get_hf_config()
vision_config = hf_config.vision_config
tokenizer = cached_get_tokenizer(model_config.tokenizer,
trust_remote_code=True)
tokenizer = cached_get_tokenizer(
model_config.tokenizer,
trust_remote_code=model_config.trust_remote_code)

seq_data = dummy_seq_data_for_clip(
vision_config,
Expand Down
137 changes: 108 additions & 29 deletions vllm/model_executor/models/minicpmv.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from torch import nn
from torch.nn.init import trunc_normal_
from transformers import PretrainedConfig
from typing_extensions import NotRequired

from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig
Expand All @@ -52,6 +53,7 @@
from vllm.model_executor.models.qwen2 import Qwen2Model
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.base import MultiModalInputs
from vllm.multimodal.image import cached_get_image_processor
from vllm.multimodal.utils import cached_get_tokenizer
from vllm.sequence import IntermediateTensors, SequenceData
Expand All @@ -64,6 +66,17 @@
}


class MiniCPMVImageInput(TypedDict):
"""Input mapper input with auxiliary data for computing image bounds."""
image: Image.Image

# Image bounds token ids in 0-dim scaler tensor.
im_start_id: torch.Tensor
im_end_id: torch.Tensor
slice_start_id: NotRequired[torch.Tensor]
slice_end_id: NotRequired[torch.Tensor]


class MiniCPMVImagePixelInputs(TypedDict):
pixel_values: List[torch.Tensor]
"""
Expand All @@ -88,8 +101,6 @@ class MiniCPMVImagePixelInputs(TypedDict):
"""


MiniCPMVImageInputs = MiniCPMVImagePixelInputs

DEFAULT_LN = partial(nn.LayerNorm, eps=1e-6)


Expand Down Expand Up @@ -234,6 +245,25 @@ def forward(self, x: torch.Tensor,
return x


def _build_image_input(ctx: InputContext,
image: Image.Image) -> MiniCPMVImageInput:
tokenizer = cached_get_tokenizer(
ctx.model_config.tokenizer,
trust_remote_code=ctx.model_config.trust_remote_code)
if hasattr(tokenizer, "slice_start_id"):
return MiniCPMVImageInput(
image=image,
im_start_id=torch.tensor(tokenizer.im_start_id),
im_end_id=torch.tensor(tokenizer.im_end_id),
slice_start_id=torch.tensor(tokenizer.slice_start_id),
slice_end_id=torch.tensor(tokenizer.slice_end_id))
else:
return MiniCPMVImageInput(image=image,
im_start_id=torch.tensor(
tokenizer.im_start_id),
im_end_id=torch.tensor(tokenizer.im_end_id))


def get_version_by_config(config: PretrainedConfig) -> Tuple[int, ...]:
version_float = getattr(config, "version", None)

Expand All @@ -257,10 +287,13 @@ def dummy_seq_data_for_minicpmv(seq_len: int, num_images: int):
return SequenceData.from_token_counts((0, seq_len))


def dummy_image_for_minicpmv(hf_config: PretrainedConfig, num_images: int):
def dummy_image_for_minicpmv(ctx: InputContext, hf_config: PretrainedConfig,
num_images: int):
width = height = hf_config.image_size
image = Image.new("RGB", (width, height), color=0)
return {"image": image if num_images == 1 else [image] * num_images}
image = _build_image_input(ctx,
image=Image.new("RGB", (width, height),
color=0))
return {"image": [image] if num_images == 1 else [image] * num_images}


def dummy_data_for_minicpmv(ctx: InputContext, seq_len: int,
Expand All @@ -269,7 +302,7 @@ def dummy_data_for_minicpmv(ctx: InputContext, seq_len: int,
num_images = mm_counts["image"]

seq_data = dummy_seq_data_for_minicpmv(seq_len, num_images)
mm_data = dummy_image_for_minicpmv(hf_config, num_images)
mm_data = dummy_image_for_minicpmv(ctx, hf_config, num_images)

return seq_data, mm_data

Expand All @@ -280,8 +313,9 @@ def input_processor_for_minicpmv(ctx: InputContext, llm_inputs: LLMInputs):
return llm_inputs
model_config = ctx.model_config
version = get_version_by_config(model_config.hf_config)
tokenizer = cached_get_tokenizer(model_config.tokenizer,
trust_remote_code=True)
tokenizer = cached_get_tokenizer(
model_config.tokenizer,
trust_remote_code=model_config.trust_remote_code)
image_processor = cached_get_image_processor(model_config.tokenizer)

def get_placeholder(image_size: Tuple[int, int], num_image: int):
Expand Down Expand Up @@ -317,6 +351,10 @@ def get_placeholder(image_size: Tuple[int, int], num_image: int):
new_prompt = "".join(new_prompt_chunks)
new_token_ids = tokenizer.encode(new_prompt)

multi_modal_data["image"] = [
_build_image_input(ctx, image) for image in images
]

llm_inputs = LLMInputs(
prompt_token_ids=new_token_ids,
prompt=new_prompt,
Expand All @@ -325,6 +363,32 @@ def get_placeholder(image_size: Tuple[int, int], num_image: int):
return llm_inputs


def input_mapper_for_minicpmv(ctx: InputContext, data: object):
model_config = ctx.model_config

image_processor = cached_get_image_processor(
model_config.model, trust_remote_code=model_config.trust_remote_code)
if image_processor is None:
raise RuntimeError("No HuggingFace processor is available "
"to process the image object")

if not isinstance(data, list):
raise ValueError(
"Image input must be list of MiniCPMVImageInput, got (%s)", data)
batch_data = image_processor \
.preprocess([img["image"] for img in data], return_tensors="pt") \
.data

if len(data) > 0:
batch_data["im_start_id"] = data[0]["im_start_id"]
batch_data["im_end_id"] = data[0]["im_end_id"]
if "slice_start_id" in data[0]:
batch_data["slice_start_id"] = data[0]["slice_start_id"]
batch_data["slice_end_id"] = data[0]["slice_end_id"]

return MultiModalInputs(batch_data)


class MiniCPMVBaseModel(nn.Module, SupportsMultiModal):
"""
The abstract class of MiniCPMV can only be inherited, but cannot be
Expand Down Expand Up @@ -365,7 +429,7 @@ def __init__(
def get_embedding(
self,
input_ids: torch.Tensor,
image_inputs: Optional[MiniCPMVImageInputs],
image_inputs: Optional[MiniCPMVImagePixelInputs],
) -> Tuple[torch.Tensor, torch.Tensor]:
vlm_embedding: torch.Tensor = self.llm.embed_tokens(input_ids)
if hasattr(self.config, "scale_emb"):
Expand Down Expand Up @@ -393,14 +457,20 @@ def get_embedding(

return vlm_embedding, vision_hidden_states

def _get_image_bounds(self, input_ids: torch.Tensor) -> torch.Tensor:
tokenizer = cached_get_tokenizer(self.config._name_or_path,
trust_remote_code=True)
start_cond = input_ids == tokenizer.im_start_id
end_cond = input_ids == tokenizer.im_end_id
if hasattr(tokenizer, "slice_start_id"):
start_cond |= (input_ids == tokenizer.slice_start_id)
end_cond |= (input_ids == tokenizer.slice_end_id)
def _get_image_bounds(
self,
input_ids: torch.Tensor,
im_start_id: torch.Tensor,
im_end_id: torch.Tensor,
slice_start_id: Optional[torch.Tensor] = None,
slice_end_id: Optional[torch.Tensor] = None) -> torch.Tensor:
# All the images in the batch should share the same special image
# bound token ids.
start_cond = input_ids == im_start_id[0]
end_cond = input_ids == im_end_id[0]
if slice_start_id is not None:
start_cond |= (input_ids == slice_start_id[0])
end_cond |= (input_ids == slice_end_id[0])

image_start_tokens, = torch.where(start_cond)
image_start_tokens += 1
Expand All @@ -419,7 +489,7 @@ def _parse_and_validate_inputs(
self,
input_ids: torch.Tensor,
**kwargs: object,
) -> Optional[MiniCPMVImageInputs]:
) -> Optional[MiniCPMVImagePixelInputs]:
pixel_values = kwargs.pop("pixel_values", [])
tgt_sizes = kwargs.pop("tgt_sizes", [])

Expand Down Expand Up @@ -456,8 +526,17 @@ def _parse_and_validate_inputs(
if len(pixel_values_flat) == 0:
return None

return MiniCPMVImageInputs(
image_bounds=self._get_image_bounds(input_ids),
im_start_id = kwargs.pop("im_start_id", None)
im_end_id = kwargs.pop("im_end_id", None)
slice_start_id = kwargs.pop("slice_start_id", None)
slice_end_id = kwargs.pop("slice_end_id", None)
if im_start_id is None:
return None

return MiniCPMVImagePixelInputs(
image_bounds=self._get_image_bounds(input_ids, im_start_id,
im_end_id, slice_start_id,
slice_end_id),
pixel_values=pixel_values_flat,
tgt_sizes=torch.stack(tgt_sizes_flat),
)
Expand Down Expand Up @@ -564,8 +643,8 @@ def get_vision_embedding(
) -> torch.Tensor:
raise NotImplementedError

def get_vision_hidden_states(self,
data: MiniCPMVImageInputs) -> torch.Tensor:
def get_vision_hidden_states(
self, data: MiniCPMVImagePixelInputs) -> torch.Tensor:
raise NotImplementedError

def is_default_weight_loading(self, name: str) -> bool:
Expand Down Expand Up @@ -654,8 +733,8 @@ def get_vision_embedding(
res.append(self.resampler(vision_embedding, tgt_size))
return torch.vstack(res)

def get_vision_hidden_states(self,
data: MiniCPMVImageInputs) -> torch.Tensor:
def get_vision_hidden_states(
self, data: MiniCPMVImagePixelInputs) -> torch.Tensor:
pixel_values = data["pixel_values"]

return self.get_vision_embedding(pixel_values)
Expand Down Expand Up @@ -713,8 +792,8 @@ def get_vision_embedding(
vision_embedding = self.resampler(vision_embedding, tgt_sizes)
return vision_embedding

def get_vision_hidden_states(self,
data: MiniCPMVImageInputs) -> torch.Tensor:
def get_vision_hidden_states(
self, data: MiniCPMVImagePixelInputs) -> torch.Tensor:
pixel_values = data["pixel_values"]
tgt_sizes = data["tgt_sizes"]

Expand Down Expand Up @@ -807,8 +886,8 @@ def get_vision_embedding(
).last_hidden_state
return vision_embedding

def get_vision_hidden_states(self,
data: MiniCPMVImageInputs) -> torch.Tensor:
def get_vision_hidden_states(
self, data: MiniCPMVImagePixelInputs) -> torch.Tensor:
pixel_values = data["pixel_values"]
tgt_sizes = data["tgt_sizes"]

Expand Down Expand Up @@ -851,7 +930,7 @@ def is_default_weight_loading(self, name: str) -> bool:
}


@MULTIMODAL_REGISTRY.register_image_input_mapper()
@MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_minicpmv)
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_minicpmv_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_minicpmv)
@INPUT_REGISTRY.register_input_processor(input_processor_for_minicpmv)
Expand Down
15 changes: 9 additions & 6 deletions vllm/model_executor/models/qwen.py
Original file line number Diff line number Diff line change
Expand Up @@ -674,8 +674,9 @@ def input_processor_for_qwen(ctx: InputContext,
prompt = llm_inputs.get("prompt")
prompt_token_ids = llm_inputs["prompt_token_ids"]
model_config = ctx.model_config
tokenizer = cached_get_tokenizer(model_config.tokenizer,
trust_remote_code=True)
tokenizer = cached_get_tokenizer(
model_config.tokenizer,
trust_remote_code=model_config.trust_remote_code)
image_data = multi_modal_data["image"]
if isinstance(image_data, torch.Tensor):
num_dims = len(image_data.shape)
Expand Down Expand Up @@ -735,8 +736,9 @@ def input_mapper_for_qwen(ctx: InputContext, data: object) -> MultiModalInputs:
return MultiModalInputs()

model_config = ctx.model_config
tokenizer = cached_get_tokenizer(model_config.tokenizer,
trust_remote_code=True)
tokenizer = cached_get_tokenizer(
model_config.tokenizer,
trust_remote_code=model_config.trust_remote_code)

image_pair_tok = tokenizer.encode(IMG_START + IMG_END,
add_special_tokens=False,
Expand Down Expand Up @@ -824,8 +826,9 @@ def dummy_data_for_qwen(
# We have a visual component - use images to warm up
num_images = mm_counts["image"]
model_config = ctx.model_config
tokenizer = cached_get_tokenizer(model_config.tokenizer,
trust_remote_code=True)
tokenizer = cached_get_tokenizer(
model_config.tokenizer,
trust_remote_code=model_config.trust_remote_code)

# Build the image prompts with no imgpads; the tokenizer will add img pads
image_prompt = ''.join(
Expand Down

0 comments on commit e3dd069

Please sign in to comment.