
[Bug]: ValueError: could not broadcast input array from shape (513,) into shape (512,) #8068

Closed
1 task done
ashgold opened this issue Sep 1, 2024 · 15 comments · Fixed by #8340
Labels
bug Something isn't working

Comments

@ashgold

ashgold commented Sep 1, 2024

Your current environment

The output of `python collect_env.py`
Collecting environment information...
PyTorch version: 2.4.0+cu121
Is debug build: False
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.6 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.2) 9.4.0
Clang version: Could not collect
CMake version: Could not collect
Libc version: glibc-2.31

Python version: 3.10.14 (main, Apr  6 2024, 18:45:05) [GCC 9.4.0] (64-bit runtime)
Python platform: Linux-5.15.0-25-generic-x86_64-with-glibc2.31
Is CUDA available: True
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration:
GPU 0: NVIDIA H100 80GB HBM3
GPU 1: NVIDIA H100 80GB HBM3
GPU 2: NVIDIA H100 80GB HBM3
GPU 3: NVIDIA H100 80GB HBM3

Nvidia driver version: 535.86.10
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                    x86_64
CPU op-mode(s):                  32-bit, 64-bit
Byte Order:                      Little Endian
Address sizes:                   52 bits physical, 57 bits virtual
CPU(s):                          96
On-line CPU(s) list:             0-95
Thread(s) per core:              1
Core(s) per socket:              48
Socket(s):                       2
NUMA node(s):                    8
Vendor ID:                       GenuineIntel
CPU family:                      6
Model:                           143
Model name:                      Intel(R) Xeon(R) Platinum 8468
Stepping:                        8
CPU MHz:                         799.999
CPU max MHz:                     2100.0000
CPU min MHz:                     800.0000
BogoMIPS:                        4200.00
L1d cache:                       4.5 MiB
L1i cache:                       3 MiB
L2 cache:                        192 MiB
L3 cache:                        210 MiB
NUMA node0 CPU(s):               0-11
NUMA node1 CPU(s):               12-23
NUMA node2 CPU(s):               24-35
NUMA node3 CPU(s):               36-47
NUMA node4 CPU(s):               48-59
NUMA node5 CPU(s):               60-71
NUMA node6 CPU(s):               72-83
NUMA node7 CPU(s):               84-95
Vulnerability Itlb multihit:     Not affected
Vulnerability L1tf:              Not affected
Vulnerability Mds:               Not affected
Vulnerability Meltdown:          Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1:        Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:        Mitigation; Enhanced IBRS, IBPB conditional, RSB filling
Vulnerability Srbds:             Not affected
Vulnerability Tsx async abort:   Not affected
Flags:                           fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc art arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf tsc_known_freq pni pclmulqdq dtes64 ds_cpl smx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid dca sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch cpuid_fault epb cat_l3 cat_l2 cdp_l3 invpcid_single intel_ppin cdp_l2 ssbd mba ibrs ibpb stibp ibrs_enhanced fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb intel_pt avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local split_lock_detect avx_vnni avx512_bf16 wbnoinvd dtherm arat pln pts hwp hwp_act_window hwp_epp hwp_pkg_req avx512vbmi umip pku ospke waitpkg avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg tme avx512_vpopcntdq la57 rdpid bus_lock_detect cldemote movdiri movdir64b enqcmd fsrm md_clear serialize tsxldtrk pconfig arch_lbr avx512_fp16 flush_l1d arch_capabilities

Versions of relevant libraries:
[pip3] flashinfer==0.1.4+cu121torch2.4
[pip3] numpy==1.26.4
[pip3] nvidia-cublas-cu12==12.1.3.1
[pip3] nvidia-cuda-cupti-cu12==12.1.105
[pip3] nvidia-cuda-nvrtc-cu12==12.1.105
[pip3] nvidia-cuda-runtime-cu12==12.1.105
[pip3] nvidia-cudnn-cu12==9.1.0.70
[pip3] nvidia-cufft-cu12==11.0.2.54
[pip3] nvidia-curand-cu12==10.3.2.106
[pip3] nvidia-cusolver-cu12==11.4.5.107
[pip3] nvidia-cusparse-cu12==12.1.0.106
[pip3] nvidia-ml-py==12.560.30
[pip3] nvidia-nccl-cu12==2.20.5
[pip3] nvidia-nvjitlink-cu12==12.6.20
[pip3] nvidia-nvtx-cu12==12.1.105
[pip3] pyzmq==26.2.0
[pip3] torch==2.4.0
[pip3] torchvision==0.19.0
[pip3] transformers==4.44.2
[pip3] triton==3.0.0
[conda] Could not collect
ROCM Version: Could not collect
Neuron SDK Version: N/A
vLLM Version: 0.5.5@
vLLM Build Flags:
CUDA Archs: Not Set; ROCm: Disabled; Neuron: Disabled
GPU Topology:
GPU0    GPU1    GPU2    GPU3    NIC0    NIC1    NIC2    NIC3    NIC4    NIC5    NIC6    NIC7    CPU Affinity    NUMA Affinity   GPU NUMA ID
GPU0     X      NV18    NV18    NV18    PXB     SYS     SYS     SYS     SYS     SYS     SYS     SYS     0-11    0               N/A
GPU1    NV18     X      NV18    NV18    SYS     PIX     PIX     PXB     SYS     SYS     SYS     SYS     24-35   2               N/A
GPU2    NV18    NV18     X      NV18    SYS     SYS     SYS     SYS     SYS     SYS     SYS     PXB     72-83   6               N/A
GPU3    NV18    NV18    NV18     X      SYS     SYS     SYS     SYS     SYS     SYS     SYS     PIX     72-83   6               N/A
NIC0    PXB     SYS     SYS     SYS      X      SYS     SYS     SYS     SYS     SYS     SYS     SYS
NIC1    SYS     PIX     SYS     SYS     SYS      X      PIX     PXB     SYS     SYS     SYS     SYS
NIC2    SYS     PIX     SYS     SYS     SYS     PIX      X      PXB     SYS     SYS     SYS     SYS
NIC3    SYS     PXB     SYS     SYS     SYS     PXB     PXB      X      SYS     SYS     SYS     SYS
NIC4    SYS     SYS     SYS     SYS     SYS     SYS     SYS     SYS      X      PIX     PXB     SYS
NIC5    SYS     SYS     SYS     SYS     SYS     SYS     SYS     SYS     PIX      X      PXB     SYS
NIC6    SYS     SYS     SYS     SYS     SYS     SYS     SYS     SYS     PXB     PXB      X      SYS
NIC7    SYS     SYS     PXB     PIX     SYS     SYS     SYS     SYS     SYS     SYS     SYS      X

Legend:

  X    = Self
  SYS  = Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI)
  NODE = Connection traversing PCIe as well as the interconnect between PCIe Host Bridges within a NUMA node
  PHB  = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU)
  PXB  = Connection traversing multiple PCIe bridges (without traversing the PCIe Host Bridge)
  PIX  = Connection traversing at most a single PCIe bridge
  NV#  = Connection traversing a bonded set of # NVLinks

NIC Legend:

  NIC0: mlx5_0
  NIC1: mlx5_1
  NIC2: mlx5_2
  NIC3: mlx5_3
  NIC4: mlx5_4
  NIC5: mlx5_5
  NIC6: mlx5_6
  NIC7: mlx5_7

🐛 Describe the bug

ERROR 08-31 23:19:40 async_llm_engine.py:65] Engine background task failed
ERROR 08-31 23:19:40 async_llm_engine.py:65] Traceback (most recent call last):
ERROR 08-31 23:19:40 async_llm_engine.py:65]   File "/usr/local/lib/python3.10/dist-packages/vllm/engine/async_llm_engine.py", line 55, in _log_task_completion
ERROR 08-31 23:19:40 async_llm_engine.py:65]     return_value = task.result()
ERROR 08-31 23:19:40 async_llm_engine.py:65]   File "/usr/local/lib/python3.10/dist-packages/vllm/engine/async_llm_engine.py", line 930, in run_engine_loop
ERROR 08-31 23:19:40 async_llm_engine.py:65]     result = task.result()
ERROR 08-31 23:19:40 async_llm_engine.py:65]   File "/usr/local/lib/python3.10/dist-packages/vllm/engine/async_llm_engine.py", line 873, in engine_step
ERROR 08-31 23:19:40 async_llm_engine.py:65]     request_outputs = await self.engine.step_async(virtual_engine)
ERROR 08-31 23:19:40 async_llm_engine.py:65]   File "/usr/local/lib/python3.10/dist-packages/vllm/engine/async_llm_engine.py", line 337, in step_async
ERROR 08-31 23:19:40 async_llm_engine.py:65]     output = await self.model_executor.execute_model_async(
ERROR 08-31 23:19:40 async_llm_engine.py:65]   File "/usr/local/lib/python3.10/dist-packages/vllm/executor/distributed_gpu_executor.py", line 175, in execute_model_async
ERROR 08-31 23:19:40 async_llm_engine.py:65]     return await self._driver_execute_model_async(execute_model_req)
ERROR 08-31 23:19:40 async_llm_engine.py:65]   File "/usr/local/lib/python3.10/dist-packages/vllm/executor/multiproc_gpu_executor.py", line 224, in _driver_execute_model_async
ERROR 08-31 23:19:40 async_llm_engine.py:65]     return await self.driver_exec_model(execute_model_req)
ERROR 08-31 23:19:40 async_llm_engine.py:65]   File "/usr/lib/python3.10/concurrent/futures/thread.py", line 58, in run
ERROR 08-31 23:19:40 async_llm_engine.py:65]     result = self.fn(*self.args, **self.kwargs)
ERROR 08-31 23:19:40 async_llm_engine.py:65]   File "/usr/local/lib/python3.10/dist-packages/vllm/worker/worker_base.py", line 298, in execute_model
ERROR 08-31 23:19:40 async_llm_engine.py:65]     inputs = self.prepare_input(execute_model_req)
ERROR 08-31 23:19:40 async_llm_engine.py:65]   File "/usr/local/lib/python3.10/dist-packages/vllm/worker/multi_step_worker.py", line 157, in prepare_input
ERROR 08-31 23:19:40 async_llm_engine.py:65]     kwargs) = self._get_driver_input_and_broadcast(execute_model_req)
ERROR 08-31 23:19:40 async_llm_engine.py:65]   File "/usr/local/lib/python3.10/dist-packages/vllm/worker/multi_step_worker.py", line 60, in _get_driver_input_and_broadcast
ERROR 08-31 23:19:40 async_llm_engine.py:65]     self.model_runner.prepare_model_input(
ERROR 08-31 23:19:40 async_llm_engine.py:65]   File "/usr/local/lib/python3.10/dist-packages/vllm/worker/multi_step_model_runner.py", line 208, in prepare_model_input
ERROR 08-31 23:19:40 async_llm_engine.py:65]     frozen_model_input = self._base_model_runner.prepare_model_input(
ERROR 08-31 23:19:40 async_llm_engine.py:65]   File "/usr/local/lib/python3.10/dist-packages/vllm/worker/model_runner.py", line 1345, in prepare_model_input
ERROR 08-31 23:19:40 async_llm_engine.py:65]     model_input = self._prepare_model_input_tensors(
ERROR 08-31 23:19:40 async_llm_engine.py:65]   File "/usr/local/lib/python3.10/dist-packages/vllm/worker/model_runner.py", line 1006, in _prepare_model_input_tensors
ERROR 08-31 23:19:40 async_llm_engine.py:65]     return builder.build()  # type: ignore
ERROR 08-31 23:19:40 async_llm_engine.py:65]   File "/usr/local/lib/python3.10/dist-packages/vllm/worker/model_runner.py", line 707, in build
ERROR 08-31 23:19:40 async_llm_engine.py:65]     attn_metadata = self.attn_metadata_builder.build(
ERROR 08-31 23:19:40 async_llm_engine.py:65]   File "/usr/local/lib/python3.10/dist-packages/vllm/attention/backends/flash_attn.py", line 467, in build
ERROR 08-31 23:19:40 async_llm_engine.py:65]     input_block_tables[i, :len(block_table)] = block_table
ERROR 08-31 23:19:40 async_llm_engine.py:65] ValueError: could not broadcast input array from shape (513,) into shape (512,)

This is the same issue as the one raised in #5563. From what I've found, it is related to the size of the buffer allocated for CUDA graph capture.

Here are the arguments I used to run vLLM:

    containers:
    - args:
      - --model
      - /data/models/llama-3-1-70b-instruct/base
      - --tensor-parallel-size
      - "4"
      - --load-format
      - "auto"
      - --max-model-len
      - "16384"
#      - --max-seq-len-to-capture
#      - "16384"
      - --disable-log-requests
      - --uvicorn-log-level
      - "warning"
      - --gpu-memory-utilization
      - "0.9"
      - --enable-prefix-caching
      - --num-scheduler-steps
      - "8"
      image: aspcr01-queffmyz.scr.skr-west.scp-in.com/russianblue/vllm:v0.5.5

The error occurs at the following point:

    if use_captured_graph:
        self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size)
        self.block_tables.extend([] * cuda_graph_pad_size)
        num_decode_tokens = batch_size
        # The shape of graph_block_tables is
        # [max batch size, max context len // block size].
        input_block_tables = self.runner.graph_block_tables[:batch_size]
        for i, block_table in enumerate(self.block_tables):
            if block_table:
                input_block_tables[i, :len(block_table)] = block_table
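The failing assignment can be reproduced in isolation with NumPy. This is a minimal sketch, assuming a block size of 16 and max_seq_len_to_capture = 8192, so the captured table is 512 columns wide; the batch dimension of 8 and the helper name try_fill are hypothetical. Slicing past the last column is silently clipped, so a 513-entry block table cannot be broadcast into the row.

```python
import numpy as np

# Assumed: 8192 // 16 = 512 columns in the captured block table.
MAX_BLOCKS_PER_SEQ = 512
MAX_BATCH_TO_CAPTURE = 8  # hypothetical, kept small for the example

graph_block_tables = np.zeros((MAX_BATCH_TO_CAPTURE, MAX_BLOCKS_PER_SEQ),
                              dtype=np.int32)


def try_fill(num_blocks: int):
    """Copy a block table with num_blocks entries into row 0, the same way
    flash_attn.py does, and return the error message (None on success)."""
    block_table = list(range(num_blocks))
    try:
        # [:513] on a 512-column row is clipped to 512 columns, so the
        # right-hand side (513 values) cannot be broadcast into it.
        graph_block_tables[0, :len(block_table)] = block_table
    except ValueError as exc:
        return str(exc)
    return None


print(try_fill(512))  # fits exactly
print(try_fill(513))  # fails with the same kind of error as the traceback
```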

The problem lies in the second dimension, max context len // block size.
The code that allocates self.runner.graph_block_tables does it like this:

    self.graph_block_tables = np.zeros(
        (self.max_batchsize_to_capture, self.get_max_block_per_batch()),
        dtype=np.int32)

    def get_max_block_per_batch(self) -> int:
        block_size = self.block_size
        return (self.max_seq_len_to_capture + block_size - 1) // block_size
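get_max_block_per_batch is a ceiling division. Assuming a block size of 16 (the issue does not state it), the default max_seq_len_to_capture of 8192 gives exactly 512 columns, matching the (512,) in the error, so any sequence longer than 8192 tokens needs a 513th block. A standalone sketch of the arithmetic, written as free functions rather than the real methods:

```python
def get_max_block_per_batch(max_seq_len_to_capture: int, block_size: int) -> int:
    # Ceiling division: columns needed to hold max_seq_len_to_capture tokens.
    return (max_seq_len_to_capture + block_size - 1) // block_size


def blocks_needed(num_tokens: int, block_size: int) -> int:
    # Blocks a single sequence of num_tokens tokens occupies in its block table.
    return (num_tokens + block_size - 1) // block_size


print(get_max_block_per_batch(8192, 16))  # width of the captured table
print(blocks_needed(8193, 16))            # one token past the captured length
```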

max_seq_len_to_capture defaults to 8192 unless it is set explicitly:

    max_seq_len_to_capture: int = 8192

Ultimately, the value of self.max_seq_len_to_capture is determined by the following logic in vllm/vllm/config.py (lines 333 to 337 at commit 5b86b19):

    def _verify_cuda_graph(self) -> None:
        if self.max_seq_len_to_capture is None:
            self.max_seq_len_to_capture = self.max_model_len
        self.max_seq_len_to_capture = min(self.max_seq_len_to_capture,
                                          self.max_model_len)
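The effect of the min can be traced with the values from this report. Below is a standalone sketch of the same logic (a free function with a hypothetical name, not the real config method). With --max-model-len 16384 and the flag left at its default of 8192, the capture length stays 8192; setting --max-seq-len-to-capture 16384 raises it to 16384, which with an assumed block size of 16 would make the captured table 1024 blocks wide.

```python
from typing import Optional


def resolve_max_seq_len_to_capture(max_seq_len_to_capture: Optional[int],
                                   max_model_len: int) -> int:
    # Same steps as _verify_cuda_graph: fall back to max_model_len when
    # unset, then cap the capture length at max_model_len.
    if max_seq_len_to_capture is None:
        max_seq_len_to_capture = max_model_len
    return min(max_seq_len_to_capture, max_model_len)


print(resolve_max_seq_len_to_capture(8192, 16384))   # flag left at its default
print(resolve_max_seq_len_to_capture(16384, 16384))  # flag set to max-model-len
```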

I think this could be fixed by replacing min with max.
I'm curious what the intent was behind taking the min.

Before submitting a new issue...

  • Make sure you already searched for relevant issues, and asked the chatbot living at the bottom right corner of the documentation page, which can answer lots of frequently asked questions.
@ashgold ashgold added the bug Something isn't working label Sep 1, 2024
@youkaichao
Member

Great investigation!

I think the min operation makes sense, but somewhere we made the wrong decision to use CUDA graph.

The problem here is that the CUDA graph only works for block tables with up to 512 blocks, but somehow we are using it even when a block table has 513 blocks.

If you can provide an example input, or even find out which code made the wrong decision, it would be very helpful.

You should be able to find the code to blame in vllm/vllm/worker/model_runner.py or vllm/attention/backends/flash_attn.py.
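The point above can be written as an invariant: the captured graph may only be used when every block table in the batch fits into graph_block_tables. The function below is a hypothetical sketch of such a guard; its name, parameters and the values in the usage lines are illustrative and are not vLLM's actual API. It checks block-table lengths directly rather than sequence lengths, which is one way to make the decision agree with the assignment that fails in flash_attn.py.

```python
def can_use_captured_graph(decode_only: bool,
                           enforce_eager: bool,
                           batch_size: int,
                           max_batchsize_to_capture: int,
                           block_table_lens: list[int],
                           max_blocks_per_seq: int) -> bool:
    """Hypothetical guard: replay the captured CUDA graph only when the batch
    is decode-only, eager mode is off, the batch fits a captured size, and
    every block table fits in the captured table width."""
    if not decode_only or enforce_eager:
        return False
    if batch_size > max_batchsize_to_capture:
        return False
    # Look at the blocks actually present in each block table, so blocks
    # reserved ahead of time are counted too.
    return all(n <= max_blocks_per_seq for n in block_table_lens)


# A batch whose longest block table has 513 entries falls back to eager mode.
print(can_use_captured_graph(True, False, 2, 256, [512, 300], 512))
print(can_use_captured_graph(True, False, 2, 256, [513, 300], 512))
```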

@ashgold
Author

ashgold commented Sep 3, 2024

I'm generating load by replaying conversation history, and the point at which vLLM throws the error varies from run to run. So it doesn't seem to be caused by a specific prompt, but rather by some condition that is met while the system is under load.

Anyway, if I set --max-seq-len-to-capture to the same value as --max-model-len, would that avoid the current error?

@youkaichao
Member

I think so, but I'm not sure. Feel free to try it and report back.

@Ximingwang-09

I think PR #8145 may solve this problem.

@youkaichao
Member

@Ximingwang-09 thanks for the investigation!

      - --num-scheduler-steps
      - "8"

I just noticed that @ashgold uses multi-step scheduling, which indeed uses lookahead slots.
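Lookahead slots can explain the 513th block with a bit of arithmetic. This sketch rests on several assumptions, none confirmed in this thread: block_size = 16, max_seq_len_to_capture = 8192, the CUDA-graph decision compares only the current sequence length against max_seq_len_to_capture, the block table already holds blocks for the lookahead slots, and --num-scheduler-steps 8 reserves 7 lookahead slots. Under those assumptions, a sequence a few tokens short of 8192 passes the length check yet occupies 513 blocks. That would also fit the observation that the failure depends on load and timing rather than on a specific prompt.

```python
BLOCK_SIZE = 16                 # assumed block size
MAX_SEQ_LEN_TO_CAPTURE = 8192   # default when --max-seq-len-to-capture is unset


def blocks_for(num_tokens: int, block_size: int = BLOCK_SIZE) -> int:
    """Ceiling division: blocks needed to hold num_tokens tokens."""
    return (num_tokens + block_size - 1) // block_size


MAX_BLOCKS = blocks_for(MAX_SEQ_LEN_TO_CAPTURE)  # width of the captured table


def check(seq_len: int, num_lookahead_slots: int) -> tuple[bool, int]:
    """Return (passes a length-only graph check, blocks actually reserved).

    Assumes the check looks at seq_len alone, while the block table also
    covers the lookahead slots reserved for multi-step scheduling.
    """
    passes_len_check = seq_len <= MAX_SEQ_LEN_TO_CAPTURE
    reserved_blocks = blocks_for(seq_len + num_lookahead_slots)
    return passes_len_check, reserved_blocks


for seq_len in (8180, 8185, 8190):
    ok, blocks = check(seq_len, num_lookahead_slots=7)
    overflow = ok and blocks > MAX_BLOCKS
    print(f"seq_len={seq_len}: graph check={ok}, blocks={blocks}, overflow={overflow}")
```

With these numbers, 8185 is the last length that still fits (8185 + 7 = 8192 tokens, exactly 512 blocks); any sequence from 8186 to 8192 tokens would pass the length check and still overflow the table.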

@stefanobranco

stefanobranco commented Sep 5, 2024

I can confirm, I have the same issue, but only if using multi-step scheduling - without it everything works fine.

@ashgold
Author

ashgold commented Sep 5, 2024

This is just a guess, but with multi-step scheduling the engine needs consecutive memory space for the upcoming steps. In that case, is there any chance that a sequence goes beyond the 512 blocks that were allocated?

@alexm-redhat
Collaborator

alexm-redhat commented Sep 10, 2024

#8340 may provide a temporary solution. It would be good to know whether it solves your issue.

@ashgold
Author

ashgold commented Sep 10, 2024

I'll check whether the issue still reproduces when the next release comes out. Thank you!

@youkaichao
Member

@ashgold you don't need to wait for a release; we publish per-commit wheels.

After that PR is merged, you can follow https://docs.vllm.ai/en/latest/getting_started/installation.html to install the wheel for that commit.

In addition, if you want to try it right away, you can just add the lines from the PR to your code. It's only a few lines of Python.

@ashgold
Author

ashgold commented Sep 10, 2024

Okay.
There is another PR (#8267) that fixes other issues, so I'll check both together once this PR is merged.

@ashgold
Author

ashgold commented Sep 13, 2024

I confirmed that the issue no longer reproduces in v0.6.1.post1. I think the problem is solved!

@andrea-veritas

andrea-veritas commented Oct 16, 2024

It seems the issue reproduces again in v0.6.3. Downgrading to v0.6.1.post1 solved the problem in my case.

@youkaichao
Member

@andrea-veritas please open a new issue with detailed info

@FuryMartin
Contributor

#9848
