Add silu mul kernel #2469
Conversation
grimoire commented on Sep 14, 2024
- Add a kernel that fuses silu and mul in the MLP (a minimal sketch of this pattern is shown below).
- Optimize the apply_rotary and rmsnorm kernels.
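The fused kernel computes silu(gate) * up in a single pass over the MLP projection output. A minimal sketch of that pattern (the names, signatures, and block-size handling are assumptions, not the PR's actual kernel):

```python
import torch
import triton
import triton.language as tl


@triton.jit
def _silu_mul_kernel(gate_ptr, up_ptr, out_ptr, N, BLOCK_N: tl.constexpr):
    """Compute out = silu(gate) * up for one block of one row."""
    row = tl.program_id(0)
    col_block = tl.program_id(1)
    offs = col_block * BLOCK_N + tl.arange(0, BLOCK_N)
    mask = offs < N
    gate = tl.load(gate_ptr + row * N + offs, mask=mask, other=0.0).to(tl.float32)
    up = tl.load(up_ptr + row * N + offs, mask=mask, other=0.0).to(tl.float32)
    out = gate * tl.sigmoid(gate) * up  # silu(x) = x * sigmoid(x)
    tl.store(out_ptr + row * N + offs, out, mask=mask)


def silu_mul(gate: torch.Tensor, up: torch.Tensor) -> torch.Tensor:
    """Launch helper; gate and up are contiguous tensors of shape (num_tokens, N)."""
    assert gate.shape == up.shape and gate.is_contiguous() and up.is_contiguous()
    num_tokens, N = gate.shape
    out = torch.empty_like(gate)
    BLOCK_N = min(triton.next_power_of_2(N), 1024)
    grid = (num_tokens, triton.cdiv(N, BLOCK_N))
    _silu_mul_kernel[grid](gate, up, out, N, BLOCK_N=BLOCK_N)
    return out
```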
TRITON_VERSION = version.parse(triton.__version__)

if TRITON_VERSION >= version.parse('3.0.0'):
Do we support triton 3.0.0?
Tested on llama3
qeh_ptrs = qe_ptr[:, None] + feat_offset_h[None, :] * stride_qed
ql_ptrs += head_id * stride_qh
qh_ptrs += head_id * stride_qh
qel_ptrs += head_id * stride_qeh
Is it possible that stride_qeh is not equal to stride_qh?
I was wondering if it is necessary to pass stride_qes, stride_qeh, stride_qed, stride_kes, stride_keh and stride_ked.
q and k can be slices of the qkv tensor, so their strides can differ when the output is not written in place.
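For illustration, a small PyTorch example of why the strides can differ (hypothetical shapes, not the kernel's actual arguments): q taken as a view of a packed qkv tensor inherits the packed tensor's sequence stride, while a freshly allocated output is contiguous with a different stride.

```python
import torch

seq_len, num_q_heads, num_kv_heads, head_dim = 4, 8, 2, 64
qkv = torch.randn(seq_len, (num_q_heads + 2 * num_kv_heads) * head_dim)

# q is a view into qkv: its sequence stride spans the whole packed row.
q = qkv[:, :num_q_heads * head_dim].view(seq_len, num_q_heads, head_dim)
print(q.stride())      # (768, 64, 1)

# A non-inplace output tensor is contiguous, so its strides differ from q's.
q_out = torch.empty(seq_len, num_q_heads, head_dim)
print(q_out.stride())  # (512, 64, 1)
```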
def forward(self, x):
    """forward."""

    if x.size(-1) % 2048 != 0:
Why only use the fused kernel when x.size(-1) % 2048 == 0?
I fixed the block size in
BLOCK_SIZE_N = min(N, 1024)
I am so lazy.
I didn't get it
The kernel would be more complex and slow if we had to support an arbitrary input shape.
This kernel only supports aligned input; unaligned input falls back to the default implementation.
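A minimal sketch of that dispatch (the function name and the exact alignment check are assumptions based on this thread, not the PR's actual code):

```python
import torch
import torch.nn.functional as F

BLOCK_SIZE_N = 1024  # block size hard-coded in the fused Triton kernel


def silu_and_mul(x: torch.Tensor, fused_kernel=None) -> torch.Tensor:
    """Split x into (gate, up) halves and return silu(gate) * up.

    The fused kernel is used only when the last dimension is aligned to the
    fixed block size; otherwise fall back to the default eager implementation.
    """
    gate, up = x.chunk(2, dim=-1)
    if fused_kernel is None or x.size(-1) % 2048 != 0:  # 2048 = 2 * BLOCK_SIZE_N
        return F.silu(gate) * up
    return fused_kernel(gate, up)
```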