Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Bugfix] fix flashinfer cudagraph capture for PP #6708

Merged
merged 3 commits into from
Jul 24, 2024

Conversation

SolitaryThinker
Copy link
Contributor

The previous cudagraph capture assumes only a single pass through the batch sizes and clobbers the pre-allocated max_batch_size indptr and last_page_len tensors. However with PP, we need to capture the graphs for each VE, thus the clobbered tensors results in a size mismatch in the flashinfer wrapper.


PR Checklist (Click to Expand)

Thank you for your contribution to vLLM! Before submitting the pull request, please ensure the PR meets the following criteria. This helps vLLM maintain the code quality and improve the efficiency of the review process.

PR Title and Classification

Only specific types of PRs will be reviewed. The PR title is prefixed appropriately to indicate the type of change. Please use one of the following:

  • [Bugfix] for bug fixes.
  • [CI/Build] for build or continuous integration improvements.
  • [Doc] for documentation fixes and improvements.
  • [Model] for adding a new model or improving an existing model. Model name should appear in the title.
  • [Frontend] For changes on the vLLM frontend (e.g., OpenAI API server, LLM class, etc.)
  • [Kernel] for changes affecting CUDA kernels or other compute kernels.
  • [Core] for changes in the core vLLM logic (e.g., LLMEngine, AsyncLLMEngine, Scheduler, etc.)
  • [Hardware][Vendor] for hardware-specific changes. Vendor name should appear in the prefix (e.g., [Hardware][AMD]).
  • [Misc] for PRs that do not fit the above categories. Please use this sparingly.

Note: If the PR spans more than one category, please include all relevant prefixes.

Code Quality

The PR need to meet the following code quality standards:

  • We adhere to Google Python style guide and Google C++ style guide.
  • Pass all linter checks. Please use format.sh to format your code.
  • The code need to be well-documented to ensure future contributors can easily understand the code.
  • Include sufficient tests to ensure the project to stay correct and robust. This includes both unit tests and integration tests.
  • Please add documentation to docs/source/ if the PR modifies the user-facing behaviors of vLLM. It helps vLLM user understand and utilize the new features or changes.

Notes for Large Changes

Please keep the changes as concise as possible. For major architectural changes (>500 LOC excluding kernel/data/config/test), we would expect a GitHub issue (RFC) discussing the technical design and justification. Otherwise, we will tag it with rfc-required and might not go through the PR.

What to Expect for the Reviews

The goal of the vLLM team is to be a transparent reviewing machine. We would like to make the review process transparent and efficient and make sure no contributor feel confused or frustrated. However, the vLLM team is small, so we need to prioritize some PRs over others. Here is what you can expect from the review process:

  • After the PR is submitted, the PR will be assigned to a reviewer. Every reviewer will pick up the PRs based on their expertise and availability.
  • After the PR is assigned, the reviewer will provide status update every 2-3 days. If the PR is not reviewed within 7 days, please feel free to ping the reviewer or the vLLM team.
  • After the review, the reviewer will put an action-required label on the PR if there are changes required. The contributor should address the comments and ping the reviewer to re-review the PR.
  • Please respond to all comments within a reasonable time frame. If a comment isn't clear or you disagree with a suggestion, feel free to ask for clarification or discuss the suggestion.

Thank You

Finally, thank you for taking the time to read these guidelines and for your interest in contributing to vLLM. Your contributions make vLLM a great tool for everyone!

Copy link

👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which consists a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of default ones by unblocking the steps in your fast-check build on Buildkite UI.

Once the PR is approved and ready to go, please make sure to run full CI as it is required to merge (or just use auto-merge).

To run full CI, you can do one of these:

  • Comment /ready on the PR
  • Add ready label to the PR
  • Enable auto-merge.

🚀

@comaniac
Copy link
Collaborator

Thanks for the fix. Will the engine crash without this PR when enabling CUDA graph + PP + FlashInfer? If so we could try to add a test to cover it. Also cc @LiuXiaoxuanPKU

@SolitaryThinker
Copy link
Contributor Author

SolitaryThinker commented Jul 23, 2024

Yes, currently the cudagraph capture fails as flashinfer wrapper asserts that the batch_size in begin_foward() matches the buffer size it was initialized with.

[rank0]:   File "/home/ray/anaconda3/lib/python3.10/runpy.py", line 196, in _run_module_as_main
[rank0]:     return _run_code(code, main_globals, None,
[rank0]:   File "/home/ray/anaconda3/lib/python3.10/runpy.py", line 86, in _run_code
[rank0]:     exec(code, run_globals)
[rank0]:   File "/mnt/user_storage/vllm-nsight/vllm/entrypoints/openai/api_server.py", line 317, in <module>
[rank0]:     run_server(args)
[rank0]:   File "/mnt/user_storage/vllm-nsight/vllm/entrypoints/openai/api_server.py", line 231, in run_server
[rank0]:     if llm_engine is not None else AsyncLLMEngine.from_engine_args(
[rank0]:   File "/mnt/user_storage/vllm-nsight/vllm/engine/async_llm_engine.py", line 466, in from_engine_args
[rank0]:     engine = cls(
[rank0]:   File "/mnt/user_storage/vllm-nsight/vllm/engine/async_llm_engine.py", line 380, in __init__
[rank0]:     self.engine = self._init_engine(*args, **kwargs)
[rank0]:   File "/mnt/user_storage/vllm-nsight/vllm/engine/async_llm_engine.py", line 547, in _init_engine
[rank0]:     return engine_class(*args, **kwargs)
[rank0]:   File "/mnt/user_storage/vllm-nsight/vllm/engine/llm_engine.py", line 265, in __init__
[rank0]:     self._initialize_kv_caches()
[rank0]:   File "/mnt/user_storage/vllm-nsight/vllm/engine/llm_engine.py", line 377, in _initialize_kv_caches
[rank0]:     self.model_executor.initialize_cache(num_gpu_blocks, num_cpu_blocks)
[rank0]:   File "/mnt/user_storage/vllm-nsight/vllm/executor/distributed_gpu_executor.py", line 62, in initialize_cache
[rank0]:     self._run_workers("initialize_cache",
[rank0]:   File "/mnt/user_storage/vllm-nsight/vllm/executor/ray_gpu_executor.py", line 350, in _run_workers
[rank0]:     self.driver_worker.execute_method(method, *driver_args,
[rank0]:   File "/mnt/user_storage/vllm-nsight/vllm/worker/worker_base.py", line 383, in execute_method
[rank0]:     raise e
[rank0]:   File "/mnt/user_storage/vllm-nsight/vllm/worker/worker_base.py", line 374, in execute_method
[rank0]:     return executor(*args, **kwargs)
[rank0]:   File "/mnt/user_storage/vllm-nsight/vllm/worker/worker.py", line 220, in initialize_cache
[rank0]:     self._warm_up_model()
[rank0]:   File "/mnt/user_storage/vllm-nsight/vllm/worker/worker.py", line 236, in _warm_up_model
[rank0]:     self.model_runner.capture_model(self.gpu_cache)
[rank0]:   File "/home/ray/anaconda3/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
[rank0]:     return func(*args, **kwargs)
[rank0]:   File "/mnt/user_storage/vllm-nsight/vllm/worker/model_runner.py", line 1096, in capture_model
[rank0]:     attn_metadata.begin_forward()
[rank0]:   File "/mnt/user_storage/vllm-nsight/vllm/attention/backends/flashinfer.py", line 163, in begin_forward
[rank0]:     self.decode_wrapper.begin_forward(
[rank0]:   File "/home/ray/anaconda3/lib/python3.10/site-packages/flashinfer/decode.py", line 450, in begin_forward
[rank0]:     raise ValueError(
[rank0]: ValueError: The batch size should be fixed in cudagraph mode, the runtime batch size 256  mismatches the batch size set during initialization 1

@comaniac
Copy link
Collaborator

Do you think it's possible to add a unit test to tests/distributed/test_pipeline_parallel.py if it's straightforward?

@SolitaryThinker SolitaryThinker changed the title fix flashinfer cudagraph capture for PP [Bugfix] fix flashinfer cudagraph capture for PP Jul 23, 2024
@SolitaryThinker
Copy link
Contributor Author

@comaniac Let me know if this makes sense, not sure if there are other attn backends I should add?

Copy link
Collaborator

@comaniac comaniac left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM! We could merge it if test passed.

@SolitaryThinker
Copy link
Contributor Author

/ready

@github-actions github-actions bot added the ready ONLY add when PR is ready to merge/full CI is needed label Jul 23, 2024
@Yard1 Yard1 enabled auto-merge (squash) July 24, 2024 00:15
@Yard1 Yard1 merged commit 5e8ca97 into vllm-project:main Jul 24, 2024
72 of 73 checks passed
xjpang pushed a commit to xjpang/vllm that referenced this pull request Jul 24, 2024
xjpang pushed a commit to xjpang/vllm that referenced this pull request Jul 24, 2024
cduk pushed a commit to cduk/vllm-pascal that referenced this pull request Aug 6, 2024
kylesayrs pushed a commit to neuralmagic/vllm that referenced this pull request Aug 17, 2024
Alvant pushed a commit to compressa-ai/vllm that referenced this pull request Oct 26, 2024
KuntaiDu pushed a commit to KuntaiDu/vllm that referenced this pull request Nov 20, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
ready ONLY add when PR is ready to merge/full CI is needed
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants