BatchPrefillWithPagedKVCacheWrapper has performance degradation when setting use_cuda_graph=True #411
It's mainly because you are using a very large page size. If you choose a smaller page size (e.g. `page_size = 16`, as in the script below), the gap goes away:

```python
import torch
from triton.testing import do_bench
import flashinfer

bsz = 64
device = torch.device("cuda:0")
num_qo_heads = 8
num_kv_heads = 2
head_dim = 128
max_len = 16016
page_size = 16

# Benchmark with CUDA graph enabled: all metadata buffers are preallocated
# and passed to the wrapper at construction time.
for dec_len in [1, 2, 3, 4, 5, 6, 7, 8]:
    decode_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device=device)
    qo_indptr = torch.arange(bsz + 1, dtype=torch.int32, device=device)
    paged_kv_indptr = torch.arange(bsz + 1, dtype=torch.int32, device=device) * (
        max_len // page_size
    )
    paged_kv_indices = torch.arange(
        bsz * (max_len // page_size), dtype=torch.int32, device=device
    )
    paged_kv_last_page_len = torch.full((bsz,), 16, dtype=torch.int32, device=device)
    decode_wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
        decode_buffer,
        "NHD",
        use_cuda_graph=True,
        qo_indptr_buf=qo_indptr,
        paged_kv_indptr_buf=paged_kv_indptr,
        paged_kv_indices_buf=paged_kv_indices,
        paged_kv_last_page_len_buf=paged_kv_last_page_len,
    )
    q_flashinfer = torch.randn(
        bsz * dec_len, num_qo_heads, head_dim, dtype=torch.bfloat16
    ).to("cuda:0")
    kv_cache_flash_infer = torch.randn(
        bsz * max_len // page_size,
        2,
        page_size,
        num_kv_heads,
        head_dim,
        dtype=torch.bfloat16,
        device="cuda:0",
    )
    decode_wrapper.begin_forward(
        qo_indptr=qo_indptr * dec_len,
        paged_kv_indptr=paged_kv_indptr,
        paged_kv_indices=paged_kv_indices,
        paged_kv_last_page_len=paged_kv_last_page_len,
        num_qo_heads=num_qo_heads,
        num_kv_heads=num_kv_heads,
        head_dim=head_dim,
        page_size=16,
        q_data_type=torch.bfloat16,
    )
    torch.cuda.synchronize()
    print("dec_len:", dec_len)
    print(
        do_bench(
            lambda: decode_wrapper.forward(
                q_flashinfer, kv_cache_flash_infer, causal=True
            )
        )
    )
    decode_wrapper.end_forward()

# Same benchmark without CUDA graph (default wrapper construction).
for dec_len in [1, 2, 3, 4, 5, 6, 7, 8]:
    decode_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device=device)
    qo_indptr = torch.arange(bsz + 1, dtype=torch.int32, device=device)
    paged_kv_indptr = torch.arange(bsz + 1, dtype=torch.int32, device=device) * (
        max_len // page_size
    )
    paged_kv_indices = torch.arange(
        bsz * (max_len // page_size), dtype=torch.int32, device=device
    )
    paged_kv_last_page_len = torch.full((bsz,), 16, dtype=torch.int32, device=device)
    decode_wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
        decode_buffer, "NHD"
    )
    q_flashinfer = torch.randn(
        bsz * dec_len, num_qo_heads, head_dim, dtype=torch.bfloat16
    ).to("cuda:0")
    kv_cache_flash_infer = torch.randn(
        bsz * max_len // page_size,
        2,
        page_size,
        num_kv_heads,
        head_dim,
        dtype=torch.bfloat16,
        device="cuda:0",
    )
    decode_wrapper.begin_forward(
        qo_indptr=qo_indptr * dec_len,
        paged_kv_indptr=paged_kv_indptr,
        paged_kv_indices=paged_kv_indices,
        paged_kv_last_page_len=paged_kv_last_page_len,
        num_qo_heads=num_qo_heads,
        num_kv_heads=num_kv_heads,
        head_dim=head_dim,
        page_size=16,
        q_data_type=torch.bfloat16,
    )
    print("dec_len:", dec_len)
    print(
        do_bench(
            lambda: decode_wrapper.forward(
                q_flashinfer, kv_cache_flash_infer, causal=True
            )
        )
    )
    decode_wrapper.end_forward()
```

and here is my result:
You can see the difference is minimal.
Let me explain a bit more about the kernel difference between using and not using CUDAGraph. Because of the requirements of PyTorch CUDAGraph, we must fix the grid size when CUDAGraph is enabled, so for our prefill kernels we always launch a fixed, maximum-sized grid. Our scheduler partitions the work (with a minimal granularity of one page) and dispatches different (request, pages) chunks to different threadblocks, so that the work is balanced across threadblocks. But when you use a very large page size, each request owns only one page and we cannot partition it further; in that case we launch 2 * 132 threadblocks but only 2 * 64 of them have work to do, which is not efficient.
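To make the counting argument concrete, here is a rough sketch of the arithmetic only (this is not FlashInfer's actual scheduler; it assumes the fixed grid is num_kv_heads * num_SMs = 2 * 132 threadblocks, matching the H100 figure above):

```python
# Hypothetical illustration of the load-balancing argument above; this is NOT
# FlashInfer's scheduler, just the arithmetic behind "2 * 132 launched vs.
# 2 * 64 busy" with a huge page size, and full occupancy with page_size=16.
num_sms = 132        # H100
num_kv_heads = 2
bsz = 64
max_len = 16016

fixed_grid = num_kv_heads * num_sms  # threadblocks always launched under CUDAGraph

for page_size in (max_len, 16):      # one huge page per request vs. small pages
    pages_per_request = max(1, max_len // page_size)
    # Work is split at page granularity, so the number of independent work
    # units is bounded by the number of (kv head, request, page) chunks.
    work_units = num_kv_heads * bsz * pages_per_request
    busy = min(fixed_grid, work_units)
    print(
        f"page_size={page_size:>5}: {busy}/{fixed_grid} threadblocks busy "
        f"({100 * busy / fixed_grid:.0f}%)"
    )
```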
This is a very clear explanation! Thank you very much!
Hi @yzh119, I hope you don't mind me reaching out regarding this closed issue. I am encountering illegal memory access or correctness problems when using this wrapper with CUDAGraph enabled. The variable …
Do you have any suggestions on how to address this issue? Thank you for your time and assistance.
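For reference, the capture/replay pattern I have in mind is roughly the following (a minimal sketch reusing the shapes and wrapper arguments from the benchmark above; the capture structure and the assumption that begin_forward is re-run outside the graph before each replay are my own guesses, not verified):

```python
import torch
import flashinfer

# Sketch only: shapes and buffer sizes reuse the benchmark above (dec_len = 1).
bsz, num_qo_heads, num_kv_heads, head_dim = 64, 8, 2, 128
max_len, page_size = 16016, 16
device = torch.device("cuda:0")

workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device=device)
qo_indptr = torch.arange(bsz + 1, dtype=torch.int32, device=device)
paged_kv_indptr = torch.arange(bsz + 1, dtype=torch.int32, device=device) * (max_len // page_size)
paged_kv_indices = torch.arange(bsz * (max_len // page_size), dtype=torch.int32, device=device)
paged_kv_last_page_len = torch.full((bsz,), page_size, dtype=torch.int32, device=device)

wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
    workspace,
    "NHD",
    use_cuda_graph=True,
    qo_indptr_buf=qo_indptr,
    paged_kv_indptr_buf=paged_kv_indptr,
    paged_kv_indices_buf=paged_kv_indices,
    paged_kv_last_page_len_buf=paged_kv_last_page_len,
)

# Static tensors the captured graph will read from.
q = torch.randn(bsz, num_qo_heads, head_dim, dtype=torch.bfloat16, device=device)
kv_cache = torch.randn(
    bsz * (max_len // page_size), 2, page_size, num_kv_heads, head_dim,
    dtype=torch.bfloat16, device=device,
)

def plan():
    # Assumption: planning happens outside of graph capture.
    wrapper.begin_forward(
        qo_indptr=qo_indptr,
        paged_kv_indptr=paged_kv_indptr,
        paged_kv_indices=paged_kv_indices,
        paged_kv_last_page_len=paged_kv_last_page_len,
        num_qo_heads=num_qo_heads,
        num_kv_heads=num_kv_heads,
        head_dim=head_dim,
        page_size=page_size,
        q_data_type=torch.bfloat16,
    )

plan()
wrapper.forward(q, kv_cache, causal=True)  # warm-up run before capture
torch.cuda.synchronize()

g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    out = wrapper.forward(q, kv_cache, causal=True)

# Later steps: refresh inputs in place, re-plan, then replay the captured graph.
q.copy_(torch.randn_like(q))
plan()
g.replay()
```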
Hi @ZhongYingMatrix, are you using flashinfer v0.1.6 or the nightly version?
@yzh119 Thank you for your help! I am currently using version 0.1.6. I will try building from the latest main branch (nightly) to see if it resolves the issue. By the way, do you happen to know if a newer tagged version will be released soon?
I got the following result:
It is obvious that after setting use_cuda_graph, the performance is worse than before. Specifically, there is a latency jump from dec_len=4 to dec_len=5. I tested this on an H100.