Optimize GQA/MQA #1649

Merged: 4 commits merged into InternLM:main on May 24, 2024

Conversation
@grimoire (Collaborator) commented on May 23, 2024:

Enable tensor cores in the decoding MHA kernel.
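For context, here is a minimal PyTorch reference of the GQA/MQA decode step this kernel targets (the function name and shapes are illustrative, not this PR's API): since each KV head serves a whole group of query heads, the group can be folded into the M dimension of the score matmul, which is what lets the decode kernel use tensor cores instead of per-head mat-vec products.

```python
import torch

def gqa_decode_ref(q, k_cache, v_cache):
    # Hypothetical reference, not the Triton kernel in this PR.
    # q:       [num_q_heads, head_dim]            -- one decode step
    # k_cache: [num_kv_heads, seq_len, head_dim]
    # v_cache: [num_kv_heads, seq_len, head_dim]
    num_q_heads, head_dim = q.shape
    num_kv_heads = k_cache.shape[0]
    group = num_q_heads // num_kv_heads
    # Fold the query heads sharing a KV head into the matmul M dimension:
    # [num_kv_heads, group, head_dim] @ [num_kv_heads, head_dim, seq_len]
    qg = q.view(num_kv_heads, group, head_dim)
    scores = torch.einsum('hgd,hsd->hgs', qg, k_cache) * head_dim**-0.5
    probs = scores.softmax(dim=-1)
    out = torch.einsum('hgs,hsd->hgd', probs, v_cache)
    return out.reshape(num_q_heads, head_dim)
```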

internlm2-chat-20b tp=2 batch_size=256 num_reqs=3000

origin

first token latency(s)(min, max, ave): 1.151, 20.570, 5.350
per-token latency(s) percentile(50, 75, 95, 99): [0.074, 0.087, 0.387, 0.515]

number of prompt tokens: 684711
number of completion tokens: 624144
token throughput (completion token): 1647.393 token/s
token throughput (prompt + completion token): 3454.650 token/s
RPS (request per second): 7.918 req/s
RPM (request per minute): 475.100 req/min

This PR

concurrency: 256
elapsed_time: 346.817s

first token latency(s)(min, max, ave): 2.643, 21.116, 5.051
per-token latency(s) percentile(50, 75, 95, 99): [0.065, 0.083, 0.407, 0.504]

number of prompt tokens: 684711
number of completion tokens: 624144
token throughput (completion token): 1799.636 token/s
token throughput (prompt + completion token): 3773.910 token/s
RPS (request per second): 8.650 req/s
RPM (request per minute): 519.006 req/min

internlm2-chat-20b tp=1 batch_size=128 num_reqs=3000

origin

concurrency: 128
elapsed_time: 566.050s

first token latency(s)(min, max, ave): 1.929, 12.539, 3.287
per-token latency(s) percentile(50, 75, 95, 99): [0.064, 0.066, 0.288, 0.609]

number of prompt tokens: 684711
number of completion tokens: 624144
token throughput (completion token): 1102.630 token/s
token throughput (prompt + completion token): 2312.260 token/s
RPS (request per second): 5.300 req/s
RPM (request per minute): 317.993 req/min

This PR

concurrency: 128
elapsed_time: 480.266s

first token latency(s)(min, max, ave): 1.489, 10.305, 2.808
per-token latency(s) percentile(50, 75, 95, 99): [0.051, 0.053, 0.251, 0.571]

number of prompt tokens: 684711
number of completion tokens: 624144
token throughput (completion token): 1299.579 token/s
token throughput (prompt + completion token): 2725.268 token/s
RPS (request per second): 6.247 req/s
RPM (request per minute): 374.792 req/min

Llama-3-8b-instruct tp=1 batch_size=128 num_reqs=3000

origin

concurrency: 256
elapsed_time: 260.775s

first token latency(s)(min, max, ave): 1.044, 15.817, 3.555
per-token latency(s) percentile(50, 75, 95, 99): [0.056, 0.06, 0.315, 0.379]

number of prompt tokens: 676779
number of completion tokens: 612685
token throughput (completion token): 2349.476 token/s
token throughput (prompt + completion token): 4944.734 token/s
RPS (request per second): 11.504 req/s
RPM (request per minute): 690.250 req/min

This PR

concurrency: 256
elapsed_time: 246.199s

first token latency(s)(min, max, ave): 2.152, 13.018, 3.580
per-token latency(s) percentile(50, 75, 95, 99): [0.051, 0.065, 0.316, 0.361]

number of prompt tokens: 676779
number of completion tokens: 612685
token throughput (completion token): 2488.573 token/s
token throughput (prompt + completion token): 5237.481 token/s
RPS (request per second): 12.185 req/s
RPM (request per minute): 731.115 req/min
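Summarizing the completion-token throughput gains above (plain arithmetic on the numbers already reported, no new measurements):

```python
origin  = {'internlm2 tp=2': 1647.393, 'internlm2 tp=1': 1102.630, 'llama3 tp=1': 2349.476}
this_pr = {'internlm2 tp=2': 1799.636, 'internlm2 tp=1': 1299.579, 'llama3 tp=1': 2488.573}
for name in origin:
    gain = (this_pr[name] / origin[name] - 1) * 100
    print(f'{name}: +{gain:.1f}%')   # +9.2%, +17.9%, +5.9%
```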

@grimoire added the improvement label and removed the enhancement label on May 24, 2024
qk *= sm_scale
# NOTE: inf - inf = nan, and nan leads to errors
qk_mask = history_len >= (start_n + offs_n)
if window_size > 0:
Collaborator:

Is it related to local attention?

grimoire (Author):

Yes
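For readers skimming the diff: window_size > 0 switches on sliding-window (local) attention, where each token only attends to the most recent window_size positions. A rough sketch of the mask logic, with illustrative names (this is not the kernel's exact code):

```python
def visible(history_len: int, key_pos: int, window_size: int) -> bool:
    # A decode query at position `history_len` may attend to `key_pos` only
    # if it is causal and, when a window is set, at most `window_size` back.
    if key_pos > history_len:                       # future token: masked
        return False
    if window_size > 0:                             # local attention
        return history_len - key_pos < window_size
    return True                                     # plain causal attention
```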

SPLIT_K: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
BLOCK_DV: tl.constexpr,
Collaborator:

What does 'DV' refer to?

grimoire (Author) commented on May 24, 2024:

The head_dim of Value.

Value can have a different head_dim from Key and Query (whose head_dim is referred to as BLOCK_DMODEL).
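To make the distinction concrete, a small sketch with assumed dimensions (the numbers are hypothetical, chosen to mirror an MLA-style cache where K carries extra rope channels):

```python
import triton

Lk, Lv = 576, 512                          # hypothetical: K = [ci, rope], V = [ci]
BLOCK_DMODEL = triton.next_power_of_2(Lk)  # Q/K tile width -> 1024
BLOCK_DV = triton.next_power_of_2(Lv)      # V/output tile width -> 512
```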

else:
    BLOCK_DV = triton.next_power_of_2(Lv)
BLOCK_M = max(16, min(BLOCK, 16384 // BLOCK_DMODEL))
if Lk > 512 and BLOCK > 32:
Collaborator:

"and" or "or"?

grimoire (Author):

"and".

Lk>512 => BLOCK_DMODEL>=1024
the key smem usage is BLOCK * BLOCK_DMODEL * sizeof(half).
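A back-of-envelope check (with assumed numbers) of why both conditions are needed together:

```python
BLOCK_DMODEL = 1024                    # Lk > 512 forces the next power of 2 to >= 1024
for BLOCK in (32, 64):
    smem = BLOCK * BLOCK_DMODEL * 2    # sizeof(half) = 2 bytes
    print(BLOCK, smem // 1024, 'KiB')  # 32 -> 64 KiB, 64 -> 128 KiB
# 128 KiB exceeds the shared memory available per block on most GPUs, so a
# large BLOCK is only a problem when the head dim is also large: "and", not "or".
```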


-sm_scale = 1.0 / (Lq**0.5)
+if sm_scale is None:
+    sm_scale = 1.0 / (Lq**0.5)
Collaborator:

Just out of curiosity, which model uses an sm_scale other than 1/sqrt(dim)?

grimoire (Author):

@@ -413,6 +569,8 @@ def paged_attention_fwd(
    kv_seqlens: Tensor,
    max_seqlen: int,
    window_size: int = None,
    sm_scale: float = None,
    shared_kv: int = False,
Collaborator:

int -> bool.
What does shared_kv mean?

grimoire (Author):

https://kexue.fm/archives/10091#Part%203

K is [ci, rope] and V is [ci]; V shares the same memory as K.

Enabling this flag is not recommended, since the layout of the shared V is not friendly to matmul.
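A minimal sketch of that shared layout (the dimensions are assumptions based on the linked post, not values from this codebase):

```python
import torch

ci_dim, rope_dim, seq_len = 512, 64, 8
# K packs [compressed-KV (ci) | rope] channels; V is just the ci slice of
# the same buffer, so no separate V cache has to be stored.
k = torch.randn(seq_len, ci_dim + rope_dim, dtype=torch.float16)
v = k[:, :ci_dim]                          # V aliases K's storage
assert v.data_ptr() == k.data_ptr()        # same underlying memory
```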

@RunningLeon (Collaborator) left a comment:

LGTM

@lvhan028 lvhan028 merged commit cd19422 into InternLM:main May 24, 2024
5 checks passed