Mixtral FastGen Support #4828
Conversation
This is amazing, I was also working on this, but I think most of what I have you already added in this PR. Thanks @cmikeh2 :)
Diff context:

} // namespace scatter

- template <typename T, int copyUnroll>
+ template <typename T, int copyUnroll, int N_TOP_K>
@cmikeh2, I know you like to generalize this function, but I was wondering if we could have two kernels here, one for top-1 and one for top-k, just so that we can remove some of the complexity added for top-1. What do you think?
Are you observing any slowdown with top-1? The re-org intention here was primarily to simplify things. Previously, block 0 did a cumsum for the GEMM kernel while the rest of the thread blocks did max reductions. The max reduction is of similar complexity to the cumsum anyway (log(n) steps), and since it's necessary on all blocks for the top-N case anyway, I thought it made sense to remove the branch from the code and have a unified path.
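For illustration, here is a minimal sketch of the unified path being described. This is hypothetical code, not the PR's actual kernel: the helper name `warp_top_k`, its signature, and the warp-shuffle approach are assumptions, and it assumes at most 32 experts (true for Mixtral's 8).

```cuda
#include <cfloat>

// Hypothetical device helper (not from the DeepSpeed source): one warp selects
// the N_TOP_K highest expert scores for a single token by running N_TOP_K
// argmax reductions, masking out each winner before the next round.
template <int N_TOP_K>
__device__ void warp_top_k(const float* scores,  // [n_experts] gate scores for one token
                           int n_experts,        // assumed <= 32
                           int* top_experts,     // [N_TOP_K] output expert ids
                           float* top_scores)    // [N_TOP_K] output scores
{
    const int lane = threadIdx.x & 31;

    // Each lane owns at most one expert score.
    float local_score = (lane < n_experts) ? scores[lane] : -FLT_MAX;
    int local_expert = (lane < n_experts) ? lane : -1;

#pragma unroll
    for (int k = 0; k < N_TOP_K; ++k) {
        float best_score = local_score;
        int best_expert = local_expert;

        // log2(32) = 5 shuffle steps: the same O(log n) cost whether this is
        // the old top-1 max reduction or one round of top-k selection.
        for (int offset = 16; offset > 0; offset >>= 1) {
            float other_score = __shfl_down_sync(0xffffffff, best_score, offset);
            int other_expert = __shfl_down_sync(0xffffffff, best_expert, offset);
            if (other_score > best_score) {
                best_score = other_score;
                best_expert = other_expert;
            }
        }

        // Lane 0 now holds the k-th largest remaining score; broadcast it.
        best_score = __shfl_sync(0xffffffff, best_score, 0);
        best_expert = __shfl_sync(0xffffffff, best_expert, 0);
        top_scores[k] = best_score;
        top_experts[k] = best_expert;

        // Mask out the winner so the next iteration selects the runner-up.
        if (local_expert == best_expert) {
            local_score = -FLT_MAX;
            local_expert = -1;
        }
    }
}
```

With `N_TOP_K == 1` the selection loop runs exactly once and reduces to the warp max reduction the old top-1 path performed, which is why the unified path shouldn't add meaningful work for top-1.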
I haven't, but I will try to do some profiling of this in the next few days. Thanks for the clarification on the changes :)
Thanks @cmikeh2!
Resolved review threads:
- deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/blocked_kv_rotary.cpp
- deepspeed/inference/v2/kernels/ragged_ops/moe_gather/moe_gather.cu (outdated)
The Mixtral PR #4828 introduced the positional embedding config class, which is a required argument of the `make_attn_layer()` function. This forced users to override and duplicate the `make_attn_layer()` call for new model implementations using RoPE (and also broke the Falcon model implementation). This PR:
- refactors the inference transformer base class to avoid code duplication by adding a new abstract `positional_embedding_config` property
- fixes the Falcon model implementation to use the positional embedding config.

The models llama_v2, OPT, Mistral 7B, Mixtral, Falcon, and Phi-2 were tested with this PR.

Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
Adds support for Mixtral with FastGen. Key features implemented:
1. Top-2 MoE support
2. Better support for RoPE thetas
3. The Mixtral model implementation

Co-authored-by: Michael Wyatt <michaelwyatt@microsoft.com>
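As a reference for feature 1 above, here is a small host-side sketch of Mixtral-style top-2 routing. It is hypothetical code, not taken from the FastGen kernels: the struct and function names are made up, and the exact normalization in the real implementation may differ. The idea is that the router picks the two largest expert logits per token and softmaxes over just those two to produce the mixing weights.

```cpp
#include <cmath>
#include <utility>

// Per-token routing decision: two expert ids and their mixing weights.
struct Top2Route {
    int expert[2];
    float weight[2];  // sums to 1
};

// Select the two highest router logits (requires n_experts >= 2) and apply a
// softmax over just those two logits.
inline Top2Route route_top2(const float* logits, int n_experts)
{
    int e0 = 0, e1 = 1;
    if (logits[e1] > logits[e0]) std::swap(e0, e1);
    for (int e = 2; e < n_experts; ++e) {
        if (logits[e] > logits[e0]) { e1 = e0; e0 = e; }
        else if (logits[e] > logits[e1]) { e1 = e; }
    }

    // Softmax over the two selected logits only (shift by the max for stability).
    float m = logits[e0];
    float w0 = std::exp(logits[e0] - m);
    float w1 = std::exp(logits[e1] - m);
    float denom = w0 + w1;
    return {{e0, e1}, {w0 / denom, w1 / denom}};
}
```

Conceptually, the fused gating/scatter kernels discussed in the review thread above produce the same per-token (expert id, weight) pairs on the device.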