feat: pass a dynamic token count to the cascade kernels (#635)
Under CUDA graphs, the graph is built with a maximal token count; at replay time the actual number of tokens, read from `qo_indptr`, is now passed on to the cascade kernels.
nandor authored Nov 25, 2024
1 parent db9c48d commit 5fe9f7d
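For context, the mechanism this commit relies on: a CUDA graph freezes kernel launch parameters at capture time, so a token count passed by value would be stale on every replay. Passing a pointer to a counter that the plan stage rewrites lets the captured kernels see the live value. The following standalone sketch illustrates the pattern under stated assumptions; `demo_kernel`, the mapped-pinned counter, and all other names are illustrative, not FlashInfer APIs:

```cuda
#include <cuda_runtime.h>
#include <cstdint>

// Toy analogue of this commit's change: the grid is sized for the maximal
// token count baked into the graph, and each thread reads the *live* count
// through a pointer before doing any work.
__global__ void demo_kernel(float* out, uint32_t max_num_rows,
                            const uint32_t* num_rows_ptr) {
  // nullptr => not running under CUDA graph; the static bound is exact.
  const uint32_t num_rows = num_rows_ptr ? *num_rows_ptr : max_num_rows;
  const uint32_t row = blockIdx.x * blockDim.x + threadIdx.x;
  if (row >= num_rows) return;  // rows beyond the live count are skipped
  out[row] = static_cast<float>(row);
}

int main() {
  constexpr uint32_t kMaxRows = 1024;  // capture-time maximum
  float* out;
  uint32_t* num_rows;  // pinned + mapped, so the captured kernel can read it
  cudaMalloc(&out, kMaxRows * sizeof(float));
  cudaHostAlloc((void**)&num_rows, sizeof(uint32_t), cudaHostAllocMapped);

  cudaStream_t stream;
  cudaStreamCreate(&stream);
  cudaGraph_t graph;
  cudaGraphExec_t exec;

  // Capture once with the maximal grid; launch arguments are now frozen.
  cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal);
  demo_kernel<<<(kMaxRows + 127) / 128, 128, 0, stream>>>(out, kMaxRows, num_rows);
  cudaStreamEndCapture(stream, &graph);
  cudaGraphInstantiate(&exec, graph, nullptr, nullptr, 0);

  // Replay with different actual token counts; only the pinned value changes.
  *num_rows = 640;
  cudaGraphLaunch(exec, stream);
  cudaStreamSynchronize(stream);
  *num_rows = 320;
  cudaGraphLaunch(exec, stream);
  cudaStreamSynchronize(stream);

  cudaGraphExecDestroy(exec);
  cudaGraphDestroy(graph);
  cudaStreamDestroy(stream);
  cudaFree(out);
  cudaFreeHost(num_rows);
  return 0;
}
```

The real change routes the counter through FlashInfer's device int workspace (refilled from the page-locked plan buffer) rather than mapped pinned memory, but the load-through-pointer guard is the same idea.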
Showing 6 changed files with 220 additions and 138 deletions.
26 changes: 12 additions & 14 deletions include/flashinfer/attention/prefill.cuh
@@ -2121,7 +2121,6 @@ cudaError_t BatchPrefillWithRaggedKVCacheDispatched(typename AttentionVariant::P
   const uint32_t num_qo_heads = params.num_qo_heads;
   const uint32_t num_kv_heads = params.num_kv_heads;
   const uint_fastdiv group_size_fastdiv(num_qo_heads / num_kv_heads);
-  const uint32_t total_num_rows = params.total_num_rows;
   constexpr uint32_t NUM_MMA_Q = get_num_mma_q(CTA_TILE_Q);
   constexpr uint32_t NUM_WARPS_Q = get_num_warps_q(CTA_TILE_Q);
   constexpr uint32_t NUM_WARPS_KV = get_num_warps_kv(CTA_TILE_Q);
@@ -2198,13 +2197,13 @@ cudaError_t BatchPrefillWithRaggedKVCacheDispatched(typename AttentionVariant::P
       FLASHINFER_CUDA_CALL(
           cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
       if constexpr (AttentionVariant::use_softmax) {
-        FLASHINFER_CUDA_CALL(VariableLengthMergeStates(tmp_v, tmp_s, params.merge_indptr, o, lse,
-                                                       total_num_rows, nullptr, num_qo_heads,
-                                                       HEAD_DIM, stream));
+        FLASHINFER_CUDA_CALL(VariableLengthMergeStates(
+            tmp_v, tmp_s, params.merge_indptr, o, lse, params.max_total_num_rows,
+            params.total_num_rows, num_qo_heads, HEAD_DIM, stream));
       } else {
-        FLASHINFER_CUDA_CALL(VariableLengthAttentionSum(tmp_v, params.merge_indptr, o,
-                                                        total_num_rows, nullptr, num_qo_heads,
-                                                        HEAD_DIM, stream));
+        FLASHINFER_CUDA_CALL(
+            VariableLengthAttentionSum(tmp_v, params.merge_indptr, o, params.max_total_num_rows,
+                                       params.total_num_rows, num_qo_heads, HEAD_DIM, stream));
       }
     }
   }
@@ -2223,7 +2222,6 @@ cudaError_t BatchPrefillWithPagedKVCacheDispatched(typename AttentionVariant::Pa
   const uint32_t num_qo_heads = params.num_qo_heads;
   const uint32_t num_kv_heads = params.paged_kv.num_heads;
   const uint_fastdiv group_size_fastdiv(num_qo_heads / num_kv_heads);
-  const uint32_t total_num_rows = params.total_num_rows;
   constexpr uint32_t NUM_MMA_Q = get_num_mma_q(CTA_TILE_Q);
   constexpr uint32_t NUM_WARPS_Q = get_num_warps_q(CTA_TILE_Q);
   constexpr uint32_t NUM_WARPS_KV = get_num_warps_kv(CTA_TILE_Q);
@@ -2300,13 +2298,13 @@ cudaError_t BatchPrefillWithPagedKVCacheDispatched(typename AttentionVariant::Pa
       FLASHINFER_CUDA_CALL(
           cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
       if constexpr (AttentionVariant::use_softmax) {
-        FLASHINFER_CUDA_CALL(VariableLengthMergeStates(tmp_v, tmp_s, params.merge_indptr, o, lse,
-                                                       total_num_rows, nullptr, num_qo_heads,
-                                                       HEAD_DIM, stream));
+        FLASHINFER_CUDA_CALL(VariableLengthMergeStates(
+            tmp_v, tmp_s, params.merge_indptr, o, lse, params.max_total_num_rows,
+            params.total_num_rows, num_qo_heads, HEAD_DIM, stream));
       } else {
-        FLASHINFER_CUDA_CALL(VariableLengthAttentionSum(tmp_v, params.merge_indptr, o,
-                                                        total_num_rows, nullptr, num_qo_heads,
-                                                        HEAD_DIM, stream));
+        FLASHINFER_CUDA_CALL(
+            VariableLengthAttentionSum(tmp_v, params.merge_indptr, o, params.max_total_num_rows,
+                                       params.total_num_rows, num_qo_heads, HEAD_DIM, stream));
       }
     }
   }
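In both dispatchers above, the merge/sum epilogue now receives the static bound `params.max_total_num_rows` (used for sizing) together with the dynamic pointer `params.total_num_rows`, where the old code passed a scalar count and `nullptr`. A minimal sketch of how a consumer can reconcile the pair (an illustrative helper, not the actual FlashInfer kernel code):

```cuda
// Illustrative only: reconcile the capture-time bound with the live counter.
__device__ __forceinline__ uint32_t effective_num_rows(
    uint32_t max_total_num_rows, const uint32_t* total_num_rows) {
  // A null pointer means CUDA graphs are off and the static bound is exact.
  return total_num_rows == nullptr ? max_total_num_rows : *total_num_rows;
}
```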
12 changes: 8 additions & 4 deletions include/flashinfer/attention/prefill_params.cuh
@@ -136,7 +136,8 @@ struct BatchPrefillRaggedParams {
   IdType* o_indptr;
   IdType* kv_chunk_size_ptr;
   bool* block_valid_mask;
-  uint32_t total_num_rows;
+  uint32_t max_total_num_rows;
+  uint32_t* total_num_rows;
   uint32_t padded_batch_size;
   bool partition_kv;

@@ -178,7 +179,8 @@ struct BatchPrefillRaggedParams {
         o_indptr(nullptr),
         kv_chunk_size_ptr(nullptr),
         block_valid_mask(nullptr),
-        total_num_rows(0),
+        max_total_num_rows(0),
+        total_num_rows(nullptr),
         padded_batch_size(0),
         partition_kv(false) {}

@@ -227,7 +229,8 @@ struct BatchPrefillPagedParams {
   IdType* o_indptr;
   bool* block_valid_mask;
   IdType* kv_chunk_size_ptr;
-  uint32_t total_num_rows;
+  uint32_t max_total_num_rows;
+  uint32_t* total_num_rows;
   uint32_t padded_batch_size;
   bool partition_kv;

@@ -261,7 +264,8 @@
         o_indptr(nullptr),
         block_valid_mask(nullptr),
         kv_chunk_size_ptr(nullptr),
-        total_num_rows(0),
+        max_total_num_rows(0),
+        total_num_rows(nullptr),
         padded_batch_size(0),
         partition_kv(false) {}

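The params structs thus split one field into a capture-time bound plus an optional run-time counter. A minimal sketch of the intended invariant, using a hypothetical stand-in struct (`TokenCounts`) for the relevant BatchPrefill*Params members:

```cuda
#include <cstdint>

// Hypothetical stand-in for the two new BatchPrefill*Params fields.
struct TokenCounts {
  uint32_t max_total_num_rows;  // fixed when the CUDA graph is built
  uint32_t* total_num_rows;     // live count in the workspace, or nullptr
};

// Intended invariant whenever the counter is set (checked here on the host,
// which assumes a host-visible pointer; in FlashInfer it lives on device).
bool counts_consistent(const TokenCounts& c) {
  return c.total_num_rows == nullptr ||
         *c.total_num_rows <= c.max_total_num_rows;
}
```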
38 changes: 25 additions & 13 deletions include/flashinfer/attention/scheduler.cuh
@@ -518,6 +518,7 @@ inline auto PrefillSplitQOKVIndptr(IdType* qo_indptr_h, IdType* kv_indptr_h, uin
 struct PrefillPlanInfo {
   int64_t padded_batch_size;
   int64_t total_num_rows;
+  int64_t total_num_rows_offset;
   int64_t cta_tile_q;
   int64_t request_indices_offset;
   int64_t qo_tile_indices_offset;
@@ -534,6 +535,7 @@ struct PrefillPlanInfo {
   PrefillPlanInfo()
       : padded_batch_size(0),
         total_num_rows(0),
+        total_num_rows_offset(0),
         cta_tile_q(0),
         request_indices_offset(0),
         qo_tile_indices_offset(0),
@@ -551,6 +553,7 @@ struct PrefillPlanInfo {
   std::vector<int64_t> ToVector() const {
     return {padded_batch_size,
             total_num_rows,
+            total_num_rows_offset,
             cta_tile_q,
             request_indices_offset,
             qo_tile_indices_offset,
@@ -567,25 +570,26 @@

   // From std::vector<int64_t> to PrefillPlanInfo
   void FromVector(const std::vector<int64_t>& vec) {
-    if (vec.size() != 14) {
+    if (vec.size() != 15) {
       std::ostringstream err_msg;
       err_msg << "PrefillPlanInfo::FromVector: vec.size() should be 14, but got " << vec.size();
       FLASHINFER_ERROR(err_msg.str());
     }
     padded_batch_size = vec[0];
     total_num_rows = vec[1];
-    cta_tile_q = vec[2];
-    request_indices_offset = vec[3];
-    qo_tile_indices_offset = vec[4];
-    kv_tile_indices_offset = vec[5];
-    merge_indptr_offset = vec[6];
-    o_indptr_offset = vec[7];
-    kv_chunk_size_ptr_offset = vec[8];
-    v_offset = vec[9];
-    s_offset = vec[10];
-    block_valid_mask_offset = vec[11];
-    enable_cuda_graph = vec[12];
-    split_kv = vec[13];
+    total_num_rows_offset = vec[2];
+    cta_tile_q = vec[3];
+    request_indices_offset = vec[4];
+    qo_tile_indices_offset = vec[5];
+    kv_tile_indices_offset = vec[6];
+    merge_indptr_offset = vec[7];
+    o_indptr_offset = vec[8];
+    kv_chunk_size_ptr_offset = vec[9];
+    v_offset = vec[10];
+    s_offset = vec[11];
+    block_valid_mask_offset = vec[12];
+    enable_cuda_graph = vec[13];
+    split_kv = vec[14];
   }
 };


Expand Down Expand Up @@ -640,6 +644,14 @@ inline cudaError_t PrefillPlan(void* float_buffer, size_t float_workspace_size_i
plan_info.kv_chunk_size_ptr_offset =
int_allocator.aligned_alloc_offset(sizeof(IdType), 1, "batch_prefill_kv_chunk_size_ptr");

if (plan_info.enable_cuda_graph) {
plan_info.total_num_rows_offset =
int_allocator.aligned_alloc_offset(sizeof(uint32_t), 16, "batch_prefill_total_num_rows");
uint32_t* total_num_rows_h =
GetPtrFromBaseOffset<uint32_t>(page_locked_int_buffer, plan_info.total_num_rows_offset);
*total_num_rows_h = qo_indptr_h[batch_size];
}

IdType* request_indices_h =
GetPtrFromBaseOffset<IdType>(page_locked_int_buffer, plan_info.request_indices_offset);
IdType* qo_tile_indices_h =
Expand Down
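`qo_indptr_h` is a CSR-style offset array over the batch, so `qo_indptr_h[batch_size]` is the total number of query tokens; that is the value `PrefillPlan` writes into the newly allocated, 16-byte-aligned slot of the page-locked buffer on every plan call. A tiny self-contained illustration of why the last entry is the total:

```cuda
#include <cassert>
#include <cstdint>

int main() {
  // Three requests with 3, 1, and 4 query tokens respectively.
  const int batch_size = 3;
  const int32_t qo_indptr_h[] = {0, 3, 4, 8};  // prefix sums of per-request lengths
  // The last entry (8) is the batch-wide token count stashed for the kernels.
  assert(qo_indptr_h[batch_size] == 3 + 1 + 4);
  return 0;
}
```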
12 changes: 10 additions & 2 deletions python/csrc/batch_prefill.cu
@@ -160,8 +160,12 @@ void BatchPrefillWithRaggedKVCacheRun(
           GetPtrFromBaseOffset<bool>(int_buffer_ptr, plan_info.block_valid_mask_offset);
     }
   }
-  params.total_num_rows = plan_info.total_num_rows;
   params.padded_batch_size = plan_info.padded_batch_size;
+  params.max_total_num_rows = plan_info.total_num_rows;
+  if (plan_info.enable_cuda_graph) {
+    params.total_num_rows =
+        GetPtrFromBaseOffset<uint32_t>(int_buffer_ptr, plan_info.total_num_rows_offset);
+  }

   cudaError_t status = cudaSuccess;

@@ -290,8 +294,12 @@ void BatchPrefillWithPagedKVCacheRun(
           GetPtrFromBaseOffset<bool>(int_buffer_ptr, plan_info.block_valid_mask_offset);
     }
   }
-  params.total_num_rows = plan_info.total_num_rows;
   params.padded_batch_size = plan_info.padded_batch_size;
+  params.max_total_num_rows = plan_info.total_num_rows;
+  if (plan_info.enable_cuda_graph) {
+    params.total_num_rows =
+        GetPtrFromBaseOffset<uint32_t>(int_buffer_ptr, plan_info.total_num_rows_offset);
+  }

   cudaError_t status = cudaSuccess;
