Optimize GQA/MQA #1649
Conversation
qk *= sm_scale
# NOTE: inf - inf = nan, and nan leads to errors
qk_mask = history_len >= (start_n + offs_n)
if window_size > 0:
Is it related to local attention?
Yes
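For reference, a minimal sketch of how the sliding window combines with the causal mask in the quoted lines (plain PyTorch, not the PR's Triton kernel; variable names follow the diff, values are illustrative):

```python
import torch

# history_len, start_n, offs_n and window_size are the kernel variables
# from the quoted diff, with made-up values for illustration.
history_len, start_n, window_size = 100, 64, 32
offs_n = torch.arange(16)            # key offsets within this tile
key_pos = start_n + offs_n

qk_mask = history_len >= key_pos     # causal: no future keys visible
if window_size > 0:
    # local (sliding-window) attention: also drop keys outside the window
    qk_mask &= key_pos >= (history_len - window_size)

qk = torch.randn(16)                 # stand-in for the scaled scores
# assign -inf via the mask instead of subtracting masked maxima, so the
# inf - inf = nan case the NOTE warns about never arises
qk = torch.where(qk_mask, qk, torch.tensor(float("-inf")))
```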
SPLIT_K: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
BLOCK_DV: tl.constexpr,
What does 'DV' refer to?
Head dim of value.
The value can have a different head_dim from the key and query (whose head dim is referred to as BLOCK_DMODEL).
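A small illustration of the two tile widths, assuming `Lk`/`Lv` are the key/query and value head dims as in the surrounding launcher code (the values are made up):

```python
import triton

# Lq/Lk is the head_dim of query/key, Lv the head_dim of value.
# In MLA-style layouts they can differ, hence two separate constexprs.
Lk, Lv = 576, 512                          # illustrative values only
BLOCK_DMODEL = triton.next_power_of_2(Lk)  # tile width over q/k head_dim
BLOCK_DV = triton.next_power_of_2(Lv)      # tile width over v head_dim
```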
else:
    BLOCK_DV = triton.next_power_of_2(Lv)
BLOCK_M = max(16, min(BLOCK, 16384 // BLOCK_DMODEL))
if Lk > 512 and BLOCK > 32:
"and" or "or"?
"and".
Lk>512
=> BLOCK_DMODEL>=1024
the key smem usage is BLOCK * BLOCK_DMODEL * sizeof(half)
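A back-of-the-envelope version of that argument (illustrative numbers; the `16384 // BLOCK_DMODEL` cap and the `Lk > 512 and BLOCK > 32` branch mirror the quoted launcher):

```python
import triton

Lk = 576                                   # any head_dim > 512 ...
BLOCK_DMODEL = triton.next_power_of_2(Lk)  # ... rounds up to >= 1024
BLOCK = 64
key_smem_bytes = BLOCK * BLOCK_DMODEL * 2  # sizeof(half) == 2 -> 128 KiB
# too large for one SM's shared memory, so the launcher shrinks the tile:
if Lk > 512 and BLOCK > 32:
    BLOCK = 32                             # halves the key tile to 64 KiB
```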
-    sm_scale = 1.0 / (Lq**0.5)
+    if sm_scale is None:
+        sm_scale = 1.0 / (Lq**0.5)
Just out of curiosity, which model uses a different sm_scale other than 1/sqrt(dim)?
https://huggingface.co/deepseek-ai/DeepSeek-V2-Lite-Chat/blob/main/modeling_deepseek.py#L749
And MLA might use the sm_scale of the original dim.
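For instance, paraphrasing the linked modeling_deepseek.py from memory (the constants below are DeepSeek-V2's as I recall them; check the link for the exact code):

```python
import math

# DeepSeek-V2 MLA head dims (cited from memory, not from this PR):
qk_nope_head_dim, qk_rope_head_dim = 128, 64
q_head_dim = qk_nope_head_dim + qk_rope_head_dim   # 192
softmax_scale = q_head_dim ** (-0.5)               # not 1/sqrt(v_head_dim)
# with YaRN rope scaling the model further multiplies this scale by a
# mscale factor, so the kernel cannot reconstruct it from head_dim alone
# and must accept sm_scale from the caller.
```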
@@ -413,6 +569,8 @@ def paged_attention_fwd(
     kv_seqlens: Tensor,
     max_seqlen: int,
     window_size: int = None,
+    sm_scale: float = None,
+    shared_kv: int = False,
int -> bool. What does `shared_kv` mean?
https://kexue.fm/archives/10091#Part%203
K is [ci, rope], V is [ci]. V shares the same memory as K.
And it is not recommended to enable this flag, since the layout of the shared V is not friendly to matmul.
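A sketch of that layout (tensor names and sizes are hypothetical, following the linked post):

```python
import torch

num_blocks, block_size = 128, 64
ci_dim, rope_dim = 512, 64                 # compressed-kv and rope widths

# each cache slot stores [ci, rope]; K reads the whole row,
# V is just a view of the ci part, so no separate V cache is needed
kv_cache = torch.empty(num_blocks, block_size, ci_dim + rope_dim,
                       dtype=torch.float16)
k = kv_cache                               # K is [ci, rope]
v = kv_cache[..., :ci_dim]                 # V is [ci], sharing K's storage
assert v.data_ptr() == k.data_ptr()        # same memory
```

The drawback mentioned above follows from this: V is a strided view rather than a contiguous tile, so loading it for the PV matmul is less friendly to tensor cores.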
LGTM
Enable tensor cores in the decoding MHA kernel.
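Conceptually (a sketch of the general technique, not this PR's kernel), GQA/MQA decoding reaches tensor cores by stacking the query heads that share a kv head into one tile, so the score computation becomes a batched matmul instead of many matrix-vector products:

```python
import torch

num_q_heads, num_kv_heads, head_dim, seq_len = 32, 8, 128, 1024
group = num_q_heads // num_kv_heads        # query heads per kv head

q = torch.randn(num_q_heads, head_dim, dtype=torch.float16, device="cuda")
k = torch.randn(num_kv_heads, seq_len, head_dim,
                dtype=torch.float16, device="cuda")

# [num_kv_heads, group, head_dim] @ [num_kv_heads, head_dim, seq_len]:
# one real matmul per kv head, which is what lets the kernel issue
# tensor-core MMA instructions during decoding
scores = torch.bmm(q.view(num_kv_heads, group, head_dim),
                   k.transpose(1, 2))      # [num_kv_heads, group, seq_len]
```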
Benchmarks (num_reqs=3000, origin vs. this PR; the detailed numbers were posted as screenshots):
- internlm2-chat-20b, tp=2, batch_size=256
- internlm2-chat-20b, tp=1, batch_size=128
- LLama-3-8b-instruct, tp=1, batch_size=128