[1/n] Triton sampling kernel #3186
Conversation
Left a few comments & questions and hope you don't mind them!
Thanks @Yard1!
```python
_SAMPLING_EPS = 1e-5
_SEED_0_REPLACEMENT = 3403598558


class SamplingMetadata:
```
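As a quick illustration of how an epsilon like `_SAMPLING_EPS` is typically used in samplers (a hedged sketch, not the PR's exact code; the helper names are illustrative): parameters within epsilon of their neutral value are treated as disabled, so near-zero temperature is handled as greedy sampling.

```python
# Hedged sketch: epsilon-based parameter checks of the kind samplers
# commonly use; function names here are illustrative, not vLLM API.
_SAMPLING_EPS = 1e-5


def is_greedy(temperature: float) -> bool:
    # Temperature ~0 degenerates to argmax, so treat it as greedy.
    return temperature < _SAMPLING_EPS


def top_p_enabled(top_p: float) -> bool:
    # top_p within epsilon of 1.0 keeps the full distribution (a no-op).
    return top_p < 1.0 - _SAMPLING_EPS
```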
Could we skip any of the new operations in this class when no seeds are in use? (I expect that would be very common.)
Good question - I think there are three considerations:
- Skipping seeds could bring slightly better performance.
- Skipping seeds introduces more special cases (undesirable).
- Not skipping seeds allows for request-level reproducibility on the server side, which could be useful for debugging model behavior.

Aside from those, Triton's random operations require a seed of some sort, so generating one would be necessary regardless; see the sketch below.
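To make that last point concrete, here is a minimal sketch (assuming nothing from this PR, only Triton itself) showing that Triton's random primitive, `tl.rand`, always takes an explicit seed, so even "unseeded" requests need one generated for them:

```python
import torch
import triton
import triton.language as tl


@triton.jit
def uniform_kernel(out_ptr, seed, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    # tl.rand has no seedless form: it derives each value from the
    # (seed, offset) pair, so a seed must always be supplied.
    values = tl.rand(seed, offsets)
    tl.store(out_ptr + offsets, values, mask=mask)


out = torch.empty(1024, device="cuda")
grid = (triton.cdiv(out.numel(), 256),)
uniform_kernel[grid](out, seed=42, n_elements=out.numel(), BLOCK_SIZE=256)
```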
"""Get `seeds_to_generate` child seeds from `seed` and extra entropy.""" | ||
if not is_greedy: | ||
if seed is None: | ||
randint_fn = random.randint |
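Since the hunk above is truncated, here is a hedged reconstruction of how such a seed-derivation helper could read end to end. It assumes the `_SEED_0_REPLACEMENT` constant from the earlier hunk and a kernel convention where a seed of 0 marks a non-sampling row; it is an illustration, not the PR's exact code.

```python
import random
from typing import List, Optional

import torch

_SEED_0_REPLACEMENT = 3403598558  # constant from the earlier hunk


def get_sequence_seeds(
    seed: Optional[int],
    *extra_entropy: Optional[int],
    seeds_to_generate: int,
    is_greedy: bool,
) -> List[Optional[int]]:
    """Get `seeds_to_generate` child seeds from `seed` and extra entropy."""
    if is_greedy:
        # Greedy sequences never sample, so they need no seeds.
        return [None] * seeds_to_generate
    if seed is None:
        # Unseeded request: draw child seeds from the global RNG.
        randint_fn = random.randint
    else:
        # Seeded request: derive child seeds deterministically from the
        # user seed plus extra entropy (e.g. request/sequence indices).
        randint_fn = random.Random(str((seed,) + extra_entropy)).randint
    lo, hi = torch.iinfo(torch.long).min, torch.iinfo(torch.long).max
    # Assumed convention: a seed of 0 is reserved for "no sampling" in
    # the kernel, so replace an accidental 0 with a fixed nonzero value.
    return [randint_fn(lo, hi) or _SEED_0_REPLACEMENT
            for _ in range(seeds_to_generate)]
```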
If there's effectively no overhead of seeded vs. non-seeded random sampling, a nice feature would be to treat the `random.randint` result here as equivalent to a passed-in seed, and then always return this seed in the API response.
This would let users reuse the returned seed to reproduce the same output (for example, one they particularly liked) without having to provide a seed explicitly up-front. A rough sketch of the idea is below.
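A hypothetical sketch of that proposal (the function name and shape are illustrative, not vLLM API): when no seed arrives with the request, generate one, use it exactly like a user-supplied seed, and echo it back in the response.

```python
import random
from typing import Optional


def resolve_seed(request_seed: Optional[int]) -> int:
    """Return the seed to sample with; generate one if none was given."""
    if request_seed is not None:
        return request_seed
    # Server-generated seed; returning it in the API response would let
    # the client replay the request deterministically later.
    return random.getrandbits(63)
```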
Yes, I agree! That's one of the advantages of always generating the seed. I think it would be good to include it in a follow-up (ideally once we are using just the kernel, so the logic stays consistent).
Yeah it would be costly to do this in the non-kernel case.
These are only benefits if they translate to non-negligible end-to-end performance improvements, right? I'm curious what the speedup looks like as a proportion of total TPOT. I guess it depends on the mix of parameters, in particular whether there are many seeded requests (presumably uncommon) and/or a mix of greedy, random, and seeded-random requests in the same batch (presumably more common). This question might matter more here given the nontrivial amount of new code introduced for this specific optimization.
Would these optimizations be applicable whether or not the dedicated kernel is used?
Stamp. Please address @njhill's comment before merge.
@njhill We are seeing a ~10% reduction in sampler time in our fork, but achieving that will require more work: the next two PRs add Triton AOT compilation for these kernels and refactor the sampler code to avoid unnecessary operations. This PR only adds the kernel, to streamline the review process. Furthermore, once we can move fully to the kernel, we'll be able to remove the existing torch-based sampling code (not including the logit-processing code).
I think they would make the sampler code easier to work with, though they would be tailored for the kernel. In general, the introduction of this kernel will allow us to push code complexity away from the sampler and into the kernel.
@Yard1 do you have a rough sense of what percentage of TPOT sampler time accounts for? (I know the proportion would vary with model size.) E.g. if that is <10%, then I guess this would translate to <1%?
@njhill You are correct that it's not that noticeable in normal usage, but we are seeing large gains in draft-model speculative decoding, where the draft model is CPU-bound. It can reduce ITL by several ms in that case.
* upstream/main:
  [Misc] Bump up transformers to v4.39.0 & Remove StarCoder2Config (vllm-project#3551)
  [Misc][Log] Add log for tokenizer length not equal to vocabulary size (vllm-project#3500)
  [🚀 Ready to be merged] Added support for Jais models (vllm-project#3183)
  Fix 1D query issue from `_prune_hidden_states` (vllm-project#3539)
  [PREFIX CACHING FOLLOW UP] OrderedDict-based evictor (vllm-project#3431)
  [BugFix] Hot fix in setup.py for neuron build (vllm-project#3537)
  Migrate `logits` computation and gather to `model_runner` (vllm-project#3233)
  [1/n][Chunked Prefill] Refactor input query shapes (vllm-project#3236)
  [1/n] Triton sampling kernel (vllm-project#3186)
  [Bugfix] Fix ROCm support in CMakeLists.txt (vllm-project#3534)
Co-authored-by: Roger Wang <136131678+ywang96@users.noreply.github.com>
This PR is the first in a series.
This PR adds a custom Triton sampling kernel, giving us the following benefits:
Currently the code path using the Triton kernel is disabled due to the following issues: