Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Core] Implement disagg prefill by StatelessProcessGroup #10502

Merged
merged 35 commits into from
Dec 2, 2024

Conversation

KuntaiDu
Copy link
Collaborator

@KuntaiDu KuntaiDu commented Nov 20, 2024

A light-weight implementation of disaggregated prefill. I switched from PR #8498 to here in order to fix DCO issues.

Copy link

👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can do one of these:

  • Add ready label to the PR
  • Enable auto-merge.

🚀

KuntaiDu and others added 2 commits November 20, 2024 21:46
Signed-off-by: KuntaiDu <kuntai@uchicago.edu>
Co-authored-by: ApostaC <yihua98@uchicago.edu>
Co-authored-by: YaoJiayi <120040070@link.cuhk.edu.cn>
Signed-off-by: KuntaiDu <kuntai@uchicago.edu>
@KuntaiDu KuntaiDu force-pushed the kuntai-disagg-fix-DCO branch from 4541111 to 1eadc94 Compare November 20, 2024 21:47
Signed-off-by: KuntaiDu <kuntai@uchicago.edu>
Signed-off-by: KuntaiDu <kuntai@uchicago.edu>
@KuntaiDu KuntaiDu added the ready ONLY add when PR is ready to merge/full CI is needed label Nov 20, 2024
Signed-off-by: KuntaiDu <kuntai@uchicago.edu>
Signed-off-by: KuntaiDu <kuntai@uchicago.edu>
… package

Signed-off-by: KuntaiDu <kuntai@uchicago.edu>
Signed-off-by: KuntaiDu <kuntai@uchicago.edu>
Signed-off-by: KuntaiDu <kuntai@uchicago.edu>
Signed-off-by: KuntaiDu <kuntai@uchicago.edu>
Copy link

mergify bot commented Nov 22, 2024

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @KuntaiDu.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Nov 22, 2024
@mergify mergify bot removed the needs-rebase label Nov 22, 2024
Copy link

mergify bot commented Nov 22, 2024

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @KuntaiDu.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Nov 22, 2024
Signed-off-by: KuntaiDu <kuntai@uchicago.edu>
Signed-off-by: KuntaiDu <kuntai@uchicago.edu>
@mergify mergify bot removed the needs-rebase label Nov 24, 2024
Signed-off-by: KuntaiDu <kuntai@uchicago.edu>
Signed-off-by: KuntaiDu <kuntai@uchicago.edu>
@KuntaiDu KuntaiDu added the ready ONLY add when PR is ready to merge/full CI is needed label Dec 1, 2024
Signed-off-by: KuntaiDu <kuntai@uchicago.edu>
Signed-off-by: KuntaiDu <kuntai@uchicago.edu>
@KuntaiDu KuntaiDu merged commit 0590ec3 into vllm-project:main Dec 2, 2024
68 checks passed
@KuntaiDu KuntaiDu deleted the kuntai-disagg-fix-DCO branch December 2, 2024 01:01
@wilhelmjung
Copy link

wilhelmjung commented Dec 2, 2024

Hello. @KuntaiDu Shortly after a successful invocation to the proxy. I got these errors from the vllm instances.
And the second invocation will be blocked. Anything goes wrong? or any parameter should be configured. Thx!

[rank0]:[W1202 17:11:14.631061365 socket.cpp:487] [c10d] waitForInput: socket SocketImpl(fd=125, addr=[::ffff:127.0.0.1]:50228, remote=[::ffff:127.0.0.1]:14580) timed out after 300000ms
ERROR 12-02 17:11:14 pynccl_pipe.py:261] Encountering exception in KV receiving thread
ERROR 12-02 17:11:14 pynccl_pipe.py:262] wait timeout after 300000ms, keys: /send_to/0/4
ERROR 12-02 17:11:14 pynccl_pipe.py:263] My device: cpu
Traceback (most recent call last):
File "/data/miniconda3/envs/will/lib/python3.10/site-packages/vllm/distributed/kv_transfer/kv_pipe/pynccl_pipe.py", line 259, in recv_tensor
tensor = future.result()
File "/data/miniconda3/envs/will/lib/python3.10/concurrent/futures/_base.py", line 458, in result
return self.__get_result()
File "/data/miniconda3/envs/will/lib/python3.10/concurrent/futures/_base.py", line 403, in __get_result
raise self._exception
File "/data/miniconda3/envs/will/lib/python3.10/concurrent/futures/thread.py", line 58, in run
result = self.fn(*self.args, **self.kwargs)
File "/data/miniconda3/envs/will/lib/python3.10/site-packages/vllm/distributed/kv_transfer/kv_pipe/pynccl_pipe.py", line 189, in _recv_impl
metadata = self._recv_metadata()
File "/data/miniconda3/envs/will/lib/python3.10/site-packages/vllm/distributed/kv_transfer/kv_pipe/pynccl_pipe.py", line 164, in _recv_metadata
return self.group.recv_obj(self.target_rank_for_recv)
File "/data/miniconda3/envs/will/lib/python3.10/site-packages/vllm/distributed/utils.py", line 148, in recv_obj
self.store.get(
torch.distributed.DistStoreError: wait timeout after 300000ms, keys: /send_to/0/4
Exception in thread Thread-3 (drop_select_handler):
Traceback (most recent call last):
File "/data/miniconda3/envs/will/lib/python3.10/threading.py", line 1016, in _bootstrap_inner
self.run()
File "/data/miniconda3/envs/will/lib/python3.10/threading.py", line 953, in run
self._target(*self._args, **self._kwargs)
File "/data/miniconda3/envs/will/lib/python3.10/site-packages/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py", line 177, in drop_select_handler
raise e
File "/data/miniconda3/envs/will/lib/python3.10/site-packages/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py", line 132, in drop_select_handler
signal = self.signal_pipe.recv_tensor()
File "/data/miniconda3/envs/will/lib/python3.10/site-packages/vllm/distributed/kv_transfer/kv_pipe/pynccl_pipe.py", line 266, in recv_tensor
raise e
File "/data/miniconda3/envs/will/lib/python3.10/site-packages/vllm/distributed/kv_transfer/kv_pipe/pynccl_pipe.py", line 259, in recv_tensor
tensor = future.result()
File "/data/miniconda3/envs/will/lib/python3.10/concurrent/futures/_base.py", line 458, in result
return self.__get_result()
File "/data/miniconda3/envs/will/lib/python3.10/concurrent/futures/_base.py", line 403, in __get_result
raise self._exception
File "/data/miniconda3/envs/will/lib/python3.10/concurrent/futures/thread.py", line 58, in run
result = self.fn(*self.args, **self.kwargs)
File "/data/miniconda3/envs/will/lib/python3.10/site-packages/vllm/distributed/kv_transfer/kv_pipe/pynccl_pipe.py", line 189, in _recv_impl
metadata = self._recv_metadata()
File "/data/miniconda3/envs/will/lib/python3.10/site-packages/vllm/distributed/kv_transfer/kv_pipe/pynccl_pipe.py", line 164, in _recv_metadata
return self.group.recv_obj(self.target_rank_for_recv)
File "/data/miniconda3/envs/will/lib/python3.10/site-packages/vllm/distributed/utils.py", line 148, in recv_obj
self.store.get(
torch.distributed.DistStoreError: wait timeout after 300000ms, keys: /send_to/0/4

@ShangmingCai
Copy link
Contributor

Hello. @KuntaiDu Shortly after a successful invocation to the proxy. I got these errors from the vllm instances. And the second invocation will be blocked. Anything goes wrong? or any parameter should be configured. Thx!

[rank0]:[W1202 17:11:14.631061365 socket.cpp:487] [c10d] waitForInput: socket SocketImpl(fd=125, addr=[::ffff:127.0.0.1]:50228, remote=[::ffff:127.0.0.1]:14580) timed out after 300000ms ERROR 12-02 17:11:14 pynccl_pipe.py:261] Encountering exception in KV receiving thread ERROR 12-02 17:11:14 pynccl_pipe.py:262] wait timeout after 300000ms, keys: /send_to/0/4 ERROR 12-02 17:11:14 pynccl_pipe.py:263] My device: cpu Traceback (most recent call last): File "/data/miniconda3/envs/will/lib/python3.10/site-packages/vllm/distributed/kv_transfer/kv_pipe/pynccl_pipe.py", line 259, in recv_tensor tensor = future.result() File "/data/miniconda3/envs/will/lib/python3.10/concurrent/futures/_base.py", line 458, in result return self.__get_result() File "/data/miniconda3/envs/will/lib/python3.10/concurrent/futures/_base.py", line 403, in __get_result raise self._exception File "/data/miniconda3/envs/will/lib/python3.10/concurrent/futures/thread.py", line 58, in run result = self.fn(*self.args, **self.kwargs) File "/data/miniconda3/envs/will/lib/python3.10/site-packages/vllm/distributed/kv_transfer/kv_pipe/pynccl_pipe.py", line 189, in _recv_impl metadata = self._recv_metadata() File "/data/miniconda3/envs/will/lib/python3.10/site-packages/vllm/distributed/kv_transfer/kv_pipe/pynccl_pipe.py", line 164, in _recv_metadata return self.group.recv_obj(self.target_rank_for_recv) File "/data/miniconda3/envs/will/lib/python3.10/site-packages/vllm/distributed/utils.py", line 148, in recv_obj self.store.get( torch.distributed.DistStoreError: wait timeout after 300000ms, keys: /send_to/0/4 Exception in thread Thread-3 (drop_select_handler): Traceback (most recent call last): File "/data/miniconda3/envs/will/lib/python3.10/threading.py", line 1016, in _bootstrap_inner self.run() File "/data/miniconda3/envs/will/lib/python3.10/threading.py", line 953, in run self._target(*self._args, **self._kwargs) File "/data/miniconda3/envs/will/lib/python3.10/site-packages/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py", line 177, in drop_select_handler raise e File "/data/miniconda3/envs/will/lib/python3.10/site-packages/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py", line 132, in drop_select_handler signal = self.signal_pipe.recv_tensor() File "/data/miniconda3/envs/will/lib/python3.10/site-packages/vllm/distributed/kv_transfer/kv_pipe/pynccl_pipe.py", line 266, in recv_tensor raise e File "/data/miniconda3/envs/will/lib/python3.10/site-packages/vllm/distributed/kv_transfer/kv_pipe/pynccl_pipe.py", line 259, in recv_tensor tensor = future.result() File "/data/miniconda3/envs/will/lib/python3.10/concurrent/futures/_base.py", line 458, in result return self.__get_result() File "/data/miniconda3/envs/will/lib/python3.10/concurrent/futures/_base.py", line 403, in __get_result raise self._exception File "/data/miniconda3/envs/will/lib/python3.10/concurrent/futures/thread.py", line 58, in run result = self.fn(*self.args, **self.kwargs) File "/data/miniconda3/envs/will/lib/python3.10/site-packages/vllm/distributed/kv_transfer/kv_pipe/pynccl_pipe.py", line 189, in _recv_impl metadata = self._recv_metadata() File "/data/miniconda3/envs/will/lib/python3.10/site-packages/vllm/distributed/kv_transfer/kv_pipe/pynccl_pipe.py", line 164, in _recv_metadata return self.group.recv_obj(self.target_rank_for_recv) File "/data/miniconda3/envs/will/lib/python3.10/site-packages/vllm/distributed/utils.py", line 148, in recv_obj self.store.get( torch.distributed.DistStoreError: wait timeout after 300000ms, keys: /send_to/0/4

Currently, PyNcclPipe does not implement heartbeats, so the connection will be closed if no request is sent within 5 minutes.

afeldman-nm pushed a commit to neuralmagic/vllm that referenced this pull request Dec 2, 2024
…t#10502)

This PR provides initial support for single-node disaggregated prefill in 1P1D scenario.
Signed-off-by: KuntaiDu <kuntai@uchicago.edu>
Co-authored-by: ApostaC <yihua98@uchicago.edu>
Co-authored-by: YaoJiayi <120040070@link.cuhk.edu.cn>
Signed-off-by: Andrew Feldman <afeldman@neuralmagic.com>
@wilhelmjung
Copy link

Hello. @KuntaiDu Shortly after a successful invocation to the proxy. I got these errors from the vllm instances. And the second invocation will be blocked. Anything goes wrong? or any parameter should be configured. Thx!
[rank0]:[W1202 17:11:14.631061365 socket.cpp:487] [c10d] waitForInput: socket SocketImpl(fd=125, addr=[::ffff:127.0.0.1]:50228, remote=[::ffff:127.0.0.1]:14580) timed out after 300000ms ERROR 12-02 17:11:14 pynccl_pipe.py:261] Encountering exception in KV receiving thread ERROR 12-02 17:11:14 pynccl_pipe.py:262] wait timeout after 300000ms, keys: /send_to/0/4 ERROR 12-02 17:11:14 pynccl_pipe.py:263] My device: cpu Traceback (most recent call last): File "/data/miniconda3/envs/will/lib/python3.10/site-packages/vllm/distributed/kv_transfer/kv_pipe/pynccl_pipe.py", line 259, in recv_tensor tensor = future.result() File "/data/miniconda3/envs/will/lib/python3.10/concurrent/futures/_base.py", line 458, in result return self.__get_result() File "/data/miniconda3/envs/will/lib/python3.10/concurrent/futures/_base.py", line 403, in __get_result raise self._exception File "/data/miniconda3/envs/will/lib/python3.10/concurrent/futures/thread.py", line 58, in run result = self.fn(*self.args, **self.kwargs) File "/data/miniconda3/envs/will/lib/python3.10/site-packages/vllm/distributed/kv_transfer/kv_pipe/pynccl_pipe.py", line 189, in _recv_impl metadata = self._recv_metadata() File "/data/miniconda3/envs/will/lib/python3.10/site-packages/vllm/distributed/kv_transfer/kv_pipe/pynccl_pipe.py", line 164, in _recv_metadata return self.group.recv_obj(self.target_rank_for_recv) File "/data/miniconda3/envs/will/lib/python3.10/site-packages/vllm/distributed/utils.py", line 148, in recv_obj self.store.get( torch.distributed.DistStoreError: wait timeout after 300000ms, keys: /send_to/0/4 Exception in thread Thread-3 (drop_select_handler): Traceback (most recent call last): File "/data/miniconda3/envs/will/lib/python3.10/threading.py", line 1016, in _bootstrap_inner self.run() File "/data/miniconda3/envs/will/lib/python3.10/threading.py", line 953, in run self._target(*self._args, **self._kwargs) File "/data/miniconda3/envs/will/lib/python3.10/site-packages/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py", line 177, in drop_select_handler raise e File "/data/miniconda3/envs/will/lib/python3.10/site-packages/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py", line 132, in drop_select_handler signal = self.signal_pipe.recv_tensor() File "/data/miniconda3/envs/will/lib/python3.10/site-packages/vllm/distributed/kv_transfer/kv_pipe/pynccl_pipe.py", line 266, in recv_tensor raise e File "/data/miniconda3/envs/will/lib/python3.10/site-packages/vllm/distributed/kv_transfer/kv_pipe/pynccl_pipe.py", line 259, in recv_tensor tensor = future.result() File "/data/miniconda3/envs/will/lib/python3.10/concurrent/futures/_base.py", line 458, in result return self.__get_result() File "/data/miniconda3/envs/will/lib/python3.10/concurrent/futures/_base.py", line 403, in __get_result raise self._exception File "/data/miniconda3/envs/will/lib/python3.10/concurrent/futures/thread.py", line 58, in run result = self.fn(*self.args, **self.kwargs) File "/data/miniconda3/envs/will/lib/python3.10/site-packages/vllm/distributed/kv_transfer/kv_pipe/pynccl_pipe.py", line 189, in _recv_impl metadata = self._recv_metadata() File "/data/miniconda3/envs/will/lib/python3.10/site-packages/vllm/distributed/kv_transfer/kv_pipe/pynccl_pipe.py", line 164, in _recv_metadata return self.group.recv_obj(self.target_rank_for_recv) File "/data/miniconda3/envs/will/lib/python3.10/site-packages/vllm/distributed/utils.py", line 148, in recv_obj self.store.get( torch.distributed.DistStoreError: wait timeout after 300000ms, keys: /send_to/0/4

Currently, PyNcclPipe does not implement heartbeats, so the connection will be closed if no request is sent within 5 minutes.

Is there a workaround to solve this timeout. Or how to modify the 5min timeout. Thanks!

@ShangmingCai
Copy link
Contributor

Is there a workaround to solve this timeout. Or how to modify the 5min timeout. Thanks!

This is a known issue, it will be addressed in the future PR by @KuntaiDu . If you need a quick workaround, you can modify disagg_prefill_proxy_server.py to send a shadow request every 4 min through apscheduler.

@liweiqing1997
Copy link

Hello, I encountered the following issue while running the decomposition reasoning on the 'mian' branch:
ValueError: not enough values to unpack (expected 4, got 2).

The actual real kvcache shape is “kv_cache[0] shape torch.Size([2162, 81920])”

INFO 12-03 14:31:48 model_runner_base.py:120] Writing input of failed execution to /tmp/err_execute_model_input_20241203-143148.pkl...
INFO 12-03 14:31:48 model_runner_base.py:149] Completed writing input of failed execution to /tmp/err_execute_model_input_20241203-143148.pkl.
ERROR 12-03 14:31:48 engine.py:135] ValueError('Error in model execution (input dumped to /tmp/err_execute_model_input_20241203-143148.pkl): not enough values to unpack (expected 4, got 2)')
ERROR 12-03 14:31:48 engine.py:135] Traceback (most recent call last):
ERROR 12-03 14:31:48 engine.py:135] File "/mnt/data_disk101/data_disk/lwq/LLM_INFER/split_platform/opensource/vllm-kuntai-disagg-refactor_1202/vllm/worker/model_runner_base.py", line 116, in _wrapper
ERROR 12-03 14:31:48 engine.py:135] return func(*args, **kwargs)
ERROR 12-03 14:31:48 engine.py:135] File "/mnt/data_disk101/data_disk/lwq/LLM_INFER/split_platform/opensource/vllm-kuntai-disagg-refactor_1202/vllm/worker/model_runner.py", line 1718, in execute_model
ERROR 12-03 14:31:48 engine.py:135] get_kv_transfer_group().send_kv_caches_and_hidden_states(
ERROR 12-03 14:31:48 engine.py:135] File "/mnt/data_disk101/data_disk/lwq/LLM_INFER/split_platform/opensource/vllm-kuntai-disagg-refactor_1202/vllm/distributed/kv_transfer/kv_transfer_agent.py", line 60, in send_kv_caches_and_hidden_states
ERROR 12-03 14:31:48 engine.py:135] self.connector.send_kv_caches_and_hidden_states(
ERROR 12-03 14:31:48 engine.py:135] File "/mnt/data_disk101/data_disk/lwq/LLM_INFER/split_platform/opensource/vllm-kuntai-disagg-refactor_1202/vllm/distributed/kv_transfer/kv_connector/simple_connector.py", line 134, in send_kv_caches_and_hidden_states
ERROR 12-03 14:31:48 engine.py:135] _, _, num_heads, head_size = kv_cache[0].shape
ERROR 12-03 14:31:48 engine.py:135] ValueError: not enough values to unpack (expected 4, got 2)
ERROR 12-03 14:31:48 engine.py:135]

My startup command is:

CUDA_VISIBLE_DEVICES=3 nohup nohup python3
-m vllm.entrypoints.openai.api_server
--model $model
--port 8100
--max-model-len 1000
--gpu-memory-utilization 0.7
--kv-transfer-config '{"kv_connector":"PyNcclConnector","kv_role":"kv_producer","kv_rank":0,"kv_parallel_size":2}' > $Log_folder/p.log 2>&1 &

CUDA_VISIBLE_DEVICES=4 nohup python3
-m vllm.entrypoints.openai.api_server
--model $model
--port 8200
--max-model-len 1000
--gpu-memory-utilization 0.7
--kv-transfer-config '{"kv_connector":"PyNcclConnector","kv_role":"kv_consumer","kv_rank":1,"kv_parallel_size":2}' > $Log_folder/D.log 2>&1 &

nohup python3 disagg_prefill_proxy_server.py > $Log_folder/proxy_server.log 2>&1 &

My startup command is:

CUDA_VISIBLE_DEVICES=3 nohup nohup python3
-m vllm.entrypoints.openai.api_server
--model $model
--port 8100
--max-model-len 1000
--gpu-memory-utilization 0.7
--kv-transfer-config '{"kv_connector":"PyNcclConnector","kv_role":"kv_producer","kv_rank":0,"kv_parallel_size":2}' > $Log_folder/p.log 2>&1 &

CUDA_VISIBLE_DEVICES=4 nohup python3
-m vllm.entrypoints.openai.api_server
--model $model
--port 8200
--max-model-len 1000
--gpu-memory-utilization 0.7
--kv-transfer-config '{"kv_connector":"PyNcclConnector","kv_role":"kv_consumer","kv_rank":1,"kv_parallel_size":2}' > $Log_folder/D.log 2>&1 &

nohup python3 disagg_prefill_proxy_server.py > $Log_folder/proxy_server.log 2>&1 &

@KuntaiDu
Copy link
Collaborator Author

KuntaiDu commented Dec 3, 2024

Hi @liweiqing1997 , currently I only tested Llama-style model. What kind of model are you using?

@ShangmingCai
Copy link
Contributor

Hello, I encountered the following issue while running the decomposition reasoning on the 'mian' branch: ValueError: not enough values to unpack (expected 4, got 2).

The actual real kvcache shape is “kv_cache[0] shape torch.Size([2162, 81920])”

INFO 12-03 14:31:48 model_runner_base.py:120] Writing input of failed execution to /tmp/err_execute_model_input_20241203-143148.pkl... INFO 12-03 14:31:48 model_runner_base.py:149] Completed writing input of failed execution to /tmp/err_execute_model_input_20241203-143148.pkl. ERROR 12-03 14:31:48 engine.py:135] ValueError('Error in model execution (input dumped to /tmp/err_execute_model_input_20241203-143148.pkl): not enough values to unpack (expected 4, got 2)') ERROR 12-03 14:31:48 engine.py:135] Traceback (most recent call last): ERROR 12-03 14:31:48 engine.py:135] File "/mnt/data_disk101/data_disk/lwq/LLM_INFER/split_platform/opensource/vllm-kuntai-disagg-refactor_1202/vllm/worker/model_runner_base.py", line 116, in _wrapper ERROR 12-03 14:31:48 engine.py:135] return func(*args, **kwargs) ERROR 12-03 14:31:48 engine.py:135] File "/mnt/data_disk101/data_disk/lwq/LLM_INFER/split_platform/opensource/vllm-kuntai-disagg-refactor_1202/vllm/worker/model_runner.py", line 1718, in execute_model ERROR 12-03 14:31:48 engine.py:135] get_kv_transfer_group().send_kv_caches_and_hidden_states( ERROR 12-03 14:31:48 engine.py:135] File "/mnt/data_disk101/data_disk/lwq/LLM_INFER/split_platform/opensource/vllm-kuntai-disagg-refactor_1202/vllm/distributed/kv_transfer/kv_transfer_agent.py", line 60, in send_kv_caches_and_hidden_states ERROR 12-03 14:31:48 engine.py:135] self.connector.send_kv_caches_and_hidden_states( ERROR 12-03 14:31:48 engine.py:135] File "/mnt/data_disk101/data_disk/lwq/LLM_INFER/split_platform/opensource/vllm-kuntai-disagg-refactor_1202/vllm/distributed/kv_transfer/kv_connector/simple_connector.py", line 134, in send_kv_caches_and_hidden_states ERROR 12-03 14:31:48 engine.py:135] _, _, num_heads, head_size = kv_cache[0].shape ERROR 12-03 14:31:48 engine.py:135] ValueError: not enough values to unpack (expected 4, got 2) ERROR 12-03 14:31:48 engine.py:135]

My startup command is:

CUDA_VISIBLE_DEVICES=3 nohup nohup python3 -m vllm.entrypoints.openai.api_server --model $model --port 8100 --max-model-len 1000 --gpu-memory-utilization 0.7 --kv-transfer-config '{"kv_connector":"PyNcclConnector","kv_role":"kv_producer","kv_rank":0,"kv_parallel_size":2}' > $Log_folder/p.log 2>&1 &

CUDA_VISIBLE_DEVICES=4 nohup python3 -m vllm.entrypoints.openai.api_server --model $model --port 8200 --max-model-len 1000 --gpu-memory-utilization 0.7 --kv-transfer-config '{"kv_connector":"PyNcclConnector","kv_role":"kv_consumer","kv_rank":1,"kv_parallel_size":2}' > $Log_folder/D.log 2>&1 &

nohup python3 disagg_prefill_proxy_server.py > $Log_folder/proxy_server.log 2>&1 &

My startup command is:

CUDA_VISIBLE_DEVICES=3 nohup nohup python3 -m vllm.entrypoints.openai.api_server --model $model --port 8100 --max-model-len 1000 --gpu-memory-utilization 0.7 --kv-transfer-config '{"kv_connector":"PyNcclConnector","kv_role":"kv_producer","kv_rank":0,"kv_parallel_size":2}' > $Log_folder/p.log 2>&1 &

CUDA_VISIBLE_DEVICES=4 nohup python3 -m vllm.entrypoints.openai.api_server --model $model --port 8200 --max-model-len 1000 --gpu-memory-utilization 0.7 --kv-transfer-config '{"kv_connector":"PyNcclConnector","kv_role":"kv_consumer","kv_rank":1,"kv_parallel_size":2}' > $Log_folder/D.log 2>&1 &

nohup python3 disagg_prefill_proxy_server.py > $Log_folder/proxy_server.log 2>&1 &

I guess you are using a GPU card with Volta or Turing architecture? I have found this problem in the older version of this PR. @KuntaiDu If you don't have bandwidth, I can propose a PR to fix this.

@liweiqing1997
Copy link

liweiqing1997 commented Dec 3, 2024

Hi @liweiqing1997 , currently I only tested Llama-style model. What kind of model are you using?

I am testing the Qwen 1.5 14B chat. Previously, I tested a version that had not been merged into the vllm/main branch, and it ran successfully. However, the main branch version does not work. I'm not sure if any changes were made or if there is an issue with my settings.

@KuntaiDu
Copy link
Collaborator Author

KuntaiDu commented Dec 3, 2024

BTW feel free to also comment in disaggregated prefill roadmap (#10818)

@liweiqing1997
Copy link

您好,我在“mian”分支上运行分解推理时遇到以下问题:ValueError:没有足够的值可以解包(预期 4 个,但得到 2 个)。
实际的真实 kvcache 形状是“kv_cache[0] shape torch.Size([2162, 81920])”
INFO 12-03 14:31:48 model_runner_base.py:120] 将失败执行的输入写入 /tmp/err_execute_model_input_20241203-143148.pkl... INFO 12-03 14:31:48 model_runner_base.py:149] 已完成将失败执行的输入写入 /tmp/err_execute_model_input_20241203-143148.pkl。错误 12-03 14:31:48 engine.py:135] ValueError('模型执行错误(输入转储到 /tmp/err_execute_model_input_20241203-143148.pkl):没有足够的值来解包(预期 4 个,得到 2 个)')错误 12-03 14:31:48 engine.py:135] 回溯(最近一次调用最后一次):错误 12-03 14:31:48 engine.py:135] 文件“/mnt/data_disk101/data_disk/lwq/LLM_INFER/split_platform/opensource/vllm-kuntai-disagg-refactor_1202/vllm/worker/model_runner_base.py”,第 116 行,在 _wrapper 中错误12-03 14:31:48 engine.py:135] 返回 func(*args,**kwargs) 错误 12-03 14:31:48 engine.py:135] 文件“/mnt/data_disk101/data_disk/lwq/LLM_INFER/split_platform/opensource/vllm-kuntai-disagg-refactor_1202/vllm/worker/model_runner.py”,第 1718 行,在 execute_model 中错误 12-03 14:31:48 engine.py:135] get_kv_transfer_group().send_kv_caches_and_hidden_​​states(错误 12-03 14:31:48 engine.py:135] 文件“/mnt/data_disk101/data_disk/lwq/LLM_INFER/split_platform/opensource/vllm-kuntai-disagg-refactor_1202/vllm/distributed/kv_transfer/kv_transfer_agent.py”,第 60 行,在 send_kv_caches_and_hidden_​​states 中错误 12-03 14:31:48 engine.py:135] self.connector.send_kv_caches_and_hidden_​​states(错误 12-03 14:31:48 engine.py:135] 文件“/mnt/data_disk101/data_disk/lwq/LLM_INFER/split_platform/opensource/vllm-kuntai-disagg-refactor_1202/vllm/distributed/kv_transfer/kv_connector/simple_connector.py”,第 134 行,位于 send_kv_caches_and_hidden_​​states ERROR 12-03 14:31:48 engine.py:135] ,num_heads,head_size = kv_cache[0].shape ERROR 12-03 14:31:48 engine.py:135] ValueError:没有足够的值来解压(预期 4 个,得到 2 个)ERROR 12-03 14:31:48 engine.py:135]
我的启动命令是:
CUDA_VISIBLE_DEVICES=3 nohup nohup python3 -m vllm.entrypoints.openai.api_server --model $model --port 8100 --max-model-len 1000 --gpu-memory-utilization 0.7 --kv-transfer-config '{"kv_connector":"PyNcclConnector","kv_role":"kv_producer","kv_rank":0,"kv_parallel_size":2}' > $Log_folder/p.log 2>&1 &
CUDA_VISIBLE_DEVICES=4 nohup python3 -m vllm.entrypoints.openai.api_server --model $model --port 8200 --max-model-len 1000 --gpu-memory-utilization 0.7 --kv-transfer-config '{"kv_connector":"PyNcclConnector","kv_role":"kv_consumer","kv_rank":1,"kv_parallel_size":2}' > $Log_folder/D.log 2>&1 &
nohup python3 disagg_prefill_proxy_server.py > $Log_folder/proxy_server.log 2>&1 &
我的启动命令是:
CUDA_VISIBLE_DEVICES=3 nohup nohup python3 -m vllm.entrypoints.openai.api_server --model $model --port 8100 --max-model-len 1000 --gpu-memory-utilization 0.7 --kv-transfer-config '{"kv_connector":"PyNcclConnector","kv_role":"kv_producer","kv_rank":0,"kv_parallel_size":2}' > $Log_folder/p.log 2>&1 &
CUDA_VISIBLE_DEVICES=4 nohup python3 -m vllm.entrypoints.openai.api_server --model $model --port 8200 --max-model-len 1000 --gpu-memory-utilization 0.7 --kv-transfer-config '{"kv_connector":"PyNcclConnector","kv_role":"kv_consumer","kv_rank":1,"kv_parallel_size":2}' > $Log_folder/D.log 2>&1 &
nohup python3 disagg_prefill_proxy_server.py > $Log_folder/proxy_server.log 2>&1 &

我猜你使用的是 Volta 或 Turing 架构的 GPU 卡?我在这个 PR 的旧版本中发现了这个问题。@KuntaiDu如果您没有带宽,我可以提出 PR 来解决这个问题。

NVIDIA A100-SXM4-80GB

@ShangmingCai
Copy link
Contributor

NVIDIA A100-SXM4-80GB

OK, then this bug may affect a wider range than I thought. My solution is ​​to obtain num_heads and head_size from model_executable.model.config instead of getting them from kv_cache[0].shape.

sleepwalker2017 pushed a commit to sleepwalker2017/vllm that referenced this pull request Dec 13, 2024
…t#10502)

This PR provides initial support for single-node disaggregated prefill in 1P1D scenario.
Signed-off-by: KuntaiDu <kuntai@uchicago.edu>
Co-authored-by: ApostaC <yihua98@uchicago.edu>
Co-authored-by: YaoJiayi <120040070@link.cuhk.edu.cn>
@liuyumoye
Copy link

Hello, I noticed that in #6170 you used torch.distributed.init_process_group to initialize all ranks for prefill node and decode node, but later changed it to StatelessProcessGroup for kv cache transfer.
However, StatelessProcessGroup only supports nccl backend. If I want to use the CPU for transferring KV cache, do you have any good suggestions? It seems that TCPStore might not be suitable for transferring large amounts of data.

@youkaichao
Copy link
Member

@liuyumoye can you take a look at #10884 ? I think mooncake transfer engine should support cpu transfer.

@liuyumoye
Copy link

@liuyumoye can you take a look at #10884 ? I think mooncake transfer engine should support cpu transfer.

Thanks, I'll try your suggestion

BKitor pushed a commit to BKitor/vllm that referenced this pull request Dec 30, 2024
…t#10502)

This PR provides initial support for single-node disaggregated prefill in 1P1D scenario.
Signed-off-by: KuntaiDu <kuntai@uchicago.edu>
Co-authored-by: ApostaC <yihua98@uchicago.edu>
Co-authored-by: YaoJiayi <120040070@link.cuhk.edu.cn>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
ci/build ready ONLY add when PR is ready to merge/full CI is needed
Projects
None yet
Development

Successfully merging this pull request may close these issues.

8 participants