bugfix: fix the sliding window iteration bound for SWA in batch prefill operators (#563)

The iteration bound for the sliding window in the batch prefill kernels is
wrong; this PR fixes it.
Note that this bug does not affect kernel accuracy, but it might harm
kernel performance in some circumstances.
yzh119 authored Oct 27, 2024
1 parent 9d2996d commit 4800368
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions include/flashinfer/attention/prefill.cuh
@@ -1633,7 +1633,7 @@ __launch_bounds__(NUM_WARPS_Q* NUM_WARPS_KV* WARP_SIZE) void BatchPrefillWithRag
                 16 * NUM_WARPS_KV * NUM_FRAGS_KV);

   const uint32_t window_iteration =
-      ceil_div(sub_if_greater_or_zero(kv_len + (bx + 1) * num_rows_per_cta,
+      ceil_div(sub_if_greater_or_zero(kv_len + (qo_tile_idx + 1) * num_rows_per_cta,
                                       qo_len + window_left + chunk_start),
                (16 * NUM_WARPS_KV * NUM_FRAGS_KV));

@@ -1962,7 +1962,7 @@ __launch_bounds__(NUM_WARPS_Q* NUM_WARPS_KV* WARP_SIZE) void BatchPrefillWithPag
                 16 * NUM_WARPS_KV * NUM_FRAGS_KV);

   const uint32_t window_iteration =
-      ceil_div(sub_if_greater_or_zero(kv_len + (bx + 1) * num_rows_per_cta,
+      ceil_div(sub_if_greater_or_zero(kv_len + (qo_tile_idx + 1) * num_rows_per_cta,
                                       qo_len + window_left + chunk_start),
                (16 * NUM_WARPS_KV * NUM_FRAGS_KV));
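
To illustrate how the choice of tile index changes the bound, here is a minimal host-side C++ sketch of the expression in the two hunks above. It is not the kernel code: the bodies of `ceil_div` and `sub_if_greater_or_zero` are assumed from their names, the `window_iteration` wrapper and its `kv_tile_size` parameter (standing in for `16 * NUM_WARPS_KV * NUM_FRAGS_KV`) are hypothetical, and all numeric values are made up for the example. It also assumes a batch in which the block index `bx` is larger than the per-request `qo_tile_idx`, for instance because tiles of earlier requests occupy the lower block indices.

```cpp
#include <cassert>
#include <cstdint>

// Assumed definitions, inferred from the helper names used in the diff.
constexpr uint32_t ceil_div(uint32_t a, uint32_t b) { return (a + b - 1) / b; }
constexpr uint32_t sub_if_greater_or_zero(uint32_t x, uint32_t y) {
  return x > y ? x - y : 0u;
}

// Hypothetical wrapper around the expression from the diff. `tile_idx` is
// either `bx` (before the fix) or `qo_tile_idx` (after the fix), and
// `kv_tile_size` stands in for 16 * NUM_WARPS_KV * NUM_FRAGS_KV.
constexpr uint32_t window_iteration(uint32_t kv_len, uint32_t qo_len,
                                    uint32_t window_left, uint32_t chunk_start,
                                    uint32_t tile_idx, uint32_t num_rows_per_cta,
                                    uint32_t kv_tile_size) {
  return ceil_div(
      sub_if_greater_or_zero(kv_len + (tile_idx + 1) * num_rows_per_cta,
                             qo_len + window_left + chunk_start),
      kv_tile_size);
}

// Made-up example values, not taken from the commit.
constexpr uint32_t kKvLen = 1024, kQoLen = 256, kWindowLeft = 512, kChunkStart = 0;
constexpr uint32_t kRowsPerCta = 64, kKvTileSize = 64;
constexpr uint32_t kQoTileIdx = 1;  // tile index within the current request
constexpr uint32_t kBx = 9;         // assumed block index across the whole batch

// After the fix: (1024 + 2 * 64 - 768) / 64 = 6 iterations.
static_assert(window_iteration(kKvLen, kQoLen, kWindowLeft, kChunkStart,
                               kQoTileIdx, kRowsPerCta, kKvTileSize) == 6,
              "bound computed from qo_tile_idx");

// Before the fix: (1024 + 10 * 64 - 768) / 64 = 14 iterations.
static_assert(window_iteration(kKvLen, kQoLen, kWindowLeft, kChunkStart,
                               kBx, kRowsPerCta, kKvTileSize) == 14,
              "bound computed from bx");
```

Under these assumptions the pre-fix expression yields a bound of 14 where the per-request tile index yields 6. A bound that is only too large would be consistent with the commit message's statement that the bug can cost performance without changing results, though the sketch itself does not model how the kernel consumes `window_iteration`.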
