Another bucket sort #5109
If this continues we'll soon see a GPU tensor-core sorting kernel beating this one again :)
Well, when the GPU beats this, there is still room for improvement. One can easily shave off another 10% or so from the time by having a top_k sampler instance that has its buffers pre-allocated, so one doesn't need to do memory allocations on each invocation. More seriously, I do agree with you that if the usage of large top_k becomes standard practice, it is better to prefilter the logits in some way.
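To make the pre-allocation idea concrete, here is a minimal sketch of a sampler object that owns its scratch buffer. `TopKSampler`, `Candidate` and `apply` are hypothetical names for illustration, not the project's API, and the sketch assumes the vocabulary size is known when the sampler is constructed.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <functional>
#include <utility>
#include <vector>

// Hypothetical type: a (logit, token id) pair, not the project's own struct.
using Candidate = std::pair<float, int>;

// Hypothetical sampler that owns its scratch buffer, so repeated calls
// reuse one allocation instead of allocating on every invocation.
class TopKSampler {
public:
    explicit TopKSampler(std::size_t vocab_size) { scratch_.reserve(vocab_size); }

    // Returns the k largest candidates in descending order of logit.
    // The returned reference is only valid until the next call.
    const std::vector<Candidate> & apply(const std::vector<float> & logits, std::size_t k) {
        // Assumption of this sketch: the input never exceeds the reserved size,
        // so filling the buffer below cannot trigger a reallocation.
        assert(logits.size() <= scratch_.capacity());

        scratch_.clear();  // clear() keeps the capacity
        for (std::size_t i = 0; i < logits.size(); ++i) {
            scratch_.emplace_back(logits[i], static_cast<int>(i));
        }

        k = std::min(k, scratch_.size());
        std::partial_sort(scratch_.begin(),
                          scratch_.begin() + static_cast<std::ptrdiff_t>(k),
                          scratch_.end(),
                          std::greater<Candidate>());
        scratch_.resize(k);  // shrinking also keeps the capacity
        return scratch_;
    }

private:
    std::vector<Candidate> scratch_;
};
```

The sort itself is a plain `std::partial_sort` here; the point is only that the buffer outlives the call, so any of the faster selection schemes discussed in this thread could be dropped in without changing the allocation pattern.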
PR for a min_p implementation that works on unsorted tokens: #5115.
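For context on why min_p does not need sorted input: a token survives if its probability is at least `min_p` times the largest probability, and since softmax is monotonic this is equivalent to `logit >= max_logit + log(min_p)`. The sketch below illustrates that idea only; it is not the code from #5115, and `Token` and `min_p_unsorted` are hypothetical names.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <vector>

// Hypothetical candidate type for this sketch.
struct Token {
    int   id;
    float logit;
};

// Keeps only tokens whose probability is at least min_p * p_max.
// Works on unsorted candidates: one pass for the maximum, one pass to filter.
void min_p_unsorted(std::vector<Token> & cand, float min_p) {
    assert(min_p > 0.0f && min_p <= 1.0f);
    if (cand.empty()) {
        return;
    }

    float max_logit = cand[0].logit;
    for (const Token & t : cand) {
        max_logit = std::max(max_logit, t.logit);
    }

    // p_i / p_max = exp(logit_i - max_logit), so the probability cut-off
    // becomes a simple threshold on the logits.
    const float threshold = max_logit + std::log(min_p);

    cand.erase(std::remove_if(cand.begin(), cand.end(),
                              [threshold](const Token & t) { return t.logit < threshold; }),
               cand.end());
}
```

The relative order of the surviving tokens is unchanged, so a later top_k step only has to sort whatever is left.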
LGTM. The code has become more difficult to understand, but in my opinion speed is more important unless the code is irrelevant for performance. But maybe we should wait for the opinion of another dev just to be sure.
I can confirm that the performance is better than both master and my bucket sort PR:
Good job!
* Initial bucket sort
* Bucket sort: slightly better version
* Bucket sort: another minor improvement

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
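For readers who have not looked at the diff, the general idea of a bucket-based top-k can be sketched as follows. This is an illustration under stated assumptions, not the code from these commits: `Token`, `top_k_bucket` and the default of 128 buckets are hypothetical, and all logits are assumed to be finite.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical candidate type for this sketch.
struct Token {
    int   id;
    float logit;
};

// Returns the k largest tokens in descending order of logit.
// Assumes all logits are finite.
std::vector<Token> top_k_bucket(const std::vector<Token> & cand, std::size_t k, int n_buckets = 128) {
    assert(n_buckets > 0);
    k = std::min(k, cand.size());
    if (k == 0) {
        return {};
    }

    // Value range of the logits.
    float lo = cand[0].logit;
    float hi = cand[0].logit;
    for (const Token & t : cand) {
        lo = std::min(lo, t.logit);
        hi = std::max(hi, t.logit);
    }

    // Map [lo, hi] onto bucket indices; if all logits are equal, use bucket 0.
    const float scale = hi > lo ? n_buckets / (hi - lo) : 0.0f;
    auto bucket_of = [&](float x) {
        const int b = static_cast<int>((x - lo) * scale);
        return std::min(b, n_buckets - 1);
    };

    // Pass 1: histogram of bucket occupancy.
    std::vector<std::size_t> hist(static_cast<std::size_t>(n_buckets), 0);
    for (const Token & t : cand) {
        ++hist[bucket_of(t.logit)];
    }

    // Walk down from the top bucket until at least k tokens are covered.
    int cut = n_buckets - 1;
    std::size_t covered = hist[cut];
    while (covered < k) {
        covered += hist[--cut];
    }

    // Pass 2: gather only the tokens in buckets >= cut.
    std::vector<Token> out;
    out.reserve(covered);
    for (const Token & t : cand) {
        if (bucket_of(t.logit) >= cut) {
            out.push_back(t);
        }
    }

    // Sort the (usually small) set of survivors and keep the best k.
    std::sort(out.begin(), out.end(),
              [](const Token & a, const Token & b) { return a.logit > b.logit; });
    out.resize(k);
    return out;
}
```

The comparison sort only ever sees the tokens that landed in the top few buckets, which is where the speedup for moderate `k` would come from in a scheme like this.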
We now have 3 PRs related to sorting logits with the goal of speeding up top-k sampling:
std::nth_element before std::partial_sort
The table shows a comparison between master and these 3 PRs as a function of top_k. Tests run on a Ryzen-5975WX + RTX-4080 with Ubuntu 22.04 and GCC 11.4.0.
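As a reference point for the `std::nth_element` before `std::partial_sort` variant mentioned above, a minimal sketch of that combination might look like this. `Token` and `top_k_nth_element` are hypothetical names, and the PR in question may differ in its details.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical candidate type for this sketch.
struct Token {
    int   id;
    float logit;
};

// Truncates cand to its k largest tokens, in descending order of logit.
void top_k_nth_element(std::vector<Token> & cand, std::size_t k) {
    assert(k > 0);  // top-k with k == 0 would drop every candidate
    k = std::min(k, cand.size());

    const auto by_logit_desc = [](const Token & a, const Token & b) { return a.logit > b.logit; };
    const auto mid = cand.begin() + static_cast<std::ptrdiff_t>(k);

    // Step 1: move the k largest tokens into [begin, mid), in no particular
    // order. This is linear in cand.size() on average.
    std::nth_element(cand.begin(), mid, cand.end(), by_logit_desc);

    // Step 2: order only those k tokens. With middle == last, partial_sort
    // sorts the whole sub-range.
    std::partial_sort(cand.begin(), mid, mid, by_logit_desc);

    cand.resize(k);
}
```

The selection step discards most of the vocabulary cheaply, so the ordering step only pays for `k` elements rather than for every candidate.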