Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
zhyncs committed Aug 14, 2024
1 parent c0dd96f commit 00cf0a6
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions include/flashinfer/attention/prefill.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -1338,7 +1338,7 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void BatchPrefillWithRagg
constexpr uint32_t num_rows_per_cta = num_frags_x * num_warps_x * 16;
const uint32_t qo_len = q_indptr[request_idx + 1] - q_indptr[request_idx],
kv_len = kv_indptr[request_idx + 1] - kv_indptr[request_idx];
const uint32_t kv_len_num = kv_len > 0 ? kv_len : 1;
const uint32_t kv_len_safe = kv_len > 0 ? kv_len : 1;
const uint32_t window_left = (maybe_window_left >= 0) ? maybe_window_left : kv_len;
const uint32_t max_chunk_size = partition_kv ? kv_chunk_size : kv_len;
const uint32_t chunk_start = partition_kv ? kv_tile_idx * max_chunk_size : 0;
Expand Down Expand Up @@ -1559,7 +1559,7 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void BatchPrefillWithRagg
// normalize d
normalize_d<num_frags_x, num_frags_y>(o_frag, m, d);

const uint32_t num_kv_chunks = ceil_div(kv_len_num, kv_chunk_size);
const uint32_t num_kv_chunks = (kv_len_safe + kv_chunk_size - 1) / kv_chunk_size;

// write back
write_o_reg_gmem<num_warps_x, num_warps_z, num_frags_x, num_frags_y>(
Expand Down Expand Up @@ -1633,7 +1633,7 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void BatchPrefillWithPage
1) * paged_kv.page_size +
paged_kv.last_page_len[request_idx]
: 0;
const uint32_t kv_len_num = kv_len > 0 ? kv_len : 1;
const uint32_t kv_len_safe = kv_len > 0 ? kv_len : 1;
const uint32_t window_left = (maybe_window_left >= 0) ? maybe_window_left : kv_len;
const uint32_t max_chunk_size = partition_kv ? kv_chunk_size : kv_len;
const uint32_t chunk_start = partition_kv ? kv_tile_idx * max_chunk_size : 0;
Expand Down Expand Up @@ -1874,7 +1874,7 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void BatchPrefillWithPage
// normalize d
normalize_d<num_frags_x, num_frags_y>(o_frag, m, d);

const uint32_t num_kv_chunks = ceil_div(kv_len_num, kv_chunk_size);
const uint32_t num_kv_chunks = (kv_len_safe + kv_chunk_size - 1) / kv_chunk_size;

// write_back
write_o_reg_gmem<num_warps_x, num_warps_z, num_frags_x, num_frags_y>(
Expand Down

0 comments on commit 00cf0a6

Please sign in to comment.