[Bug]: Port binding failure when using pp > 1 after commit 7c7714d856eee6fa94aade729b67f00584f72a4c #8791

Open
dengminhao opened this issue Sep 25, 2024 · 19 comments
Labels
bug (Something isn't working), stale

Comments

@dengminhao

Your current environment

The output of `python collect_env.py`
PyTorch version: 2.4.0+cu121
Is debug build: False
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A

OS: CentOS Linux 7 (Core) (x86_64)
GCC version: (conda-forge gcc 11.4.0-13) 11.4.0
Clang version: Could not collect
CMake version: version 2.8.12.2
Libc version: glibc-2.17

Python version: 3.11.9 | packaged by conda-forge | (main, Apr 19 2024, 18:36:13) [GCC 12.3.0] (64-bit runtime)
Python platform: Linux-3.10.0-1160.95.1.el7.x86_64-x86_64-with-glibc2.17
Is CUDA available: True
CUDA runtime version: 12.1.105
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: 
GPU 0: NVIDIA GeForce RTX 3090
GPU 1: NVIDIA GeForce RTX 3090
GPU 2: NVIDIA GeForce RTX 3090
GPU 3: NVIDIA GeForce RTX 3090
GPU 4: NVIDIA GeForce RTX 3090
GPU 5: NVIDIA GeForce RTX 3090

Nvidia driver version: 535.104.05
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:          x86_64
CPU op-mode(s):        32-bit, 64-bit
Byte Order:            Little Endian
CPU(s):                128
On-line CPU(s) list:   0-127
Thread(s) per core:    2
Core(s) per socket:    32
Socket(s):             2
NUMA node(s):          2
Vendor ID:             GenuineIntel
CPU family:            6
Model:                 106
Model name:            Intel(R) Xeon(R) Platinum 8336C CPU @ 2.30GHz
Stepping:              6
CPU MHz:               800.000
CPU max MHz:           2301.0000
CPU min MHz:           800.0000
BogoMIPS:              4600.00
Virtualization:        VT-x
L1d cache:             48K
L1i cache:             32K
L2 cache:              1280K
L3 cache:              55296K
NUMA node0 CPU(s):     0-31,64-95
NUMA node1 CPU(s):     32-63,96-127
Flags:                 fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc art arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc aperfmperf eagerfpu pni pclmulqdq dtes64 monitor ds_cpl vmx smx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid dca sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch epb cat_l3 invpcid_single ssbd mba rsb_ctxsw ibrs ibpb stibp ibrs_enhanced tpr_shadow vnmi flexpriority ept vpid fsgsbase tsc_adjust bmi1 hle avx2 smep bmi2 erms invpcid rtm cqm rdt_a avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb intel_pt avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local dtherm ida arat pln pts avx512vbmi umip pku ospke avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq md_clear pconfig spec_ctrl intel_stibp flush_l1d arch_capabilities

Versions of relevant libraries:
[pip3] flashinfer==0.1.6+cu121torch2.4
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.26.4
[pip3] nvidia-cublas-cu12==12.1.3.1
[pip3] nvidia-cuda-cupti-cu12==12.1.105
[pip3] nvidia-cuda-nvrtc-cu12==12.1.105
[pip3] nvidia-cuda-runtime-cu12==12.1.105
[pip3] nvidia-cudnn-cu12==9.1.0.70
[pip3] nvidia-cufft-cu12==11.0.2.54
[pip3] nvidia-curand-cu12==10.3.2.106
[pip3] nvidia-cusolver-cu12==11.4.5.107
[pip3] nvidia-cusparse-cu12==12.1.0.106
[pip3] nvidia-ml-py==12.560.30
[pip3] nvidia-nccl-cu12==2.20.5
[pip3] nvidia-nvjitlink-cu12==12.6.20
[pip3] nvidia-nvtx-cu12==12.1.105
[pip3] pyzmq==26.2.0
[pip3] torch==2.4.0
[pip3] torchvision==0.19.0
[pip3] transformers==4.45.0.dev0
[pip3] triton==3.0.0
[pip3] zmq==0.0.0
[conda] flashinfer                0.1.6+cu121torch2.4          pypi_0    pypi
[conda] numpy                     1.26.4                   pypi_0    pypi
[conda] nvidia-cublas-cu12        12.1.3.1                 pypi_0    pypi
[conda] nvidia-cuda-cupti-cu12    12.1.105                 pypi_0    pypi
[conda] nvidia-cuda-nvrtc-cu12    12.1.105                 pypi_0    pypi
[conda] nvidia-cuda-runtime-cu12  12.1.105                 pypi_0    pypi
[conda] nvidia-cudnn-cu12         9.1.0.70                 pypi_0    pypi
[conda] nvidia-cufft-cu12         11.0.2.54                pypi_0    pypi
[conda] nvidia-curand-cu12        10.3.2.106               pypi_0    pypi
[conda] nvidia-cusolver-cu12      11.4.5.107               pypi_0    pypi
[conda] nvidia-cusparse-cu12      12.1.0.106               pypi_0    pypi
[conda] nvidia-ml-py              12.560.30                pypi_0    pypi
[conda] nvidia-nccl-cu12          2.20.5                   pypi_0    pypi
[conda] nvidia-nvjitlink-cu12     12.6.20                  pypi_0    pypi
[conda] nvidia-nvtx-cu12          12.1.105                 pypi_0    pypi
[conda] pyzmq                     26.2.0                   pypi_0    pypi
[conda] torch                     2.4.0                    pypi_0    pypi
[conda] torchvision               0.19.0                   pypi_0    pypi
[conda] transformers              4.45.0.dev0              pypi_0    pypi
[conda] triton                    3.0.0                    pypi_0    pypi
[conda] zmq                       0.0.0                    pypi_0    pypi
ROCM Version: Could not collect
Neuron SDK Version: N/A
vLLM Version: 0.6.1.post2@0faab90eb006c677add65cd4c2d0f740a63e064d
vLLM Build Flags:
CUDA Archs: Not Set; ROCm: Disabled; Neuron: Disabled
GPU Topology:
GPU0	GPU1	GPU2	GPU3	GPU4	GPU5	CPU Affinity	NUMA Affinity	GPU NUMA ID
GPU0	 X 	NODE	NODE	SYS	SYS	SYS	0-31,64-95	0		N/A
GPU1	NODE	 X 	PXB	SYS	SYS	SYS	0-31,64-95	0		N/A
GPU2	NODE	PXB	 X 	SYS	SYS	SYS	0-31,64-95	0		N/A
GPU3	SYS	SYS	SYS	 X 	NODE	NODE	32-63,96-127	1		N/A
GPU4	SYS	SYS	SYS	NODE	 X 	PXB	32-63,96-127	1		N/A
GPU5	SYS	SYS	SYS	NODE	PXB	 X 	32-63,96-127	1		N/A


Model Input Dumps

No response

🐛 Describe the bug

After a binary search, I found that after commit 7c7714d the main port binding fails when pp > 1. If we only set tp > 1, the binding succeeds.

For example:
vllm serve /home/ai/ai/model/Qwen2.5-3B-Instruct/ --served-model-name qwen2.5-3B -pp 2 --trust-remote-code --max-model-len 4096 --enforce-eager --port 18004 --gpu-memory-utilization 1 --preemption-mode swap
will fail with ERROR:
ERROR: [Errno 98] error while attempting to bind on address ('0.0.0.0', 18004): address already in use

But
vllm serve /home/ai/ai/model/Qwen2.5-3B-Instruct/ --served-model-name qwen2.5-3B -tp 2 --trust-remote-code --max-model-len 4096 --enforce-eager --port 18004 --gpu-memory-utilization 1 --preemption-mode swap
will run successfully with:
INFO: Uvicorn running on http://0.0.0.0:18004 (Press CTRL+C to quit)

If we check out commit 9d104b5 on the main branch, we can launch successfully with pp > 1.

@dengminhao dengminhao added the bug Something isn't working label Sep 25, 2024
@dengminhao dengminhao changed the title [Bug]: Binding failure when using pp > 1 after commit 7c7714d856eee6fa94aade729b67f00584f72a4c [Bug]: Port binding failure when using pp > 1 after commit 7c7714d856eee6fa94aade729b67f00584f72a4c Sep 25, 2024
@SolitaryThinker
Contributor

cc @youkaichao @kevin314

@youkaichao
Member

Looks strange. @robertgshaw2-neuralmagic do you have any ideas?

@dengminhao did you try --disable-frontend-multiprocessing?

@dengminhao
Author

dengminhao commented Sep 25, 2024

Nothing changed with --disable-frontend-multiprocessing.
I added code in uvicorn, so I am sure there is only one process trying to bind the main port.
I also copied the code in utils that tries to find any process bound to this port, and still I cannot find any process holding the main port after the binding failure.
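
For reference, the check I used is roughly equivalent to the following (an illustrative psutil sketch, not the exact vllm.utils helper):

import psutil

def find_port_holders(port: int):
    # Return (pid, process name) pairs for sockets bound to `port`.
    holders = []
    for conn in psutil.net_connections(kind="inet"):
        if conn.laddr and conn.laddr.port == port and conn.pid is not None:
            try:
                holders.append((conn.pid, psutil.Process(conn.pid).name()))
            except psutil.NoSuchProcess:
                pass
    return holders

print(find_port_holders(18004))  # prints [] here, even right after the bind failure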

@dengminhao
Author

I also tried reuse_port=True when calling loop.create_server in uvicorn/server.py. It doesn't help.

@dengminhao
Author

Oh, another finding.
Adding --disable-frontend-multiprocessing makes tp=2 fail to bind as well.
Removing --disable-frontend-multiprocessing, tp=2 runs successfully.

@dengminhao
Author

After this experiment I had an idea and then found a quick fix.
Hope it is helpful to you. It still seems to be a multi-process resource issue.

diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py
index 1b9eb30..8512b9a 100644
--- a/vllm/entrypoints/openai/api_server.py
+++ b/vllm/entrypoints/openai/api_server.py
@@ -530,6 +530,7 @@ async def run_server(args, **uvicorn_kwargs) -> None:
         raise KeyboardInterrupt("terminated")
 
     signal.signal(signal.SIGTERM, signal_handler)
+    temp_socket.close()
 
     async with build_async_engine_client(args) as engine_client:
         # If None, creation of the client failed and we exit.
@@ -541,7 +542,7 @@ async def run_server(args, **uvicorn_kwargs) -> None:
         model_config = await engine_client.get_model_config()
         init_app_state(engine_client, model_config, app.state, args)
 
-        temp_socket.close()
 
         shutdown_task = await serve_http(
             app,

@youkaichao
Member

We use the socket to hold the port so that the engine does not take this port. Your change essentially disables that functionality.

It seems still a multi-process resource issue

Did you try setting export VLLM_WORKER_MULTIPROC_METHOD=spawn or using the ray backend?
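
For context, the placeholder-socket pattern described above is roughly the following (a simplified sketch of the idea, not the actual api_server.py code):

import socket

port = 18004
temp_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
temp_socket.bind(("", port))  # reserve the port so the engine processes
                              # started below cannot accidentally take it

# ... build the async engine client / start worker processes here ...

temp_socket.close()           # release the port just before the HTTP server
                              # binds ("0.0.0.0", port)

Closing the placeholder before the engine starts, as in the patch above, leaves the port free during the whole engine startup window, which is what defeats its purpose.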

@dengminhao
Author

I know what you expect; that's why I said it was a quick fix. I rolled back my change, and then:
export VLLM_WORKER_MULTIPROC_METHOD=spawn
works.
Adding --distributed-executor-backend=ray
also works.

@dengminhao
Author

I upgraded to 0.6.2 and the bug still exists.
Of course I can use VLLM_WORKER_MULTIPROC_METHOD=spawn or --distributed-executor-backend=ray, but I believe it is still a bug.

@youkaichao
Member

Is it because of port resource management after fork?

@dengminhao
Author

I guess that in my case the temp socket was duplicated on fork. But I don't know what the difference between tp and pp is in this case.
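
A minimal standalone reproduction of that suspected mechanism (plain Python, nothing vLLM-specific, port number assumed): a child forked while the placeholder socket is open inherits the bound file descriptor, so the port stays in use even after the parent closes its copy.

import os, socket, time

s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
s.bind(("", 18004))          # placeholder socket bound in the parent

pid = os.fork()
if pid == 0:
    # Child: inherits the open, bound socket fd and keeps it alive.
    time.sleep(30)
    os._exit(0)

s.close()                    # parent releases its copy of the fd ...
s2 = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
s2.bind(("", 18004))         # ... but this raises OSError: [Errno 98]
                             # Address already in use, because the child's
                             # inherited fd still holds the port

With the spawn start method the worker is a fresh process and does not inherit the fd, which would explain why VLLM_WORKER_MULTIPROC_METHOD=spawn avoids the error.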

@HelloCard

HelloCard commented Oct 8, 2024

Same issue with 0.6.2.
python3 -m vllm.entrypoints.openai.api_server --model /mnt/e/Code/models/Mistral-Nemo-Instruct-2407-W8A8 --max-model-len 8192 --gpu-memory-utilization 0.88 --swap_space=0 --pipeline-parallel-size 2 --dtype=half --max-num-seqs=1
ERROR: [Errno 98] error while attempting to bind on address ('0.0.0.0', 8000): address already in use
Setting --port 8001 gives:
ERROR: [Errno 98] error while attempting to bind on address ('0.0.0.0', 8001): address already in use
export VLLM_WORKER_MULTIPROC_METHOD=spawn worked, thanks.

@youkaichao
Member

@HelloCard can you test whether #8537 solves this issue? You can follow https://docs.vllm.ai/en/latest/getting_started/installation.html#install-the-latest-code to install the latest wheel.

@HelloCard

@HelloCard can you test whether #8537 solves this issue? You can follow https://docs.vllm.ai/en/latest/getting_started/installation.html#install-the-latest-code to install the latest wheel.

I get failures whether or not I use the --disable-frontend-multiprocessing argument, so I'm not sure my environment will give you the result you expect. Anyway, I tried installing the latest version.

Successfully installed opencv-python-headless-4.10.0.84 vllm-0.6.3.dev144+gdc4aea67.d20241009
python3 -m vllm.entrypoints.openai.api_server --model /mnt/e/Code/models/Orca-2-13b-W8A8 --max-model-len 4096 --pipeline-parallel-size 2 --gpu-memory-utilization 0.85 --dtype=half --max-num-seqs=1 --swap-space 0 --disable-frontend-multiprocessing

INFO 10-09 19:24:17 api_server.py:528] vLLM API server version 0.6.3.dev144+gdc4aea67.d20241009
INFO 10-09 19:24:17 api_server.py:529] args: Namespace(host=None, port=8000, uvicorn_log_level='info', allow_credentials=False, allowed_origins=['*'], allowed_methods=['*'], allowed_headers=['*'], api_key=None, lora_modules=None, prompt_adapters=None, chat_template=None, response_role='assistant', ssl_keyfile=None, ssl_certfile=None, ssl_ca_certs=None, ssl_cert_reqs=0, root_path=None, middleware=[], return_tokens_as_token_ids=False, disable_frontend_multiprocessing=True, enable_auto_tool_choice=False, tool_call_parser=None, tool_parser_plugin='', model='/mnt/e/Code/models/Orca-2-13b-W8A8', tokenizer=None, skip_tokenizer_init=False, revision=None, code_revision=None, tokenizer_revision=None, tokenizer_mode='auto', trust_remote_code=False, download_dir=None, load_format='auto', config_format='auto', dtype='half', kv_cache_dtype='auto', quantization_param_path=None, max_model_len=4096, guided_decoding_backend='outlines', distributed_executor_backend=None, worker_use_ray=False, pipeline_parallel_size=2, tensor_parallel_size=1, max_parallel_loading_workers=None, ray_workers_use_nsight=False, block_size=16, enable_prefix_caching=False, disable_sliding_window=False, use_v2_block_manager=True, num_lookahead_slots=0, seed=0, swap_space=0.0, cpu_offload_gb=0, gpu_memory_utilization=0.85, num_gpu_blocks_override=None, max_num_batched_tokens=None, max_num_seqs=1, max_logprobs=20, disable_log_stats=False, quantization=None, rope_scaling=None, rope_theta=None, enforce_eager=False, max_context_len_to_capture=None, max_seq_len_to_capture=8192, disable_custom_all_reduce=False, tokenizer_pool_size=0, tokenizer_pool_type='ray', tokenizer_pool_extra_config=None, limit_mm_per_prompt=None, mm_processor_kwargs=None, enable_lora=False, max_loras=1, max_lora_rank=16, lora_extra_vocab_size=256, lora_dtype='auto', long_lora_scaling_factors=None, max_cpu_loras=None, fully_sharded_loras=False, enable_prompt_adapter=False, max_prompt_adapters=1, max_prompt_adapter_token=0, device='auto', num_scheduler_steps=1, multi_step_stream_outputs=True, scheduler_delay_factor=0.0, enable_chunked_prefill=None, speculative_model=None, speculative_model_quantization=None, num_speculative_tokens=None, speculative_disable_mqa_scorer=False, speculative_draft_tensor_parallel_size=None, speculative_max_model_len=None, speculative_disable_by_batch_size=None, ngram_prompt_lookup_max=None, ngram_prompt_lookup_min=None, spec_decoding_acceptance_method='rejection_sampler', typical_acceptance_sampler_posterior_threshold=None, typical_acceptance_sampler_posterior_alpha=None, disable_logprobs_during_spec_decoding=None, model_loader_extra_config=None, ignore_patterns=[], preemption_mode=None, served_model_name=None, qlora_adapter_name_or_path=None, otlp_traces_endpoint=None, collect_detailed_traces=None, disable_async_output_proc=False, override_neuron_config=None, scheduling_policy='fcfs', disable_log_requests=False, max_log_len=None, disable_fastapi_docs=False)
WARNING 10-09 19:24:17 config.py:1646] Casting torch.bfloat16 to torch.float16.
INFO 10-09 19:24:25 config.py:875] Defaulting to use mp for distributed inference
WARNING 10-09 19:24:25 config.py:346] Async output processing can not be enabled with pipeline parallel
INFO 10-09 19:24:30 llm_engine.py:237] Initializing an LLM engine (v0.6.3.dev144+gdc4aea67.d20241009) with config: model='/mnt/e/Code/models/Orca-2-13b-W8A8', speculative_config=None, tokenizer='/mnt/e/Code/models/Orca-2-13b-W8A8', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config=None, rope_scaling=None, rope_theta=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.float16, max_seq_len=4096, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=1, pipeline_parallel_size=2, disable_custom_all_reduce=False, quantization=compressed-tensors, enforce_eager=False, kv_cache_dtype=auto, quantization_param_path=None, device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='outlines'), observability_config=ObservabilityConfig(otlp_traces_endpoint=None, collect_model_forward_time=False, collect_model_execute_time=False), seed=0, served_model_name=/mnt/e/Code/models/Orca-2-13b-W8A8, use_v2_block_manager=True, num_scheduler_steps=1, chunked_prefill_enabled=False multi_step_stream_outputs=True, enable_prefix_caching=False, use_async_output_proc=False, use_cached_outputs=False, mm_processor_kwargs=None)
WARNING 10-09 19:24:30 multiproc_gpu_executor.py:53] Reducing Torch parallelism from 4 threads to 1 to avoid unnecessary CPU contention. Set OMP_NUM_THREADS in the external environment to tune this value as needed.
INFO 10-09 19:24:30 custom_cache_manager.py:17] Setting Triton cache manager to: vllm.triton_utils.custom_cache_manager:CustomCacheManager
WARNING 10-09 19:24:30 utils.py:769] Using 'pin_memory=False' as WSL is detected. This may slow down the performance.
(VllmWorkerProcess pid=6409) WARNING 10-09 19:24:30 utils.py:769] Using 'pin_memory=False' as WSL is detected. This may slow down the performance.
INFO 10-09 19:24:30 selector.py:217] Cannot use FlashAttention-2 backend for Volta and Turing GPUs.
INFO 10-09 19:24:30 selector.py:116] Using XFormers backend.
(VllmWorkerProcess pid=6409) INFO 10-09 19:24:30 selector.py:217] Cannot use FlashAttention-2 backend for Volta and Turing GPUs.
(VllmWorkerProcess pid=6409) INFO 10-09 19:24:30 selector.py:116] Using XFormers backend.
/root/miniconda3/lib/python3.12/site-packages/xformers/ops/fmha/flash.py:211: FutureWarning: `torch.library.impl_abstract` was renamed to `torch.library.register_fake`. Please use that instead; we will remove `torch.library.impl_abstract` in a future version of PyTorch.
  @torch.library.impl_abstract("xformers_flash::flash_fwd")
(VllmWorkerProcess pid=6409) /root/miniconda3/lib/python3.12/site-packages/xformers/ops/fmha/flash.py:211: FutureWarning: `torch.library.impl_abstract` was renamed to `torch.library.register_fake`. Please use that instead; we will remove `torch.library.impl_abstract` in a future version of PyTorch.
(VllmWorkerProcess pid=6409)   @torch.library.impl_abstract("xformers_flash::flash_fwd")
(VllmWorkerProcess pid=6409) /root/miniconda3/lib/python3.12/site-packages/xformers/ops/fmha/flash.py:344: FutureWarning: `torch.library.impl_abstract` was renamed to `torch.library.register_fake`. Please use that instead; we will remove `torch.library.impl_abstract` in a future version of PyTorch.
/root/miniconda3/lib/python3.12/site-packages/xformers/ops/fmha/flash.py:344: FutureWarning: `torch.library.impl_abstract` was renamed to `torch.library.register_fake`. Please use that instead; we will remove `torch.library.impl_abstract` in a future version of PyTorch.
  @torch.library.impl_abstract("xformers_flash::flash_bwd")
(VllmWorkerProcess pid=6409)   @torch.library.impl_abstract("xformers_flash::flash_bwd")
(VllmWorkerProcess pid=6409) INFO 10-09 19:24:32 multiproc_worker_utils.py:216] Worker ready; awaiting tasks
(VllmWorkerProcess pid=6409) INFO 10-09 19:24:35 utils.py:1005] Found nccl from library libnccl.so.2
INFO 10-09 19:24:35 utils.py:1005] Found nccl from library libnccl.so.2
(VllmWorkerProcess pid=6409) INFO 10-09 19:24:35 pynccl.py:63] vLLM is using nccl==2.20.5
INFO 10-09 19:24:35 pynccl.py:63] vLLM is using nccl==2.20.5
INFO 10-09 19:24:35 model_runner.py:1051] Starting to load model /mnt/e/Code/models/Orca-2-13b-W8A8...
(VllmWorkerProcess pid=6409) INFO 10-09 19:24:35 model_runner.py:1051] Starting to load model /mnt/e/Code/models/Orca-2-13b-W8A8...
INFO 10-09 19:24:35 selector.py:217] Cannot use FlashAttention-2 backend for Volta and Turing GPUs.
(VllmWorkerProcess pid=6409) INFO 10-09 19:24:35 selector.py:217] Cannot use FlashAttention-2 backend for Volta and Turing GPUs.
(VllmWorkerProcess pid=6409) INFO 10-09 19:24:35 selector.py:116] Using XFormers backend.
INFO 10-09 19:24:35 selector.py:116] Using XFormers backend.
Loading safetensors checkpoint shards:   0% Completed | 0/3 [00:00<?, ?it/s]
Loading safetensors checkpoint shards:  33% Completed | 1/3 [02:15<04:31, 135.58s/it]
Loading safetensors checkpoint shards:  67% Completed | 2/3 [03:03<01:24, 84.04s/it]
Loading safetensors checkpoint shards: 100% Completed | 3/3 [03:03<00:00, 45.74s/it]
Loading safetensors checkpoint shards: 100% Completed | 3/3 [03:03<00:00, 61.24s/it]

INFO 10-09 19:27:39 model_runner.py:1062] Loading model weights took 6.2885 GB
(VllmWorkerProcess pid=6409) INFO 10-09 19:27:40 model_runner.py:1062] Loading model weights took 6.2885 GB
ERROR 10-09 19:27:40 _custom_ops.py:53] Error in calling custom op cutlass_scaled_mm: No compiled cutlass_scaled_mm for a compute capability less than CUDA device capability: 75
ERROR 10-09 19:27:40 _custom_ops.py:53] Not implemented or built, mostly likely because the current current device does not support this kernel (less likely TORCH_CUDA_ARCH_LIST was set incorrectly while building)
(VllmWorkerProcess pid=6409) ERROR 10-09 19:27:40 _custom_ops.py:53] Error in calling custom op cutlass_scaled_mm: No compiled cutlass_scaled_mm for a compute capability less than CUDA device capability: 75
(VllmWorkerProcess pid=6409) ERROR 10-09 19:27:40 _custom_ops.py:53] Not implemented or built, mostly likely because the current current device does not support this kernel (less likely TORCH_CUDA_ARCH_LIST was set incorrectly while building)
INFO 10-09 19:27:40 model_runner_base.py:120] Writing input of failed execution to /tmp/err_execute_model_input_20241009-192740.pkl...
(VllmWorkerProcess pid=6409) INFO 10-09 19:27:40 model_runner_base.py:120] Writing input of failed execution to /tmp/err_execute_model_input_20241009-192740.pkl...
INFO 10-09 19:27:41 model_runner_base.py:149] Completed writing input of failed execution to /tmp/err_execute_model_input_20241009-192740.pkl.
[rank0]: Traceback (most recent call last):
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/vllm/_custom_ops.py", line 45, in wrapper
[rank0]:     return fn(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/vllm/_custom_ops.py", line 512, in cutlass_scaled_mm
[rank0]:     torch.ops._C.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias)
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/torch/_ops.py", line 1061, in __call__
[rank0]:     return self_._op(*args, **(kwargs or {}))
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: NotImplementedError: No compiled cutlass_scaled_mm for a compute capability less than CUDA device capability: 75

[rank0]: The above exception was the direct cause of the following exception:

[rank0]: Traceback (most recent call last):
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/vllm/worker/model_runner_base.py", line 116, in _wrapper
[rank0]:     return func(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/vllm/worker/model_runner.py", line 1648, in execute_model
[rank0]:     hidden_or_intermediate_states = model_executable(
[rank0]:                                     ^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/vllm/model_executor/models/llama.py", line 547, in forward
[rank0]:     model_output = self.model(input_ids, positions, kv_caches,
[rank0]:                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/vllm/model_executor/models/llama.py", line 337, in forward
[rank0]:     hidden_states, residual = layer(positions, hidden_states,
[rank0]:                               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/vllm/model_executor/models/llama.py", line 256, in forward
[rank0]:     hidden_states = self.self_attn(positions=positions,
[rank0]:                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/vllm/model_executor/models/llama.py", line 183, in forward
[rank0]:     qkv, _ = self.qkv_proj(hidden_states)
[rank0]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/vllm/model_executor/layers/linear.py", line 371, in forward
[rank0]:     output_parallel = self.quant_method.apply(self, input_, bias)
[rank0]:                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py", line 368, in apply
[rank0]:     return scheme.apply_weights(layer, x, bias=bias)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py", line 143, in apply_weights
[rank0]:     return apply_int8_linear(input=x,
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/vllm/model_executor/layers/quantization/utils/w8a8_utils.py", line 217, in apply_int8_linear
[rank0]:     return ops.cutlass_scaled_mm(x_q,
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/vllm/_custom_ops.py", line 54, in wrapper
[rank0]:     raise NotImplementedError(msg % (fn.__name__, e)) from e
[rank0]: NotImplementedError: Error in calling custom op cutlass_scaled_mm: No compiled cutlass_scaled_mm for a compute capability less than CUDA device capability: 75
[rank0]: Not implemented or built, mostly likely because the current current device does not support this kernel (less likely TORCH_CUDA_ARCH_LIST was set incorrectly while building)

[rank0]: The above exception was the direct cause of the following exception:

[rank0]: Traceback (most recent call last):
[rank0]:   File "<frozen runpy>", line 198, in _run_module_as_main
[rank0]:   File "<frozen runpy>", line 88, in _run_code
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/vllm/entrypoints/openai/api_server.py", line 585, in <module>
[rank0]:     uvloop.run(run_server(args))
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/uvloop/__init__.py", line 109, in run
[rank0]:     return __asyncio.run(
[rank0]:            ^^^^^^^^^^^^^^
[rank0]:   File "/root/miniconda3/lib/python3.12/asyncio/runners.py", line 194, in run
[rank0]:     return runner.run(main)
[rank0]:            ^^^^^^^^^^^^^^^^
[rank0]:   File "/root/miniconda3/lib/python3.12/asyncio/runners.py", line 118, in run
[rank0]:     return self._loop.run_until_complete(task)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "uvloop/loop.pyx", line 1517, in uvloop.loop.Loop.run_until_complete
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/uvloop/__init__.py", line 61, in wrapper
[rank0]:     return await main
[rank0]:            ^^^^^^^^^^
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/vllm/entrypoints/openai/api_server.py", line 552, in run_server
[rank0]:     async with build_async_engine_client(args) as engine_client:
[rank0]:   File "/root/miniconda3/lib/python3.12/contextlib.py", line 210, in __aenter__
[rank0]:     return await anext(self.gen)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/vllm/entrypoints/openai/api_server.py", line 107, in build_async_engine_client
[rank0]:     async with build_async_engine_client_from_engine_args(
[rank0]:   File "/root/miniconda3/lib/python3.12/contextlib.py", line 210, in __aenter__
[rank0]:     return await anext(self.gen)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/vllm/entrypoints/openai/api_server.py", line 141, in build_async_engine_client_from_engine_args
[rank0]:     engine_client = await asyncio.get_running_loop().run_in_executor(
[rank0]:                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/miniconda3/lib/python3.12/concurrent/futures/thread.py", line 58, in run
[rank0]:     result = self.fn(*self.args, **self.kwargs)
[rank0]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/vllm/engine/async_llm_engine.py", line 674, in from_engine_args
[rank0]:     engine = cls(
[rank0]:              ^^^^
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/vllm/engine/async_llm_engine.py", line 569, in __init__
[rank0]:     self.engine = self._engine_class(*args, **kwargs)
[rank0]:                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/vllm/engine/async_llm_engine.py", line 265, in __init__
[rank0]:     super().__init__(*args, **kwargs)
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/vllm/engine/llm_engine.py", line 349, in __init__
[rank0]:     self._initialize_kv_caches()
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/vllm/engine/llm_engine.py", line 484, in _initialize_kv_caches
[rank0]:     self.model_executor.determine_num_available_blocks())
[rank0]:     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/vllm/executor/distributed_gpu_executor.py", line 39, in determine_num_available_blocks
[rank0]:     num_blocks = self._run_workers("determine_num_available_blocks", )
[rank0]:                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/vllm/executor/multiproc_gpu_executor.py", line 192, in _run_workers
[rank0]:     driver_worker_output = driver_worker_method(*args, **kwargs)
[rank0]:                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
[rank0]:     return func(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/vllm/worker/worker.py", line 223, in determine_num_available_blocks
[rank0]:     self.model_runner.profile_run()
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
[rank0]:     return func(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/vllm/worker/model_runner.py", line 1292, in profile_run
[rank0]:     self.execute_model(model_input, kv_caches, intermediate_tensors)
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
[rank0]:     return func(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/vllm/worker/model_runner_base.py", line 152, in _wrapper
[rank0]:     raise type(err)(
[rank0]: NotImplementedError: Error in model execution (input dumped to /tmp/err_execute_model_input_20241009-192740.pkl): Error in calling custom op cutlass_scaled_mm: No compiled cutlass_scaled_mm for a compute capability less than CUDA device capability: 75
[rank0]: Not implemented or built, mostly likely because the current current device does not support this kernel (less likely TORCH_CUDA_ARCH_LIST was set incorrectly while building)
(VllmWorkerProcess pid=6409) INFO 10-09 19:27:41 multiproc_worker_utils.py:242] Worker exiting
INFO 10-09 19:27:41 multiproc_worker_utils.py:121] Killing local vLLM worker processes
Fatal Python error: _enter_buffered_busy: could not acquire lock for <_io.BufferedWriter name='<stdout>'> at interpreter shutdown, possibly due to daemon threads
Python runtime state: finalizing (tstate=0x00000000009c11f8)

Current thread 0x00007f24eb01f440 (most recent call first):
  <no Python frame>

Extension modules: numpy.core._multiarray_umath, numpy.core._multiarray_tests, numpy.linalg._umath_linalg, numpy.fft._pocketfft_internal, numpy.random._common, numpy.random.bit_generator, numpy.random._bounded_integers, numpy.random._mt19937, numpy.random.mtrand, numpy.random._philox, numpy.random._pcg64, numpy.random._sfc64, numpy.random._generator, torch._C, torch._C._fft, torch._C._linalg, torch._C._nested, torch._C._nn, torch._C._sparse, torch._C._special, _brotli, zstandard.backend_c, yaml._yaml, markupsafe._speedups, PIL._imaging, psutil._psutil_linux, psutil._psutil_posix, msgspec._core, sentencepiece._sentencepiece, PIL._imagingft, regex._regex, msgpack._cmsgpack, google._upb._message, setproctitle, uvloop.loop, ray._raylet, multidict._multidict, yarl._helpers_c, yarl._quoting_c, aiohttp._helpers, aiohttp._http_writer, aiohttp._http_parser, aiohttp._websocket, frozenlist._frozenlist, zmq.backend.cython._zmq (total: 45)
Aborted (core dumped)

Unfortunately, it seems that I cannot use the nightly wheel because my environment is WSL2?

@youkaichao
Member

ERROR 10-09 19:27:40 _custom_ops.py:53] Error in calling custom op cutlass_scaled_mm: No compiled cutlass_scaled_mm for a compute capability less than CUDA device capability: 75

you are trying to run a quantized model W8A8 while your gpu does not support it. try an unquantized model please.

@HelloCard

ERROR 10-09 19:27:40 _custom_ops.py:53] Error in calling custom op cutlass_scaled_mm: No compiled cutlass_scaled_mm for a compute capability less than CUDA device capability: 75

you are trying to run a quantized model W8A8 while your gpu does not support it. try an unquantized model please.

+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 530.50                 Driver Version: 531.79       CUDA Version: 12.1     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                  Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf            Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA GeForce RTX 2080 Ti      On | 00000000:01:00.0  On |                  N/A |
|  0%   36C    P0              300W / 300W|   7011MiB / 22528MiB |     96%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
|   1  NVIDIA GeForce RTX 2080 Ti      On | 00000000:02:00.0 Off |                  N/A |
| 84%   36C    P0              280W / 300W|   7011MiB / 22528MiB |     91%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+

+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
|    0   N/A  N/A        41      G   /Xwayland                                 N/A      |
|    1   N/A  N/A        41      G   /Xwayland                                 N/A      |
+---------------------------------------------------------------------------------------+

My GPU obviously supports W8A8 quantization, because I have run models in W8A8 format many times with version 0.6.2.
Anyway, I will try it with an unquantized model later.
In addition, I edited the api_server.py of version 0.6.2 according to the suggestion in #9172, which also solved the problem.

@HelloCard

ERROR 10-09 19:27:40 _custom_ops.py:53] Error in calling custom op cutlass_scaled_mm: No compiled cutlass_scaled_mm for a compute capability less than CUDA device capability: 75

you are trying to run a quantized model W8A8 while your gpu does not support it. try an unquantized model please.

Successfully installed vllm-0.6.3.dev172+ge808156f.d20241011
python3 -m vllm.entrypoints.openai.api_server --model /mnt/e/Code/models/Phi-3-medium-4k-instruct --max-model-len 4096 --gpu-memory-utilization 0.7 --swap_space=0 --tensor-parallel-size 2 --dtype=half --max-num-seqs=1

INFO 10-11 20:41:31 launcher.py:19] Available routes are:
INFO 10-11 20:41:31 launcher.py:27] Route: /openapi.json, Methods: GET, HEAD
INFO 10-11 20:41:31 launcher.py:27] Route: /docs, Methods: GET, HEAD
INFO 10-11 20:41:31 launcher.py:27] Route: /docs/oauth2-redirect, Methods: GET, HEAD
INFO 10-11 20:41:31 launcher.py:27] Route: /redoc, Methods: GET, HEAD
INFO 10-11 20:41:31 launcher.py:27] Route: /health, Methods: GET
INFO 10-11 20:41:31 launcher.py:27] Route: /tokenize, Methods: POST
INFO 10-11 20:41:31 launcher.py:27] Route: /detokenize, Methods: POST
INFO 10-11 20:41:31 launcher.py:27] Route: /v1/models, Methods: GET
INFO 10-11 20:41:31 launcher.py:27] Route: /version, Methods: GET
INFO 10-11 20:41:31 launcher.py:27] Route: /v1/chat/completions, Methods: POST
INFO 10-11 20:41:31 launcher.py:27] Route: /v1/completions, Methods: POST
INFO 10-11 20:41:31 launcher.py:27] Route: /v1/embeddings, Methods: POST
INFO:     Started server process [6520]
INFO:     Waiting for application startup.
INFO:     Application startup complete.
INFO:     Uvicorn running on socket ('0.0.0.0', 8000) (Press CTRL+C to quit)
INFO 10-11 20:41:41 metrics.py:345] Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 0.0 tokens/s, Running: 0 reqs, Swapped: 0 reqs, Pending: 0 reqs, GPU KV cache usage: 0.0%, CPU KV cache usage: 0.0%.

Loading microsoft/Phi-3-medium-4k-instruct, everything went well.
But I still have to use the 0.6.2 release, because I mainly use models in W8A8 format, and the nightly version failing to load W8A8 models seems to be a separate bug.

@HelloCard

ERROR 10-09 19:27:40 _custom_ops.py:53] Error in calling custom op cutlass_scaled_mm: No compiled cutlass_scaled_mm for a compute capability less than CUDA device capability: 75

you are trying to run a quantized model W8A8 while your gpu does not support it. try an unquantized model please.

(base) root@DESKTOP-PEPA2G9:/mnt/c/Windows/system32# python3 -m vllm.entrypoints.openai.api_server --model /mnt/e/Code/models/Phi-3-medium-4k-instruct --max-model-len 4096 --gpu-memory-utilization 0.7 --swap_space=0 --tensor-parallel-size 2 --dtype=half --max-num-seqs=1 --disable-frontend-multiprocessing
failed:

INFO 10-12 00:17:01 api_server.py:526] vLLM API server version 0.6.1.dev238+ge2c6e0a82
INFO 10-12 00:17:01 api_server.py:527] args: Namespace(host=None, port=8000, uvicorn_log_level='info', allow_credentials=False, allowed_origins=['*'], allowed_methods=['*'], allowed_headers=['*'], api_key=None, lora_modules=None, prompt_adapters=None, chat_template=None, response_role='assistant', ssl_keyfile=None, ssl_certfile=None, ssl_ca_certs=None, ssl_cert_reqs=0, root_path=None, middleware=[], return_tokens_as_token_ids=False, disable_frontend_multiprocessing=True, enable_auto_tool_choice=False, tool_call_parser=None, model='/mnt/e/Code/models/Phi-3-medium-4k-instruct', tokenizer=None, skip_tokenizer_init=False, revision=None, code_revision=None, tokenizer_revision=None, tokenizer_mode='auto', trust_remote_code=False, download_dir=None, load_format='auto', config_format='auto', dtype='half', kv_cache_dtype='auto', quantization_param_path=None, max_model_len=4096, guided_decoding_backend='outlines', distributed_executor_backend=None, worker_use_ray=False, pipeline_parallel_size=1, tensor_parallel_size=2, max_parallel_loading_workers=None, ray_workers_use_nsight=False, block_size=16, enable_prefix_caching=False, disable_sliding_window=False, use_v2_block_manager=False, num_lookahead_slots=0, seed=0, swap_space=0.0, cpu_offload_gb=0, gpu_memory_utilization=0.7, num_gpu_blocks_override=None, max_num_batched_tokens=None, max_num_seqs=1, max_logprobs=20, disable_log_stats=False, quantization=None, rope_scaling=None, rope_theta=None, enforce_eager=False, max_context_len_to_capture=None, max_seq_len_to_capture=8192, disable_custom_all_reduce=False, tokenizer_pool_size=0, tokenizer_pool_type='ray', tokenizer_pool_extra_config=None, limit_mm_per_prompt=None, mm_processor_kwargs=None, enable_lora=False, max_loras=1, max_lora_rank=16, lora_extra_vocab_size=256, lora_dtype='auto', long_lora_scaling_factors=None, max_cpu_loras=None, fully_sharded_loras=False, enable_prompt_adapter=False, max_prompt_adapters=1, max_prompt_adapter_token=0, device='auto', num_scheduler_steps=1, multi_step_stream_outputs=False, scheduler_delay_factor=0.0, enable_chunked_prefill=None, speculative_model=None, speculative_model_quantization=None, num_speculative_tokens=None, speculative_draft_tensor_parallel_size=None, speculative_max_model_len=None, speculative_disable_by_batch_size=None, ngram_prompt_lookup_max=None, ngram_prompt_lookup_min=None, spec_decoding_acceptance_method='rejection_sampler', typical_acceptance_sampler_posterior_threshold=None, typical_acceptance_sampler_posterior_alpha=None, disable_logprobs_during_spec_decoding=None, model_loader_extra_config=None, ignore_patterns=[], preemption_mode=None, served_model_name=None, qlora_adapter_name_or_path=None, otlp_traces_endpoint=None, collect_detailed_traces=None, disable_async_output_proc=False, override_neuron_config=None, disable_log_requests=False, max_log_len=None, disable_fastapi_docs=False)
WARNING 10-12 00:17:01 config.py:1656] Casting torch.bfloat16 to torch.float16.
INFO 10-12 00:17:01 config.py:899] Defaulting to use mp for distributed inference
INFO 10-12 00:17:01 llm_engine.py:226] Initializing an LLM engine (v0.6.1.dev238+ge2c6e0a82) with config: model='/mnt/e/Code/models/Phi-3-medium-4k-instruct', speculative_config=None, tokenizer='/mnt/e/Code/models/Phi-3-medium-4k-instruct', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config=None, rope_scaling=None, rope_theta=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.float16, max_seq_len=4096, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=2, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=False, kv_cache_dtype=auto, quantization_param_path=None, device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='outlines'), observability_config=ObservabilityConfig(otlp_traces_endpoint=None, collect_model_forward_time=False, collect_model_execute_time=False), seed=0, served_model_name=/mnt/e/Code/models/Phi-3-medium-4k-instruct, use_v2_block_manager=False, num_scheduler_steps=1, multi_step_stream_outputs=False, enable_prefix_caching=False, use_async_output_proc=True, use_cached_outputs=False, mm_processor_kwargs=None)
WARNING 10-12 00:17:01 multiproc_gpu_executor.py:53] Reducing Torch parallelism from 4 threads to 1 to avoid unnecessary CPU contention. Set OMP_NUM_THREADS in the external environment to tune this value as needed.
INFO 10-12 00:17:01 custom_cache_manager.py:17] Setting Triton cache manager to: vllm.triton_utils.custom_cache_manager:CustomCacheManager
(VllmWorkerProcess pid=639) WARNING 10-12 00:17:01 utils.py:747] Using 'pin_memory=False' as WSL is detected. This may slow down the performance.
WARNING 10-12 00:17:01 utils.py:747] Using 'pin_memory=False' as WSL is detected. This may slow down the performance.
INFO 10-12 00:17:01 selector.py:217] Cannot use FlashAttention-2 backend for Volta and Turing GPUs.
INFO 10-12 00:17:01 selector.py:116] Using XFormers backend.
(VllmWorkerProcess pid=639) INFO 10-12 00:17:01 selector.py:217] Cannot use FlashAttention-2 backend for Volta and Turing GPUs.
(VllmWorkerProcess pid=639) INFO 10-12 00:17:01 selector.py:116] Using XFormers backend.
/root/miniconda3/lib/python3.12/site-packages/xformers/ops/fmha/flash.py:211: FutureWarning: `torch.library.impl_abstract` was renamed to `torch.library.register_fake`. Please use that instead; we will remove `torch.library.impl_abstract` in a future version of PyTorch.
  @torch.library.impl_abstract("xformers_flash::flash_fwd")
(VllmWorkerProcess pid=639) /root/miniconda3/lib/python3.12/site-packages/xformers/ops/fmha/flash.py:211: FutureWarning: `torch.library.impl_abstract` was renamed to `torch.library.register_fake`. Please use that instead; we will remove `torch.library.impl_abstract` in a future version of PyTorch.
(VllmWorkerProcess pid=639)   @torch.library.impl_abstract("xformers_flash::flash_fwd")
/root/miniconda3/lib/python3.12/site-packages/xformers/ops/fmha/flash.py:344: FutureWarning: `torch.library.impl_abstract` was renamed to `torch.library.register_fake`. Please use that instead; we will remove `torch.library.impl_abstract` in a future version of PyTorch.
  @torch.library.impl_abstract("xformers_flash::flash_bwd")
(VllmWorkerProcess pid=639) /root/miniconda3/lib/python3.12/site-packages/xformers/ops/fmha/flash.py:344: FutureWarning: `torch.library.impl_abstract` was renamed to `torch.library.register_fake`. Please use that instead; we will remove `torch.library.impl_abstract` in a future version of PyTorch.
(VllmWorkerProcess pid=639)   @torch.library.impl_abstract("xformers_flash::flash_bwd")
(VllmWorkerProcess pid=639) INFO 10-12 00:17:03 multiproc_worker_utils.py:218] Worker ready; awaiting tasks
(VllmWorkerProcess pid=639) INFO 10-12 00:17:05 utils.py:992] Found nccl from library libnccl.so.2
(VllmWorkerProcess pid=639) INFO 10-12 00:17:05 pynccl.py:63] vLLM is using nccl==2.20.5
INFO 10-12 00:17:05 utils.py:992] Found nccl from library libnccl.so.2
INFO 10-12 00:17:05 pynccl.py:63] vLLM is using nccl==2.20.5
INFO 10-12 00:17:06 custom_all_reduce_utils.py:242] reading GPU P2P access cache from /root/.cache/vllm/gpu_p2p_access_cache_for_0,1.json
(VllmWorkerProcess pid=639) INFO 10-12 00:17:06 custom_all_reduce_utils.py:242] reading GPU P2P access cache from /root/.cache/vllm/gpu_p2p_access_cache_for_0,1.json
INFO 10-12 00:17:06 shm_broadcast.py:241] vLLM message queue communication handle: Handle(connect_ip='127.0.0.1', local_reader_ranks=[1], buffer=<vllm.distributed.device_communicators.shm_broadcast.ShmRingBuffer object at 0x7f037e3fc3e0>, local_subscribe_port=34637, remote_subscribe_port=None)
INFO 10-12 00:17:06 model_runner.py:1014] Starting to load model /mnt/e/Code/models/Phi-3-medium-4k-instruct...
(VllmWorkerProcess pid=639) INFO 10-12 00:17:06 model_runner.py:1014] Starting to load model /mnt/e/Code/models/Phi-3-medium-4k-instruct...
INFO 10-12 00:17:06 selector.py:217] Cannot use FlashAttention-2 backend for Volta and Turing GPUs.
(VllmWorkerProcess pid=639) INFO 10-12 00:17:06 selector.py:217] Cannot use FlashAttention-2 backend for Volta and Turing GPUs.
(VllmWorkerProcess pid=639) INFO 10-12 00:17:06 selector.py:116] Using XFormers backend.
INFO 10-12 00:17:06 selector.py:116] Using XFormers backend.
Loading safetensors checkpoint shards:   0% Completed | 0/6 [00:00<?, ?it/s]
Loading safetensors checkpoint shards:  17% Completed | 1/6 [01:37<08:08, 97.80s/it]
Loading safetensors checkpoint shards:  33% Completed | 2/6 [03:15<06:31, 97.86s/it]
Loading safetensors checkpoint shards:  50% Completed | 3/6 [04:57<04:58, 99.48s/it]
Loading safetensors checkpoint shards:  67% Completed | 4/6 [06:42<03:23, 101.71s/it]
Loading safetensors checkpoint shards:  83% Completed | 5/6 [08:27<01:43, 103.02s/it]
Loading safetensors checkpoint shards: 100% Completed | 6/6 [09:47<00:00, 95.16s/it]
Loading safetensors checkpoint shards: 100% Completed | 6/6 [09:47<00:00, 97.91s/it]

INFO 10-12 00:26:54 model_runner.py:1025] Loading model weights took 13.1200 GB
(VllmWorkerProcess pid=639) INFO 10-12 00:26:54 model_runner.py:1025] Loading model weights took 13.1200 GB
INFO 10-12 00:26:57 distributed_gpu_executor.py:57] # GPU blocks: 465, # CPU blocks: 0
INFO 10-12 00:26:57 model_runner.py:1329] Capturing the model for CUDA graphs. This may lead to unexpected consequences if the model is not static. To run the model in eager mode, set 'enforce_eager=True' or use '--enforce-eager' in the CLI.
INFO 10-12 00:26:57 model_runner.py:1333] CUDA graphs can take additional 1~3 GiB memory per GPU. If you are running out of memory, consider decreasing `gpu_memory_utilization` or enforcing eager mode. You can also reduce the `max_num_seqs` as needed to decrease memory usage.
(VllmWorkerProcess pid=639) INFO 10-12 00:26:57 model_runner.py:1329] Capturing the model for CUDA graphs. This may lead to unexpected consequences if the model is not static. To run the model in eager mode, set 'enforce_eager=True' or use '--enforce-eager' in the CLI.
(VllmWorkerProcess pid=639) INFO 10-12 00:26:57 model_runner.py:1333] CUDA graphs can take additional 1~3 GiB memory per GPU. If you are running out of memory, consider decreasing `gpu_memory_utilization` or enforcing eager mode. You can also reduce the `max_num_seqs` as needed to decrease memory usage.
(VllmWorkerProcess pid=639) INFO 10-12 00:26:58 custom_all_reduce.py:229] Registering 81 cuda graph addresses
INFO 10-12 00:26:58 custom_all_reduce.py:229] Registering 81 cuda graph addresses
(VllmWorkerProcess pid=639) INFO 10-12 00:26:58 model_runner.py:1456] Graph capturing finished in 1 secs.
INFO 10-12 00:26:58 model_runner.py:1456] Graph capturing finished in 1 secs.
WARNING 10-12 00:26:58 serving_embedding.py:189] embedding_mode is False. Embedding API will not work.
INFO 10-12 00:26:58 launcher.py:19] Available routes are:
INFO 10-12 00:26:58 launcher.py:27] Route: /openapi.json, Methods: HEAD, GET
INFO 10-12 00:26:58 launcher.py:27] Route: /docs, Methods: HEAD, GET
INFO 10-12 00:26:58 launcher.py:27] Route: /docs/oauth2-redirect, Methods: HEAD, GET
INFO 10-12 00:26:58 launcher.py:27] Route: /redoc, Methods: HEAD, GET
INFO 10-12 00:26:58 launcher.py:27] Route: /health, Methods: GET
INFO 10-12 00:26:58 launcher.py:27] Route: /tokenize, Methods: POST
INFO 10-12 00:26:58 launcher.py:27] Route: /detokenize, Methods: POST
INFO 10-12 00:26:58 launcher.py:27] Route: /v1/models, Methods: GET
INFO 10-12 00:26:58 launcher.py:27] Route: /version, Methods: GET
INFO 10-12 00:26:58 launcher.py:27] Route: /v1/chat/completions, Methods: POST
INFO 10-12 00:26:58 launcher.py:27] Route: /v1/completions, Methods: POST
INFO 10-12 00:26:58 launcher.py:27] Route: /v1/embeddings, Methods: POST
INFO:     Started server process [593]
INFO:     Waiting for application startup.
INFO:     Application startup complete.
ERROR:    [Errno 98] error while attempting to bind on address ('0.0.0.0', 8000): address already in use
INFO:     Waiting for application shutdown.
INFO:     Application shutdown complete.


This issue has been automatically marked as stale because it has not had any activity within 90 days. It will be automatically closed if no further activity occurs within 30 days. Leave a comment if you feel this issue should remain open. Thank you!

@github-actions github-actions bot added the stale label Jan 10, 2025