[Bug]: Issue running the Granite-7b GGUF quantized model on multiple GPUs with vLLM due to a tensor size mismatch. #12170

Closed
tarukumar opened this issue Jan 17, 2025 · 1 comment · Fixed by #12230
Labels: bug (Something isn't working)

Comments

tarukumar commented Jan 17, 2025

Your current environment

The output of `python collect_env.py`
sh-5.1$ python collect_env.py 
Collecting environment information...
PyTorch version: 2.5.1+cu124
Is debug build: False
CUDA used to build PyTorch: 12.4
ROCM used to build PyTorch: N/A

OS: Red Hat Enterprise Linux 9.5 (Plow) (x86_64)
GCC version: (GCC) 11.5.0 20240719 (Red Hat 11.5.0-2)
Clang version: Could not collect
CMake version: Could not collect
Libc version: glibc-2.34

Python version: 3.12.5 (main, Dec  3 2024, 00:00:00) [GCC 11.5.0 20240719 (Red Hat 11.5.0-2)] (64-bit runtime)
Python platform: Linux-5.14.0-427.47.1.el9_4.x86_64-x86_64-with-glibc2.34
Is CUDA available: True
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: 
GPU 0: NVIDIA A100-SXM4-80GB
GPU 1: NVIDIA A100-SXM4-80GB

Nvidia driver version: 550.127.08
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                         x86_64
CPU op-mode(s):                       32-bit, 64-bit
Address sizes:                        46 bits physical, 57 bits virtual
Byte Order:                           Little Endian
CPU(s):                               80
On-line CPU(s) list:                  0-79
Vendor ID:                            GenuineIntel
Model name:                           Intel Xeon Processor (Icelake)
CPU family:                           6
Model:                                134
Thread(s) per core:                   2
Core(s) per socket:                   20
Socket(s):                            2
Stepping:                             0
BogoMIPS:                             5600.02
Flags:                                fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology cpuid tsc_known_freq pni pclmulqdq vmx ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch cpuid_fault ssbd ibrs ibpb stibp ibrs_enhanced tpr_shadow flexpriority ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves wbnoinvd arat vnmi avx512vbmi umip pku ospke avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq la57 rdpid fsrm md_clear arch_capabilities
Virtualization:                       VT-x
Hypervisor vendor:                    KVM
Virtualization type:                  full
L1d cache:                            2.5 MiB (80 instances)
L1i cache:                            2.5 MiB (80 instances)
L2 cache:                             160 MiB (40 instances)
L3 cache:                             32 MiB (2 instances)
NUMA node(s):                         2
NUMA node0 CPU(s):                    0-39
NUMA node1 CPU(s):                    40-79
Vulnerability Gather data sampling:   Not affected
Vulnerability Itlb multihit:          Not affected
Vulnerability L1tf:                   Not affected
Vulnerability Mds:                    Not affected
Vulnerability Meltdown:               Not affected
Vulnerability Mmio stale data:        Vulnerable: Clear CPU buffers attempted, no microcode; SMT Host state unknown
Vulnerability Reg file data sampling: Vulnerable: No microcode
Vulnerability Retbleed:               Not affected
Vulnerability Spec rstack overflow:   Not affected
Vulnerability Spec store bypass:      Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1:             Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:             Mitigation; Enhanced / Automatic IBRS; IBPB conditional; RSB filling; PBRSB-eIBRS Not affected; BHI SW loop, KVM SW loop
Vulnerability Srbds:                  Not affected
Vulnerability Tsx async abort:        Not affected

Versions of relevant libraries:
[pip3] flashinfer==0.1.6+cu124torch2.4
[pip3] numpy==1.26.4
[pip3] nvidia-cublas-cu12==12.4.5.8
[pip3] nvidia-cuda-cupti-cu12==12.4.127
[pip3] nvidia-cuda-nvrtc-cu12==12.4.127
[pip3] nvidia-cuda-runtime-cu12==12.4.127
[pip3] nvidia-cudnn-cu12==9.1.0.70
[pip3] nvidia-cufft-cu12==11.2.1.3
[pip3] nvidia-curand-cu12==10.3.5.147
[pip3] nvidia-cusolver-cu12==11.6.1.9
[pip3] nvidia-cusparse-cu12==12.3.1.170
[pip3] nvidia-ml-py==12.560.30
[pip3] nvidia-nccl-cu12==2.21.5
[pip3] nvidia-nvjitlink-cu12==12.4.127
[pip3] nvidia-nvtx-cu12==12.4.127
[pip3] pyzmq==26.2.0
[pip3] torch==2.5.1
[pip3] torchvision==0.20.1
[pip3] transformers==4.48.0
[pip3] triton==3.1.0
[conda] Could not collect
ROCM Version: Could not collect
Neuron SDK Version: N/A
vLLM Version: 0.6.6.post2.dev254+g3b27f9351
vLLM Build Flags:
CUDA Archs: Not Set; ROCm: Disabled; Neuron: Disabled
GPU Topology:
GPU0    GPU1    NIC0    CPU Affinity    NUMA Affinity   GPU NUMA ID
GPU0     X      NV12    NODE    0-39    0               N/A
GPU1    NV12     X      NODE    0-39    0               N/A
NIC0    NODE    NODE     X 

Legend:

  X    = Self
  SYS  = Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI)
  NODE = Connection traversing PCIe as well as the interconnect between PCIe Host Bridges within a NUMA node
  PHB  = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU)
  PXB  = Connection traversing multiple PCIe bridges (without traversing the PCIe Host Bridge)
  PIX  = Connection traversing at most a single PCIe bridge
  NV#  = Connection traversing a bonded set of # NVLinks

NIC Legend:

  NIC0: mlx5_0

NVIDIA_VISIBLE_DEVICES=GPU-60d7e087-ad98-9a77-fb5c-5ce7d11abb96,GPU-f3a5a87e-eb5c-fc8b-3278-66ea3134947e
VLLM_ALLOW_LONG_MAX_MODEL_LEN=1
VLLM_WORKER_MULTIPROC_METHOD=fork
VLLM_USAGE_SOURCE=production-docker-image
LD_LIBRARY_PATH=/opt/vllm/lib/python3.12/site-packages/cv2/../../lib64:/opt/vllm/lib/python3.12/site-packages/nvidia/nvtx/lib:/opt/vllm/lib/python3.12/site-packages/nvidia/cuda_runtime/lib:/opt/vllm/lib/python3.12/site-packages/nvidia/cuda_nvrtc/lib:
VLLM_NO_USAGE_STATS=1
CUDA_MODULE_LOADING=LAZY

sh-5.1$ 

Model Input Dumps

No response

🐛 Describe the bug

I'm trying to load and run the Granite-7b GGUF quantized model on multiple GPUs in an OpenShift/k8s cluster, but I'm encountering a tensor size mismatch error while the model is being loaded.

NOTE:

  1. The issue does not occur when a single GPU is used: the model loads and performs inference without any issues.
  2. The issue can also be reproduced on the current master branch.

Configurations:

      args:
        - '--port=8080'
        - '--distributed-executor-backend=mp'
        - '--model=/mnt/models/granite-7b-lab-Q4_K_M.gguf'
        - '--tensor-parallel-size=2'
        - '--max-model-len=4096'
        - '--uvicorn-log-level=debug'
        - '--served-model-name=granite-7b-lab-gguf'

Error logs:

WARNING 01-17 14:00:06 multiproc_worker_utils.py:312] Reducing Torch parallelism from 40 threads to 1 to avoid unnecessary CPU contention. Set OMP_NUM_THREADS in the external environment to tune this value as needed.
INFO 01-17 14:00:06 custom_cache_manager.py:17] Setting Triton cache manager to: vllm.triton_utils.custom_cache_manager:CustomCacheManager
INFO 01-17 14:00:06 selector.py:120] Using Flash Attention backend.
(VllmWorkerProcess pid=354) INFO 01-17 14:00:06 selector.py:120] Using Flash Attention backend.
(VllmWorkerProcess pid=354) INFO 01-17 14:00:06 multiproc_worker_utils.py:222] Worker ready; awaiting tasks
(VllmWorkerProcess pid=354) INFO 01-17 14:00:07 utils.py:918] Found nccl from library libnccl.so.2
INFO 01-17 14:00:07 utils.py:918] Found nccl from library libnccl.so.2
INFO 01-17 14:00:07 pynccl.py:69] vLLM is using nccl==2.21.5
(VllmWorkerProcess pid=354) INFO 01-17 14:00:07 pynccl.py:69] vLLM is using nccl==2.21.5
INFO 01-17 14:00:07 custom_all_reduce_utils.py:204] generating GPU P2P access cache in /home/vllm/.cache/vllm/gpu_p2p_access_cache_for_0,1.json
(VllmWorkerProcess pid=354) INFO 01-17 14:00:28 custom_all_reduce_utils.py:242] reading GPU P2P access cache from /home/vllm/.cache/vllm/gpu_p2p_access_cache_for_0,1.json
INFO 01-17 14:00:28 custom_all_reduce_utils.py:242] reading GPU P2P access cache from /home/vllm/.cache/vllm/gpu_p2p_access_cache_for_0,1.json
INFO 01-17 14:00:28 shm_broadcast.py:255] vLLM message queue communication handle: Handle(connect_ip='127.0.0.1', local_reader_ranks=[1], buffer_handle=(1, 4194304, 6, 'psm_3fbb2dd8'), local_subscribe_port=59663, remote_subscribe_port=None)
INFO 01-17 14:00:28 model_runner.py:1094] Starting to load model /mnt/models/granite-7b-lab-Q4_K_M.gguf...
(VllmWorkerProcess pid=354) INFO 01-17 14:00:28 model_runner.py:1094] Starting to load model /mnt/models/granite-7b-lab-Q4_K_M.gguf...
ERROR 01-17 14:00:32 engine.py:366] The size of tensor a (16004) must match the size of tensor b (16032) at non-singleton dimension 0
ERROR 01-17 14:00:32 engine.py:366] Traceback (most recent call last):
ERROR 01-17 14:00:32 engine.py:366]   File "/opt/vllm/lib64/python3.12/site-packages/vllm/engine/multiprocessing/engine.py", line 357, in run_mp_engine
ERROR 01-17 14:00:32 engine.py:366]     engine = MQLLMEngine.from_engine_args(engine_args=engine_args,
ERROR 01-17 14:00:32 engine.py:366]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 01-17 14:00:32 engine.py:366]   File "/opt/vllm/lib64/python3.12/site-packages/vllm/engine/multiprocessing/engine.py", line 119, in from_engine_args
ERROR 01-17 14:00:32 engine.py:366]     return cls(ipc_path=ipc_path,
ERROR 01-17 14:00:32 engine.py:366]            ^^^^^^^^^^^^^^^^^^^^^^
ERROR 01-17 14:00:32 engine.py:366]   File "/opt/vllm/lib64/python3.12/site-packages/vllm/engine/multiprocessing/engine.py", line 71, in __init__
ERROR 01-17 14:00:32 engine.py:366]     self.engine = LLMEngine(*args, **kwargs)
ERROR 01-17 14:00:32 engine.py:366]                   ^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 01-17 14:00:32 engine.py:366]   File "/opt/vllm/lib64/python3.12/site-packages/vllm/engine/llm_engine.py", line 273, in __init__
ERROR 01-17 14:00:32 engine.py:366]     self.model_executor = executor_class(vllm_config=vllm_config, )
ERROR 01-17 14:00:32 engine.py:366]                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 01-17 14:00:32 engine.py:366]   File "/opt/vllm/lib64/python3.12/site-packages/vllm/executor/distributed_gpu_executor.py", line 26, in __init__
ERROR 01-17 14:00:32 engine.py:366]     super().__init__(*args, **kwargs)
ERROR 01-17 14:00:32 engine.py:366]   File "/opt/vllm/lib64/python3.12/site-packages/vllm/executor/executor_base.py", line 36, in __init__
ERROR 01-17 14:00:32 engine.py:366]     self._init_executor()
ERROR 01-17 14:00:32 engine.py:366]   File "/opt/vllm/lib64/python3.12/site-packages/vllm/executor/multiproc_gpu_executor.py", line 83, in _init_executor
ERROR 01-17 14:00:32 engine.py:366]     self._run_workers("load_model",
ERROR 01-17 14:00:32 engine.py:366]   File "/opt/vllm/lib64/python3.12/site-packages/vllm/executor/multiproc_gpu_executor.py", line 157, in _run_workers
ERROR 01-17 14:00:32 engine.py:366]     driver_worker_output = driver_worker_method(*args, **kwargs)
ERROR 01-17 14:00:32 engine.py:366]                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 01-17 14:00:32 engine.py:366]   File "/opt/vllm/lib64/python3.12/site-packages/vllm/worker/worker.py", line 155, in load_model
ERROR 01-17 14:00:32 engine.py:366]     self.model_runner.load_model()
ERROR 01-17 14:00:32 engine.py:366]   File "/opt/vllm/lib64/python3.12/site-packages/vllm/worker/model_runner.py", line 1096, in load_model
ERROR 01-17 14:00:32 engine.py:366]     self.model = get_model(vllm_config=self.vllm_config)
ERROR 01-17 14:00:32 engine.py:366]                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 01-17 14:00:32 engine.py:366]   File "/opt/vllm/lib64/python3.12/site-packages/vllm/model_executor/model_loader/__init__.py", line 12, in get_model
ERROR 01-17 14:00:32 engine.py:366]     return loader.load_model(vllm_config=vllm_config)
ERROR 01-17 14:00:32 engine.py:366]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 01-17 14:00:32 engine.py:366]   File "/opt/vllm/lib64/python3.12/site-packages/vllm/model_executor/model_loader/loader.py", line 1234, in load_model
ERROR 01-17 14:00:32 engine.py:366]     model.load_weights(
ERROR 01-17 14:00:32 engine.py:366]   File "/opt/vllm/lib64/python3.12/site-packages/vllm/model_executor/models/llama.py", line 594, in load_weights
ERROR 01-17 14:00:32 engine.py:366]     return loader.load_weights(
ERROR 01-17 14:00:32 engine.py:366]            ^^^^^^^^^^^^^^^^^^^^
ERROR 01-17 14:00:32 engine.py:366]   File "/opt/vllm/lib64/python3.12/site-packages/vllm/model_executor/models/utils.py", line 237, in load_weights
ERROR 01-17 14:00:32 engine.py:366]     autoloaded_weights = set(self._load_module("", self.module, weights))
ERROR 01-17 14:00:32 engine.py:366]                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 01-17 14:00:32 engine.py:366]   File "/opt/vllm/lib64/python3.12/site-packages/vllm/model_executor/models/utils.py", line 198, in _load_module
ERROR 01-17 14:00:32 engine.py:366]     yield from self._load_module(prefix,
ERROR 01-17 14:00:32 engine.py:366]   File "/opt/vllm/lib64/python3.12/site-packages/vllm/model_executor/models/utils.py", line 175, in _load_module
ERROR 01-17 14:00:32 engine.py:366]     loaded_params = module_load_weights(weights)
ERROR 01-17 14:00:32 engine.py:366]                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 01-17 14:00:32 engine.py:366]   File "/opt/vllm/lib64/python3.12/site-packages/vllm/model_executor/models/llama.py", line 432, in load_weights
ERROR 01-17 14:00:32 engine.py:366]     weight_loader(param, loaded_weight)
ERROR 01-17 14:00:32 engine.py:366]   File "/opt/vllm/lib64/python3.12/site-packages/vllm/model_executor/layers/vocab_parallel_embedding.py", line 398, in weight_loader
ERROR 01-17 14:00:32 engine.py:366]     param[:loaded_weight.shape[0]].data.copy_(loaded_weight)
ERROR 01-17 14:00:32 engine.py:366] RuntimeError: The size of tensor a (16004) must match the size of tensor b (16032) at non-singleton dimension 0
ERROR 01-17 14:00:32 multiproc_worker_utils.py:123] Worker VllmWorkerProcess pid 354 died, exit code: -15
INFO 01-17 14:00:32 multiproc_worker_utils.py:127] Killing local vLLM worker processes
Process SpawnProcess-1:
Traceback (most recent call last):
  File "/usr/lib64/python3.12/multiprocessing/process.py", line 314, in _bootstrap
    self.run()
  File "/usr/lib64/python3.12/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/opt/vllm/lib64/python3.12/site-packages/vllm/engine/multiprocessing/engine.py", line 368, in run_mp_engine
    raise e
  File "/opt/vllm/lib64/python3.12/site-packages/vllm/engine/multiprocessing/engine.py", line 357, in run_mp_engine
    engine = MQLLMEngine.from_engine_args(engine_args=engine_args,
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/vllm/lib64/python3.12/site-packages/vllm/engine/multiprocessing/engine.py", line 119, in from_engine_args
    return cls(ipc_path=ipc_path,
           ^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/vllm/lib64/python3.12/site-packages/vllm/engine/multiprocessing/engine.py", line 71, in __init__
    self.engine = LLMEngine(*args, **kwargs)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/vllm/lib64/python3.12/site-packages/vllm/engine/llm_engine.py", line 273, in __init__
    self.model_executor = executor_class(vllm_config=vllm_config, )
                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/vllm/lib64/python3.12/site-packages/vllm/executor/distributed_gpu_executor.py", line 26, in __init__
    super().__init__(*args, **kwargs)
  File "/opt/vllm/lib64/python3.12/site-packages/vllm/executor/executor_base.py", line 36, in __init__
    self._init_executor()
  File "/opt/vllm/lib64/python3.12/site-packages/vllm/executor/multiproc_gpu_executor.py", line 83, in _init_executor
    self._run_workers("load_model",
  File "/opt/vllm/lib64/python3.12/site-packages/vllm/executor/multiproc_gpu_executor.py", line 157, in _run_workers
    driver_worker_output = driver_worker_method(*args, **kwargs)
                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/vllm/lib64/python3.12/site-packages/vllm/worker/worker.py", line 155, in load_model
    self.model_runner.load_model()
  File "/opt/vllm/lib64/python3.12/site-packages/vllm/worker/model_runner.py", line 1096, in load_model
    self.model = get_model(vllm_config=self.vllm_config)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/vllm/lib64/python3.12/site-packages/vllm/model_executor/model_loader/__init__.py", line 12, in get_model
    return loader.load_model(vllm_config=vllm_config)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/vllm/lib64/python3.12/site-packages/vllm/model_executor/model_loader/loader.py", line 1234, in load_model
    model.load_weights(
  File "/opt/vllm/lib64/python3.12/site-packages/vllm/model_executor/models/llama.py", line 594, in load_weights
    return loader.load_weights(
           ^^^^^^^^^^^^^^^^^^^^
  File "/opt/vllm/lib64/python3.12/site-packages/vllm/model_executor/models/utils.py", line 237, in load_weights
    autoloaded_weights = set(self._load_module("", self.module, weights))
                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/vllm/lib64/python3.12/site-packages/vllm/model_executor/models/utils.py", line 198, in _load_module
    yield from self._load_module(prefix,
  File "/opt/vllm/lib64/python3.12/site-packages/vllm/model_executor/models/utils.py", line 175, in _load_module
    loaded_params = module_load_weights(weights)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/vllm/lib64/python3.12/site-packages/vllm/model_executor/models/llama.py", line 432, in load_weights
    weight_loader(param, loaded_weight)
  File "/opt/vllm/lib64/python3.12/site-packages/vllm/model_executor/layers/vocab_parallel_embedding.py", line 398, in weight_loader
    param[:loaded_weight.shape[0]].data.copy_(loaded_weight)
RuntimeError: The size of tensor a (16004) must match the size of tensor b (16032) at non-singleton dimension 0
[rank0]:[W117 14:00:33.196816217 ProcessGroupNCCL.cpp:1250] Warning: WARNING: process group has NOT been destroyed before we destruct ProcessGroupNCCL. On normal program exit, the application should call destroy_process_group to ensure that any pending NCCL operations have finished in this process. In rare cases this process can exit before this point and block the progress of another member of the process group. This constraint has always been present,  but this warning has only been added since PyTorch 2.4 (function operator())
Task exception was never retrieved
future: <Task finished name='Task-2' coro=<MQLLMEngineClient.run_output_handler_loop() done, defined at /opt/vllm/lib64/python3.12/site-packages/vllm/engine/multiprocessing/client.py:178> exception=ZMQError('Operation not supported')>
Traceback (most recent call last):
  File "/opt/vllm/lib64/python3.12/site-packages/vllm/engine/multiprocessing/client.py", line 184, in run_output_handler_loop
    while await self.output_socket.poll(timeout=VLLM_RPC_TIMEOUT
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/vllm/lib64/python3.12/site-packages/zmq/_future.py", line 400, in poll
    raise _zmq.ZMQError(_zmq.ENOTSUP)
zmq.error.ZMQError: Operation not supported
Task exception was never retrieved
future: <Task finished name='Task-3' coro=<MQLLMEngineClient.run_output_handler_loop() done, defined at /opt/vllm/lib64/python3.12/site-packages/vllm/engine/multiprocessing/client.py:178> exception=ZMQError('Operation not supported')>
Traceback (most recent call last):
  File "/opt/vllm/lib64/python3.12/site-packages/vllm/engine/multiprocessing/client.py", line 184, in run_output_handler_loop
    while await self.output_socket.poll(timeout=VLLM_RPC_TIMEOUT
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/vllm/lib64/python3.12/site-packages/zmq/_future.py", line 400, in poll
    raise _zmq.ZMQError(_zmq.ENOTSUP)
zmq.error.ZMQError: Operation not supported
Task exception was never retrieved
future: <Task finished name='Task-4' coro=<MQLLMEngineClient.run_output_handler_loop() done, defined at /opt/vllm/lib64/python3.12/site-packages/vllm/engine/multiprocessing/client.py:178> exception=ZMQError('Operation not supported')>
Traceback (most recent call last):
  File "/opt/vllm/lib64/python3.12/site-packages/vllm/engine/multiprocessing/client.py", line 184, in run_output_handler_loop
    while await self.output_socket.poll(timeout=VLLM_RPC_TIMEOUT
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/vllm/lib64/python3.12/site-packages/zmq/_future.py", line 400, in poll
    raise _zmq.ZMQError(_zmq.ENOTSUP)
zmq.error.ZMQError: Operation not supported
Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/opt/vllm/lib64/python3.12/site-packages/vllm/entrypoints/openai/api_server.py", line 774, in <module>
    uvloop.run(run_server(args))
  File "/opt/vllm/lib64/python3.12/site-packages/uvloop/__init__.py", line 109, in run
    return __asyncio.run(
           ^^^^^^^^^^^^^^
  File "/usr/lib64/python3.12/asyncio/runners.py", line 194, in run
    return runner.run(main)
           ^^^^^^^^^^^^^^^^
  File "/usr/lib64/python3.12/asyncio/runners.py", line 118, in run
    return self._loop.run_until_complete(task)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "uvloop/loop.pyx", line 1518, in uvloop.loop.Loop.run_until_complete
  File "/opt/vllm/lib64/python3.12/site-packages/uvloop/__init__.py", line 61, in wrapper
    return await main
           ^^^^^^^^^^
  File "/opt/vllm/lib64/python3.12/site-packages/vllm/entrypoints/openai/api_server.py", line 740, in run_server
    async with build_async_engine_client(args) as engine_client:
  File "/usr/lib64/python3.12/contextlib.py", line 210, in __aenter__
    return await anext(self.gen)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/opt/vllm/lib64/python3.12/site-packages/vllm/entrypoints/openai/api_server.py", line 118, in build_async_engine_client
    async with build_async_engine_client_from_engine_args(
  File "/usr/lib64/python3.12/contextlib.py", line 210, in __aenter__
    return await anext(self.gen)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/opt/vllm/lib64/python3.12/site-packages/vllm/entrypoints/openai/api_server.py", line 223, in build_async_engine_client_from_engine_args
    raise RuntimeError(
RuntimeError: Engine process failed to start. See stack trace for the root cause.
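
The failing statement in the traceback is `param[:loaded_weight.shape[0]].data.copy_(loaded_weight)`. Slicing past the end of a tensor clamps instead of raising, so the destination view keeps only the 16004 rows that `param` actually has, while the source still has 16032 rows. The pure-Python sketch below models that behaviour with lists standing in for tensors; `copy_rows` is a hypothetical helper, not a vLLM or PyTorch function, and it only models the leading-dimension size check, not PyTorch's full broadcasting rules. The row counts follow the analysis in the comment further down (parameter sized from the unpadded vocabulary, shard sized from the padded one).

```python
# Pure-Python model of the failing statement from the traceback:
#     param[:loaded_weight.shape[0]].data.copy_(loaded_weight)
# Lists stand in for tensors; only the leading-dimension size check is modelled.


def copy_rows(param: list, loaded: list) -> None:
    """Hypothetical stand-in for param[:len(loaded)].copy_(loaded)."""
    # Slicing past the end clamps instead of raising, so `dest` has
    # min(len(param), len(loaded)) rows.
    dest = param[: len(loaded)]
    a, b = len(dest), len(loaded)
    if a != b:
        # Same wording PyTorch uses when the two sizes cannot be broadcast.
        raise RuntimeError(
            f"The size of tensor a ({a}) must match the size of tensor b ({b}) "
            "at non-singleton dimension 0"
        )
    param[:b] = loaded  # in-place copy of the leading rows


param = [0.0] * 16004   # rank-0 param, sized as 32008 // 2 rows
loaded = [1.0] * 16032  # rank-0 shard, narrowed to the padded boundary
try:
    copy_rows(param, loaded)
except RuntimeError as err:
    print(err)
```

Running it prints the same message as the log, which suggests that `tensor a` is the clamped slice of the too-small parameter and `tensor b` is the loaded shard.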

Before submitting a new issue...

  • Make sure you already searched for relevant issues, and asked the chatbot living at the bottom right corner of the documentation page, which can answer lots of frequently asked questions.
@tarukumar tarukumar added the bug Something isn't working label Jan 17, 2025
@Isotr0py Isotr0py self-assigned this Jan 17, 2025
@NickLucche (Contributor)

I spent some time on it but couldn't yet figure out where it originates. It appears that the shard_size you get here https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/vocab_parallel_embedding.py#L370 is 16032 (half of the padded vocab size), while param is materialized with 16004 rows here https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/vocab_parallel_embedding.py#L359, which is half of the original vocab size 🤔
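
The numbers in this comment can be reproduced with a few lines of arithmetic. The sketch below is built on assumptions, not read from vLLM or from the GGUF file: the vocabulary is padded up to a multiple of 64 (assumed to be vLLM's default padding), the padded size is split evenly across tensor-parallel ranks, each rank's range of original-vocab rows is clamped to the original vocabulary size, and that original size is 32008, inferred from 2 * 16004. All helper names are illustrative.

```python
# Sketch of the vocab-parallel shard arithmetic behind this error.
# ASSUMPTIONS (not read from vLLM or the GGUF file):
#   * the vocabulary is padded up to a multiple of 64;
#   * the padded vocabulary is split evenly across tensor-parallel ranks;
#   * each rank's range of original-vocab rows is clamped to ORG_VOCAB_SIZE;
#   * ORG_VOCAB_SIZE = 32008 is inferred from 2 * 16004.
# All helper names are illustrative.

PAD_TO = 64
ORG_VOCAB_SIZE = 32008


def pad_vocab_size(vocab_size: int, pad_to: int = PAD_TO) -> int:
    """Round vocab_size up to the nearest multiple of pad_to."""
    return ((vocab_size + pad_to - 1) // pad_to) * pad_to


def org_shard_range(rank: int, tp_size: int,
                    org_vocab: int = ORG_VOCAB_SIZE) -> tuple[int, int]:
    """(start, end) of the original-vocab rows that `rank` loads."""
    per_partition = pad_vocab_size(org_vocab) // tp_size
    start = min(rank * per_partition, org_vocab)
    end = min(rank * per_partition + per_partition, org_vocab)
    return start, end


def check_rank(rank: int, tp_size: int,
               org_vocab: int = ORG_VOCAB_SIZE) -> tuple[int, int, bool]:
    """Compare the naively materialized param rows with the shard size."""
    start, end = org_shard_range(rank, tp_size, org_vocab)
    shard_size = end - start
    # Materialization that divides the *unpadded* vocab by the TP size.
    materialized = org_vocab // tp_size
    # param[:shard_size] only has shard_size rows if param is at least that long.
    fits = materialized >= shard_size
    return materialized, shard_size, fits


for tp in (1, 2):
    for r in range(tp):
        rows, shard, fits = check_rank(r, tp)
        print(f"tp={tp} rank={r}: param rows={rows}, shard rows={shard}, fits={fits}")
# tp=1 rank=0: param rows=32008, shard rows=32008, fits=True
# tp=2 rank=0: param rows=16004, shard rows=16032, fits=False
# tp=2 rank=1: param rows=16004, shard rows=15976, fits=True
```

Under these assumptions only rank 0 with `--tensor-parallel-size=2` fails the size check: its shard ends at the padded boundary 16032 (32064 / 2), while the parameter was sized from the unpadded 32008 // 2 = 16004. With a single GPU both numbers are 32008, which fits the report that single-GPU loading works, and a rank-0 failure is consistent with the error surfacing in the driver worker in the traceback. The sketch models only the size check, not whether the loaded rows end up in the right place.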


3 participants