From 00cf0a66deccb5fb10bef7b6aab054255e67280e Mon Sep 17 00:00:00 2001 From: zhyncs Date: Wed, 14 Aug 2024 18:24:12 +1000 Subject: [PATCH] fix: rename kv_len_num to kv_len_safe and inline ceil_div for num_kv_chunks in batch prefill kernels --- include/flashinfer/attention/prefill.cuh | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/include/flashinfer/attention/prefill.cuh b/include/flashinfer/attention/prefill.cuh index 48dfac5f..0a9e501a 100644 --- a/include/flashinfer/attention/prefill.cuh +++ b/include/flashinfer/attention/prefill.cuh @@ -1338,7 +1338,7 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void BatchPrefillWithRagg constexpr uint32_t num_rows_per_cta = num_frags_x * num_warps_x * 16; const uint32_t qo_len = q_indptr[request_idx + 1] - q_indptr[request_idx], kv_len = kv_indptr[request_idx + 1] - kv_indptr[request_idx]; - const uint32_t kv_len_num = kv_len > 0 ? kv_len : 1; + const uint32_t kv_len_safe = kv_len > 0 ? kv_len : 1; const uint32_t window_left = (maybe_window_left >= 0) ? maybe_window_left : kv_len; const uint32_t max_chunk_size = partition_kv ? kv_chunk_size : kv_len; const uint32_t chunk_start = partition_kv ? kv_tile_idx * max_chunk_size : 0; @@ -1559,7 +1559,7 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void BatchPrefillWithRagg // normalize d normalize_d(o_frag, m, d); - const uint32_t num_kv_chunks = ceil_div(kv_len_num, kv_chunk_size); + const uint32_t num_kv_chunks = (kv_len_safe + kv_chunk_size - 1) / kv_chunk_size; // write back write_o_reg_gmem( @@ -1633,7 +1633,7 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void BatchPrefillWithPage 1) * paged_kv.page_size + paged_kv.last_page_len[request_idx] : 0; - const uint32_t kv_len_num = kv_len > 0 ? kv_len : 1; + const uint32_t kv_len_safe = kv_len > 0 ? kv_len : 1; const uint32_t window_left = (maybe_window_left >= 0) ? maybe_window_left : kv_len; const uint32_t max_chunk_size = partition_kv ? kv_chunk_size : kv_len; const uint32_t chunk_start = partition_kv ? 
kv_tile_idx * max_chunk_size : 0; @@ -1874,7 +1874,7 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void BatchPrefillWithPage // normalize d normalize_d(o_frag, m, d); - const uint32_t num_kv_chunks = ceil_div(kv_len_num, kv_chunk_size); + const uint32_t num_kv_chunks = (kv_len_safe + kv_chunk_size - 1) / kv_chunk_size; // write_back write_o_reg_gmem(