Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: bugfix to pr 135 #136

Merged
merged 5 commits into from
Feb 25, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 15 additions & 5 deletions include/flashinfer/prefill.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -1530,9 +1530,14 @@ cudaError_t SinglePrefillWithKVCacheWorkEstimation(
(num_blocks_per_sm * num_sm) /
(num_kv_heads *
ceil_div(qo_len * group_size, num_rows_per_cta));
uint32_t chunk_size =
max(ceil_div(kv_len, max_num_kv_chunks), 256);
uint32_t num_chunks = ceil_div(kv_len, chunk_size);
uint32_t num_chunks;
if (max_num_kv_chunks > 0) {
uint32_t chunk_size =
max(ceil_div(kv_len, max_num_kv_chunks), 256);
num_chunks = ceil_div(kv_len, chunk_size);
} else {
num_chunks = 0;
}

max_grid_size = num_blocks_per_sm * num_sm;
if (num_chunks > 1) {
Expand Down Expand Up @@ -1626,8 +1631,13 @@ cudaError_t SinglePrefillWithKVCacheDispatched(DTypeIn* q, DTypeIn* k, DTypeIn*
uint32_t max_num_kv_chunks =
(num_blocks_per_sm * num_sm) /
(num_kv_heads * ceil_div(qo_len * GROUP_SIZE, num_rows_per_cta));
uint32_t chunk_size = max(ceil_div(kv_len, max_num_kv_chunks), 256);
uint32_t num_chunks = ceil_div(kv_len, chunk_size);
uint32_t num_chunks;
if (max_num_kv_chunks > 0) {
uint32_t chunk_size = max(ceil_div(kv_len, max_num_kv_chunks), 256);
num_chunks = ceil_div(kv_len, chunk_size);
} else {
num_chunks = 0;
}

if (num_chunks <= 1 || tmp == nullptr) {
// Enough parallelism, do not split-kv
Expand Down