[Core] Implement disagg prefill by StatelessProcessGroup #10502
Conversation
Hello @KuntaiDu. Shortly after a successful invocation of the proxy, I got these errors from the vLLM instances:
[rank0]:[W1202 17:11:14.631061365 socket.cpp:487] [c10d] waitForInput: socket SocketImpl(fd=125, addr=[::ffff:127.0.0.1]:50228, remote=[::ffff:127.0.0.1]:14580) timed out after 300000ms
Currently, PyNcclPipe does not implement heartbeats, so the connection will be closed if no request is sent within 5 minutes.
Is there a workaround for this timeout, or a way to modify the 5-minute timeout? Thanks!
This is a known issue; it will be addressed in a future PR by @KuntaiDu. If you need a quick workaround, you can modify
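Until then, one client-side mitigation is to keep the pipe warm by issuing a tiny request before the 5-minute idle window elapses. The sketch below assumes the proxy exposes the usual OpenAI-compatible /v1/completions endpoint on localhost:8000; the URL, model name, and interval are illustrative, not values defined by this PR.

```python
# Illustrative keepalive loop: send a tiny request to the disagg-prefill proxy
# before the 5-minute PyNcclPipe idle timeout elapses. The URL, model name, and
# interval below are assumptions for this sketch, not values defined by the PR.
import time

import requests

PROXY_URL = "http://localhost:8000/v1/completions"  # assumed proxy endpoint
KEEPALIVE_INTERVAL_S = 240  # comfortably under the 300 s idle timeout


def keepalive_forever() -> None:
    while True:
        try:
            requests.post(
                PROXY_URL,
                json={
                    "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",  # placeholder
                    "prompt": "ping",
                    "max_tokens": 1,
                },
                timeout=30,
            )
        except requests.RequestException as exc:
            print(f"keepalive request failed: {exc}")
        time.sleep(KEEPALIVE_INTERVAL_S)


if __name__ == "__main__":
    keepalive_forever()
```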
Hello, I encountered the following issue while running disaggregated prefill on the main branch. The actual KV cache shape is kv_cache[0].shape == torch.Size([2162, 81920]), and the log shows:
INFO 12-03 14:31:48 model_runner_base.py:120] Writing input of failed execution to /tmp/err_execute_model_input_20241203-143148.pkl...
My startup commands are:
CUDA_VISIBLE_DEVICES=3 nohup python3 …
CUDA_VISIBLE_DEVICES=4 nohup python3 …
nohup python3 disagg_prefill_proxy_server.py > $Log_folder/proxy_server.log 2>&1 &
Hi @liweiqing1997, currently I have only tested Llama-style models. What kind of model are you using?
I guess you are using a GPU with Volta or Turing architecture? I found this problem in an older version of this PR. @KuntaiDu, if you don't have bandwidth, I can propose a PR to fix this.
I am testing Qwen1.5-14B-Chat. Previously, I tested a version that had not yet been merged into vllm/main, and it ran successfully. However, the main-branch version does not work. I'm not sure whether anything changed or whether there is an issue with my settings.
BTW, feel free to also comment on the disaggregated prefill roadmap (#10818).
NVIDIA A100-SXM4-80GB |
OK, then this bug may affect a wider range than I thought. My solution is to obtain
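A hypothetical sketch of that direction, assuming the fix is to take the block size, KV head count, and head size from the config and reshape the flattened cache explicitly instead of inferring the layout from kv_cache.shape (all names and shapes below are illustrative assumptions, not the actual patch):

```python
# Hypothetical illustration only: split a flattened per-layer KV-cache tensor
# into key and value views using sizes taken from the config rather than
# guessed from kv_cache.shape. The assumed layout after reshaping is
# [2, num_blocks, block_size, num_kv_heads, head_size].
import torch


def split_flat_kv_cache(kv_cache: torch.Tensor,
                        block_size: int,
                        num_kv_heads: int,
                        head_size: int) -> tuple[torch.Tensor, torch.Tensor]:
    """Return (key_cache, value_cache), each shaped
    [num_blocks, block_size, num_kv_heads, head_size]."""
    elems_per_block = 2 * block_size * num_kv_heads * head_size
    assert kv_cache.numel() % elems_per_block == 0, "unexpected KV-cache layout"
    num_blocks = kv_cache.numel() // elems_per_block
    kv = kv_cache.reshape(2, num_blocks, block_size, num_kv_heads, head_size)
    return kv[0], kv[1]
```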
Hello, I noticed that in #6170 you used torch.distributed.init_process_group to initialize all ranks for the prefill and decode nodes, but later changed it to StatelessProcessGroup for KV cache transfer.
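For context, the idea behind a store-based ("stateless") group can be sketched as below; this is an illustration built directly on torch.distributed.TCPStore, not the code in this PR. The point is that the prefill and decode processes only need a host/port rendezvous plus a key/value store to exchange handshake metadata, without claiming the global default group that torch.distributed.init_process_group sets up.

```python
# Conceptual sketch (not vLLM's StatelessProcessGroup implementation): a
# TCPStore-based rendezvous that lets independent processes exchange small
# handshake metadata (shapes, dtypes, connection info for a KV-cache pipe)
# without calling torch.distributed.init_process_group.
from datetime import timedelta

import torch.distributed as dist


def rendezvous(host: str, port: int, rank: int, world_size: int) -> dist.TCPStore:
    # Rank 0 hosts the store; the other ranks connect to it as clients.
    store = dist.TCPStore(host, port, world_size,
                          is_master=(rank == 0),
                          timeout=timedelta(seconds=300))
    store.set(f"rank_{rank}/ready", "1")
    # Block until every rank has checked in, then hand back the store for
    # further key/value exchanges.
    for r in range(world_size):
        store.get(f"rank_{r}/ready")
    return store
```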
@liuyumoye can you take a look at #10884? I think the Mooncake transfer engine should support CPU transfer.
Thanks, I'll try your suggestion |
A lightweight implementation of disaggregated prefill, providing initial support for single-node disaggregated prefill in a 1P1D (one prefill instance, one decode instance) scenario. I switched from PR #8498 to here in order to fix DCO issues.
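For reference, here is a sketch of how a 1P1D setup might be launched with this feature, assuming the --kv-transfer-config flag with a PyNcclConnector in kv_producer/kv_consumer roles; the model name, ports, and GPU indices are placeholders, and the example scripts shipped with this PR are the authoritative reference.

```python
# Sketch: print launch commands for a 1P1D disaggregated-prefill setup.
# Assumes the --kv-transfer-config JSON interface (PyNcclConnector with
# kv_producer / kv_consumer roles); model, ports, and GPUs are placeholders.
import json

MODEL = "meta-llama/Meta-Llama-3.1-8B-Instruct"  # placeholder model


def kv_config(role: str, rank: int) -> str:
    return json.dumps({
        "kv_connector": "PyNcclConnector",
        "kv_role": role,        # "kv_producer" (prefill) or "kv_consumer" (decode)
        "kv_rank": rank,
        "kv_parallel_size": 2,  # one prefill + one decode instance
    })


# Prefill instance on GPU 0 (port 8100), decode instance on GPU 1 (port 8200);
# a proxy in front of both then forwards each request to prefill, then decode.
for gpu, port, role, rank in [(0, 8100, "kv_producer", 0),
                              (1, 8200, "kv_consumer", 1)]:
    print(f"CUDA_VISIBLE_DEVICES={gpu} vllm serve {MODEL} --port {port} "
          f"--kv-transfer-config '{kv_config(role, rank)}'")
```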