
[Feature]: Chunked prefill for multimodal models #10290

Open
1 task done
QiuJingkai opened this issue Nov 13, 2024 · 1 comment

Comments


QiuJingkai commented Nov 13, 2024

Your current environment

[pip3] numpy==1.25.1
[pip3] nvidia-cublas-cu12==12.4.5.8
[pip3] nvidia-cuda-cupti-cu12==12.4.127
[pip3] nvidia-cuda-nvrtc-cu12==12.4.127
[pip3] nvidia-cuda-runtime-cu12==12.4.127
[pip3] nvidia-cudnn-cu12==9.1.0.70
[pip3] nvidia-cufft-cu12==11.2.1.3
[pip3] nvidia-curand-cu12==10.3.5.147
[pip3] nvidia-cusolver-cu12==11.6.1.9
[pip3] nvidia-cusparse-cu12==12.3.1.170
[pip3] nvidia-ml-py==12.560.30
[pip3] nvidia-nccl-cu12==2.21.5
[pip3] nvidia-nvjitlink-cu12==12.4.127
[pip3] nvidia-nvtx-cu12==12.4.127
[pip3] pyzmq==26.2.0
[pip3] torch==2.5.1
[pip3] torchvision==0.20.1
[pip3] transformers==4.46.2
[pip3] triton==3.1.0
[conda] Could not collect
ROCM Version: Could not collect
Neuron SDK Version: N/A
vLLM Version: 0.6.3.post2.dev280+ge036e527

Model Input Dumps

err_execute_model_input_20241113-062507.zip

🐛 Describe the bug

When a multimodal model runs inference with chunked prefill and the number of prompt tokens exceeds max_num_batched_tokens, the current prefill chunk does not have enough slots for all of the multimodal placeholder tokens, and merge_multimodal_embeddings raises an error.
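For context, chunked prefill processes a long prompt in pieces of at most max_num_batched_tokens tokens, so a chunk boundary can fall in the middle of an image's run of placeholder tokens. Below is a minimal sketch that assumes a single request receiving the whole token budget; the helper names split_into_chunks and placeholders_in_chunk, the (offset, length) range format, and the numbers (9 images x 576 placeholders after a 5-token text prefix, a 5,200-token prompt, a 4,096-token budget) are hypothetical and illustrative, not taken from the dumped input or from vLLM's scheduler.

```python
# Hypothetical illustration of chunked prefill boundaries; not vLLM's scheduler code.
from typing import List, Tuple


def split_into_chunks(num_prompt_tokens: int, token_budget: int) -> List[Tuple[int, int]]:
    """Split a prompt into [start, end) chunks of at most token_budget tokens."""
    chunks = []
    start = 0
    while start < num_prompt_tokens:
        end = min(start + token_budget, num_prompt_tokens)
        chunks.append((start, end))
        start = end
    return chunks


def placeholders_in_chunk(ranges: List[Tuple[int, int]], start: int, end: int) -> int:
    """Count placeholder positions inside [start, end).

    ranges holds one (offset, length) pair per image, in full-prompt positions.
    """
    total = 0
    for offset, length in ranges:
        lo = max(offset, start)
        hi = min(offset + length, end)
        total += max(0, hi - lo)  # size of the overlap, or 0 if none
    return total


# Illustrative numbers only: 9 images x 576 placeholders after a 5-token text prefix.
ranges = [(5 + 576 * i, 576) for i in range(9)]
for start, end in split_into_chunks(5200, 4096):
    print(start, end, placeholders_in_chunk(ranges, start, end))
# prints:
# 0 4096 4091
# 4096 5200 1093
```

The first chunk holds only 4091 of the 5184 placeholders, so a merge step that expects every image's full embeddings in that chunk has nowhere to put the remaining 1093.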

ERROR 11-13 06:25:07 engine.py:144] ValueError('Error in model execution (input dumped to /tmp/err_execute_model_input_20241113-062507.pkl): Attempted to assign 9 x 576 = 5184 multimodal tokens to 4879 placeholders')
ERROR 11-13 06:25:07 engine.py:144] Traceback (most recent call last):
ERROR 11-13 06:25:07 engine.py:144]   File "/usr/local/lib/python3.10/dist-packages/vllm/worker/model_runner_base.py", line 116, in _wrapper
ERROR 11-13 06:25:07 engine.py:144]     return func(*args, **kwargs)
ERROR 11-13 06:25:07 engine.py:144]   File "/usr/local/lib/python3.10/dist-packages/vllm/worker/model_runner.py", line 1646, in execute_model
ERROR 11-13 06:25:07 engine.py:144]     hidden_or_intermediate_states = model_executable(
ERROR 11-13 06:25:07 engine.py:144]   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
ERROR 11-13 06:25:07 engine.py:144]     return self._call_impl(*args, **kwargs)
ERROR 11-13 06:25:07 engine.py:144]   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1747, in _call_impl
ERROR 11-13 06:25:07 engine.py:144]     return forward_call(*args, **kwargs)
ERROR 11-13 06:25:07 engine.py:144]   File "/usr/local/lib/python3.10/dist-packages/vllm/model_executor/models/llava.py", line 505, in forward
ERROR 11-13 06:25:07 engine.py:144]     inputs_embeds = merge_multimodal_embeddings(
ERROR 11-13 06:25:07 engine.py:144]   File "/usr/local/lib/python3.10/dist-packages/vllm/model_executor/models/utils.py", line 426, in merge_multimodal_embeddings
ERROR 11-13 06:25:07 engine.py:144]     return _merge_multimodal_embeddings(
ERROR 11-13 06:25:07 engine.py:144]   File "/usr/local/lib/python3.10/dist-packages/vllm/model_executor/models/utils.py", line 365, in _merge_multimodal_embeddings
ERROR 11-13 06:25:07 engine.py:144]     raise ValueError(
ERROR 11-13 06:25:07 engine.py:144] ValueError: Attempted to assign 9 x 576 = 5184 multimodal tokens to 4879 placeholders
ERROR 11-13 06:25:07 engine.py:144] 
ERROR 11-13 06:25:07 engine.py:144] The above exception was the direct cause of the following exception:
ERROR 11-13 06:25:07 engine.py:144] 
ERROR 11-13 06:25:07 engine.py:144] Traceback (most recent call last):
ERROR 11-13 06:25:07 engine.py:144]   File "/usr/local/lib/python3.10/dist-packages/vllm/engine/multiprocessing/engine.py", line 142, in start
ERROR 11-13 06:25:07 engine.py:144]     self.run_engine_loop()
ERROR 11-13 06:25:07 engine.py:144]   File "/usr/local/lib/python3.10/dist-packages/vllm/engine/multiprocessing/engine.py", line 205, in run_engine_loop
ERROR 11-13 06:25:07 engine.py:144]     request_outputs = self.engine_step()
ERROR 11-13 06:25:07 engine.py:144]   File "/usr/local/lib/python3.10/dist-packages/vllm/engine/multiprocessing/engine.py", line 223, in engine_step
ERROR 11-13 06:25:07 engine.py:144]     raise e
ERROR 11-13 06:25:07 engine.py:144]   File "/usr/local/lib/python3.10/dist-packages/vllm/engine/multiprocessing/engine.py", line 214, in engine_step
ERROR 11-13 06:25:07 engine.py:144]     return self.engine.step()
ERROR 11-13 06:25:07 engine.py:144]   File "/usr/local/lib/python3.10/dist-packages/vllm/engine/llm_engine.py", line 1461, in step
ERROR 11-13 06:25:07 engine.py:144]     outputs = self.model_executor.execute_model(
ERROR 11-13 06:25:07 engine.py:144]   File "/usr/local/lib/python3.10/dist-packages/vllm/executor/gpu_executor.py", line 125, in execute_model
ERROR 11-13 06:25:07 engine.py:144]     output = self.driver_worker.execute_model(execute_model_req)
ERROR 11-13 06:25:07 engine.py:144]   File "/usr/local/lib/python3.10/dist-packages/vllm/worker/worker_base.py", line 343, in execute_model
ERROR 11-13 06:25:07 engine.py:144]     output = self.model_runner.execute_model(
ERROR 11-13 06:25:07 engine.py:144]   File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
ERROR 11-13 06:25:07 engine.py:144]     return func(*args, **kwargs)
ERROR 11-13 06:25:07 engine.py:144]   File "/usr/local/lib/python3.10/dist-packages/vllm/worker/model_runner_base.py", line 152, in _wrapper
ERROR 11-13 06:25:07 engine.py:144]     raise type(err)(
ERROR 11-13 06:25:07 engine.py:144] ValueError: Error in model execution (input dumped to /tmp/err_execute_model_input_20241113-062507.pkl): Attempted to assign 9 x 576 = 5184 multimodal tokens to 4879 placeholders
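The numbers in the error follow from a simple consistency check: 9 images x 576 embeddings each = 5184 vectors, but the scheduled chunk contains only 4879 placeholder positions, leaving 5184 - 4879 = 305 vectors with nowhere to go. The sketch below is a hypothetical, list-based re-creation of that kind of check; the function name merge_embeddings_strict and the placeholder id are assumptions, and vLLM's actual _merge_multimodal_embeddings works on tensors and may differ in detail.

```python
# Simplified, hypothetical re-creation of a strict placeholder check; not vLLM's code.
from typing import Any, List, Sequence


def merge_embeddings_strict(input_ids: Sequence[int], text_embeds: List[Any],
                            image_embeds: List[List[Any]], placeholder_id: int) -> List[Any]:
    """Replace every placeholder position with the next image embedding vector."""
    # Positions in this batch that are reserved for image features.
    positions = [i for i, tok in enumerate(input_ids) if tok == placeholder_id]
    # Flatten image embeddings in order: image 0's vectors, then image 1's, ...
    flat = [vec for image in image_embeds for vec in image]
    if len(flat) != len(positions):
        per_image = len(image_embeds[0]) if image_embeds else 0
        raise ValueError(
            f"Attempted to assign {len(image_embeds)} x {per_image} = {len(flat)} "
            f"multimodal tokens to {len(positions)} placeholders")
    merged = list(text_embeds)
    for pos, vec in zip(positions, flat):
        merged[pos] = vec
    return merged


PLACEHOLDER = -1  # hypothetical placeholder token id
input_ids = [PLACEHOLDER] * 4879 + [7] * 100     # the chunk holds only 4879 placeholders
text_embeds = [0.0] * len(input_ids)
image_embeds = [[1.0] * 576 for _ in range(9)]   # 9 images x 576 vectors
try:
    merge_embeddings_strict(input_ids, text_embeds, image_embeds, PLACEHOLDER)
except ValueError as err:
    print(err)  # Attempted to assign 9 x 576 = 5184 multimodal tokens to 4879 placeholders
```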

PRs #8425 and #8346 aim to solve this problem and have been merged, so I switched to the vLLM nightly build. However, I found that these PRs only support the Ultravox model, via the function merge_multimodal_embeddings_from_map.
How can I use chunked prefill with other multimodal models, such as LLaVA? Will vLLM support other models in a future version? And would modifying the source code in llava.py to use merge_multimodal_embeddings_from_map solve this problem?
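One possible way to make merging chunk-aware is to record each image's placeholder range in the full prompt and, for each chunk, copy in only the slice of that image's embeddings that overlaps the chunk. The sketch below is hypothetical: the function name merge_chunk, the (offset, length) range format, and the list types are assumptions, and it is not the signature or implementation of merge_multimodal_embeddings_from_map.

```python
# Hypothetical sketch of a chunk-aware merge; NOT vLLM's merge_multimodal_embeddings_from_map.
from typing import Any, List, Tuple


def merge_chunk(chunk_embeds: List[Any], chunk_start: int,
                image_embeds: List[List[Any]],
                placeholder_ranges: List[Tuple[int, int]]) -> List[Any]:
    """Overwrite the placeholder positions that fall inside one prefill chunk.

    chunk_embeds holds text embeddings for prompt positions
    [chunk_start, chunk_start + len(chunk_embeds)). placeholder_ranges gives the
    (offset, length) of each image's placeholders in the *full* prompt.
    """
    chunk_end = chunk_start + len(chunk_embeds)
    merged = list(chunk_embeds)
    for embeds, (offset, length) in zip(image_embeds, placeholder_ranges):
        if len(embeds) != length:
            raise ValueError("image embedding count does not match its placeholder range")
        # Overlap between this image's placeholders and the current chunk.
        lo = max(offset, chunk_start)
        hi = min(offset + length, chunk_end)
        for pos in range(lo, hi):
            # Prompt position pos maps to embedding index pos - offset.
            merged[pos - chunk_start] = embeds[pos - offset]
    return merged


# Illustrative: an 8-token prompt whose image placeholders occupy positions 2..5,
# prefilled in two chunks of 4 tokens.
image = [["i0", "i1", "i2", "i3"]]
ranges = [(2, 4)]
print(merge_chunk(["t0", "t1", "t2", "t3"], 0, image, ranges))  # ['t0', 't1', 'i0', 'i1']
print(merge_chunk(["t4", "t5", "t6", "t7"], 4, image, ranges))  # ['i2', 'i3', 't6', 't7']
```

This sketch assumes each image's full embedding list is available whenever any of its chunks is processed (for example, by caching the vision encoder output); whether to recompute or cache those embeddings is a separate design choice.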

Before submitting a new issue...

  • Make sure you already searched for relevant issues, and asked the chatbot living at the bottom right corner of the documentation page, which can answer lots of frequently asked questions.

QiuJingkai added the "bug" label Nov 13, 2024
DarkLight1337 (Member) commented

Please see #9950

DarkLight1337 added the "feature request" label and removed the "bug" label Nov 13, 2024
DarkLight1337 changed the title from "[Bug]: Chunked prefill error in multimodal" to "[Feature]: Chunked prefill for multimodal models" Nov 13, 2024