Skip to content

Commit

Permalink
bugfix: fix wrong padded_batch_size_ (#296)
Browse files Browse the repository at this point in the history
In #294 , we set `padded_batch_size_` to `num_kv_heads * batch_size`,
which should be `batch_size`
  • Loading branch information
yzh119 authored Jun 11, 2024
1 parent 60459e4 commit aff4cf0
Showing 1 changed file with 14 additions and 14 deletions.
28 changes: 14 additions & 14 deletions include/flashinfer/attention/handler.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -318,56 +318,56 @@ class BatchDecodeHandler {
<< " initialized for CUDAGraph";
throw std::runtime_error(err_msg.str());
}
size_t padded_batch_size_after_partition = max_grid_size / num_kv_heads;
size_t padded_batch_size = max_grid_size / num_kv_heads;
if (tmp_size > 0) {
padded_batch_size_ = padded_batch_size_after_partition;
padded_batch_size_ = padded_batch_size;
AlignedAllocator allocator(buffer, workspace_size_in_bytes);
tmp_v_ = allocator.aligned_alloc<void>(
num_qo_heads * padded_batch_size_after_partition * HEAD_DIM * sizeof(DTypeOut), 16);
num_qo_heads * padded_batch_size * HEAD_DIM * sizeof(DTypeOut), 16);
tmp_s_ = allocator.aligned_alloc<void>(
num_qo_heads * padded_batch_size_after_partition * 2 * sizeof(float), 16);
num_qo_heads * padded_batch_size * 2 * sizeof(float), 16);
new_indptr_ = allocator.aligned_alloc<void>(
(padded_batch_size_after_partition + 1) * sizeof(IdType), 16);
(padded_batch_size + 1) * sizeof(IdType), 16);

void* new_indptr_h_ = page_locked_buffer_;
new_last_page_len_ =
allocator.aligned_alloc<void>(padded_batch_size_after_partition * sizeof(IdType), 16);
allocator.aligned_alloc<void>(padded_batch_size * sizeof(IdType), 16);
void* new_last_page_len_h_ =
(char*)page_locked_buffer_ + ((char*)new_last_page_len_ - (char*)new_indptr_);
chunk_indptr_ = allocator.aligned_alloc<void>(
(padded_batch_size_after_partition + 1) * sizeof(IdType), 16);
(padded_batch_size + 1) * sizeof(IdType), 16);
void* chunk_indptr_h_ =
(char*)page_locked_buffer_ + ((char*)chunk_indptr_ - (char*)new_indptr_);
batch_idx_map_ =
allocator.aligned_alloc<void>(padded_batch_size_after_partition * sizeof(IdType), 16);
allocator.aligned_alloc<void>(padded_batch_size * sizeof(IdType), 16);
void* batch_idx_map_h_ =
(char*)page_locked_buffer_ + ((char*)batch_idx_map_ - (char*)new_indptr_);
chunk_start_pos_ =
allocator.aligned_alloc<void>(padded_batch_size_after_partition * sizeof(IdType), 16);
allocator.aligned_alloc<void>(padded_batch_size * sizeof(IdType), 16);
void* chunk_start_pos_h_ =
(char*)page_locked_buffer_ + ((char*)chunk_start_pos_ - (char*)new_indptr_);
seq_lengths_before_partition_ =
allocator.aligned_alloc<void>(padded_batch_size_after_partition * sizeof(IdType), 16);
allocator.aligned_alloc<void>(padded_batch_size * sizeof(IdType), 16);
void* seq_lengths_before_partition_h_ =
(char*)page_locked_buffer_ +
((char*)seq_lengths_before_partition_ - (char*)new_indptr_);
block_valid_mask_ =
allocator.aligned_alloc<bool>(padded_batch_size_after_partition * sizeof(bool), 16);
allocator.aligned_alloc<bool>(padded_batch_size * sizeof(bool), 16);
bool* block_valid_mask_h_ =
(bool*)page_locked_buffer_ + ((bool*)block_valid_mask_ - (bool*)new_indptr_);
std::fill(block_valid_mask_h_, block_valid_mask_h_ + padded_batch_size_after_partition, 0);
std::fill(block_valid_mask_h_, block_valid_mask_h_ + padded_batch_size, 0);

size_t num_bytes_to_copy = (char*)allocator.ptr - (char*)new_indptr_;
FLASHINFER_CUDA_CALL(PartitionPagedKVCacheComputeAuxiliaryInfo(
max_num_pages_per_batch, batch_size, padded_batch_size_after_partition, page_size,
max_num_pages_per_batch, batch_size, padded_batch_size, page_size,
indptr, last_page_len, (IdType*)new_indptr_h_, (IdType*)new_last_page_len_h_,
(IdType*)chunk_indptr_h_, (IdType*)batch_idx_map_h_, (IdType*)chunk_start_pos_h_,
(IdType*)seq_lengths_before_partition_h_, block_valid_mask_h_,
/*device_buffer=*/new_indptr_,
/*host_buffer=*/page_locked_buffer_, num_bytes_to_copy, stream_));
} else {
block_valid_mask_ = nullptr;
padded_batch_size_ = num_kv_heads * batch_size;
padded_batch_size_ = batch_size;
}
} else {
// NOTE(Zihao): we don't use block_valid_mask when CUDAGraph is disabled.
Expand Down

0 comments on commit aff4cf0

Please sign in to comment.