Skip to content

Commit

Permalink
* Add llava-next-video tests
Browse files Browse the repository at this point in the history
  • Loading branch information
TKONIY committed Aug 31, 2024
1 parent 67261b5 commit 46960dc
Show file tree
Hide file tree
Showing 5 changed files with 331 additions and 14 deletions.
56 changes: 55 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

from vllm import LLM, SamplingParams
from vllm.assets.image import ImageAsset
from vllm.assets.video import VideoAsset
from vllm.config import TokenizerPoolConfig
from vllm.connections import global_http_connection
from vllm.distributed import (destroy_distributed_environment,
Expand All @@ -44,6 +45,7 @@
PromptImageInput = Union[List[Image.Image], List[List[Image.Image]]]
PromptAudioInput = Union[List[Tuple[np.ndarray, int]],
List[List[Tuple[np.ndarray, int]]]]
PromptVideoInput = Union[List[np.ndarray], List[List[np.ndarray]]]


def _read_prompts(filename: str) -> List[str]:
Expand Down Expand Up @@ -85,8 +87,35 @@ def prompts(self, prompts: _ImageAssetPrompts) -> List[str]:
return [prompts["stop_sign"], prompts["cherry_blossom"]]


class _VideoAssetPrompts(TypedDict):
sample_demo_1: str


if sys.version_info < (3, 9):
# UserList cannot be subscripted
class _VideoAssetsBase(UserList):
pass
else:

class _VideoAssetsBase(UserList[VideoAsset]):
pass


class _VideoAssets(_VideoAssetsBase):

def __init__(self) -> None:
super().__init__([
VideoAsset("sample_demo_1.mp4"),
])

def prompts(self, prompts: _VideoAssetPrompts) -> List[str]:
return [prompts["sample_demo_1"]]


IMAGE_ASSETS = _ImageAssets()
"""Singleton instance of :class:`_ImageAssets`."""
VIDEO_ASSETS = _VideoAssets()
"""Singleton instance of :class:`_VideoAssets`."""


@pytest.fixture(autouse=True)
Expand Down Expand Up @@ -202,6 +231,11 @@ def image_assets() -> _ImageAssets:
return IMAGE_ASSETS


@pytest.fixture(scope="session")
def video_assets() -> _VideoAssets:
return VIDEO_ASSETS


_T = TypeVar("_T", nn.Module, torch.Tensor, BatchEncoding, BatchFeature)


Expand Down Expand Up @@ -279,6 +313,7 @@ def generate(
self,
prompts: List[str],
images: Optional[List[Image.Image]] = None,
videos: Optional[List[np.ndarray]] = None,
**kwargs: Any,
) -> List[Tuple[List[List[int]], List[str]]]:
if images:
Expand All @@ -292,6 +327,8 @@ def generate(
}
if images is not None and images[i] is not None:
processor_kwargs["images"] = images[i]
if videos is not None and videos[i] is not None:
processor_kwargs["videos"] = videos[i]

inputs = self.processor(**processor_kwargs)
inputs = self.postprocess_inputs(inputs)
Expand Down Expand Up @@ -352,6 +389,7 @@ def generate_greedy_logprobs(
prompts: List[str],
max_tokens: int,
images: Optional[List[Image.Image]] = None,
videos: Optional[List[np.ndarray]] = None,
**kwargs: Any,
) -> List[List[torch.Tensor]]:
all_logprobs: List[List[torch.Tensor]] = []
Expand All @@ -362,6 +400,8 @@ def generate_greedy_logprobs(
}
if images is not None and images[i] is not None:
processor_kwargs["images"] = images[i]
if videos is not None and videos[i] is not None:
processor_kwargs["videos"] = videos[i]

inputs = self.processor(**processor_kwargs)
inputs = self.postprocess_inputs(inputs)
Expand Down Expand Up @@ -435,6 +475,7 @@ def generate_greedy_logprobs_limit(
num_logprobs: int,
images: Optional[List[Image.Image]] = None,
audios: Optional[List[Tuple[np.ndarray, int]]] = None,
videos: Optional[List[np.ndarray]] = None,
**kwargs: Any,
) -> List[Tuple[List[int], str, List[Dict[int, float]]]]:
all_logprobs: List[List[Dict[int, float]]] = []
Expand All @@ -454,6 +495,8 @@ def generate_greedy_logprobs_limit(
processor_kwargs["audio"] = audio
processor_kwargs["sampling_rate"] = sr

if videos is not None:
processor_kwargs["videos"] = videos[i]
inputs = self.processor(**processor_kwargs)
inputs = self.postprocess_inputs(inputs)

Expand Down Expand Up @@ -634,12 +677,16 @@ def generate_w_logprobs(
sampling_params: SamplingParams,
images: Optional[PromptImageInput] = None,
audios: Optional[PromptAudioInput] = None,
videos: Optional[PromptVideoInput] = None,
) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
assert sampling_params.logprobs is not None

if images is not None:
assert len(prompts) == len(images)

if videos is not None:
assert len(prompts) == len(videos)

inputs = [TextPrompt(prompt=prompt) for prompt in prompts]
if images is not None:
for i, image in enumerate(images):
Expand All @@ -649,6 +696,11 @@ def generate_w_logprobs(
for i, audio in enumerate(audios):
inputs[i]["multi_modal_data"] = {"audio": audio}

if videos is not None:
for i, video in enumerate(videos):
inputs[i]["multi_modal_data"] = {"video": video}
print(f"[INPUTS!!!!]: {inputs}, {sampling_params}")

req_outputs = self.model.generate(inputs,
sampling_params=sampling_params)
return self._final_steps_generate_w_logprobs(req_outputs)
Expand Down Expand Up @@ -685,6 +737,7 @@ def generate_greedy_logprobs(
num_logprobs: int,
images: Optional[PromptImageInput] = None,
audios: Optional[PromptAudioInput] = None,
videos: Optional[PromptVideoInput] = None,
stop_token_ids: Optional[List[int]] = None,
) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
greedy_logprobs_params = SamplingParams(temperature=0.0,
Expand All @@ -694,7 +747,8 @@ def generate_greedy_logprobs(
outputs = self.generate_w_logprobs(prompts,
greedy_logprobs_params,
images=images,
audios=audios)
audios=audios,
videos=videos)

return [(output_ids, output_str, output_logprobs)
for output_ids, output_str, output_logprobs in outputs]
Expand Down
228 changes: 228 additions & 0 deletions tests/models/test_llava_next_video.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,228 @@
from typing import List, Optional, Tuple, Type, overload

import pytest
from transformers import AutoConfig, AutoModelForVision2Seq, AutoTokenizer

from vllm.multimodal.utils import rescale_video_size, resize_video
from vllm.multimodal.utils import sample_frames_from_video
from vllm.sequence import SampleLogprobs

from ..conftest import VIDEO_ASSETS, HfRunner, VllmRunner, _VideoAssets
from .utils import check_logprobs_close

pytestmark = pytest.mark.vlm

_PREFACE = (
"A chat between a curious human and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the human's "
"questions.")

HF_VIDEO_PROMPTS = VIDEO_ASSETS.prompts({
"sample_demo_1": f"{_PREFACE}USER: <video>\nWhy is this video funny? ASSISTANT:"
})

models = ["llava-hf/LLaVA-NeXT-Video-7B-hf"]


def vllm_to_hf_output(vllm_output: Tuple[List[int], str,
Optional[SampleLogprobs]],
model: str):
"""Sanitize vllm output to be comparable with hf output."""
output_ids, output_str, out_logprobs = vllm_output

config = AutoConfig.from_pretrained(model)
video_token_id = config.video_token_index

tokenizer = AutoTokenizer.from_pretrained(model)
eos_token_id = tokenizer.eos_token_id

hf_output_ids = [
token_id for idx, token_id in enumerate(output_ids)
if token_id != video_token_id or output_ids[idx - 1] != video_token_id
]

assert output_str[0] == " "
hf_output_str = output_str[1:]
if hf_output_ids[-1] == eos_token_id:
hf_output_str = hf_output_str + tokenizer.decode(eos_token_id)

return hf_output_ids, hf_output_str, out_logprobs


@overload
def run_test(
hf_runner: Type[HfRunner],
vllm_runner: Type[VllmRunner],
video_assets: _VideoAssets,
model: str,
*,
size_factors: List[float],
dtype: str,
max_tokens: int,
num_logprobs: int,
num_frames: int,
tensor_parallel_size: int,
distributed_executor_backend: Optional[str] = None,
):
...


@overload
def run_test(
hf_runner: Type[HfRunner],
vllm_runner: Type[VllmRunner],
video_assets: _VideoAssets,
model: str,
*,
sizes: List[Tuple[int, int]],
dtype: str,
max_tokens: int,
num_logprobs: int,
num_frames: int,
tensor_parallel_size: int,
distributed_executor_backend: Optional[str] = None,
):
...


def run_test(
hf_runner: Type[HfRunner],
vllm_runner: Type[VllmRunner],
video_assets: _VideoAssets,
model: str,
*,
size_factors: Optional[List[float]] = None,
sizes: Optional[List[Tuple[int, int]]] = None,
dtype: str,
max_tokens: int,
num_logprobs: int,
num_frames: int,
tensor_parallel_size: int,
distributed_executor_backend: Optional[str] = None,
):
videos = [
sample_frames_from_video(asset.np_ndarrays, num_frames)
for asset in video_assets]

for video in videos:
print(video.shape)

if size_factors is not None:
inputs_per_video = [(
[prompt for _ in size_factors],
[rescale_video_size(video, factor) for factor in size_factors],
) for video, prompt in zip(videos, HF_VIDEO_PROMPTS)]
elif sizes is not None:
inputs_per_video = [(
[prompt for _ in sizes],
[resize_video(video, size) for size in sizes],
) for video, prompt in zip(videos, HF_VIDEO_PROMPTS)]
else:
raise ValueError("You must provide either `size_factors` or `sizes`")

# max_model_len should be greater than image_feature_size
with vllm_runner(model,
dtype=dtype,
max_model_len=4096,
tensor_parallel_size=tensor_parallel_size,
distributed_executor_backend=distributed_executor_backend,
enforce_eager=True) as vllm_model:
vllm_outputs_per_video = [
vllm_model.generate_greedy_logprobs(prompts,
max_tokens,
num_logprobs=num_logprobs,
videos=videos)
for prompts, videos in inputs_per_video
]

with hf_runner(model, dtype=dtype,
auto_cls=AutoModelForVision2Seq) as hf_model:
hf_outputs_per_video = [
hf_model.generate_greedy_logprobs_limit(prompts,
max_tokens,
num_logprobs=num_logprobs,
videos=videos)
for prompts, videos in inputs_per_video
]

for hf_outputs, vllm_outputs in zip(hf_outputs_per_video,
vllm_outputs_per_video):
# TODO: Check whether using original CLIPVisionModel can improve
# consistency against HF
check_logprobs_close(
outputs_0_lst=hf_outputs,
outputs_1_lst=[
vllm_to_hf_output(vllm_output, model)
for vllm_output in vllm_outputs
],
name_0="hf",
name_1="vllm",
)


@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize(
"size_factors",
[
# No video
[],
# Single-scale
[1.0],
# Single-scale, batched
[1.0, 1.0, 1.0],
# Multi-scale
[0.25, 0.5, 1.0],
],
)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5])
@pytest.mark.parametrize("num_frames", [16])
def test_models(hf_runner, vllm_runner, video_assets, model, size_factors,
dtype, max_tokens, num_logprobs, num_frames) -> None:
"""Inference result should be the same between hf and vllm.
All the image fixtures for the test is under tests/videos.
For huggingface runner, we provide the np.ndarray as input.
For vllm runner, we provide MultiModalDataDict objects
and corresponding MultiModalConfig as input.
Note, the text input is also adjusted to abide by vllm contract.
The text output is sanitized to be able to compare with hf.
"""
run_test(
hf_runner,
vllm_runner,
video_assets,
model,
size_factors=size_factors,
dtype=dtype,
max_tokens=max_tokens,
num_logprobs=num_logprobs,
num_frames=num_frames,
tensor_parallel_size=1,
)


@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize(
"sizes",
[[(1669, 2560), (2560, 1669), (183, 488), (488, 183)]],
)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5])
@pytest.mark.parametrize("num_frames", [16])
def test_models_fixed_sizes(hf_runner, vllm_runner, video_assets, model, sizes,
dtype, max_tokens, num_logprobs, num_frames) -> None:
run_test(
hf_runner,
vllm_runner,
video_assets,
model,
sizes=sizes,
dtype=dtype,
max_tokens=max_tokens,
num_logprobs=num_logprobs,
num_frames=num_frames,
tensor_parallel_size=1,
)
Loading

0 comments on commit 46960dc

Please sign in to comment.