[Model] Support NVLM-D and fix QK Norm in InternViT (vllm-project#9045)
Co-authored-by: Roger Wang <ywang@roblox.com>
Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>
3 people authored Oct 7, 2024
1 parent f19da64 commit 151ef4e
Showing 12 changed files with 518 additions and 236 deletions.
9 changes: 9 additions & 0 deletions docs/source/models/supported_models.rst
@@ -315,6 +315,9 @@ Multimodal Language Models

.. _supported_vlms:

Text Generation
---------------

.. list-table::
:widths: 25 25 25 25 5 5
:header-rows: 1
@@ -384,7 +387,13 @@ Multimodal Language Models
- Image
- :code:`meta-llama/Llama-3.2-90B-Vision-Instruct`, :code:`meta-llama/Llama-3.2-11B-Vision`, etc.
-
-
* - :code:`NVLM_D_Model`
- NVLM-D 1.0
- Image\ :sup:`E+`
- :code:`nvidia/NVLM-D-72B`, etc.
-
- ✅︎
* - :code:`PaliGemmaForConditionalGeneration`
- PaliGemma
- Image\ :sup:`E`
55 changes: 40 additions & 15 deletions examples/offline_inference_vision_language.py
@@ -18,7 +18,7 @@


# LLaVA-1.5
def run_llava(question, modality):
def run_llava(question: str, modality: str):
assert modality == "image"

prompt = f"USER: <image>\n{question}\nASSISTANT:"
@@ -29,7 +29,7 @@ def run_llava(question, modality):


# LLaVA-1.6/LLaVA-NeXT
def run_llava_next(question, modality):
def run_llava_next(question: str, modality: str):
assert modality == "image"

prompt = f"[INST] <image>\n{question} [/INST]"
@@ -40,7 +40,7 @@ def run_llava_next(question, modality):

# LlaVA-NeXT-Video
# Currently only support for video input
def run_llava_next_video(question, modality):
def run_llava_next_video(question: str, modality: str):
assert modality == "video"

prompt = f"USER: <video>\n{question} ASSISTANT:"
@@ -50,7 +50,7 @@ def run_llava_next_video(question, modality):


# LLaVA-OneVision
def run_llava_onevision(question, modality):
def run_llava_onevision(question: str, modality: str):

if modality == "video":
prompt = f"<|im_start|>user <video>\n{question}<|im_end|> \
@@ -67,7 +67,7 @@ def run_llava_onevision(question, modality):


# Fuyu
def run_fuyu(question, modality):
def run_fuyu(question: str, modality: str):
assert modality == "image"

prompt = f"{question}\n"
@@ -77,7 +77,7 @@ def run_fuyu(question, modality):


# Phi-3-Vision
def run_phi3v(question, modality):
def run_phi3v(question: str, modality: str):
assert modality == "image"

prompt = f"<|user|>\n<|image_1|>\n{question}<|end|>\n<|assistant|>\n" # noqa: E501
@@ -112,7 +112,7 @@ def run_phi3v(question, modality):


# PaliGemma
def run_paligemma(question, modality):
def run_paligemma(question: str, modality: str):
assert modality == "image"

# PaliGemma has special prompt format for VQA
@@ -123,7 +123,7 @@ def run_paligemma(question, modality):


# Chameleon
def run_chameleon(question, modality):
def run_chameleon(question: str, modality: str):
assert modality == "image"

prompt = f"{question}<image>"
@@ -133,7 +133,7 @@ def run_chameleon(question, modality):


# MiniCPM-V
def run_minicpmv(question, modality):
def run_minicpmv(question: str, modality: str):
assert modality == "image"

# 2.0
@@ -176,7 +176,7 @@ def run_minicpmv(question, modality):


# InternVL
def run_internvl(question, modality):
def run_internvl(question: str, modality: str):
assert modality == "image"

model_name = "OpenGVLab/InternVL2-2B"
@@ -203,8 +203,32 @@ def run_internvl(question, modality):
return llm, prompt, stop_token_ids


# NVLM-D
def run_nvlm_d(question: str, modality: str):
assert modality == "image"

model_name = "nvidia/NVLM-D-72B"

    # Adjust this as necessary to fit in GPU memory
llm = LLM(
model=model_name,
trust_remote_code=True,
max_model_len=4096,
tensor_parallel_size=4,
)

tokenizer = AutoTokenizer.from_pretrained(model_name,
trust_remote_code=True)
messages = [{'role': 'user', 'content': f"<image>\n{question}"}]
prompt = tokenizer.apply_chat_template(messages,
tokenize=False,
add_generation_prompt=True)
stop_token_ids = None
return llm, prompt, stop_token_ids
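
# For context only -- a minimal usage sketch, not part of this commit, showing
# how the (llm, prompt, stop_token_ids) tuple returned by run_nvlm_d would
# typically be consumed with vLLM's multimodal generate API. The image path
# and sampling settings below are placeholders.
#
# from PIL import Image
# from vllm import SamplingParams
#
# llm, prompt, stop_token_ids = run_nvlm_d("What is in this image?", "image")
# image = Image.open("example.jpg").convert("RGB")
# sampling_params = SamplingParams(temperature=0.2,
#                                  max_tokens=64,
#                                  stop_token_ids=stop_token_ids)
# outputs = llm.generate(
#     {
#         "prompt": prompt,
#         "multi_modal_data": {"image": image},
#     },
#     sampling_params=sampling_params)
# print(outputs[0].outputs[0].text)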


# BLIP-2
def run_blip2(question, modality):
def run_blip2(question: str, modality: str):
assert modality == "image"

# BLIP-2 prompt format is inaccurate on HuggingFace model repository.
@@ -216,7 +240,7 @@ def run_blip2(question, modality):


# Qwen
def run_qwen_vl(question, modality):
def run_qwen_vl(question: str, modality: str):
assert modality == "image"

llm = LLM(
@@ -232,7 +256,7 @@ def run_qwen_vl(question, modality):


# Qwen2-VL
def run_qwen2_vl(question, modality):
def run_qwen2_vl(question: str, modality: str):
assert modality == "image"

model_name = "Qwen/Qwen2-VL-7B-Instruct"
@@ -252,8 +276,8 @@ def run_qwen2_vl(question, modality):
return llm, prompt, stop_token_ids


# LLama
def run_mllama(question, modality):
# LLama 3.2
def run_mllama(question: str, modality: str):
assert modality == "image"

model_name = "meta-llama/Llama-3.2-11B-Vision-Instruct"
@@ -287,6 +311,7 @@ def run_mllama(question, modality):
"minicpmv": run_minicpmv,
"blip-2": run_blip2,
"internvl_chat": run_internvl,
"NVLM_D": run_nvlm_d,
"qwen_vl": run_qwen_vl,
"qwen2_vl": run_qwen2_vl,
"mllama": run_mllama,
34 changes: 34 additions & 0 deletions examples/offline_inference_vision_language_multi_image.py
@@ -144,6 +144,39 @@ def load_internvl(question: str, image_urls: List[str]) -> ModelRequestData:
)


def load_nvlm_d(question: str, image_urls: List[str]):
model_name = "nvidia/NVLM-D-72B"

    # Adjust this as necessary to fit in GPU memory
llm = LLM(
model=model_name,
trust_remote_code=True,
max_model_len=8192,
tensor_parallel_size=4,
limit_mm_per_prompt={"image": len(image_urls)},
mm_processor_kwargs={"max_dynamic_patch": 4},
)

placeholders = "\n".join(f"Image-{i}: <image>\n"
for i, _ in enumerate(image_urls, start=1))
messages = [{'role': 'user', 'content': f"{placeholders}\n{question}"}]

tokenizer = AutoTokenizer.from_pretrained(model_name,
trust_remote_code=True)
prompt = tokenizer.apply_chat_template(messages,
tokenize=False,
add_generation_prompt=True)
stop_token_ids = None

return ModelRequestData(
llm=llm,
prompt=prompt,
stop_token_ids=stop_token_ids,
image_data=[fetch_image(url) for url in image_urls],
chat_template=None,
)
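
# As above, a hedged usage sketch (not part of this commit; the image URLs are
# placeholders) of how the returned ModelRequestData could drive a multi-image
# request -- all images for one prompt go as a list under the "image" key.
#
# from vllm import SamplingParams
#
# req = load_nvlm_d(
#     "What are the differences between these two images?",
#     ["https://example.com/a.jpg", "https://example.com/b.jpg"])
# sampling_params = SamplingParams(temperature=0.2,
#                                  max_tokens=128,
#                                  stop_token_ids=req.stop_token_ids)
# outputs = req.llm.generate(
#     {
#         "prompt": req.prompt,
#         "multi_modal_data": {"image": req.image_data},
#     },
#     sampling_params=sampling_params)
# print(outputs[0].outputs[0].text)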


def load_qwen2_vl(question, image_urls: List[str]) -> ModelRequestData:
try:
from qwen_vl_utils import process_vision_info
@@ -204,6 +237,7 @@ def load_qwen2_vl(question, image_urls: List[str]) -> ModelRequestData:
model_example_map = {
"phi3_v": load_phi3v,
"internvl_chat": load_internvl,
"NVLM_D": load_nvlm_d,
"qwen2_vl": load_qwen2_vl,
"qwen_vl_chat": load_qwenvl_chat,
}
2 changes: 1 addition & 1 deletion vllm/entrypoints/chat_utils.py
@@ -157,7 +157,7 @@ def _placeholder_str(self, modality: ModalityStr,
if model_type.startswith("llava"):
return self._cached_token_str(self._tokenizer,
hf_config.image_token_index)
if model_type in ("chameleon", "internvl_chat"):
if model_type in ("chameleon", "internvl_chat", "NVLM_D"):
return "<image>"
if model_type == "mllama":
return "<|image|>"
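
The "<image>" placeholder above is what the OpenAI-compatible chat path substitutes for each image part of an NVLM-D request. A sketch of such a request, assuming a locally running vLLM OpenAI-compatible server (the server URL and image URL are placeholders, not part of this commit):

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
# chat_utils renders the image_url part as the literal "<image>" placeholder
# for model_type "NVLM_D" before the text prompt is built.
response = client.chat.completions.create(
    model="nvidia/NVLM-D-72B",
    messages=[{
        "role": "user",
        "content": [
            {"type": "text", "text": "What is in this image?"},
            {"type": "image_url",
             "image_url": {"url": "https://example.com/a.jpg"}},
        ],
    }],
    max_tokens=64,
)
print(response.choices[0].message.content)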
32 changes: 30 additions & 2 deletions vllm/model_executor/layers/layernorm.py
@@ -18,10 +18,16 @@ def __init__(
self,
hidden_size: int,
eps: float = 1e-6,
var_hidden_size: Optional[int] = None,
) -> None:
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))

self.hidden_size = hidden_size
self.variance_epsilon = eps
self.variance_size_override = (None if var_hidden_size == hidden_size
else var_hidden_size)

self.weight = nn.Parameter(torch.ones(hidden_size))

def forward_native(
self,
@@ -35,7 +41,23 @@ def forward_native(
x = x + residual.to(torch.float32)
residual = x.to(orig_dtype)

variance = x.pow(2).mean(dim=-1, keepdim=True)
hidden_size = x.shape[-1]
if hidden_size != self.hidden_size:
raise ValueError("Expected hidden_size to be "
f"{self.hidden_size}, but found: {hidden_size}")

if self.variance_size_override is None:
x_var = x
else:
if hidden_size < self.variance_size_override:
raise ValueError(
"Expected hidden_size to be at least "
f"{self.variance_size_override}, but found: {hidden_size}")

x_var = x[:, :, :self.variance_size_override]

variance = x_var.pow(2).mean(dim=-1, keepdim=True)

x = x * torch.rsqrt(variance + self.variance_epsilon)
x = x.to(orig_dtype) * self.weight
if residual is None:
@@ -48,6 +70,9 @@ def forward_cuda(
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
if self.variance_size_override is not None:
return self.forward_native(x, residual)

from vllm import _custom_ops as ops

if residual is not None:
@@ -72,6 +97,9 @@ def forward_xpu(
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
if self.variance_size_override is not None:
return self.forward_native(x, residual)

from vllm._ipex_ops import ipex_ops as ops

if residual is not None:
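
A standalone sketch of the partial-variance RMSNorm that forward_native now implements, assuming plain PyTorch, no residual path, and with shape validation omitted: the variance is computed over only the first var_hidden_size channels of the last dimension, while the rsqrt scaling and the learned weight are applied to the full hidden dimension.

import torch


def partial_rms_norm(x: torch.Tensor,
                     weight: torch.Tensor,
                     var_hidden_size: int,
                     eps: float = 1e-6) -> torch.Tensor:
    # Mirrors the variance_size_override branch of forward_native (sketch only).
    x_f32 = x.to(torch.float32)
    # Variance over the leading var_hidden_size channels of the last dim.
    variance = x_f32[..., :var_hidden_size].pow(2).mean(dim=-1, keepdim=True)
    x_f32 = x_f32 * torch.rsqrt(variance + eps)
    # Scale the full hidden dimension, back in the original dtype.
    return x_f32.to(x.dtype) * weight

Note that forward_cuda and forward_xpu fall back to forward_native whenever variance_size_override is set, since the fused kernels compute the variance over the full hidden size.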
(Diffs for the remaining changed files are not shown here.)
