-
-
Notifications
You must be signed in to change notification settings - Fork 5.3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Bugfix][Kernel] Use int64_t for indices in fp8 quant kernels #6649
[Bugfix][Kernel] Use int64_t for indices in fp8 quant kernels #6649
Conversation
👋 Hi! Thank you for contributing to the vLLM project. Once the PR is approved and ready to go, please make sure to run full CI as it is required to merge (or just use auto-merge). To run full CI, you can do one of these:
🚀 |
/ready |
Out of curiosity, do you observe any performance regression with int64 indices? I previously saw an issue in triton reported by someone that since int64 uses 2 registers. |
@comaniac I did some quick e2e decode runs -- looks like we may be very slightly slower using int64. But IMO we should still land the PR. Are there any micro-benchmarks for just these kernels?
This PR:
main:
|
Thanks for the benchmarking and it seems ok. |
btw do you think it makes sense to have a unit test (if it doesn't take unaccepted long time) to cover this bug? |
I added a unit test but going to try to get the memory footprint down as it takes over 16 GB right now. It's quick to run though. Edit: Given that our smallest runners have 24GB, we should be good to go |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM. Thanks!
…roject#6649) Signed-off-by: Alvant <alvasian@yandex.ru>
Running
neuralmagic/Mistral-Nemo-Instruct-2407-FP8
, we are accessing an illegal memory address because our loop counters over the number of elements are int32_t, and the model has a maximum sequence length of 1 million.Let's change these to int64_t. We should also strongly consider compiling with
-Wconversion
or even-Werror=conversion
to prevent these. I tried adding-Wconversion
in this PR but it's extremely noisy right now, and it will be a much larger change to fix the warnings it cases.PR Checklist (Click to Expand)
Thank you for your contribution to vLLM! Before submitting the pull request, please ensure the PR meets the following criteria. This helps vLLM maintain the code quality and improve the efficiency of the review process.
PR Title and Classification
Only specific types of PRs will be reviewed. The PR title is prefixed appropriately to indicate the type of change. Please use one of the following:
[Bugfix]
for bug fixes.[CI/Build]
for build or continuous integration improvements.[Doc]
for documentation fixes and improvements.[Model]
for adding a new model or improving an existing model. Model name should appear in the title.[Frontend]
For changes on the vLLM frontend (e.g., OpenAI API server,LLM
class, etc.)[Kernel]
for changes affecting CUDA kernels or other compute kernels.[Core]
for changes in the core vLLM logic (e.g.,LLMEngine
,AsyncLLMEngine
,Scheduler
, etc.)[Hardware][Vendor]
for hardware-specific changes. Vendor name should appear in the prefix (e.g.,[Hardware][AMD]
).[Misc]
for PRs that do not fit the above categories. Please use this sparingly.Note: If the PR spans more than one category, please include all relevant prefixes.
Code Quality
The PR need to meet the following code quality standards:
format.sh
to format your code.docs/source/
if the PR modifies the user-facing behaviors of vLLM. It helps vLLM user understand and utilize the new features or changes.Notes for Large Changes
Please keep the changes as concise as possible. For major architectural changes (>500 LOC excluding kernel/data/config/test), we would expect a GitHub issue (RFC) discussing the technical design and justification. Otherwise, we will tag it with
rfc-required
and might not go through the PR.What to Expect for the Reviews
The goal of the vLLM team is to be a transparent reviewing machine. We would like to make the review process transparent and efficient and make sure no contributor feel confused or frustrated. However, the vLLM team is small, so we need to prioritize some PRs over others. Here is what you can expect from the review process:
action-required
label on the PR if there are changes required. The contributor should address the comments and ping the reviewer to re-review the PR.Thank You
Finally, thank you for taking the time to read these guidelines and for your interest in contributing to vLLM. Your contributions make vLLM a great tool for everyone!