Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[V1] AsyncLLM Implementation #9826

Merged
merged 175 commits into from
Nov 11, 2024

Conversation

robertgshaw2-neuralmagic
Copy link
Collaborator

@robertgshaw2-neuralmagic robertgshaw2-neuralmagic commented Oct 30, 2024

SUMMARY:

  • AsyncLLM in V1 - better overlapping of GPU and CPU

TODO:

  • Make mypy happy
  • Remove debugging polling in io threads
  • LM Eval testing
  • Cleaner shutdown
  • Get the test passing in the CI :)

FOLLOW UP PRS:

  • Benchmarking with CUDAGraphs (todo as follow up given cudagraphs are broken)
  • Robustness (health checks, make sure abort is working properly everywhere)
  • More AsyncLLM and LLMEngine tests (abort, stop string, other unit)
  • Enable multiprocessing for LLM by default (need to figure out a way around fork) - currently, need to set VLLM_ENABLE_V1_MULTIPROCESSING=1

DIAGRAM:

  • Note: this diagram is a bit dated. There is an EngineCoreClient class that is used by the AsyncLLM to interact with the EngineCore, but the overall architecture is close to what we have.
  • Note: stop strings are detected in the detokenizer and we send an abort message from output_handler_loop to EngineCore
image

…s-proto

# Conflicts:
#	vllm/v1/engine/llm_engine.py
#	vllm/v1/tokenizer/detokenizer.py
@mergify mergify bot removed the needs-rebase label Nov 11, 2024
Copy link
Collaborator

@WoosukKwon WoosukKwon left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the great work!

Comment on lines +479 to +480
"VLLM_ENABLE_V1_MULTIPROCESSING":
lambda: bool(int(os.getenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0"))),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

QQ: In which case should we turn this on?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

VLLM_ENABLE_V1_MULTIPROCESSING=1 enables multiprocessing for EngineCore inside LLM (multiprocessing is always used for AsyncLLM right now). It is faster than the current implementation.

image

We will want to enable VLLM_ENABLE_V1_MULTIPROCESSING=1, but right now it is a problem for LLM since we cannot spawn without an if __name__ == "__main__" guard. We left solving this issue for follow up work.

@robertgshaw2-neuralmagic robertgshaw2-neuralmagic merged commit 6ace6fb into vllm-project:main Nov 11, 2024
72 checks passed
@robertgshaw2-neuralmagic
Copy link
Collaborator Author

@robertgshaw2-neuralmagic @njhill this looks super 🚀🚀🚀

Couple questions for my own understanding:

  • How should I interpret the name core? There's both a v1/core package as well as v1/engine/core.py
  • It looks like this is way faster than v0 + multistep decoding, are we planning on ditching multistep in v1 or is that still TBD?

Thanks!

  • Its a bit unfortunate of a naming conflict. we can consider moving some files from this diff into v1/core
  • Goal of V1 is to simplify vLLM and make it faster such that multistep is not needed, since the code is complex and hard to maintain

@lixiaolx
Copy link

lixiaolx commented Nov 13, 2024

@robertgshaw2-neuralmagic I'm glad to see your optimized pr. I found some problems during the test and wanted to ask for advice. I set llama2-7b, 1gpu, batch=256, used V1-engine for testing and analysis, and used pr Comparing the test with your PR, the token gap is analyzed as follows:
pr-9289:
image

this-pr:
image

I am very happy that the new implementation has removed the token enqueue and dequeue time, but I found that the new version of update_schedule and schedule take longer. There is no major change in the total gap time
I carefully compared the code implementation. I found that there are no big changes.
I wonder if the new multi-threading of encode and decode causes the time consuming to become longer.

@robertgshaw2-neuralmagic
Copy link
Collaborator Author

Hey @lixiaolx - thanks for taking a look. I am having a hard time understanding your analysis - could you clarify?

@njhill
Copy link
Member

njhill commented Nov 13, 2024

Thanks @lixiaolx, nice profiles! What you observe is not unexpected since the scheduling logic currently contends for the GIL with the IPC message serialization/deserialization.

Our intention is to improve this very soon but doing the IPC work in a separate thread is still a big win as a first step since much of that work overlaps with parts of the critical loop that don't contend for the GIL, primarily the forward pass in the GPU.

rickyyx pushed a commit to rickyyx/vllm that referenced this pull request Nov 13, 2024
Signed-off-by: Nick Hill <nickhill@us.ibm.com>
Signed-off-by: rshaw@neuralmagic.com <rshaw@neuralmagic.com>
Signed-off-by: Nick Hill <nhill@redhat.com>
Co-authored-by: Nick Hill <nickhill@us.ibm.com>
Co-authored-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
Co-authored-by: Nick Hill <nhill@redhat.com>
Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
@lixiaolx
Copy link

Thanks @lixiaolx, nice profiles! What you observe is not unexpected since the scheduling logic currently contends for the GIL with the IPC message serialization/deserialization.

Our intention is to improve this very soon but doing the IPC work in a separate thread is still a big win as a first step since much of that work overlaps with parts of the critical loop that don't contend for the GIL, primarily the forward pass in the GPU.

Thank you very much for your answer. I tried to compare this solution. If we solve the GIL problem, the remaining gap time will be 2-3ms according to the above calculation.
I would like to ask if we have any plans to do asynchronous scheduling? Compared with sglang asynchronous, there is still a gap.
I recently analyzed that the overall test gap of sglang's asynchronous solution under the same conditions is between 200-300us. If you have a plan, are there any arrangements?

@lixiaolx
Copy link

Hey @lixiaolx - thanks for taking a look. I am having a hard time understanding your analysis - could you clarify?
@robertgshaw2-neuralmagic
I compared the previous pr with your current pr, and did nsys analysis. I added nvtx to analyze the time overhead where the mainloop function is called, and split and analyzed the CPU overhead between the two forwards before and after the GPU.

omer-dayan pushed a commit to omer-dayan/vllm that referenced this pull request Nov 14, 2024
Signed-off-by: Nick Hill <nickhill@us.ibm.com>
Signed-off-by: rshaw@neuralmagic.com <rshaw@neuralmagic.com>
Signed-off-by: Nick Hill <nhill@redhat.com>
Co-authored-by: Nick Hill <nickhill@us.ibm.com>
Co-authored-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
Co-authored-by: Nick Hill <nhill@redhat.com>
Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
Signed-off-by: OmerD <omer@run.ai>
@lixiaolx
Copy link

@robertgshaw2-neuralmagic @njhill Hello, does our pr support multiple gpu cards? Well, when testing llama2-70b 8gpu,occurs server log was stuck here.
image
I use nvidia-smi found that only 0 gpu card was occupying only about 500MB.

@njhill
Copy link
Member

njhill commented Nov 14, 2024

@lixiaolx the V1 path is still in an alpha state and does not yet support multiple GPUs, but will do soon.

@lixiaolx
Copy link

Thanks @lixiaolx, nice profiles! What you observe is not unexpected since the scheduling logic currently contends for the GIL with the IPC message serialization/deserialization.
Our intention is to improve this very soon but doing the IPC work in a separate thread is still a big win as a first step since much of that work overlaps with parts of the critical loop that don't contend for the GIL, primarily the forward pass in the GPU.

Thank you very much for your answer. I tried to compare this solution. If we solve the GIL problem, the remaining gap time will be 2-3ms according to the above calculation. I would like to ask if we have any plans to do asynchronous scheduling? Compared with sglang asynchronous, there is still a gap. I recently analyzed that the overall test gap of sglang's asynchronous solution under the same conditions is between 200-300us. If you have a plan, are there any arrangements?

@njhill ,Is there any arrangement for this asynchronous scheduling?

@lixiaolx
Copy link

@lixiaolx the V1 path is still in an alpha state and does not yet support multiple GPUs, but will do soon.

OK,thank you

sumitd2 pushed a commit to sumitd2/vllm that referenced this pull request Nov 14, 2024
Signed-off-by: Nick Hill <nickhill@us.ibm.com>
Signed-off-by: rshaw@neuralmagic.com <rshaw@neuralmagic.com>
Signed-off-by: Nick Hill <nhill@redhat.com>
Co-authored-by: Nick Hill <nickhill@us.ibm.com>
Co-authored-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
Co-authored-by: Nick Hill <nhill@redhat.com>
Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
Signed-off-by: Sumit Dubey <sumit.dubey2@ibm.com>
@njhill
Copy link
Member

njhill commented Nov 14, 2024

@njhill ,Is there any arrangement for this asynchronous scheduling?

Not yet, our plan is to first optimize other aspects first since it will be complex to combine this with certain other optimizations.

KuntaiDu pushed a commit to KuntaiDu/vllm that referenced this pull request Nov 20, 2024
Signed-off-by: Nick Hill <nickhill@us.ibm.com>
Signed-off-by: rshaw@neuralmagic.com <rshaw@neuralmagic.com>
Signed-off-by: Nick Hill <nhill@redhat.com>
Co-authored-by: Nick Hill <nickhill@us.ibm.com>
Co-authored-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
Co-authored-by: Nick Hill <nhill@redhat.com>
Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
mfournioux pushed a commit to mfournioux/vllm that referenced this pull request Nov 20, 2024
Signed-off-by: Nick Hill <nickhill@us.ibm.com>
Signed-off-by: rshaw@neuralmagic.com <rshaw@neuralmagic.com>
Signed-off-by: Nick Hill <nhill@redhat.com>
Co-authored-by: Nick Hill <nickhill@us.ibm.com>
Co-authored-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
Co-authored-by: Nick Hill <nhill@redhat.com>
Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
Signed-off-by: Maxime Fournioux <55544262+mfournioux@users.noreply.github.com>
tlrmchlsmth added a commit to neuralmagic/vllm that referenced this pull request Nov 23, 2024
Signed-off-by: Nick Hill <nickhill@us.ibm.com>
Signed-off-by: rshaw@neuralmagic.com <rshaw@neuralmagic.com>
Signed-off-by: Nick Hill <nhill@redhat.com>
Co-authored-by: Nick Hill <nickhill@us.ibm.com>
Co-authored-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
Co-authored-by: Nick Hill <nhill@redhat.com>
Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
@B-201
Copy link
Contributor

B-201 commented Nov 28, 2024

Hey @lixiaolx - thanks for taking a look. I am having a hard time understanding your analysis - could you clarify?
@robertgshaw2-neuralmagic
I compared the previous pr with your current pr, and did nsys analysis. I added nvtx to analyze the time overhead where the mainloop function is called, and split and analyzed the CPU overhead between the two forwards before and after the GPU.

Sorry to bother you, but I’d like to ask how you added nvtx to analyze the time overhead of these function calls?

sleepwalker2017 pushed a commit to sleepwalker2017/vllm that referenced this pull request Dec 13, 2024
Signed-off-by: Nick Hill <nickhill@us.ibm.com>
Signed-off-by: rshaw@neuralmagic.com <rshaw@neuralmagic.com>
Signed-off-by: Nick Hill <nhill@redhat.com>
Co-authored-by: Nick Hill <nickhill@us.ibm.com>
Co-authored-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
Co-authored-by: Nick Hill <nhill@redhat.com>
Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
ci/build frontend ready ONLY add when PR is ready to merge/full CI is needed
Projects
None yet
Development

Successfully merging this pull request may close these issues.

9 participants