
BatchPrefillWithPagedKVCacheWrapper has performance degradation when setting use_cuda_graph=True #411

Closed
jianc99 opened this issue Jul 30, 2024 · 6 comments

jianc99 commented Jul 30, 2024

import torch
import flashinfer
import time

bsz = 64
device = torch.device("cuda:0")
num_qo_heads = 8
num_kv_heads = 2
head_dim = 128
max_len = 16016

# Benchmark with CUDA graph buffers (use_cuda_graph=True)
for dec_len in [1, 2, 3, 4, 5, 6, 7, 8]:
    decode_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device=device)
    qo_indptr = torch.arange(bsz + 1, dtype=torch.int32, device=device)
    paged_kv_indptr = torch.arange(bsz + 1, dtype=torch.int32, device=device)
    paged_kv_indices = torch.arange(bsz, dtype=torch.int32, device=device)
    paged_kv_last_page_len = torch.full((bsz,), 16000, dtype=torch.int32, device=device)
    decode_wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
        decode_buffer,
        "NHD",
        use_cuda_graph=True,
        qo_indptr_buf=qo_indptr,
        paged_kv_indptr_buf=paged_kv_indptr,
        paged_kv_indices_buf=paged_kv_indices,
        paged_kv_last_page_len_buf=paged_kv_last_page_len,
    )

    q_flashinfer = torch.randn(bsz * dec_len, num_qo_heads, head_dim, dtype=torch.bfloat16, device=device)
    kv_cache_flash_infer = torch.randn(bsz, 2, max_len, num_kv_heads, head_dim, dtype=torch.bfloat16, device=device)

    decode_wrapper.begin_forward(
        qo_indptr=qo_indptr * dec_len,
        paged_kv_indptr=paged_kv_indptr,
        paged_kv_indices=paged_kv_indices,
        paged_kv_last_page_len=paged_kv_last_page_len,
        num_qo_heads=num_qo_heads,
        num_kv_heads=num_kv_heads,
        head_dim=head_dim,
        page_size=max_len,
        q_data_type=torch.bfloat16,
    )
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(1000):
        decode_wrapper.forward(q_flashinfer, kv_cache_flash_infer, causal=True)
    torch.cuda.synchronize()
    end = time.perf_counter()
    decode_wrapper.end_forward()
    print("dec_len:", dec_len)
    print((end - start) / 1000)

# Same benchmark without CUDA graph (default wrapper)
for dec_len in [1, 2, 3, 4, 5, 6, 7, 8]:
    decode_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device=device)
    qo_indptr = torch.arange(bsz + 1, dtype=torch.int32, device=device)
    paged_kv_indptr = torch.arange(bsz + 1, dtype=torch.int32, device=device)
    paged_kv_indices = torch.arange(bsz, dtype=torch.int32, device=device)
    paged_kv_last_page_len = torch.full((bsz,), 16000, dtype=torch.int32, device=device)
    decode_wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(decode_buffer, "NHD")

    q_flashinfer = torch.randn(bsz * dec_len, num_qo_heads, head_dim, dtype=torch.bfloat16, device=device)
    kv_cache_flash_infer = torch.randn(bsz, 2, max_len, num_kv_heads, head_dim, dtype=torch.bfloat16, device=device)

    decode_wrapper.begin_forward(
        qo_indptr=qo_indptr * dec_len,
        paged_kv_indptr=paged_kv_indptr,
        paged_kv_indices=paged_kv_indices,
        paged_kv_last_page_len=paged_kv_last_page_len,
        num_qo_heads=num_qo_heads,
        num_kv_heads=num_kv_heads,
        head_dim=head_dim,
        page_size=max_len,
        q_data_type=torch.bfloat16,
    )
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(1000):
        decode_wrapper.forward(q_flashinfer, kv_cache_flash_infer, causal=True)
    torch.cuda.synchronize()
    end = time.perf_counter()
    decode_wrapper.end_forward()
    print("dec_len:", dec_len)
    print((end - start) / 1000)

I got the following results (seconds per call; the first block is with use_cuda_graph=True, the second without):

dec_len: 1
0.00037962544430047274
dec_len: 2
0.0003796318080276251
dec_len: 3
0.00038123549008741976
dec_len: 4
0.00038167582917958496
dec_len: 5
0.000498349045868963
dec_len: 6
0.0004994462868198752
dec_len: 7
0.0005127514158375561
dec_len: 8
0.0005026387223042547
dec_len: 1
0.00033620123798027633
dec_len: 2
0.00033656913228332996
dec_len: 3
0.00033665854297578336
dec_len: 4
0.00033720786310732363
dec_len: 5
0.00034855281002819536
dec_len: 6
0.00034927211282774804
dec_len: 7
0.0003585655922070146
dec_len: 8
0.0003496954077854753

With use_cuda_graph=True the performance is clearly worse than with the default wrapper. In particular, there is a latency jump from dec_len=4 to dec_len=5. I tested this on an H100.

yzh119 (Collaborator) commented Jul 30, 2024

This is mainly because you are using a very large page size.
The minimal scheduling granularity is one page, and we do not split a page into multiple chunks. In your example the page size is 16000 and each request owns only a single page, which blocks all of the load-balancing optimizations our scheduler can do.

If you choose a smaller page_size (e.g. 16), the gap becomes much smaller because the scheduler can dispatch different pages to different SMs:

import torch
from triton.testing import do_bench
import flashinfer
import time

bsz = 64
device = torch.device("cuda:0")
num_qo_heads = 8
num_kv_heads = 2
head_dim = 128
max_len = 16016
page_size = 16

for dec_len in [1, 2, 3, 4, 5, 6, 7, 8]:
    decode_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device=device)
    qo_indptr = torch.arange(bsz + 1, dtype=torch.int32, device=device)
    paged_kv_indptr = torch.arange(bsz + 1, dtype=torch.int32, device=device) * (
        max_len // page_size
    )
    paged_kv_indices = torch.arange(
        bsz * (max_len // page_size), dtype=torch.int32, device=device
    )
    paged_kv_last_page_len = torch.full((bsz,), 16, dtype=torch.int32, device=device)
    decode_wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
        decode_buffer,
        "NHD",
        use_cuda_graph=True,
        qo_indptr_buf=qo_indptr,
        paged_kv_indptr_buf=paged_kv_indptr,
        paged_kv_indices_buf=paged_kv_indices,
        paged_kv_last_page_len_buf=paged_kv_last_page_len,
    )

    q_flashinfer = torch.randn(
        bsz * dec_len, num_qo_heads, head_dim, dtype=torch.bfloat16
    ).to("cuda:0")
    kv_cache_flash_infer = torch.randn(
        bsz * max_len // page_size,
        2,
        page_size,
        num_kv_heads,
        head_dim,
        dtype=torch.bfloat16,
        device="cuda:0",
    )

    decode_wrapper.begin_forward(
        qo_indptr=qo_indptr * dec_len,
        paged_kv_indptr=paged_kv_indptr,
        paged_kv_indices=paged_kv_indices,
        paged_kv_last_page_len=paged_kv_last_page_len,
        num_qo_heads=num_qo_heads,
        num_kv_heads=num_kv_heads,
        head_dim=head_dim,
        page_size=16,
        q_data_type=torch.bfloat16,
    )
    torch.cuda.synchronize()
    print("dec_len:", dec_len)
    print(
        do_bench(
            lambda: decode_wrapper.forward(
                q_flashinfer, kv_cache_flash_infer, causal=True
            )
        )
    )
    decode_wrapper.end_forward()

for dec_len in [1, 2, 3, 4, 5, 6, 7, 8]:
    decode_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device=device)
    qo_indptr = torch.arange(bsz + 1, dtype=torch.int32, device=device)
    paged_kv_indptr = torch.arange(bsz + 1, dtype=torch.int32, device=device) * (
        max_len // page_size
    )
    paged_kv_indices = torch.arange(
        bsz * (max_len // page_size), dtype=torch.int32, device=device
    )
    paged_kv_last_page_len = torch.full((bsz,), 16, dtype=torch.int32, device=device)
    decode_wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
        decode_buffer, "NHD"
    )

    q_flashinfer = torch.randn(
        bsz * dec_len, num_qo_heads, head_dim, dtype=torch.bfloat16
    ).to("cuda:0")
    kv_cache_flash_infer = torch.randn(
        bsz * max_len // page_size,
        2,
        page_size,
        num_kv_heads,
        head_dim,
        dtype=torch.bfloat16,
        device="cuda:0",
    )

    decode_wrapper.begin_forward(
        qo_indptr=qo_indptr * dec_len,
        paged_kv_indptr=paged_kv_indptr,
        paged_kv_indices=paged_kv_indices,
        paged_kv_last_page_len=paged_kv_last_page_len,
        num_qo_heads=num_qo_heads,
        num_kv_heads=num_kv_heads,
        head_dim=head_dim,
        page_size=16,
        q_data_type=torch.bfloat16,
    )
    print("dec_len:", dec_len)
    print(
        do_bench(
            lambda: decode_wrapper.forward(
                q_flashinfer, kv_cache_flash_infer, causal=True
            )
        )
    )
    decode_wrapper.end_forward()

and here is my result (do_bench reports milliseconds; again the first block is with CUDA graph, the second without):

dec_len: 1
0.3627031445503235
dec_len: 2
0.3647567331790924
dec_len: 3
0.36595797538757324
dec_len: 4
0.3674812316894531
dec_len: 5
0.3738013207912445
dec_len: 6
0.3759947419166565
dec_len: 7
0.375791996717453
dec_len: 8
0.37761634588241577
dec_len: 1
0.3598800599575043
dec_len: 2
0.3614909052848816
dec_len: 3
0.3625786304473877
dec_len: 4
0.36442112922668457
dec_len: 5
0.36965465545654297
dec_len: 6
0.3719732165336609
dec_len: 7
0.37129437923431396
dec_len: 8
0.3727892339229584

You can see the difference is minimal.

yzh119 (Collaborator) commented Jul 30, 2024

Let me explain a little more about the kernel difference between using and not using CUDAGraph.

Due to the requirements of PyTorch CUDAGraph, we must fix the grid size when CUDAGraph is enabled. For our prefill kernels, when CUDAGraph is enabled we fix the grid size to 2 * #SM (#SM is the number of SMs, which is 132 on H100); when CUDAGraph is not enabled, the grid configuration is data-dependent: (64, 2) in your case, where 64 is the batch_size and 2 is num_kv_heads.

Our scheduler partitions the work (at the granularity of a page) and dispatches different (request, pages) chunks to different threadblocks so that the work across threadblocks is balanced. But when you have a very large page size, each request owns only one page and we cannot partition it; in that case we launch 2 * 132 threadblocks but only 2 * 64 of them actually do work, which is not efficient.
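To make the arithmetic concrete, here is a small back-of-the-envelope sketch. It is only an illustration of the page-granularity constraint described above, not the actual FlashInfer scheduler; the SM count, the fixed 2 * #SM grid, and the (request, page) per KV head work unit are taken from the explanation above, and the batch size, KV length, and head count come from the original snippet.

# Illustration only: how many of the fixed CUDA-graph threadblocks get work
# when the schedulable unit is one (request, page) per KV head.
import math

num_sm = 132             # H100, per the explanation above
fixed_grid = 2 * num_sm  # grid size used when CUDAGraph is enabled
bsz = 64
num_kv_heads = 2
kv_len = 16000

for page_size in [16000, 16]:
    pages_per_request = math.ceil(kv_len / page_size)
    work_units = bsz * pages_per_request * num_kv_heads
    busy_blocks = min(fixed_grid, work_units)
    print(f"page_size={page_size:>5}: {work_units} schedulable units -> "
          f"{busy_blocks}/{fixed_grid} threadblocks busy")

With page_size=16000 only 128 of the 264 fixed threadblocks ever get work, while page_size=16 produces far more schedulable units than threadblocks, so every block stays busy.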

jianc99 (Author) commented Jul 31, 2024

This is a very clear explanation! Thank you very much!

jianc99 closed this as completed Jul 31, 2024
@ZhongYingMatrix

> Due to the requirements of PyTorch CUDAGraph, we must fix the grid size when CUDAGraph is enabled. For our prefill kernels, when CUDAGraph is enabled we fix the grid size to 2 * #SM (#SM is the number of SMs, which is 132 on H100).

Hi @yzh119, I hope you don't mind me reaching out about this closed issue. I am encountering illegal memory access errors or correctness problems when using BatchPrefillWithPagedKVCacheWrapper in CUDA graph mode, especially with large batch sizes. Upon investigation, I found that the grid dim3 differs between the capture phase and the replay phase, with the grid size exceeding 2 * #SM.

The variable total_num_tiles_q seems to be causing variability here:

padded_batch_size_ = std::max(split_max_batch_size, total_num_tiles_q);

Do you have any suggestions on how to address this issue? Thank you for your time and assistance.
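(For readers following along, here is a minimal sketch of the capture/replay pattern under discussion, reusing the wrapper and tensor names from the snippets earlier in this thread; the new_* tensors are hypothetical placeholders for the next batch's metadata, and the exact planning API may differ between flashinfer versions.)

# Sketch only: forward() is captured once and then replayed, so the kernel
# launched at capture time must keep the same grid at replay time.
g = torch.cuda.CUDAGraph()

# Plan with the shapes used for capture (warm-up runs omitted for brevity).
decode_wrapper.begin_forward(
    qo_indptr=qo_indptr * dec_len,
    paged_kv_indptr=paged_kv_indptr,
    paged_kv_indices=paged_kv_indices,
    paged_kv_last_page_len=paged_kv_last_page_len,
    num_qo_heads=num_qo_heads,
    num_kv_heads=num_kv_heads,
    head_dim=head_dim,
    page_size=page_size,
    q_data_type=torch.bfloat16,
)
with torch.cuda.graph(g):
    out = decode_wrapper.forward(q_flashinfer, kv_cache_flash_infer, causal=True)

# Later, for a new batch: re-plan with the new metadata and replay the graph.
# The failure mode described above is that this re-plan can choose a larger
# padded batch size than the captured grid (total_num_tiles_q exceeding
# split_max_batch_size), so the replayed launch no longer matches the plan.
decode_wrapper.begin_forward(
    qo_indptr=new_qo_indptr,
    paged_kv_indptr=new_paged_kv_indptr,
    paged_kv_indices=new_paged_kv_indices,
    paged_kv_last_page_len=new_paged_kv_last_page_len,
    num_qo_heads=num_qo_heads,
    num_kv_heads=num_kv_heads,
    head_dim=head_dim,
    page_size=page_size,
    q_data_type=torch.bfloat16,
)
g.replay()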

yzh119 (Collaborator) commented Dec 4, 2024

Hi @ZhongYingMatrix, are you using flashinfer v0.1.6 or the nightly version?
I think @nandor's recent work has fixed this behavior (#626); you can try it in the flashinfer nightly build.

@ZhongYingMatrix

@yzh119 Thank you for your help! I am currently using version 0.1.6. I will try building from the latest main branch (nightly) to see if it resolves the issue.

By the way, do you happen to know if a newer tagged version will be released soon?
