Skip to content

Commit

Permalink
bugfix: Fix the computation of total_num_tiles_q (#652)
Browse files Browse the repository at this point in the history
The previous upper bound omitted the multiplication by `gqa_group_size`.
  • Loading branch information
nandor authored Dec 12, 2024
1 parent 4313654 commit 4c15777
Showing 1 changed file with 17 additions and 14 deletions.
31 changes: 17 additions & 14 deletions include/flashinfer/attention/scheduler.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include <vector>

#include "../allocator.h"
#include "../exception.h"
#include "../pos_enc.cuh"
#include "../utils.cuh"

Expand Down Expand Up @@ -483,7 +484,7 @@ inline auto PrefillSplitQOKVIndptr(IdType* qo_indptr_h, IdType* kv_indptr_h,
// number of rows and the batch size. The per-request qo lengths, each scaled
// by gqa_group_size and rounded up to a whole number of cta_tile_q-sized
// tiles, sum to no more than this bound, which is derived from the total
// number of rows.
total_num_tiles_q = ceil_div(total_num_rows, cta_tile_q) + batch_size - 1;
total_num_tiles_q = ceil_div(total_num_rows * gqa_group_size, cta_tile_q) + batch_size - 1;
} else {
int64_t sum_packed_qo_len = 0;
for (uint32_t i = 0; i < batch_size; ++i) {
Expand All @@ -505,10 +506,11 @@ inline auto PrefillSplitQOKVIndptr(IdType* qo_indptr_h, IdType* kv_indptr_h,
// step 3: split qo_indptr and kv_indptr
uint32_t new_batch_size = 0;
for (uint32_t request_idx = 0; request_idx < batch_size; ++request_idx) {
int64_t packed_qo_len = packed_qo_len_arr[request_idx],
kv_len = std::max(int(kv_len_arr[request_idx]), 1);
int64_t num_tiles_q = ceil_div(packed_qo_len, cta_tile_q),
num_tiles_kv = ceil_div(kv_len, kv_chunk_size);
const int64_t packed_qo_len = packed_qo_len_arr[request_idx];
const int64_t kv_len = std::max(int(kv_len_arr[request_idx]), 1);
const int64_t num_tiles_q = ceil_div(packed_qo_len, cta_tile_q);
const int64_t num_tiles_kv = ceil_div(kv_len, kv_chunk_size);

for (uint32_t q_tile_idx = 0; q_tile_idx < num_tiles_q; ++q_tile_idx) {
for (uint32_t kv_tile_idx = 0; kv_tile_idx < num_tiles_kv; ++kv_tile_idx) {
new_batch_size += 1;
Expand All @@ -525,14 +527,16 @@ inline auto PrefillSplitQOKVIndptr(IdType* qo_indptr_h, IdType* kv_indptr_h,
o_indptr.push_back(o_indptr.back() + qo_len * num_tiles_kv);
}

size_t padded_batch_size =
const size_t padded_batch_size =
enable_cuda_graph ? std::max(max_batch_size_if_split, total_num_tiles_q) : new_batch_size;
FLASHINFER_CHECK(new_batch_size <= padded_batch_size,
"new batch size should not exceed padded batch size");

// step 4: multiply kv_chunk_size by page_size
kv_chunk_size *= page_size;

return std::make_tuple(split_kv, total_num_tiles_q, new_batch_size, padded_batch_size, cta_tile_q,
kv_chunk_size, std::move(request_indices), std::move(qo_tile_indices),
return std::make_tuple(split_kv, new_batch_size, padded_batch_size, cta_tile_q, kv_chunk_size,
std::move(request_indices), std::move(qo_tile_indices),
std::move(kv_tile_indices), std::move(merge_indptr), std::move(o_indptr));
}

Expand Down Expand Up @@ -639,15 +643,14 @@ inline cudaError_t PrefillPlan(void* float_buffer, size_t float_workspace_size_i
uint32_t max_batch_size_if_split = max_grid_size / num_kv_heads;

// step 2: determine kv_chunk_size
auto [split_kv, total_num_tiles_q, new_batch_size, padded_batch_size, cta_tile_q, kv_chunk_size,
request_indices_vec, qo_tile_indices_vec, kv_tile_indices_vec, merge_indptr_vec,
o_indptr_vec] = PrefillSplitQOKVIndptr(qo_indptr_h, kv_indptr_h, total_num_rows, batch_size,
num_qo_heads, num_kv_heads, head_dim, page_size,
max_batch_size_if_split, enable_cuda_graph);
auto [split_kv, new_batch_size, padded_batch_size, cta_tile_q, kv_chunk_size, request_indices_vec,
qo_tile_indices_vec, kv_tile_indices_vec, merge_indptr_vec, o_indptr_vec] =
PrefillSplitQOKVIndptr(qo_indptr_h, kv_indptr_h, total_num_rows, batch_size, num_qo_heads,
num_kv_heads, head_dim, page_size, max_batch_size_if_split,
enable_cuda_graph);

plan_info.cta_tile_q = cta_tile_q;
plan_info.total_num_rows = total_num_rows;

plan_info.enable_cuda_graph = enable_cuda_graph;
plan_info.padded_batch_size = padded_batch_size;
plan_info.split_kv = split_kv;
Expand Down

0 comments on commit 4c15777

Please sign in to comment.