Skip to content

Commit

Permalink
refactor: refactor decode handler (#294)
Browse files Browse the repository at this point in the history
Replace the optional `fixed_grid_size` parameter with a required `padded_batch_size`, from which the launch grid size is derived.
  • Loading branch information
yzh119 authored Jun 10, 2024
1 parent 4c5e28b commit 60459e4
Show file tree
Hide file tree
Showing 7 changed files with 28 additions and 31 deletions.
8 changes: 3 additions & 5 deletions include/flashinfer/attention/decode.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -849,13 +849,11 @@ template <uint32_t GROUP_SIZE, uint32_t HEAD_DIM, PageStorage page_storage, QKVL
cudaError_t BatchDecodeWithPagedKVCacheDispatched(
DTypeIn* q, IdType* q_offset, paged_kv_t<page_storage, kv_layout, DTypeIn, IdType> paged_kv,
kv_partition_info_t<IdType> kv_partition_info, DTypeOut* o, DTypeOut* tmp_v, float* tmp_s,
float* lse, bool* block_valid_mask, std::optional<uint32_t> fixed_grid_size, float sm_scale,
float* lse, bool* block_valid_mask, uint32_t padded_batch_size, float sm_scale,
float rope_scale, float rope_theta, cudaStream_t stream) {
const float rope_rcp_scale = 1.f / rope_scale;
const float rope_rcp_theta = 1.f / rope_theta;
const uint32_t num_kv_heads = paged_kv.num_heads;
const uint32_t batch_size = paged_kv.batch_size;
const uint32_t grid_size = fixed_grid_size.value_or(batch_size * num_kv_heads);
const uint32_t num_qo_heads = num_kv_heads * GROUP_SIZE;

constexpr uint32_t vec_size = std::max(16UL / sizeof(DTypeIn), HEAD_DIM / 32UL);
Expand All @@ -872,7 +870,7 @@ cudaError_t BatchDecodeWithPagedKVCacheDispatched(

if (tmp_v == nullptr) {
// do not use partition-kv kernel
dim3 nblks(grid_size / num_kv_heads, num_kv_heads);
dim3 nblks(padded_batch_size, num_kv_heads);
dim3 nthrs(bdx, bdy, bdz);
auto kernel =
BatchDecodeWithPagedKVCacheKernel</*partition_kv=*/false, POS_ENCODING_MODE,
Expand Down Expand Up @@ -913,7 +911,7 @@ cudaError_t BatchDecodeWithPagedKVCacheDispatched(
(void*)&sm_scale,
(void*)&rope_rcp_scale,
(void*)&rope_rcp_theta};
dim3 nblks(grid_size / num_kv_heads, num_kv_heads);
dim3 nblks(padded_batch_size, num_kv_heads);
dim3 nthrs(bdx, bdy, bdz);
FLASHINFER_CUDA_CALL(
cudaLaunchKernel((void*)partition_kv_kernel, nblks, nthrs, args, smem_size, stream));
Expand Down
12 changes: 7 additions & 5 deletions include/flashinfer/attention/handler.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,7 @@ class BatchDecodeHandler {
return (IdType*)seq_lengths_before_partition_;
}

uint32_t GetFixedGridSize() const { return fixed_grid_size_; }
uint32_t GetPaddedBatchSize() const { return padded_batch_size_; }

bool* GetBlockValidMask() const { return block_valid_mask_; }

Expand Down Expand Up @@ -320,7 +320,7 @@ class BatchDecodeHandler {
}
size_t padded_batch_size_after_partition = max_grid_size / num_kv_heads;
if (tmp_size > 0) {
fixed_grid_size_ = padded_batch_size_after_partition * num_kv_heads;
padded_batch_size_ = padded_batch_size_after_partition;
AlignedAllocator allocator(buffer, workspace_size_in_bytes);
tmp_v_ = allocator.aligned_alloc<void>(
num_qo_heads * padded_batch_size_after_partition * HEAD_DIM * sizeof(DTypeOut), 16);
Expand Down Expand Up @@ -367,11 +367,13 @@ class BatchDecodeHandler {
/*host_buffer=*/page_locked_buffer_, num_bytes_to_copy, stream_));
} else {
block_valid_mask_ = nullptr;
fixed_grid_size_ = num_kv_heads * batch_size;
padded_batch_size_ = num_kv_heads * batch_size;
}
} else {
// NOTE(Zihao): we don't use block_valid_mask when CUDAGraph is disabled.
block_valid_mask_ = nullptr;
// do not pad the batch size when not using CUDAGraph
padded_batch_size_ = batch_size_after_partition_;
if (tmp_size > 0) {
AlignedAllocator allocator(buffer, workspace_size_in_bytes);
tmp_v_ = allocator.aligned_alloc<void>(tmp_size, 16);
Expand Down Expand Up @@ -418,7 +420,7 @@ class BatchDecodeHandler {

cudaError_t EndForward() {
forward_started_ = false;
fixed_grid_size_ = 0;
padded_batch_size_ = 0;
batch_size_before_partition_ = 0;
batch_size_after_partition_ = 0;
block_valid_mask_ = nullptr;
Expand Down Expand Up @@ -492,7 +494,7 @@ class BatchDecodeHandler {
void* seq_lengths_before_partition_;
bool forward_started_;
bool cuda_graph_enabled_;
uint32_t fixed_grid_size_;
uint32_t padded_batch_size_;
uint32_t fixed_batch_size_;
cudaStream_t stream_;
};
Expand Down
8 changes: 3 additions & 5 deletions include/flashinfer/decode_attention_decl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ template <uint32_t GROUP_SIZE, uint32_t HEAD_DIM, PageStorage page_storage, QKVL
cudaError_t BatchDecodeWithPagedKVCacheDispatched(
DTypeIn* q, IdType* q_offset, paged_kv_t<page_storage, kv_layout, DTypeIn, IdType> paged_kv,
kv_partition_info_t<IdType> kv_partition_info, DTypeOut* o, DTypeOut* tmp_v, float* tmp_s,
float* lse, bool* block_valid_mask, std::optional<uint32_t> fixed_grid_size, float sm_scale,
float* lse, bool* block_valid_mask, uint32_t padded_batch_size, float sm_scale,
float rope_scale, float rope_theta, cudaStream_t stream);

template <uint32_t GROUP_SIZE, uint32_t HEAD_DIM, QKVLayout KV_LAYOUT,
Expand Down Expand Up @@ -84,10 +84,8 @@ cudaError_t BatchDecodeWithPagedKVCacheWrapperDispatched(
return BatchDecodeWithPagedKVCacheDispatched<GROUP_SIZE, HEAD_DIM, page_storage, KV_LAYOUT,
POS_ENCODING_MODE, DTypeIn, DTypeOut, IdType>(
q, q_offset, new_paged_kv, kv_partition_info, o, tmp_v, tmp_s, lse,
handler->GetBlockValidMask(),
(handler->IsCUDAGraphEnabled() ? std::optional<uint32_t>(handler->GetFixedGridSize())
: std::nullopt),
sm_scale, rope_scale, rope_theta, stream);
handler->GetBlockValidMask(), handler->GetPaddedBatchSize(), sm_scale, rope_scale, rope_theta,
stream);
}

} // namespace flashinfer
Expand Down
3 changes: 1 addition & 2 deletions python/generate_batch_paged_decode_inst.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,7 @@ def get_cu_file_str(
paged_kv_t<page_storage, {kv_layout}, {dtype_in}, {idtype}> paged_kv,
kv_partition_info_t<{idtype}> kv_partition_info,
{dtype_out}* o, {dtype_out}* tmp_v, float* tmp_s, float* lse,
bool* block_valid_mask,
std::optional<uint32_t> fixed_grid_size,
bool* block_valid_mask, uint32_t padded_batch_size,
float sm_scale, float rope_scale,
float rope_theta, cudaStream_t stream);
Expand Down
5 changes: 2 additions & 3 deletions src/bench_batch_decode.cu
Original file line number Diff line number Diff line change
Expand Up @@ -89,10 +89,9 @@ void bench_flashinfer_batch_decode(nvbench::state& state) {
} else {
state.exec([&](nvbench::launch&) {
cudaError_t status =
BatchDecodeWithPagedKVCache<PageStorage::kIndices, kv_layout, T, T, int32_t>(
BatchDecodeWithPagedKVCacheNoSplitKV<PageStorage::kIndices, kv_layout, T, T, int32_t>(
thrust::raw_pointer_cast(q.data()), /*q_offset=*/nullptr, paged_kv,
kv_partition_info_t<int32_t>(), thrust::raw_pointer_cast(o.data()), /*tmp_v=*/nullptr,
/*tmp_s=*/nullptr,
kv_partition_info_t<int32_t>(), thrust::raw_pointer_cast(o.data()),
/*lse=*/nullptr, num_qo_heads, pos_encoding_mode);
if (status != cudaSuccess) {
state.skip("CUDA error: " + std::string(cudaGetErrorString(status)));
Expand Down
13 changes: 7 additions & 6 deletions src/flashinfer_ops.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -209,10 +209,10 @@ cudaError_t BatchDecodeWithPaddedKVCache(DTypeIn* q, DTypeIn* k, DTypeIn* v, DTy

template <PageStorage page_storage, QKVLayout kv_layout, typename DTypeIn, typename DTypeOut,
typename IdType>
cudaError_t BatchDecodeWithPagedKVCache(
cudaError_t BatchDecodeWithPagedKVCacheNoSplitKV(
DTypeIn* q, IdType* q_offset, paged_kv_t<page_storage, kv_layout, DTypeIn, IdType> paged_kv,
kv_partition_info_t<IdType> kv_partition_info, DTypeOut* o, DTypeOut* tmp_v, float* tmp_s,
float* lse, uint32_t num_qo_heads, PosEncodingMode pos_encoding_mode = PosEncodingMode::kNone,
kv_partition_info_t<IdType> kv_partition_info, DTypeOut* o, float* lse, uint32_t num_qo_heads,
PosEncodingMode pos_encoding_mode = PosEncodingMode::kNone,
std::optional<float> maybe_sm_scale = std::nullopt, float rope_scale = 1.f,
float rope_theta = 1e4, cudaStream_t stream = nullptr) {
const uint32_t num_kv_heads = paged_kv.num_heads;
Expand All @@ -233,9 +233,10 @@ cudaError_t BatchDecodeWithPagedKVCache(
return BatchDecodeWithPagedKVCacheDispatched<GROUP_SIZE, HEAD_DIM, page_storage,
kv_layout, POS_ENCODING_MODE, DTypeIn,
DTypeOut, IdType>(
q, q_offset, paged_kv, kv_partition_info, o, tmp_v, tmp_s, lse,
/*block_valid_mask=*/nullptr, std::nullopt, sm_scale, rope_scale, rope_theta,
stream);
q, q_offset, paged_kv, kv_partition_info, o, /*tmp_v=*/nullptr, /*tmp_s=*/nullptr,
lse,
/*block_valid_mask=*/nullptr, /*padded_batch_size=*/paged_kv.batch_size, sm_scale,
rope_scale, rope_theta, stream);
})})});

return cudaSuccess;
Expand Down
10 changes: 5 additions & 5 deletions src/test_batch_decode.cu
Original file line number Diff line number Diff line change
Expand Up @@ -107,11 +107,11 @@ void _TestBatchDecodingKernelCorrectness(size_t page_size, size_t batch_size, si

if (!cooperative) {
// use non-cooperative kernel
cudaError_t status =
flashinfer::BatchDecodeWithPagedKVCache<PageStorage::kIndices, kv_layout, T, T, int32_t>(
thrust::raw_pointer_cast(q_device.data()), /*q_offset=*/nullptr, paged_kv,
kv_partition_info_t<int32_t>(), thrust::raw_pointer_cast(o_device.data()),
/*tmp_v=*/nullptr, /*tmp_s=*/nullptr, /*lse=*/nullptr, num_qo_heads, pos_encoding_mode);
cudaError_t status = flashinfer::BatchDecodeWithPagedKVCacheNoSplitKV<PageStorage::kIndices,
kv_layout, T, T, int32_t>(
thrust::raw_pointer_cast(q_device.data()), /*q_offset=*/nullptr, paged_kv,
kv_partition_info_t<int32_t>(), thrust::raw_pointer_cast(o_device.data()),
/*lse=*/nullptr, num_qo_heads, pos_encoding_mode);
EXPECT_EQ(status, cudaSuccess) << "CUDA error: " + std::string(cudaGetErrorString(status));
} else {
cudaError_t status = flashinfer::BatchDecodeWithPagedKVCacheWrapper<PageStorage::kIndices,
Expand Down

0 comments on commit 60459e4

Please sign in to comment.