
Add silu mul kernel #2469

Merged: 4 commits merged into InternLM:main on Sep 19, 2024
Conversation

grimoire (Collaborator):

  • Add a kernel that fuses SiLU and mul in the MLP (see the sketch after this list).
  • Optimize the apply_rotary and rmsnorm kernels.
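As a rough illustration of the first bullet, a fused SiLU-mul kernel in Triton can look like the sketch below. This is a minimal sketch under assumptions of my own (one program per row, gate and up concatenated along the last dimension, a fixed BLOCK_SIZE_N); the names _silu_and_mul_kernel and silu_and_mul are illustrative, not necessarily the ones used in this PR.

# Sketch only: fused silu(gate) * up, assuming x = [gate | up] along the last dim.
import torch
import triton
import triton.language as tl


@triton.jit
def _silu_and_mul_kernel(x_ptr, out_ptr, N, stride_m, stride_n,
                         BLOCK_SIZE_N: tl.constexpr):
    """One program handles one row of x with shape [M, 2 * N]."""
    m_id = tl.program_id(0)
    n_off = tl.arange(0, BLOCK_SIZE_N)
    for start in range(0, N, BLOCK_SIZE_N):
        offs = start + n_off
        mask = offs < N
        gate = tl.load(x_ptr + m_id * stride_m + offs * stride_n, mask=mask, other=0.0)
        up = tl.load(x_ptr + m_id * stride_m + (offs + N) * stride_n, mask=mask, other=0.0)
        gate = gate.to(tl.float32)
        out = gate * tl.sigmoid(gate) * up.to(tl.float32)
        tl.store(out_ptr + m_id * N + offs, out.to(out_ptr.dtype.element_ty), mask=mask)


def silu_and_mul(x: torch.Tensor) -> torch.Tensor:
    """x has shape [M, 2 * N]; returns silu(x[:, :N]) * x[:, N:]."""
    M, N2 = x.shape
    N = N2 // 2
    out = x.new_empty(M, N)
    BLOCK_SIZE_N = min(triton.next_power_of_2(N), 1024)
    _silu_and_mul_kernel[(M,)](x, out, N, x.stride(0), x.stride(1),
                               BLOCK_SIZE_N=BLOCK_SIZE_N)
    return out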

TRITON_VERSION = version.parse(triton.__version__)

if TRITON_VERSION >= version.parse('3.0.0'):

Collaborator:
Do we support triton 3.0.0?

grimoire (Collaborator, author):
Tested on llama3
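For context, the version check quoted above is the usual pattern for gating code paths on the installed Triton release. A minimal sketch of that pattern (the branch contents here are placeholders, not the PR's actual code):

# Sketch only: dispatch on the installed Triton version.
import triton
from packaging import version

TRITON_VERSION = version.parse(triton.__version__)

if TRITON_VERSION >= version.parse('3.0.0'):
    # Triton 3.x code path (hypothetical placeholder).
    ...
else:
    # Triton 2.x code path (hypothetical placeholder).
    ...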

qeh_ptrs = qe_ptr[:, None] + feat_offset_h[None, :] * stride_qed
ql_ptrs += head_id * stride_qh
qh_ptrs += head_id * stride_qh
qel_ptrs += head_id * stride_qeh
lvhan028 (Collaborator), Sep 18, 2024:
Is it possible that stride_qeh is not equal to stride_qh?
I was wondering if it is necessary to pass stride_qes, stride_qeh, stride_qed, stride_kes, stride_keh and stride_ked.

grimoire (Collaborator, author):
q and k can be slices of the fused qkv tensor, so their strides can differ when the output is not written in place.
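To illustrate the point (the shapes below are made up for the example): when q and k are views into a fused qkv projection they inherit the parent tensor's layout, while a freshly allocated (non-inplace) rotary output is contiguous, so the kernel cannot assume the two sets of strides match and takes them as separate arguments.

# Sketch only: strides of qkv slices vs. a freshly allocated output.
import torch

seq_len, num_heads, head_dim = 4, 8, 128
qkv = torch.randn(seq_len, 3 * num_heads, head_dim)

# q and k are views into qkv, so they inherit its layout.
q = qkv[:, :num_heads]                 # shape (4, 8, 128)
k = qkv[:, num_heads:2 * num_heads]    # shape (4, 8, 128)
print(q.stride())        # (3072, 128, 1): the seq stride comes from the parent qkv

# A non-inplace output is freshly allocated and contiguous.
q_embed = torch.empty(seq_len, num_heads, head_dim)
print(q_embed.stride())  # (1024, 128, 1): a different seq stride than q

# Hence the kernel receives stride_q* and stride_qe* (and the k counterparts) separately.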

def forward(self, x):
    """forward."""
    if x.size(-1) % 2048 != 0:
Collaborator:
Why only use the fused kernel when x.size(-1) % 2048 == 0?

grimoire (Collaborator, author):
I fixed the block size in

BLOCK_SIZE_N = min(N, 1024)

I am so lazy.

Collaborator:
I didn't get it

grimoire (Collaborator, author):
The kernel would be more complex and slower if we had to support arbitrary input shapes.
This kernel only supports aligned input; unaligned input falls back to the default implementation.
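Put differently, the module can dispatch on alignment, roughly as in this sketch (silu_and_mul stands in for the fused Triton kernel sketched in the PR description above; the class name and layout are assumptions, not the PR's code):

# Sketch only: use the fused kernel for aligned widths, fall back otherwise.
import torch
import torch.nn.functional as F


class SiluAndMul(torch.nn.Module):
    """Computes silu(gate) * up from a concatenated [..., 2 * N] tensor."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if x.size(-1) % 2048 != 0:
            # Unaligned input: plain PyTorch fallback.
            gate, up = x.chunk(2, dim=-1)
            return F.silu(gate) * up
        # Aligned input: 2 * N is a multiple of 2048, so N is a multiple of 1024
        # and the fixed BLOCK_SIZE_N = min(N, 1024) tiles each half exactly.
        out = silu_and_mul(x.reshape(-1, x.size(-1)))
        return out.view(*x.shape[:-1], -1)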

lvhan028 merged commit 97449e3 into InternLM:main on Sep 19, 2024
5 checks passed