Llama3.1: Median decode latency is high with batch size 128 on the Triton backend #1935
-
@Jackycheng0808 Nice analysis!
-
Hi @Jackycheng0808 Could you test with …
-
We might find the problem; you can also add …
-
@Jackycheng0808 Thanks for reporting the issue and providing helpful data for debugging! It will be fixed in #2134.
-
Docker Image: sglang 0.3.4.post2
Hardware: H200
Command: python3 -m sglang.bench_latency --batch-size 128 --input 128 --output 128 --model "amd/Meta-Llama-3.1-8B-Instruct-FP8-KV" --quantization fp8 --tp 1
Hi, I am testing Llama 3.1 on different attention backends (Triton and FlashInfer) and have found some unusual behavior on the Triton backend. The median decode latency and total latency increase significantly from batch size 64 to batch size 128, then drop back down at batch size 256. At batch size 128, the Triton backend is about 5x slower than the FlashInfer backend. After profiling, I found that the _fwd_grouped_kernel_stage1 kernel takes ~90% of the execution time, while the BatchDecode kernel of the FlashInfer engine takes only 24%. I am wondering if there might be an issue in the _fwd_grouped_kernel_stage1 Triton kernel implementation?
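Not part of the original report, but for anyone reproducing the profiling step: below is a minimal, hedged sketch of how a kernel-level time breakdown like the one quoted above can be collected with torch.profiler. The nn.Linear workload is purely a stand-in; the actual measurement should wrap the decode loop driven by sglang.bench_latency (or attach a profiler to the running server) instead.
```python
# Minimal sketch (not the author's method): collect a per-kernel CUDA time
# breakdown with torch.profiler. The Linear layer below is only a placeholder
# for one decode step of the real model.
import torch
from torch.profiler import profile, ProfilerActivity

device = "cuda" if torch.cuda.is_available() else "cpu"

# Placeholder workload; the report's numbers come from sglang.bench_latency, not this toy.
model = torch.nn.Linear(4096, 4096, bias=False).to(device)
batch = torch.randn(128, 4096, device=device)  # batch size 128, as in the report

activities = [ProfilerActivity.CPU]
if device == "cuda":
    activities.append(ProfilerActivity.CUDA)

with torch.no_grad(), profile(activities=activities) as prof:
    for _ in range(20):  # a few decode-like iterations
        model(batch)
    if device == "cuda":
        torch.cuda.synchronize()

# Sort by total device time to see which kernel dominates
# (e.g. _fwd_grouped_kernel_stage1 on the Triton backend).
sort_key = "cuda_time_total" if device == "cuda" else "cpu_time_total"
print(prof.key_averages().table(sort_by=sort_key, row_limit=15))
```
In the resulting table, the per-kernel share of total CUDA time is what the ~90% (Triton) vs. 24% (FlashInfer) figures above refer to.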